Linux下PyTorch图形界面操作指南
一 工具总览与定位
二 训练监控与实验追踪
# Install first: pip install tensorboard
from torch.utils.tensorboard import SummaryWriter

# Using SummaryWriter as a context manager guarantees that pending events are
# flushed and the event file is closed even if training raises midway
# (a trailing writer.close() would be skipped in that case).
with SummaryWriter(log_dir="runs/exp1") as writer:
    for epoch in range(num_epochs):
        # ... training step: compute this epoch's `loss` and `acc` here ...
        writer.add_scalar('Loss/train', loss, epoch)
        writer.add_scalar('Acc/train', acc, epoch)
启动 TensorBoard:tensorboard --logdir=runs,然后在浏览器访问 http://localhost:6006。
Visdom 需先启动服务:python -m visdom.server(默认端口 8097),再运行以下代码:
from visdom import Visdom
vis = Visdom()
for epoch in range(100):
    loss = ...  # placeholder: compute this epoch's loss here
    if epoch == 0:
        # The first call creates (or resets) the 'loss' window with a real
        # data point. Seeding the window with a dummy vis.line([0], [0]) call
        # beforehand would leave a spurious (0, 0) point at the start of the
        # curve.
        vis.line([loss], [epoch], win='loss', opts=dict(title='Training Loss'))
    else:
        # Subsequent epochs extend the existing trace.
        vis.line([loss], [epoch], win='loss', update='append')
除折线图外,Visdom 还提供 vis.image、vis.images、vis.text、vis.matplot 等接口。
三 模型结构与参数查看
# Install first: pip install netron
import torch
import torchvision

# Load ImageNet-pretrained weights through the `weights=` API; the older
# `pretrained=True` flag is deprecated since torchvision 0.13.
# ResNet18_Weights.DEFAULT selects the same checkpoint pretrained=True loaded.
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)

# NOTE(review): a state_dict holds only named tensors, not the computation
# graph. If Netron's view of this file is too sparse, export the model to
# ONNX (torch.onnx.export) and open the .onnx file instead.
torch.save(model.state_dict(), "resnet18.pth")

# Then launch the viewer from a terminal:
#   netron resnet18.pth
# Install first: pip install torchviz
# Rendering also needs the Graphviz `dot` executable
# (e.g. on Debian/Ubuntu: sudo apt install graphviz).
import torch
import torchvision
from torchviz import make_dot

# The autograd graph depends only on the architecture, so randomly initialised
# weights are enough; no pretrained checkpoint has to be downloaded.
model = torchvision.models.resnet18()
x = torch.randn(1, 3, 224, 224)  # one dummy 3x224x224 image (batch of 1)
y = model(x)

# Passing the named parameters labels the weight nodes in the drawing.
dot = make_dot(y, params=dict(model.named_parameters()))
# Writes the DOT source to "resnet18_graph" and the image to resnet18_graph.png.
dot.render("resnet18_graph", format="png")
# Install first: pip install torchinfo
from torchvision.models import resnet18
from torchinfo import summary

# summary() only inspects the architecture, so an untrained ResNet-18 suffices.
model = resnet18()
# Print a per-layer table (output shapes, parameter counts) for a batch of
# one 3x224x224 image.
summary(model=model, input_size=(1, 3, 224, 224))
四 结果绘图与本地分析
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Derive the x-axis from the recorded history instead of a separate
# num_epochs: if training stopped early, a length mismatch would make both
# plt.plot and pd.DataFrame raise ValueError.
# Assumes train_losses and val_losses each hold one value per completed epoch.
epochs = list(range(1, len(train_losses) + 1))
df = pd.DataFrame({'Epoch': epochs,
                   'Train Loss': train_losses,
                   'Val Loss': val_losses})

