PyTorch Image Models 训练脚本使用指南
2025-07-05 03:50:54作者:宣海椒Queenly
概述
PyTorch Image Models (timm) 提供了一套完整的训练、验证和推理脚本,这些脚本源自早期的 PyTorch ImageNet 示例,并在此基础上进行了大量功能增强和性能优化。本文将详细介绍如何使用这些脚本进行高效的图像模型训练。
训练脚本详解
基本训练命令
训练脚本支持多种模型和训练配置。以下是一个典型的 SE-ResNet34 训练示例:
./distributed_train.sh 4 --data-dir /data/imagenet \
--model seresnet34 \
--sched cosine \
--epochs 150 \
--warmup-epochs 5 \
--lr 0.4 \
--reprob 0.5 \
--remode pixel \
--batch-size 256 \
--amp -j 4
关键参数说明:
distributed_train.sh 4
:使用4个GPU进行分布式训练--data-dir
:指定包含train和validation子目录的数据集路径--model
:指定模型架构--sched
:学习率调度策略(cosine表示余弦退火)--amp
:启用自动混合精度训练
训练技巧与建议
- 混合精度训练:推荐使用PyTorch 1.9+的原生AMP而非APEX AMP
- 学习率预热:
--warmup-epochs
可帮助模型在训练初期稳定收敛 - 数据增强:通过
--reprob
和--remode
控制随机擦除增强的概率和模式
验证与推理脚本
验证脚本
验证脚本用于评估模型在验证集上的表现:
python validate.py \
--data-dir /imagenet/validation/ \
--model seresnext26_32x4d \
--pretrained
推理脚本
推理脚本可生成预测结果:
python inference.py \
--data-dir /imagenet/validation/ \
--model mobilenetv3_large_100 \
--checkpoint ./output/train/model_best.pth.tar
典型训练配置示例
EfficientNet-B2训练配置
./distributed_train.sh 2 --data-dir /imagenet/ \
--model efficientnet_b2 -b 128 \
--sched step --epochs 450 \
--decay-epochs 2.4 --decay-rate .97 \
--opt rmsproptf --opt-eps .001 -j 8 \
--warmup-lr 1e-6 --weight-decay 1e-5 \
--drop 0.3 --drop-path 0.2 \
--model-ema --model-ema-decay 0.9999 \
--aa rand-m9-mstd0.5 \
--remode pixel --reprob 0.2 \
--amp --lr .016
特点:
- 使用RandAugment数据增强
- 采用EMA模型权重平均
- 阶梯式学习率衰减
ResNet50训练配置
./distributed_train.sh 2 --data-dir /imagenet \
-b 64 --model resnet50 \
--sched cosine --epochs 200 \
--lr 0.05 --amp \
--remode pixel --reprob 0.6 \
--aug-splits 3 \
--aa rand-m9-mstd0.5-inc1 \
--resplit --split-bn --jsd \
--dist-bn reduce
特点:
- 使用JSD损失函数
- 三路数据增强(clean+2xRA)
- 分离的BatchNorm层
训练优化建议
- 学习率调整:根据batch size线性调整学习率
- 训练时长:较大模型需要更多训练epoch
- 正则化策略:适当调整dropout和drop-path率
- 数据增强:RandAugment强度需根据模型容量调整
- 分布式训练:使用
--dist-bn reduce
优化BatchNorm同步
常见问题处理
- 训练不稳定:尝试降低学习率或增加warmup周期
- 验证精度低:检查数据增强是否过度或不足
- 内存不足:减小batch size或使用梯度累积
- 收敛速度慢:调整学习率调度策略或优化器参数
通过合理配置这些参数,您可以在PyTorch Image Models上获得优异的模型性能。