
How to Use PyTorch for Image Recognition on Ubuntu

小樊
2025-08-03 18:28:08
Category: Intelligent Operations and Maintenance

Image recognition with PyTorch on Ubuntu usually involves the following steps:

  1. Install the required software and libraries
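sudo apt update
# Ubuntu ships the venv module separately, in the python3-venv package
sudo apt install python3 python3-pip python3-venv
python3 -m venv pytorch-env
source pytorch-env/bin/activate
# pandas, scikit-learn and Pillow are used by the preprocessing code in step 2
pip install torch torchvision torchaudio pandas scikit-learn pillow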
  2. Preprocess the data
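import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import transforms

# Resize to the 224x224 input used by ResNet, convert to a tensor, and normalize
# with the ImageNet statistics that the pretrained weights in step 3 expect
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# faces_dataset.csv is expected to have an image_path column and an integer
# label column holding class indices 0..num_classes-1
data = pd.read_csv('faces_dataset.csv')

# The transform needs a PIL image, not a path string, so open each file as RGB first.
# All images are kept in memory, which is only practical for small datasets.
images = [transform(Image.open(path).convert('RGB')) for path in data['image_path']]
labels = data['label'].tolist()

train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.2, random_state=42)

train_data = torch.utils.data.TensorDataset(torch.stack(train_images), torch.tensor(train_labels))
test_data = torch.utils.data.TensorDataset(torch.stack(test_images), torch.tensor(test_labels))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32)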
  3. Build the model
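import torch.nn as nn
from torchvision import models
# ResNet-18 with ImageNet weights (weights= needs torchvision >= 0.13; older versions used pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the last layer so it outputs one score per class (one per distinct label in the CSV)
model.fc = nn.Linear(model.fc.in_features, data['label'].nunique())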
  4. Train the model
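import torch
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 10  # example value; adjust to your dataset

for epoch in range(num_epochs):
    model.train()  # training mode: batch norm uses per-batch statistics
    running_loss = 0.0
    for images, labels in train_loader:
        # Move each batch to the same device as the model
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item() * images.size(0)
    print(f'Epoch {epoch + 1}/{num_epochs}: loss {running_loss / len(train_loader.dataset):.4f}')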
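`nn.CrossEntropyLoss` takes raw scores (logits) and integer class indices, applies log-softmax internally and, by default, averages the negative log-probability of the true class over the batch; this is why the model's last layer has no softmax of its own. The sketch below checks that equivalence on a made-up batch of two samples and three classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for 2 samples x 3 classes, plus the true class index of each sample
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

# Built-in loss, as used in the training loop (default reduction="mean")
builtin = nn.CrossEntropyLoss()(logits, targets)

# Same quantity by hand: log-softmax over classes, pick each true class, negate, average
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(targets.size(0)), targets].mean()

print(builtin.item(), manual.item())  # the two values agree up to float rounding
```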
  5. Test the model
model.eval()  # evaluation mode: batch norm uses its running statistics
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest score in each row
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: {:.2f} %'.format(100 * correct / total))
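`torch.max(outputs, 1)` returns a (values, indices) pair per row; only the indices, i.e. the predicted class IDs, matter for accuracy, and `outputs.argmax(dim=1)` yields the same indices. The toy example below uses made-up scores for four images and three classes to show the arithmetic.

```python
import torch

# Made-up scores: 4 images (rows) x 3 classes (columns), and the true labels
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [1.5, 0.3, 0.2],
                        [0.0, 0.1, 0.9],
                        [2.2, 0.0, 0.1]])
labels = torch.tensor([1, 0, 1, 0])

values, predicted = torch.max(outputs, 1)  # per-row maximum score and its column index
assert torch.equal(predicted, outputs.argmax(dim=1))

correct = (predicted == labels).sum().item()  # predictions [1, 0, 2, 0]: 3 of 4 match
accuracy = 100 * correct / labels.size(0)
print(f"Accuracy: {accuracy:.1f} %")  # Accuracy: 75.0 %
```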
  6. Deploy the model

The steps above give a basic framework; the actual implementation will vary with the requirements of each project.
