
How to Deploy PyTorch Models on CentOS

小樊
2025-05-09 05:51:41
Category: Intelligent O&M

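There are several ways to deploy a PyTorch model on CentOS. The common steps are:
1. Prepare a Python environment.
2. Convert and save the model (TorchScript, quantization, ONNX).
3. Optionally, accelerate inference with TensorRT on NVIDIA GPUs.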

Environment setup

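  1. Create a virtual environment (recommended). On CentOS, call python3 explicitly, because on CentOS 7 the plain python command is Python 2. If python3 is missing, install it with sudo yum install -y python3.
python3 -m venv myenv
source myenv/bin/activate
  2. Install PyTorch and other dependencies inside the activated environment:
pip install --upgrade pip
pip install torch torchvision torchaudio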
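After installation, a quick sanity check confirms two things: that PyTorch imports inside the virtual environment, and whether a CUDA GPU is visible. This is a minimal sketch that assumes the environment created above is active. The helper name describe_environment is hypothetical and used only for illustration.

```python
import platform

import torch


def describe_environment() -> dict:
    """Hypothetical helper: collect basic facts about the PyTorch install."""
    info = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        # torch.version.cuda is None for CPU-only builds
        "cuda_build": torch.version.cuda,
    }
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

On a CPU-only server, cuda_available prints False and cuda_build prints None. That is expected, and the CPU-based methods below still apply.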

Model conversion and saving

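  1. Compile the model with TorchScript. Tracing (torch.jit.trace) runs the model once on an example input and records the operations it executes: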
import torch
import torchvision.models as models

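model = models.resnet18()  # untrained weights; load your trained state_dict here
model.eval()  # inference mode: fixes BatchNorm/Dropout behavior before tracing
example = torch.rand(1, 3, 224, 224)  # example input with the inference-time shape
traced_script_module = torch.jit.trace(model, example)  # records the ops run on this input
traced_script_module.save("traced_model.pt")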
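Use scripting (torch.jit.script) instead for models whose forward() has data-dependent control flow, such as the if statement below. Tracing records only the branch taken by the example input, whereas scripting compiles the Python code, including both branches.
import torch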

class MyModule(torch.nn.Module):
    def __init__(self, n, m):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(n, m))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight @ input
        return output

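my_module = MyModule(10, 20)  # weight is (10, 20); inputs must be 1-D tensors of length 20
scripted_model = torch.jit.script(my_module)  # compiles forward(), keeping both branches of the if
scripted_model.save("scripted_model.pt")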
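On the serving side, a saved TorchScript file can be loaded with torch.jit.load, without the original Python class definition. The sketch below does the following:
- Redefines a module shaped like MyModule above so the snippet runs on its own.
- Saves the scripted module to a temporary path chosen for illustration and reloads it.
- Checks that both branches of the if statement still match eager execution.

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):
    """Same shape as the MyModule example: weight is (n, m), input is 1-D of length m."""

    def __init__(self, n: int, m: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(n, m))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight @ input
        return output


def roundtrip_check(n: int = 10, m: int = 20) -> bool:
    """Script, save, reload, and compare against eager execution on both branches."""
    eager = MyModule(n, m).eval()
    scripted = torch.jit.script(eager)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "scripted_model.pt")
        scripted.save(path)
        # torch.jit.load needs only the file, not the MyModule class
        loaded = torch.jit.load(path, map_location="cpu")
    positive = torch.ones(m)   # sum > 0: takes the mv() branch
    negative = -torch.ones(m)  # sum < 0: takes the @ branch
    with torch.no_grad():
        return all(torch.allclose(eager(x), loaded(x)) for x in (positive, negative))


if __name__ == "__main__":
    print("TorchScript round trip matches eager:", roundtrip_check())
```

In a real deployment, only the torch.jit.load line and the inference call are needed on the server. The eager comparison is a one-off check done before shipping the file.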
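  2. Model quantization: converting weights and activations from float32 to int8 shrinks the model and typically speeds up inference on CPU.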
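import torch
from torchvision.models.quantization import resnet18
model = resnet18(quantize=False).eval()  # quantizable variant with QuantStub/DeQuantStub
model.fuse_model()  # fuse Conv+BN+ReLU before quantizing
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # fbgemm: x86 CPU backend
prepared = torch.quantization.prepare(model, inplace=False)
with torch.no_grad():  # calibration pass; feed real sample images in practice
    for _ in range(10):
        prepared(torch.rand(1, 3, 224, 224))
quantized_model = torch.quantization.convert(prepared, inplace=False)
# nn.Module has no .save(); serialize through TorchScript instead
torch.jit.trace(quantized_model, torch.rand(1, 3, 224, 224)).save("quantized_model.pt")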
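Static quantization needs a calibration pass and a model prepared with quant/dequant stubs. For models dominated by nn.Linear (or LSTM) layers, dynamic quantization is a simpler option. It uses torch.quantization.quantize_dynamic, a different technique from the static flow above. Weights are converted to int8 ahead of time, and activations are quantized on the fly, so no calibration data is needed.

Below is a minimal sketch, assuming a small hypothetical MLP and a CPU with a supported quantization backend. It compares the serialized size and the outputs of the float and quantized models.

```python
import io

import torch
import torch.nn as nn


def serialized_size(module: nn.Module) -> int:
    """Size in bytes of the module's state_dict when saved with torch.save."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes


def quantize_mlp_demo():
    """Return (fp32 bytes, int8 bytes, max abs output difference) for a toy MLP."""
    torch.manual_seed(0)
    # Hypothetical Linear-heavy model used only for illustration
    float_model = nn.Sequential(
        nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 10)
    ).eval()
    # Only nn.Linear layers are replaced; their weights are stored as int8
    int8_model = torch.quantization.quantize_dynamic(
        float_model, {nn.Linear}, dtype=torch.qint8
    )
    x = torch.randn(4, 512)
    with torch.no_grad():
        max_diff = (float_model(x) - int8_model(x)).abs().max().item()
    return serialized_size(float_model), serialized_size(int8_model), max_diff


if __name__ == "__main__":
    fp32_bytes, int8_bytes, diff = quantize_mlp_demo()
    print(f"fp32: {fp32_bytes} bytes, int8: {int8_bytes} bytes, max |diff|: {diff:.4f}")
```

The int8 state_dict should be roughly a quarter of the float one, since each weight drops from 4 bytes to 1. Check the accuracy impact on your own validation data before deploying.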
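  3. Deploy with ONNX. First install the ONNX package and the runtime: pip install onnx onnxruntime. The example below does the following:
- Defines a small CNN for 32x32 RGB images and runs it in PyTorch.
- Exports it to ONNX.
- Runs the same input through ONNX Runtime.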
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 16 * 16, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

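model = SimpleCNN()
# 'model.pth' is assumed to contain a state_dict trained for this architecture
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()  # inference mode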

def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')  # guarantee 3 channels to match conv1 and Normalize
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor

# 推理流程
image_path = 'test_image.jpg'
input_tensor = preprocess_image(image_path)
with torch.no_grad():
    output = model(input_tensor)
    _, predicted = torch.max(output.data, 1)
    print(f"Predicted class: {predicted.item()}")

# 转换为ONNX
dummy_input = input_tensor.clone().detach()
torch.onnx.export(model, dummy_input, "simple_cnn.onnx", verbose=True)

# 使用ONNX Runtime进行推理
import onnx
import onnxruntime as ort

onnx_model = onnx.load("simple_cnn.onnx")
onnx.checker.check_model(onnx_model)

ort_session = ort.InferenceSession("simple_cnn.onnx", providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
ort_inputs = {input_name: input_tensor.numpy()}
ort_outs = ort_session.run([output_name], ort_inputs)
# ort_outs[0] holds the logits with shape (1, 10); argmax gives the class index
print(f"Predicted class: {ort_outs[0][0].argmax()}")

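Deploying with TensorRT (optional)

  1. Install TensorRT. This requires an NVIDIA GPU, a working NVIDIA driver, and a CUDA version supported by the TensorRT release: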
pip install tensorrt

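These are the common ways to deploy a PyTorch model on CentOS, and the right choice depends on your requirements and hardware. As a rough guide:
- TorchScript keeps you within the PyTorch/LibTorch runtime.
- ONNX Runtime serves the model without depending on PyTorch at inference time.
- TensorRT targets the lowest latency on NVIDIA GPUs.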
