在Ubuntu下利用PyTorch进行深度学习,可按以下步骤操作:
安装系统依赖,包括build-essential、python3、pip等:
sudo apt update
# python3-venv is required on Ubuntu/Debian for `python3 -m venv` (it provides ensurepip).
sudo apt install python3 python3-pip python3-venv build-essential
创建虚拟环境(推荐使用venv或conda):
python3 -m venv pytorch_env # venv方式
# Activate the venv in the current shell (`.` is the POSIX spelling of bash's `source`).
. ./pytorch_env/bin/activate
或使用conda:
conda create -n pytorch_env python=3.9
# Activate the conda environment named pytorch_env created by the previous command.
conda activate pytorch_env
# CPU-only build. The default PyPI wheels on Linux bundle CUDA, so point pip at the CPU index.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu  # pip
# or
conda install pytorch torchvision torchaudio cpuonly -c pytorch  # conda

# CUDA build (CUDA 11.7 as an example). Use --index-url rather than --extra-index-url:
# with an extra index pip also considers PyPI and installs whichever torch has the
# highest version there, which is then not the cu117 build.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117  # pip
# or: conda CUDA builds are selected with pytorch-cuda from the pytorch + nvidia channels
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia  # conda
import torch

# Post-install sanity check: show which PyTorch version is installed.
installed_version = torch.__version__
print(installed_version)
# True means PyTorch can use a CUDA device.
cuda_available = torch.cuda.is_available()
print(cuda_available)
准备数据:使用torchvision加载数据集(如CIFAR-10),并进行预处理(归一化、裁剪等)。
import torchvision.transforms as transforms
# `import torchvision.transforms as transforms` binds only `transforms`;
# `torchvision` itself must be imported for `torchvision.datasets`.
import torch
import torchvision
import torchvision.transforms as transforms

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 for each
# of the three RGB channels then maps every channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Training split of CIFAR-10, downloaded into ./data if not already present.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Shuffled mini-batches of 4 images for the training loop.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
import torch.nn as nn


class Net(nn.Module):
    """Small convolutional classifier producing 10 class logits.

    The layer sizes fix the expected input. ``conv1`` takes 3 channels, and
    ``fc1``'s ``6 * 14 * 14`` inputs correspond to a 32x32 image after the
    5x5 convolution (32 -> 28) and the 2x2 max-pool (28 -> 14). Input is
    therefore shaped ``(N, 3, 32, 32)``, e.g. CIFAR-10 batches.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)    # halves height and width
        self.fc1 = nn.Linear(6 * 14 * 14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(N, 10)`` for batch ``x``."""
        x = self.pool(nn.functional.relu(self.conv1(x)))  # (N, 6, 14, 14)
        # nn.Linear acts on the last axis only, so the feature map must be
        # flattened (keeping the batch axis) before the fully connected layers.
        x = x.flatten(1)                                   # (N, 1176)
        # ReLU between the fully connected layers; without it fc1..fc3 would
        # collapse into a single linear map.
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself.
        return self.fc3(x)
import torch
import torch.nn as nn
import torch.optim as optim

# Select the GPU when available. The model and every batch must sit on the
# same device, so this has to happen before training, not after it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = Net()
net.to(device)  # move parameters before the optimizer is built from them

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # train for 2 epochs
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # Each batch must be moved to the model's device inside the loop.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last 2000 batches rather than the
        # loss of one single (noisy) batch.
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
import torch

# Save only the learned parameters (the state_dict), not the whole module object.
torch.save(net.state_dict(), 'model.pth')

# Load the parameters back.
# - map_location='cpu' lets a checkpoint saved from GPU tensors be read on a
#   CPU-only machine. load_state_dict then copies each tensor onto the device
#   of the matching parameter in `net`.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict needs (requires PyTorch >= 1.13).
state_dict = torch.load('model.pth', map_location='cpu', weights_only=True)
net.load_state_dict(state_dict)
安装及基础用法可参考PyTorch官网教程,具体项目可结合torchvision、torch.nn等模块扩展。