# Matplotlib: train and validation curves on one figure.
plt.figure()
plt.plot(epochs, train_losses, label='Train')
plt.plot(epochs, val_losses, label='Val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Seaborn: training-loss curve drawn from the DataFrame.
# An explicit new figure keeps plots separate even in interactive mode.
plt.figure()
sns.lineplot(data=df, x='Epoch', y='Train Loss')
plt.show()

# Seaborn: distribution of the per-epoch training losses, with a KDE overlay.
plt.figure()
sns.histplot(train_losses, kde=True)
plt.show()
五 部署交互式Web界面
# Install first: pip install gradio
import gradio as gr
def enhance_image(img):
    """Enhance one image for the Gradio demo (currently a pass-through).

    Because the input component is declared as gr.Image(type="pil"), Gradio
    hands this function a PIL.Image rather than a NumPy array (presumably
    None when nothing was uploaded). The return value is shown in the output
    gr.Image(type="pil") component.

    Replace the body with real enhancement logic; as written it simply echoes
    the input back unchanged.
    """
    return img
# Both sides exchange PIL images with enhance_image.
inputs = gr.Image(label="输入图像", type="pil")
outputs = gr.Image(label="增强结果", type="pil")

demo = gr.Interface(
    fn=enhance_image,
    inputs=inputs,
    outputs=outputs,
    title="图像增强演示",
)

# share=True asks Gradio for a temporary public URL in addition to the local
# one. Anyone holding that link can use the app, so omit it for local-only use.
demo.launch(share=True)
# Install first: pip install streamlit
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# Assumes a model has already been trained and its state_dict saved to
# mnist_cnn.pth.
class CNN(nn.Module):
    """Two-convolution classifier mapping a (N, 1, 28, 28) batch to 10 logits.

    The attribute names conv1/conv2/fc1/fc2 must not change: they are the
    state_dict keys that load_state_dict() matches against the saved weights.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 preserve the spatial size.
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Two 2x2 max-poolings shrink 28x28 to 7x7, hence 32 * 7 * 7 inputs.
        self.fc1 = nn.Linear(32*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        functional = nn.functional
        # Each feature stage: convolution -> ReLU -> 2x2 max-pooling.
        for conv in (self.conv1, self.conv2):
            x = functional.max_pool2d(functional.relu(conv(x)), 2)
        # Flatten everything except the batch dimension.
        flat = x.view(x.shape[0], -1)
        hidden = functional.relu(self.fc1(flat))
        return self.fc2(hidden)
@st.cache_resource
def load_model(path: str = "mnist_cnn.pth") -> "CNN":
    """Build the CNN, load its trained weights from *path*, return it in eval mode.

    st.cache_resource makes Streamlit reuse the loaded model across script
    reruns instead of re-reading the checkpoint on every interaction (one
    cached model per distinct *path*).

    Args:
        path: Location of the saved state_dict. Defaults to "mnist_cnn.pth",
            the file this app has always loaded.
    """
    m = CNN()
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine. weights_only=True restricts unpickling to tensors and plain
    # containers, which is all a state_dict needs, so a tampered checkpoint
    # cannot run arbitrary code while loading (requires torch >= 1.13).
    state = torch.load(path, map_location="cpu", weights_only=True)
    m.load_state_dict(state)
    # Inference mode. This CNN has no dropout/batch-norm, so it is harmless
    # today but stays correct if such layers are added.
    m.eval()
    return m
from PIL import ImageOps, ImageStat

# ToTensor scales pixels to [0, 1]; Normalize then applies what are presumably
# the MNIST training-set mean/std. These must match the preprocessing used
# when the model was trained.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

st.title("MNIST 手写体识别")
uploaded = st.file_uploader("上传灰度图", type=["png","jpg","jpeg"])
if uploaded is not None:
    img = Image.open(uploaded).convert("L").resize((28,28))
    # MNIST digits are light strokes on a dark background, while photos and
    # scans are usually dark ink on light paper. Invert predominantly bright
    # images (mean gray level above mid-scale) so the input resembles the
    # training data.
    if ImageStat.Stat(img).mean[0] > 127:
        img = ImageOps.invert(img)
    st.image(img, caption="输入", width=150)
    x = transform(img).unsqueeze(0)  # add the batch dimension: (1, 1, 28, 28)
    model = load_model()
    with torch.no_grad():
        pred = model(x).argmax(1).item()
    st.write(f"预测数字:{pred}")
将上述代码保存为 app.py,在终端执行 streamlit run app.py,浏览器会自动打开本地页面(默认端口 8501)。