PyTorch在Ubuntu上的调试方法
在调试前需确保环境配置正确,避免因环境问题导致调试困难:
Miniconda管理Python环境(如conda create -n pytorch python=3.8 && conda activate pytorch),避免依赖冲突;安装CUDA Toolkit(如11.3)和cuDNN(需与PyTorch版本兼容),并通过nvcc --version(检查CUDA版本)、nvidia-smi(检查显卡驱动与GPU状态)验证安装。python -c "import torch; print(torch.cuda.is_available())",确认PyTorch能正确识别GPU(输出True表示GPU可用)。最直接的调试方式,在代码关键位置插入print语句输出变量值(如print("Input shape:", input_data.shape))或执行流程(如print("Entering training loop")),快速定位变量异常或逻辑错误。
Python自带交互式调试工具,适合命令行环境:
import pdb; pdb.set_trace(),程序运行到该行会暂停,可通过n(下一步)、s(进入函数)、c(继续)、p variable(打印变量)等命令调试。ipdb.set_trace()),提升调试体验。比print更灵活的日志记录工具,适合生产环境或长期运行的程序:
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logging.debug("Debug message: %s", variable) # 输出DEBUG级别日志
logging.info("Info message") # 输出INFO级别日志
可通过level参数控制日志详细程度(如DEBUG显示详细信息,INFO显示关键流程)。
Shift+F9),程序会在断点处暂停。F5)。使用torch.autograd.set_detect_anomaly(True)启用梯度异常检测,若模型训练中出现梯度爆炸/消失(如NaN值),程序会抛出详细错误信息(包括计算图路径),帮助快速定位梯度问题。
使用torch.autograd.profiler分析模型性能瓶颈(如计算耗时、内存占用):
from torch.autograd import profiler
with profiler.profile(record_shapes=True, profile_memory=True) as prof:
output = model(input_data)
print(prof.key_averages().table(sort_by="cpu_time_total")) # 打印性能报告
通过报告可识别耗时操作(如某层卷积的计算时间),针对性优化。
使用torch.utils.tensorboard记录训练指标(如损失、准确率),可视化模型训练过程:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/experiment-1')
for epoch in range(epochs):
loss = train_model()
writer.add_scalar('Loss/train', loss, epoch) # 记录训练损失
writer.close()
终端运行tensorboard --logdir=runs,通过浏览器访问http://localhost:6006查看可视化界面(如损失曲线、直方图)。
使用unittest或pytest框架编写单元测试,验证模型各模块(如层、函数、损失函数)的正确性:
import unittest
class TestModel(unittest.TestCase):
def test_linear_layer(self):
layer = torch.nn.Linear(10, 1)
input = torch.randn(5, 10)
output = layer(input)
self.assertEqual(output.shape, (5, 1)) # 验证输出形状
if __name__ == '__main__':
unittest.main()
运行python -m unittest test_module.py执行测试,确保模块功能符合预期。
IPython或Jupyter Notebook逐步执行代码(如%run script.py),实时检查变量值,适合探索性调试。assert语句插入检查点(如assert variable > 0, "Variable must be positive"),若条件不成立则抛出AssertionError,快速捕获非法状态。