ubuntu

Ubuntu上PyTorch如何进行模型部署

小樊
41
2025-10-29 12:45:25
栏目: 智能运维

Ubuntu上PyTorch模型部署完整流程

1. 环境准备:安装PyTorch及依赖

在Ubuntu上部署PyTorch模型前,需先配置基础环境。建议使用虚拟环境(如condavenv)隔离依赖,避免冲突。

2. 模型准备:保存与加载

模型部署前需完成训练并保存,支持两种格式:完整模型(包含架构与权重)或状态字典(仅权重,推荐)。

3. 推理脚本编写:基础预测

编写Python脚本实现模型加载与推理,支持批处理预处理(如图像分类任务)。

import torch
from torchvision import transforms
from PIL import Image
from model import MyModel  # 替换为你的模型类

# 初始化模型
model = MyModel()
model.load_state_dict(torch.load('model_state_dict.pth'))
model.eval()

# 定义预处理(以ResNet为例)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图像尺寸
    transforms.ToTensor(),          # 转换为Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])

# 推理函数
def predict(image_path):
    # 加载图像并预处理
    image = Image.open(image_path).convert('RGB')
    input_tensor = preprocess(image).unsqueeze(0)  # 添加batch维度
    
    # 推理(禁用梯度计算,提升速度)
    with torch.no_grad():
        output = model(input_tensor)
    
    # 后处理(如分类任务取最大概率类别)
    _, predicted_class = torch.max(output, 1)
    return predicted_class.item()

# 示例调用
if __name__ == "__main__":
    result = predict('test_image.jpg')
    print(f"Predicted class: {result}")

4. 生产环境部署:Web服务化

若需将模型作为API服务对外提供,可使用Flask(轻量级)或FastAPI(高性能)框架。以下以Flask为例:

5. 生产环境部署:性能优化

为提升模型服务的性能与稳定性,需进行以下优化:

6. 高级部署:TorchServe(官方工具)

TorchServe是PyTorch官方提供的模型服务框架,支持REST APIgRPC模型热加载,适合大规模生产部署。

关键注意事项

0
看了该问题的人还看了