TransUNet训练脚本解析与使用指南
2025-07-10 06:06:20作者:昌雅子Ethen
TransUNet是一种结合了Transformer和UNet架构的医学图像分割模型,在医学图像分割任务中表现出色。本文将深入解析TransUNet项目的训练脚本(train.py),帮助读者理解其实现原理和使用方法。
训练脚本核心功能
TransUNet的训练脚本主要负责以下几个核心功能:
- 参数配置与解析
- 模型初始化
- 训练过程管理
- 结果保存
参数配置详解
训练脚本使用argparse模块提供了丰富的可配置参数,这些参数可以分为几大类:
数据相关参数
root_path
: 训练数据根目录dataset
: 使用的数据集名称(默认为'Synapse')list_dir
: 数据列表文件目录num_classes
: 网络输出通道数(类别数)
训练过程参数
max_iterations
: 最大迭代次数max_epochs
: 最大epoch数batch_size
: 每个GPU的batch大小base_lr
: 基础学习率
模型架构参数
img_size
: 网络输入图像大小n_skip
: 使用的跳跃连接数量vit_name
: 选择的ViT模型名称vit_patches_size
: ViT的patch大小
系统参数
n_gpu
: 使用的GPU数量deterministic
: 是否使用确定性训练seed
: 随机种子
关键代码解析
随机种子设置
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
这段代码确保了实验的可重复性,通过设置Python、NumPy和PyTorch的随机种子,使得每次运行都能得到相同的结果。
模型初始化
config_vit = CONFIGS_ViT_seg[args.vit_name]
config_vit.n_classes = args.num_classes
config_vit.n_skip = args.n_skip
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
net.load_from(weights=np.load(config_vit.pretrained_path))
- 根据参数
vit_name
获取对应的ViT配置 - 设置类别数和跳跃连接数
- 初始化TransUNet模型并加载预训练权重
训练器选择
trainer = {'Synapse': trainer_synapse,}
trainer[dataset_name](args, net, snapshot_path)
根据数据集名称选择对应的训练器,目前支持Synapse数据集。
使用指南
基本训练命令
python train.py --dataset Synapse --root_path ./data/Synapse/train_npz --max_epochs 150
参数调优建议
- 学习率: 可以从默认的0.01开始,根据训练情况调整
- batch size: 根据GPU显存大小调整,较大的batch size通常更稳定
- 图像大小: 默认224x224,可根据任务需求调整
- epoch数: 观察验证集性能决定是否提前停止
模型保存路径
训练脚本会自动根据参数生成模型保存路径,格式如下:
../model/TU_Synapse224/TU_pretrain_R50-ViT-B_16_skip3_vitpatch16_epo150_bs24_lr0.01_224
路径中包含以下信息:
- 数据集名称
- 图像大小
- 是否使用预训练
- ViT模型名称
- 跳跃连接数
- patch大小
- epoch数
- batch size
- 学习率
- 随机种子
高级功能
确定性训练
设置--deterministic 1
可以启用确定性训练,确保实验可重复,但可能会降低训练速度。
不同ViT模型选择
通过--vit_name
参数可以选择不同的ViT变体,目前支持'R50-ViT-B_16'等。
常见问题解决
- 显存不足: 减小batch size或图像大小
- 训练不稳定: 降低学习率或使用更小的patch size
- 性能不佳: 尝试增加epoch数或调整跳跃连接数
总结
TransUNet的训练脚本设计灵活,提供了丰富的可配置参数,支持多种训练场景。通过理解这些参数的含义和作用,用户可以针对不同的医学图像分割任务进行有效的模型训练和调优。