Linux上PyTorch可视化工具使用指南
一、常用工具与适用场景
二、TensorBoard快速上手
安装:`pip install tensorboard`(PyTorch 自带 `torch.utils.tensorboard` 接口,只需安装 tensorboard 本身,无需额外安装 tensorboardX)。
启动:`tensorboard --logdir=runs --port=6006`(默认端口为 6006,可通过 `--port` 和 `--host` 更换端口与监听地址),然后在浏览器打开 http://localhost:6006。
from torch.utils.tensorboard import SummaryWriter
import torch
import torchvision

# Fall back to the CPU so the example also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model (randomly initialised) and a dummy input batch. `weights=None` replaces
# the deprecated positional `pretrained` argument (torchvision >= 0.13).
model = torchvision.models.resnet18(weights=None).to(device)
x = torch.randn(4, 3, 224, 224, device=device)

# SummaryWriter is a context manager: leaving the block flushes and closes the
# event file even if an exception is raised inside the loop.
with SummaryWriter(log_dir="runs/exp1") as writer:
    for epoch in range(5):
        # Scalars: loss and accuracy (random placeholder values for the demo).
        loss = torch.rand(1).item()
        acc = torch.rand(1).item()
        writer.add_scalar("Loss/train", loss, epoch)
        writer.add_scalar("Acc/train", acc, epoch)

        # Histograms: distribution of every parameter tensor at this epoch.
        for name, param in model.named_parameters():
            writer.add_histogram(f"param/{name}", param, epoch)

        # Image: a grid of the input samples, logged only once.
        if epoch == 0:
            # add_image expects float pixels in [0, 1]; randn values are not,
            # so min-max rescale the grid into that range.
            grid = torchvision.utils.make_grid(x, nrow=2, normalize=True)
            writer.add_image("input", grid, epoch)
三、Visdom实时可视化
安装:`pip install visdom`。
启动服务:`python -m visdom.server -port 8097`(默认端口为 8097),然后在浏览器打开 http://localhost:8097。
import visdom, numpy as np, time
viz = visdom.Visdom(env="demo")
# The client only logs a warning when the server is unreachable and the plots
# silently go nowhere, so fail fast with an actionable message instead.
if not viz.check_connection(timeout_seconds=3):
    raise RuntimeError(
        "Visdom server not reachable; start it with: python -m visdom.server -port 8097"
    )

# Single curve: append one point per step to the "loss" window.
for step in range(100):
    loss = np.random.randn()
    viz.line([loss], [step], win="loss", update="append", opts=dict(title="Train Loss"))

# Multiple curves: each row of Y holds one value per series (loss, acc);
# the 1-D X (the step) is shared by both series.
for step in range(100):
    loss = abs(np.random.randn()) + 1
    acc = abs(np.random.randn())
    viz.line(
        [[loss, acc]],
        [step],
        win="metrics",
        update="append",
        opts=dict(title="Loss & Acc", legend=["loss", "acc"]),
    )

# Image (mind the channel order): the random array stands in for a BGR image
# and is converted to RGB in CHW layout before display.
# randint's upper bound is exclusive, so 256 is needed to cover 0..255.
img_bgr = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
img_rgb = img_bgr[:, :, ::-1]  # BGR -> RGB
img = np.transpose(img_rgb, (2, 0, 1))  # HWC -> CHW
viz.image(img, win="sample", opts=dict(title="Sample Image"))

# Text log
viz.text("Hello, Visdom!", win="log", opts=dict(title="Log"))
四、模型结构与数据流可视化
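Netron(查看模型结构):安装 `pip install netron`;命令行用法为 `netron model.pth`,或在 Python 中 `import netron; netron.start("model.pth")`(会自动打开浏览器)。可用 `--port` 指定端口,例如 `netron model.pb --port 8080`,然后访问 http://localhost:8080。注意:若 `.pth` 文件只保存了 state_dict,Netron 通常只能列出权重,建议先用 `torch.onnx.export` 导出为 ONNX 再查看完整计算图。
torchinfo(逐层摘要):安装 `pip install torchinfo`。
from torchinfo import summary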
import torchvision

# Layer-by-layer summary of ResNet-18: torchinfo feeds a dummy batch of the
# given shape through the model on the CPU and reports per-layer output
# shapes and parameter counts.
model = torchvision.models.resnet18()
summary(
    model,
    input_size=(1, 3, 224, 224),  # (batch, channels, height, width)
    device="cpu",
)
torchviz(autograd 计算图):安装 `pip install torchviz`;渲染图片还需要系统安装 Graphviz,例如 `sudo apt install graphviz`。
import torch, torchvision
from torchviz import make_dot

# Visualise the autograd graph produced by one ResNet-18 forward pass.
model = torchvision.models.resnet18()
x = torch.randn(1, 3, 224, 224)  # a single dummy 3x224x224 input
y = model(x)

# Map qualified parameter names to tensors so make_dot can label those nodes.
named_params = {name: tensor for name, tensor in model.named_parameters()}
dot = make_dot(y, params=named_params)

# Render to resnet18_graph.pdf (requires the Graphviz `dot` executable).
dot.render(filename="resnet18_graph", format="pdf")
五、远程服务器与常见问题
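- 远程访问 TensorBoard:在本地执行 `ssh -L 6006:127.0.0.1:6006 user@server` 建立端口转发,然后在本地浏览器打开 http://localhost:6006。
- 远程访问 Visdom:在本地执行 `ssh -L 8097:127.0.0.1:8097 user@server`,然后打开 http://localhost:8097。
- Visdom 图像颜色异常:OpenCV 读入的图像为 BGR 顺序,需先转换为 RGB:`img_rgb = img_bgr[:, :, ::-1]`,再转为 CHW 布局。
- 多次实验对比:为每次实验使用独立的日志目录,如 `runs/exp1/`、`runs/exp2/`,再用 `--logdir=runs` 指向父目录,便于在 TensorBoard 中对比。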