在CentOS上部署PyTorch模型可以通过多种方法实现,以下是一些常见的步骤和方法:
# Create and activate an isolated virtual environment FIRST, then install
# PyTorch into it. (The original ran pip before the venv existed, which
# installs the packages system-wide instead of into the environment.)
python -m venv myenv
source myenv/bin/activate
pip install torch torchvision torchaudio
方法一:使用 torch.jit.trace 对模型进行一次前向追踪,生成可保存、可部署的 TorchScript 模块:
import torch
import torchvision.models as models

# Build the model and switch to eval mode BEFORE tracing: torch.jit.trace
# records the ops of one forward pass, so tracing in training mode would
# bake in BatchNorm batch statistics and active Dropout.
model = models.resnet18()
model.eval()

# trace() runs the example input through the model once and records the
# executed graph; the example only fixes shapes, its values don't matter.
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)

# Serialize the TorchScript module for loading from C++ or Python.
traced_script_module.save("traced_model.pt")
方法二:当模型的 forward 中含有依赖输入数据的控制流(if/循环)时,trace 只会记录一条路径;此时应使用 torch.jit.script 直接编译:
import torch


class MyModule(torch.nn.Module):
    """Toy module with data-dependent control flow.

    The branch in forward() depends on runtime tensor values, so
    torch.jit.trace would freeze a single path; torch.jit.script
    compiles both branches.
    """

    def __init__(self, n, m):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(n, m))

    def forward(self, input):
        # Data-dependent branch: which matrix product runs is decided
        # by the values in `input`, not by its shape.
        if input.sum() > 0:
            return self.weight.mv(input)
        return self.weight @ input


my_module = MyModule(10, 20)
scripted_model = torch.jit.script(my_module)
scripted_model.save("scripted_model.pt")
import torch
import torch.quantization as quantization

# Post-training static quantization. NOTE(review): `model` is assumed to be
# an already-built float nn.Module from earlier in the file.
# Quantization must run on an eval-mode model.
model.eval()
model.qconfig = quantization.get_default_qconfig('fbgemm')  # x86 backend
quantized_model = quantization.prepare(model, inplace=False)
# (For real deployments, run representative calibration batches through
# `quantized_model` here so observers can collect activation ranges.)
quantization.convert(quantized_model, inplace=True)
# Bug fix: nn.Module has no .save() method — persist the weights with
# torch.save instead (the original `quantized_model.save(...)` raises
# AttributeError).
torch.save(quantized_model.state_dict(), "quantized_model.pt")
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
class SimpleCNN(nn.Module):
    """Minimal CNN classifier for 3x32x32 images with 10 output classes.

    Pipeline: conv (3->16, same padding) -> ReLU -> 2x2 max-pool -> linear.
    """

    def __init__(self):
        super().__init__()
        # 3x32x32 -> 16x32x32 (padding=1 keeps spatial size)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # 16x32x32 -> 16x16x16
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 16 * 16 * 16 flattened features -> 10 class logits
        self.fc = nn.Linear(16 * 16 * 16, 10)

    def forward(self, x):
        features = self.maxpool(self.relu(self.conv1(x)))
        # Flatten everything except the batch dimension for the linear head.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
model = SimpleCNN()
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
# CentOS host; weights_only=True restricts unpickling to tensors/primitives,
# avoiding arbitrary-code execution from an untrusted .pth file.
state_dict = torch.load('model.pth', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # disable dropout/batchnorm training behavior for inference
def preprocess_image(image_path):
    """Load an image from disk and turn it into a normalized model input.

    Args:
        image_path: Path to the image file.

    Returns:
        A float tensor of shape (1, 3, 32, 32) normalized with ImageNet
        channel statistics.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Bug fix: force 3 channels. Grayscale or RGBA files would otherwise
    # crash Normalize, which expects exactly 3 channels.
    image = Image.open(image_path).convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    return transform(image).unsqueeze(0)
# Inference pipeline: preprocess a single image and print the top class.
image_path = 'test_image.jpg'
input_tensor = preprocess_image(image_path)
with torch.no_grad():  # no gradient tracking needed for inference
    output = model(input_tensor)
# Index of the maximum logit along the class dimension = predicted class.
_, predicted = torch.max(output.data, 1)
print(f"Predicted class: {predicted.item()}")
# Export the model to ONNX. Naming the graph inputs/outputs gives downstream
# runtimes (ONNX Runtime, TensorRT) a stable interface instead of relying on
# auto-generated node names.
dummy_input = input_tensor.clone().detach()
torch.onnx.export(
    model,
    dummy_input,
    "simple_cnn.onnx",
    input_names=["input"],
    output_names=["output"],
    verbose=True,
)
# Run inference with ONNX Runtime
import onnx
import onnxruntime as ort

# Validate the exported graph before serving it.
onnx_model = onnx.load("simple_cnn.onnx")
onnx.checker.check_model(onnx_model)

ort_session = ort.InferenceSession("simple_cnn.onnx")
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
ort_inputs = {input_name: input_tensor.numpy()}
ort_outs = ort_session.run([output_name], ort_inputs)
# Bug fix: ort_outs[0] is the raw logits array of shape (1, 10); the original
# printed the whole logit row as the "class". Take the argmax to get the
# predicted class index, matching the PyTorch inference above.
predicted_class = ort_outs[0].argmax(axis=1)[0]
print(f"Predicted class: {predicted_class}")
# Optional: TensorRT Python bindings for GPU-accelerated inference.
# Requires an NVIDIA GPU with a CUDA/driver stack matching the wheel version.
pip install tensorrt
以上是在CentOS上部署PyTorch模型的常见方法,具体选择哪种方法取决于你的需求和硬件配置。