首页
/ PyTorch Image Models 训练脚本使用指南

PyTorch Image Models 训练脚本使用指南

2025-07-05 03:50:54作者:宣海椒Queenly

概述

PyTorch Image Models (timm) 提供了一套完整的训练、验证和推理脚本,这些脚本源自早期的 PyTorch ImageNet 示例,并在此基础上进行了大量功能增强和性能优化。本文将详细介绍如何使用这些脚本进行高效的图像模型训练。

训练脚本详解

基本训练命令

训练脚本支持多种模型和训练配置。以下是一个典型的 SE-ResNet34 训练示例:

./distributed_train.sh 4 --data-dir /data/imagenet \
  --model seresnet34 \
  --sched cosine \
  --epochs 150 \
  --warmup-epochs 5 \
  --lr 0.4 \
  --reprob 0.5 \
  --remode pixel \
  --batch-size 256 \
  --amp -j 4

关键参数说明:

  • distributed_train.sh 4:使用4个GPU进行分布式训练
  • --data-dir:指定包含train和validation子目录的数据集路径
  • --model:指定模型架构
  • --sched:学习率调度策略(cosine表示余弦退火)
  • --amp:启用自动混合精度训练

训练技巧与建议

  1. 混合精度训练:推荐使用PyTorch 1.9+的原生AMP而非APEX AMP
  2. 学习率预热--warmup-epochs可帮助模型在训练初期稳定收敛
  3. 数据增强:通过--reprob--remode控制随机擦除增强的概率和模式

验证与推理脚本

验证脚本

验证脚本用于评估模型在验证集上的表现:

python validate.py \
  --data-dir /imagenet/validation/ \
  --model seresnext26_32x4d \
  --pretrained

推理脚本

推理脚本可生成预测结果:

python inference.py \
  --data-dir /imagenet/validation/ \
  --model mobilenetv3_large_100 \
  --checkpoint ./output/train/model_best.pth.tar

典型训练配置示例

EfficientNet-B2训练配置

./distributed_train.sh 2 --data-dir /imagenet/ \
  --model efficientnet_b2 -b 128 \
  --sched step --epochs 450 \
  --decay-epochs 2.4 --decay-rate .97 \
  --opt rmsproptf --opt-eps .001 -j 8 \
  --warmup-lr 1e-6 --weight-decay 1e-5 \
  --drop 0.3 --drop-path 0.2 \
  --model-ema --model-ema-decay 0.9999 \
  --aa rand-m9-mstd0.5 \
  --remode pixel --reprob 0.2 \
  --amp --lr .016

特点:

  • 使用RandAugment数据增强
  • 采用EMA模型权重平均
  • 阶梯式学习率衰减

ResNet50训练配置

./distributed_train.sh 2 --data-dir /imagenet \
  -b 64 --model resnet50 \
  --sched cosine --epochs 200 \
  --lr 0.05 --amp \
  --remode pixel --reprob 0.6 \
  --aug-splits 3 \
  --aa rand-m9-mstd0.5-inc1 \
  --resplit --split-bn --jsd \
  --dist-bn reduce

特点:

  • 使用JSD损失函数
  • 三路数据增强(clean+2xRA)
  • 分离的BatchNorm层

训练优化建议

  1. 学习率调整:根据batch size线性调整学习率
  2. 训练时长:较大模型需要更多训练epoch
  3. 正则化策略:适当调整dropout和drop-path率
  4. 数据增强:RandAugment强度需根据模型容量调整
  5. 分布式训练:使用--dist-bn reduce优化BatchNorm同步

常见问题处理

  1. 训练不稳定:尝试降低学习率或增加warmup周期
  2. 验证精度低:检查数据增强是否过度或不足
  3. 内存不足:减小batch size或使用梯度累积
  4. 收敛速度慢:调整学习率调度策略或优化器参数

通过合理配置这些参数,您可以在PyTorch Image Models上获得优异的模型性能。