PyTorch在CentOS中的常用可视化工具及使用指南
TensorBoard是PyTorch官方推荐的可视化工具,主要用于监控训练过程中的指标变化(如损失、准确率),也可展示模型结构。
pip install tensorboard torchvision
SummaryWriter记录训练指标(如损失、准确率),并可选添加模型计算图。from torch.utils.tensorboard import SummaryWriter
import torch
# 初始化SummaryWriter,指定日志保存目录
writer = SummaryWriter(log_dir='./runs/experiment1')
# 模拟训练过程(替换为实际训练代码)
num_epochs = 10
for epoch in range(num_epochs):
train_loss = 0.5 + 0.1 * epoch # 示例损失值
val_accuracy = 0.7 + 0.05 * epoch # 示例准确率
# 记录标量数据
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/val', val_accuracy, epoch)
# 可选:添加模型计算图(需提供输入张量)
# input_tensor = torch.randn(1, 3, 224, 224) # 示例输入
# writer.add_graph(model, input_tensor)
# 关闭writer
writer.close()
http://localhost:6006查看可视化结果。tensorboard --logdir=./runs
PyTorchViz基于Graphviz库,用于生成PyTorch模型的计算图,直观展示张量操作与数据流向。
torchviz。pip install torchviz
make_dot函数生成模型计算图,需提供模型输入张量及参数。import torch
from torchviz import make_dot
import torch.nn as nn
# 定义示例模型(替换为实际模型)
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(1, 16, 3, 1, 1), nn.ReLU())
self.fc = nn.Linear(16 * 26 * 26, 10) # 假设输入为28x28图像
def forward(self, x):
x = self.conv1(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
model = ConvNet()
input_tensor = torch.randn(1, 1, 28, 28) # 示例输入
output = model(input_tensor)
# 生成计算图
dot = make_dot(output, params=dict(model.named_parameters()))
# 保存为PDF文件(可选格式:png、svg)
dot.render("model_structure", format="pdf")
Matplotlib是Python基础绘图库,适合绘制损失曲线、准确率曲线等;Seaborn基于Matplotlib,提供更高级的统计可视化(如直方图、热力图)。
pip install matplotlib seaborn
import matplotlib.pyplot as plt
# 示例数据(替换为实际数据)
epochs = range(1, 11)
train_losses = [0.5, 0.45, 0.4, 0.38, 0.36, 0.35, 0.34, 0.33, 0.32, 0.31]
val_losses = [0.55, 0.5, 0.48, 0.46, 0.45, 0.44, 0.43, 0.42, 0.41, 0.4]
# 绘制曲线
plt.plot(epochs, train_losses, 'bo-', label='Training Loss')
plt.plot(epochs, val_losses, 'ro-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
import seaborn as sns
import pandas as pd
# 示例数据(替换为实际数据)
data = pd.DataFrame({'Loss': train_losses + val_losses})
# 绘制直方图
sns.histplot(data['Loss'], kde=True, bins=10)
plt.title('Loss Distribution')
plt.xlabel('Loss')
plt.ylabel('Frequency')
plt.show()
Netron是一款跨框架的模型可视化工具,支持PyTorch(.pt/.pth格式)、ONNX等多种模型格式,可直接查看模型层结构、参数及连接。
pip install netron
netron model.pth --port 8080
启动后,在浏览器访问http://localhost:8080即可查看模型结构。torchinfo(原名torch-summary)用于输出模型的层结构、参数数量、输出形状等信息,帮助理解模型架构。
torchinfo。pip install torchinfo
summary函数,传入模型及输入尺寸。from torchinfo import summary
import torchvision.models as models
# 加载示例模型(替换为实际模型)
model = models.resnet18()
# 打印模型结构信息(输入尺寸需匹配模型要求)
summary(model, input_size=(1, 3, 224, 224))
以上工具覆盖了PyTorch在CentOS中的主要可视化需求,可根据具体场景选择使用(如训练监控用TensorBoard、模型结构用Netron/PyTorchViz、数据分布用Matplotlib/Seaborn)。