在CentOS上部署PyTorch应用程序通常包括以下步骤:
准备环境:
# Bring the system packages up to date before installing anything else.
sudo yum update -y
# Download the Miniconda installer (requires wget to be installed).
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
# -b runs the installer in batch mode (no interactive prompts) and -p fixes the
# install prefix so the following steps know where conda lives.
bash Miniconda3-latest-Linux-x86_64.sh -b -p "$HOME/miniconda3"
# Load conda's shell functions into the current shell. Without this, `conda`
# is not on PATH and `conda activate` fails right after a fresh install.
source "$HOME/miniconda3/etc/profile.d/conda.sh"
# Create an isolated environment for the application and switch to it.
conda create -n pytorch python=3.8
conda activate pytorch
安装PyTorch(conda 与 pip 两种方式任选其一,不要同时执行):
# Option A (conda): CUDA 11.8 builds of PyTorch are selected through the
# `pytorch-cuda` metapackage; its CUDA libraries come from the `nvidia` channel.
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# Option B (pip): run this INSTEAD of Option A, never in addition to it --
# installing with both tools puts two copies of torch into the same environment.
# The cu118 index selects wheels built against CUDA 11.8.
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
验证安装:
import torch

# Print the installed PyTorch version and, on the next line, whether CUDA is usable.
print(torch.__version__, torch.cuda.is_available(), sep="\n")
模型加载与推理:
# NOTE(review): SimpleCNN, transforms (torchvision), Image (PIL) and image_path
# are not defined in this snippet -- import/define them before running it.
model = SimpleCNN()
# map_location='cpu' lets a checkpoint that was saved on a GPU machine load on
# a CPU-only host as well.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()  # switch layers such as dropout / batch-norm to inference behavior
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Force three channels: Normalize above is given 3-channel statistics, so a
# grayscale or RGBA file would otherwise fail.
image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add the leading batch dimension
# Gradients are not needed for inference.
with torch.no_grad():
    output = model(image_tensor)
    _, predicted = torch.max(output, 1)
print(f"Predicted class: {predicted.item()}")
使用TorchScript编译模型(可选):
# Trace the model with one example input and save the resulting TorchScript
# module. The example input is image_tensor, the batch built in the inference step.
traced_model = torch.jit.trace(model, image_tensor)
traced_model.save("traced_model.pt")
量化模型以提高性能(可选):
import torch.quantization as quantization

# Post-training static quantization: attach a qconfig, insert observers,
# calibrate, then convert.
# NOTE(review): eager-mode static quantization normally expects the model to
# wrap its forward pass in QuantStub/DeQuantStub -- confirm SimpleCNN does so.
model.qconfig = quantization.get_default_qconfig('fbgemm')
quantized_model = quantization.prepare(model, inplace=False)
# Calibration: run data through the prepared model so the observers can record
# activation ranges before conversion. A single image is only a placeholder;
# use a representative sample of real inputs.
with torch.no_grad():
    quantized_model(image_tensor)
quantization.convert(quantized_model, inplace=True)