在Linux中部署PyTorch模型通常涉及以下几个步骤:
.pth
或 .pt
格式。以下是一个简单的例子,展示如何使用Flask和PyTorch创建一个简单的模型推理服务:
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
# 加载模型
model = torch.load('model.pth', map_location=torch.device('cpu'))
model.eval()
# 定义图像预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
# 获取上传的图片
file = request.files['image']
image = Image.open(file.stream)
# 预处理图片
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0) # 创建一个mini-batch作为模型的输入
# 进行推理
with torch.no_grad():
output = model(input_batch)
# 处理输出结果
probabilities = torch.nn.functional.softmax(output[0], dim=-1)
predicted_class = probabilities.argmax().item()
# 返回预测结果
return jsonify({'class': predicted_class})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=80)
在上面的代码中,我们创建了一个Flask应用,它有一个 /predict
端点,用于接收图片文件并返回预测的类别。在部署之前,确保你的模型文件 model.pth
位于同一目录下。
请注意,这只是一个简单的例子,实际部署时可能需要考虑更多的因素,比如错误处理、安全性、性能优化等。此外,如果你的模型很大或者推理速度要求很高,可能需要考虑使用GPU加速,并相应地调整代码以支持CUDA。