首页
/ TransUNet训练脚本解析与使用指南

TransUNet训练脚本解析与使用指南

2025-07-10 06:06:20作者:昌雅子Ethen

TransUNet是一种结合了Transformer和UNet架构的医学图像分割模型,在医学图像分割任务中表现出色。本文将深入解析TransUNet项目的训练脚本(train.py),帮助读者理解其实现原理和使用方法。

训练脚本核心功能

TransUNet的训练脚本主要负责以下几个核心功能:

  1. 参数配置与解析
  2. 模型初始化
  3. 训练过程管理
  4. 结果保存

参数配置详解

训练脚本使用argparse模块提供了丰富的可配置参数,这些参数可以分为几大类:

数据相关参数

  • root_path: 训练数据根目录
  • dataset: 使用的数据集名称(默认为'Synapse')
  • list_dir: 数据列表文件目录
  • num_classes: 网络输出通道数(类别数)

训练过程参数

  • max_iterations: 最大迭代次数
  • max_epochs: 最大epoch数
  • batch_size: 每个GPU的batch大小
  • base_lr: 基础学习率

模型架构参数

  • img_size: 网络输入图像大小
  • n_skip: 使用的跳跃连接数量
  • vit_name: 选择的ViT模型名称
  • vit_patches_size: ViT的patch大小

系统参数

  • n_gpu: 使用的GPU数量
  • deterministic: 是否使用确定性训练
  • seed: 随机种子

关键代码解析

随机种子设置

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

这段代码确保了实验的可重复性,通过设置Python、NumPy和PyTorch的随机种子,使得每次运行都能得到相同的结果。

模型初始化

config_vit = CONFIGS_ViT_seg[args.vit_name]
config_vit.n_classes = args.num_classes
config_vit.n_skip = args.n_skip
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
net.load_from(weights=np.load(config_vit.pretrained_path))
  1. 根据参数vit_name获取对应的ViT配置
  2. 设置类别数和跳跃连接数
  3. 初始化TransUNet模型并加载预训练权重

训练器选择

trainer = {'Synapse': trainer_synapse,}
trainer[dataset_name](args, net, snapshot_path)

根据数据集名称选择对应的训练器,目前支持Synapse数据集。

使用指南

基本训练命令

python train.py --dataset Synapse --root_path ./data/Synapse/train_npz --max_epochs 150

参数调优建议

  1. 学习率: 可以从默认的0.01开始,根据训练情况调整
  2. batch size: 根据GPU显存大小调整,较大的batch size通常更稳定
  3. 图像大小: 默认224x224,可根据任务需求调整
  4. epoch数: 观察验证集性能决定是否提前停止

模型保存路径

训练脚本会自动根据参数生成模型保存路径,格式如下:

../model/TU_Synapse224/TU_pretrain_R50-ViT-B_16_skip3_vitpatch16_epo150_bs24_lr0.01_224

路径中包含以下信息:

  • 数据集名称
  • 图像大小
  • 是否使用预训练
  • ViT模型名称
  • 跳跃连接数
  • patch大小
  • epoch数
  • batch size
  • 学习率
  • 随机种子

高级功能

确定性训练

设置--deterministic 1可以启用确定性训练,确保实验可重复,但可能会降低训练速度。

不同ViT模型选择

通过--vit_name参数可以选择不同的ViT变体,目前支持'R50-ViT-B_16'等。

常见问题解决

  1. 显存不足: 减小batch size或图像大小
  2. 训练不稳定: 降低学习率或使用更小的patch size
  3. 性能不佳: 尝试增加epoch数或调整跳跃连接数

总结

TransUNet的训练脚本设计灵活,提供了丰富的可配置参数,支持多种训练场景。通过理解这些参数的含义和作用,用户可以针对不同的医学图像分割任务进行有效的模型训练和调优。