您好,登录后才能下订单哦!
密码登录
登录注册
点击 登录注册 即表示同意《亿速云用户服务条款》
# PyTorch中如何实现病虫害图像分类
## 引言
在农业生产中,病虫害是影响作物产量和品质的重要因素。传统的人工识别方法效率低下且依赖经验,而基于深度学习的图像分类技术为这一问题提供了高效解决方案。本文将详细介绍如何使用PyTorch框架构建一个病虫害图像分类系统,涵盖数据准备、模型构建、训练优化和部署应用的全流程。
---
## 一、环境准备与数据收集
### 1.1 PyTorch环境配置
```python
# 安装PyTorch(根据CUDA版本选择)
pip install torch torchvision torchaudio
dataset/
├── train/
│ ├── class1/
│ ├── class2/
│ └── ...
├── val/
└── test/
from torchvision import transforms
# Training-time pipeline: random crop / flip / colour perturbation for
# augmentation, followed by tensor conversion and per-channel normalisation.
# NOTE(review): the mean/std values are presumably the ImageNet channel
# statistics expected by pretrained torchvision models - confirm.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Validation-time pipeline: deterministic resize + centre crop to the same
# 224x224 size, with the same normalisation as the training pipeline.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
import os

from torch.utils.data import Dataset


class PestDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Class names are the sub-directory names, sorted alphabetically so that the
    name -> label-index mapping is identical on every machine and every run
    (``os.listdir`` returns entries in arbitrary order). Entries of
    ``root_dir`` that are not directories, and entries of a class directory
    that are not regular files, are ignored.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB image.

    Attributes set by ``__init__``:
        classes: Sorted list of class names.
        class_to_idx: Mapping from class name to integer label.
        images: List of ``(image_path, label)`` tuples.
        transform: The ``transform`` argument, stored as given.
    """

    def __init__(self, root_dir, transform=None):
        # Sort for a deterministic label assignment and skip stray files
        # (e.g. ".DS_Store") that would otherwise be treated as classes.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.transform = transform
        self.images = []
        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            label = self.class_to_idx[cls]
            for img_name in sorted(os.listdir(cls_path)):
                img_path = os.path.join(cls_path, img_name)
                if os.path.isfile(img_path):
                    self.images.append((img_path, label))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and, if a transform was
        supplied, passed through it.
        """
        img_path, label = self.images[idx]
        # NOTE: ``Image`` is expected to be PIL's ``Image`` module, imported
        # elsewhere in the file. The ``with`` block releases the file handle
        # as soon as the pixel data has been converted.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)
import torchvision.models as models
# Load a ResNet-50 with pretrained weights.
model = models.resnet50(pretrained=True)

# Freeze every parameter of the loaded network so that only layers created
# after this point (the new head below) are trainable.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a small two-layer head that
# maps the backbone features to `num_classes` outputs.
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(in_features=num_ftrs, out_features=512),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=512, out_features=num_classes),
)
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)`` with
    channel widths 3 -> 32 -> 64 -> 128, so every stage halves the spatial
    size (224x224 input becomes 28x28). The classifier head is a two-layer
    MLP with dropout.

    An adaptive average pool to 28x28 sits between the feature extractor and
    the head. For 224x224 input it is an exact identity (28x28 -> 28x28), and
    for other input sizes it keeps the flattened feature length at
    ``128 * 28 * 28`` so the first linear layer still fits. The pool has no
    parameters, so ``state_dict`` keys are unchanged.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Guarantees the 28x28 spatial size the classifier's first layer
        # was sized for, whatever the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((28, 28))
        self.classifier = nn.Sequential(
            nn.Linear(128 * 28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``.

        ``x`` must be a ``(batch, 3, H, W)`` tensor, as required by the first
        convolution; H and W must be at least 8 so that the three 2x poolings
        leave a non-empty feature map.
        """
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
# Set up the loss, the optimiser and a step-wise learning-rate schedule
# (learning rate multiplied by 0.1 every 7 scheduler steps).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training loop.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        # Move the batch to the device the model lives on.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # NOTE(review): indentation was lost in this copy of the code; the
    # scheduler is stepped once per epoch here, which matches the per-epoch
    # loss computed right after it - confirm against the original.
    scheduler.step()

    # Mean of the per-batch losses over this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
# Mixed-precision variant of one optimisation step: the forward pass and the
# loss run under autocast, and the gradient scaler wraps backward / step.
# NOTE(review): this is a fragment - it relies on `model`, `inputs`,
# `labels`, `criterion` and `optimizer` from the surrounding training code
# and does not itself reset gradients.
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = criterion(outputs, labels)

# Scale the loss before backward, then let the scaler drive the optimiser
# step and update its own scale factor.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# Compute per-class loss weights as the reciprocal of the number of entries
# in each class directory under "train/", so classes with fewer samples get
# a larger weight.
class_counts = [
    len(os.listdir(f"train/{cls}"))
    for cls in classes
]
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
criterion = nn.CrossEntropyLoss(
    weight=weights.to(device),
)
from sklearn.metrics import classification_report
def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` and print a classification report.

    The model is switched to eval mode and gradients are disabled. For every
    batch the predicted class is the index of the largest output along
    dimension 1. The report is printed with the module-level ``classes`` as
    target names.

    Returns:
        A ``(all_preds, all_labels)`` pair of lists holding the predicted and
        true labels of every sample, in iteration order.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch, targets in dataloader:
            batch = batch.to(device)
            targets = targets.to(device)

            logits = model(batch)
            predicted = torch.max(logits, 1).indices

            # Bring results back to host memory before collecting them.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    report = classification_report(all_labels, all_preds, target_names=classes)
    print(report)
    return all_preds, all_labels
