1. TensorBoard(官方训练过程可视化工具)
TensorBoard是PyTorch官方集成的可视化工具,主要用于监控模型训练过程、可视化计算图及分析数据分布。它支持跟踪损失、准确率等指标的变化趋势,绘制直方图展示参数分布,还能可视化模型结构。安装方式为pip install tensorboard
,使用时通过SummaryWriter
记录数据,训练结束后启动tensorboard --logdir=runs
,在浏览器中访问localhost:6006
即可查看。
2. Torchinfo(模型结构与参数统计工具)
Torchinfo(原名torch-summary)用于输出PyTorch模型的详细结构信息,包括各层的类型、参数数量、输入输出形状等。安装命令为pip install torchinfo
,使用示例:from torchinfo import summary; summary(model, input_size=(batch_size, channels, height, width))
,能快速了解模型架构的全貌。
3. PyTorchviz(计算图可视化工具)
PyTorchviz基于Graphviz库,可将PyTorch模型的计算图可视化为图形文件(如PNG、PDF)。它展示了张量操作、模块调用及参数依赖关系,适合理解模型的计算流程。安装方式为pip install torchviz
,使用示例:from torchviz import make_dot; dot = make_dot(model(input_tensor), params=dict(model.named_parameters())); dot.render("model", format="png")
。
4. Netron(模型架构可视化工具)
Netron是一款跨平台的深度学习模型可视化工具,支持PyTorch的.pt
、.onnx
等格式模型文件。它能直观显示模型的层结构、参数配置及连接关系,无需编写代码即可查看模型细节。安装方式为pip install netron
,启动后通过命令netron model.pt --port 8080
在浏览器中访问localhost:8080
查看。
5. TorchView(增强型模型可视化工具)
TorchView是PyTorch的图形化模型可视化库,支持显示张量、模块、操作及输入输出形状,还能处理递归模块(如RNN)的滚动/展开。它适用于大多数PyTorch模型(包括Hugging Face模型),安装前需先安装graphviz
(sudo apt-get install graphviz
),然后通过pip install torchview
安装。使用示例:from torchview import draw_graph; draw_graph(model, input_size=(batch_size, channels), device='meta').visual_graph
。
6. Visdom(实时动态可视化工具)
Visdom由Facebook开发,支持实时数据可视化(如折线图、热力图、3D点云等),适合动态更新训练过程中的指标或图像结果。它是基于Web的轻量级工具,支持远程访问。安装方式为pip install visdom
,使用时通过visdom.Visdom()
创建服务器,调用viz.line()
等方法更新数据。
7. Captum(模型可解释性可视化工具)
Captum是PyTorch官方提供的可解释性库,用于分析模型的特征重要性和注意力机制。它支持积分梯度(Integrated Gradients)、显著性图(Saliency Maps)等方法,帮助理解模型决策过程。安装方式为pip install captum
,使用示例:from captum.attr import IntegratedGradients; ig = IntegratedGradients(model); attr = ig.attribute(input, target=label)
。
8. Matplotlib/Seaborn(基础数据可视化工具)
Matplotlib和Seaborn是Python常用的数据可视化库,可用于绘制训练过程中的损失/准确率曲线、特征分布直方图、相关性矩阵等。Matplotlib提供基础的绘图功能,Seaborn基于Matplotlib构建,提供更美观的主题和高级接口。安装方式为pip install matplotlib seaborn
,使用示例:import matplotlib.pyplot as plt; plt.plot(epochs, train_losses, label='Train Loss'); plt.show()
。