ubuntu

Ubuntu下PyTorch训练脚本怎么写

小樊
40
2025-10-04 05:58:14
栏目: 智能运维

在Ubuntu下编写PyTorch训练脚本涉及几个步骤,包括环境设置、数据准备、模型定义、损失函数和优化器的选择,以及训练循环的实现。以下是一个简单的示例,展示了如何编写一个基本的PyTorch训练脚本。

步骤1:安装PyTorch

首先,确保你已经安装了PyTorch。你可以参考PyTorch官网的安装指南来安装适合你系统的版本。

步骤2:编写训练脚本

创建一个新的Python文件,例如train.py,并编写以下代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)  # Flatten the input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

print('Finished Training')

代码解释

  1. 导入必要的库

    • torchtorch.nn 用于构建和训练神经网络。
    • torch.optim 用于定义优化器。
    • torchvision 用于加载和预处理数据。
    • DataLoader 用于批量加载数据。
  2. 定义模型

    • SimpleNet 是一个简单的全连接神经网络,包含三个全连接层和一个ReLU激活函数。
  3. 数据预处理

    • 使用 transforms.Compose 组合多个数据预处理操作。
    • ToTensor 将图像转换为PyTorch张量。
    • Normalize 对图像进行标准化。
  4. 加载数据集

    • 使用 datasets.MNIST 加载MNIST数据集,并应用预处理操作。
    • 使用 DataLoader 批量加载数据,并设置 shuffle=True 以打乱数据顺序。
  5. 初始化模型、损失函数和优化器

    • SimpleNet 实例化模型。
    • CrossEntropyLoss 作为损失函数。
    • SGD 作为优化器,设置学习率和动量。
  6. 训练模型

    • 使用嵌套的 for 循环进行训练。
    • 外层循环遍历每个epoch。
    • 内层循环遍历每个batch的数据,计算损失并进行反向传播和参数更新。
    • 打印每个epoch的损失。

运行脚本

在终端中运行以下命令来执行训练脚本:

python train.py

这个示例展示了如何在Ubuntu下编写一个基本的PyTorch训练脚本。你可以根据具体需求扩展和修改这个脚本,例如添加验证步骤、保存模型、使用GPU加速等。

0
看了该问题的人还看了