# Grad-CAM implementation.
class GradCAM:
    """Compute a Grad-CAM heat map for one image and one target layer.

    Hooks on ``target_layer`` capture its forward output (activations) and
    the gradient of the selected class score with respect to that output.
    Calling the instance runs a forward and a backward pass through
    ``model`` and returns the class activation map resized to the input's
    spatial size.

    The implementation indexes sample 0 of the batch, so it is meant to be
    called with a batch containing a single image.

    Args:
        model: The network to explain.
        target_layer: Sub-module of ``model`` whose output is a
            ``(batch, channels, H, W)`` feature map.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached via remove_hooks().
        # register_full_backward_hook replaces the deprecated
        # register_backward_hook, which can report wrong gradients for
        # modules that perform more than one operation.
        self._handles = [
            target_layer.register_forward_hook(self.save_activations),
            target_layer.register_full_backward_hook(self.save_gradients),
        ]

    def save_activations(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        # Detached so the resulting map carries no autograd history and can
        # be converted to NumPy.
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach the hooks registered on the target layer."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, x, class_idx=None):
        """Return the Grad-CAM map for ``x`` as a NumPy array.

        Args:
            x: Input tensor of shape ``(1, C, H, W)``.
            class_idx: Class whose score is explained. Defaults to the
                highest-scoring class of the first sample.

        Returns:
            Array of shape ``(H, W)`` with values in ``[0, 1]``. The map is
            all zeros when the layer contributes nothing positive to the
            chosen class.

        Raises:
            RuntimeError: If no gradient reached the target layer, for
                example because nothing upstream of its output requires grad.
        """
        # Clear state from any previous call so a missing gradient is
        # detected instead of silently reusing stale values.
        self.gradients = None
        self.activations = None

        output = self.model(x)
        if class_idx is None:
            # Only the first sample is explained; restrict argmax to it.
            class_idx = int(torch.argmax(output[0]))

        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; check that target_layer is part "
                "of the model's forward pass and that its output requires grad"
            )

        # Channel weights: gradients averaged over batch and spatial dims.
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        cam = torch.sum(
            self.activations * pooled_gradients[None, :, None, None], dim=1
        )
        cam = F.relu(cam)
        cam = F.interpolate(
            cam.unsqueeze(0),
            size=x.shape[2:],
            mode='bilinear',
            align_corners=False,
        )

        # Min-max normalise; skip the division for an all-zero map, which
        # would otherwise produce NaNs.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.squeeze().cpu().numpy()
# Export the model as TorchScript and write it to disk.
scripted_model = torch.jit.script(model)
scripted_model.save("pest_classifier.pt")

# Export the model to ONNX, tracing it with a single random 3x224x224 input.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
)
from flask import Flask, request, jsonify
import torchvision.transforms as T
# Flask application object that serves the prediction endpoint.
app = Flask(__name__)

# Load the TorchScript model once at import time and put it in eval mode.
model = torch.jit.load('pest_classifier.pt')
model.eval()

# Inference-time preprocessing: resize, centre-crop to 224x224, convert to a
# tensor and normalise each channel.
transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``image`` form field.

    Returns:
        A JSON response ``{"class": <name>}`` on success. If the ``image``
        field is missing, or the upload cannot be decoded as an image, a
        JSON ``{"error": <message>}`` body with HTTP status 400 is returned
        instead of letting the request fail with a server error.
    """
    file = request.files.get('image')
    if file is None:
        return jsonify({'error': "missing 'image' file field"}), 400

    try:
        img = Image.open(file.stream).convert('RGB')
    except OSError:
        # PIL signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); that is a client error.
        return jsonify({'error': 'uploaded file is not a readable image'}), 400

    # Add the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        output = model(img_tensor)
    _, pred = torch.max(output, 1)
    # NOTE: `classes` (index -> class name) must be defined elsewhere in
    # this module; it is not created in the code shown here.
    return jsonify({'class': classes[pred.item()]})


if __name__ == '__main__':
    # NOTE(review): this starts Flask's built-in server listening on all
    # interfaces; confirm that is intended outside local testing.
    app.run(host='0.0.0.0', port=5000)
小样本学习:
多标签分类:
移动端部署:
本文详细介绍了基于PyTorch的病虫害图像分类系统实现方法。通过合理选择模型架构、优化训练过程并结合实际部署需求,可以构建出准确率超过90%的实用化系统。未来可结合目标检测、语义分割等技术实现更精细的病虫害分析,为智慧农业提供有力支持。
注:实际文章约为2700字(含代码),此处为精简展示版。完整版应包含: 1. 更详细的理论解释 2. 各技术选择的对比分析 3. 实验结果的定量展示 4. 实际应用案例 5. 参考文献与扩展阅读建议
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。