ResNeSt 模型训练脚本解析与使用指南
2025-07-10 01:26:43作者:虞亚竹Luna
概述
ResNeSt 是一个基于 ResNet 架构改进的深度卷积神经网络,通过引入 Split-Attention 机制显著提升了模型性能。本文将详细解析 ResNeSt 项目的训练脚本(train.py),帮助读者理解其实现原理和使用方法。
环境配置与初始化
训练脚本首先进行了一系列环境变量设置,这些设置优化了 MXNet 框架在 GPU 上的运行效率:
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '1'
os.environ['MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF'] = '26'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = '999'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '25'
os.environ['MXNET_GPU_COPY_NTHREADS'] = '1'
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '54'
这些配置主要控制:
- 禁用 cuDNN 自动调优
- 设置 GPU 内存池参数
- 优化前向/反向传播的批量执行
- 调整 GPU 数据拷贝线程数
- 设置优化器聚合大小
参数解析与设置
脚本提供了丰富的训练参数配置选项:
parser = argparse.ArgumentParser(description='MXNet ImageNet Example')
# 数据相关参数
parser.add_argument('--use-rec', action='store_true', default=False)
parser.add_argument('--data-nthreads', type=int, default=8)
# 训练超参数
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--momentum', type=float, default=0.9)
# 模型相关参数
parser.add_argument('--model', type=str, default='resnet50_v1')
parser.add_argument('--last-gamma', action='store_true', default=False)
# 数据增强选项
parser.add_argument('--auto_aug', action='store_true')
parser.add_argument('--dropblock-prob', type=float, default=0)
args = parser.parse_args()
关键参数说明:
--last-gamma
: 是否将最后一个 BN 层的 gamma 初始化为 0--auto_aug
: 使用自动数据增强策略--dropblock-prob
: DropBlock 正则化的概率
分布式训练初始化
脚本使用 Horovod 进行分布式训练:
hvd.init()
num_workers = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()
数据加载与预处理
训练数据增强
脚本提供了两种数据增强策略,根据输入尺寸自动选择:
if input_size >= 320:
train_transforms.extend([
ERandomCrop(input_size),
pth_transforms.Resize((input_size, input_size)),
pth_transforms.RandomHorizontalFlip(),
pth_transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomLighting(lighting_param)
])
else:
train_transforms.extend([
transforms.RandomResizedCrop(input_size),
transforms.RandomFlipLeftRight(),
transforms.RandomColorJitter(brightness=jitter_param,
contrast=jitter_param,
saturation=jitter_param)
])
分布式数据采样
使用 SplitSampler
实现数据分片,每个 worker 处理不同的数据子集:
class SplitSampler(mx.gluon.data.sampler.Sampler):
def __init__(self, length, num_parts=1, part_index=0, random=True):
self.part_len = length // num_parts
self.start = self.part_len * part_index
self.end = self.start + self.part_len
模型初始化与配置
模型加载支持多种配置选项:
kwargs = {
'ctx': context,
'pretrained': args.use_pretrained,
'classes': num_classes,
'input_size': args.input_size
}
if args.last_gamma:
kwargs['last_gamma'] = True
if args.dropblock_prob > 0:
kwargs['dropblock_prob'] = args.dropblock_prob
net = get_model(args.model, **kwargs)
训练流程
学习率调度
使用余弦退火学习率调度器,并支持 warmup:
lr_sched = lr_scheduler.CosineScheduler(
args.num_epochs * epoch_size,
base_lr=(args.lr * num_workers),
warmup_steps=(args.warmup_epochs * epoch_size),
warmup_begin_lr=args.warmup_lr
)
混合训练策略
支持 MixUp 和标签平滑两种训练策略:
if args.mixup:
lam = np.random.beta(args.mixup_alpha, args.mixup_alpha)
data = [lam*X + (1-lam)*X[::-1] for X in data]
label = mixup_transform(label, num_classes, lam, eta)
elif args.label_smoothing:
label = smooth(label, num_classes)
DropBlock 调度
动态调整 DropBlock 概率:
drop_scheduler = DropBlockScheduler(net, 0, 0.1, args.num_epochs)
模型评估
验证阶段使用中心裁剪和标准化:
transform_test = transforms.Compose([
pth_transforms.ToPIL(),
ECenterCrop(input_size),
pth_transforms.Resize((input_size, input_size)),
pth_transforms.ToNDArray(),
transforms.ToTensor(),
normalize
])
使用建议
- 大型模型训练:对于输入尺寸大于 320 的模型,建议使用
--auto_aug
参数启用自动数据增强 - 正则化配置:通过
--dropblock-prob
控制 DropBlock 强度,通常设置为 0.1 左右 - 学习率设置:分布式训练时,基础学习率会自动乘以 worker 数量
- 混合精度训练:使用
--dtype float16
启用混合精度训练
总结
ResNeSt 训练脚本提供了丰富的配置选项和训练策略,支持分布式训练、多种数据增强方法和正则化技术。通过合理配置这些参数,用户可以高效地训练出高性能的 ResNeSt 模型。