首页
/ ResNeSt 模型训练脚本解析与使用指南

ResNeSt 模型训练脚本解析与使用指南

2025-07-10 01:26:43作者:虞亚竹Luna

概述

ResNeSt 是一个基于 ResNet 架构改进的深度卷积神经网络,通过引入 Split-Attention 机制显著提升了模型性能。本文将详细解析 ResNeSt 项目的训练脚本(train.py),帮助读者理解其实现原理和使用方法。

环境配置与初始化

训练脚本首先进行了一系列环境变量设置,这些设置优化了 MXNet 框架在 GPU 上的运行效率:

os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '1'
os.environ['MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF'] = '26'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = '999'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '25'
os.environ['MXNET_GPU_COPY_NTHREADS'] = '1'
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '54'

这些配置主要控制:

  • 禁用 cuDNN 自动调优
  • 设置 GPU 内存池参数
  • 优化前向/反向传播的批量执行
  • 调整 GPU 数据拷贝线程数
  • 设置优化器聚合大小

参数解析与设置

脚本提供了丰富的训练参数配置选项:

parser = argparse.ArgumentParser(description='MXNet ImageNet Example')
# 数据相关参数
parser.add_argument('--use-rec', action='store_true', default=False)
parser.add_argument('--data-nthreads', type=int, default=8)
# 训练超参数
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--momentum', type=float, default=0.9)
# 模型相关参数
parser.add_argument('--model', type=str, default='resnet50_v1')
parser.add_argument('--last-gamma', action='store_true', default=False)
# 数据增强选项
parser.add_argument('--auto_aug', action='store_true')
parser.add_argument('--dropblock-prob', type=float, default=0)
args = parser.parse_args()

关键参数说明:

  • --last-gamma: 是否将最后一个 BN 层的 gamma 初始化为 0
  • --auto_aug: 使用自动数据增强策略
  • --dropblock-prob: DropBlock 正则化的概率

分布式训练初始化

脚本使用 Horovod 进行分布式训练:

hvd.init()
num_workers = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()

数据加载与预处理

训练数据增强

脚本提供了两种数据增强策略,根据输入尺寸自动选择:

if input_size >= 320:
    train_transforms.extend([
        ERandomCrop(input_size),
        pth_transforms.Resize((input_size, input_size)),
        pth_transforms.RandomHorizontalFlip(),
        pth_transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.RandomLighting(lighting_param)
    ])
else:
    train_transforms.extend([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomFlipLeftRight(),
        transforms.RandomColorJitter(brightness=jitter_param, 
                                   contrast=jitter_param,
                                   saturation=jitter_param)
    ])

分布式数据采样

使用 SplitSampler 实现数据分片,每个 worker 处理不同的数据子集:

class SplitSampler(mx.gluon.data.sampler.Sampler):
    def __init__(self, length, num_parts=1, part_index=0, random=True):
        self.part_len = length // num_parts
        self.start = self.part_len * part_index
        self.end = self.start + self.part_len

模型初始化与配置

模型加载支持多种配置选项:

kwargs = {
    'ctx': context,
    'pretrained': args.use_pretrained,
    'classes': num_classes,
    'input_size': args.input_size
}

if args.last_gamma:
    kwargs['last_gamma'] = True

if args.dropblock_prob > 0:
    kwargs['dropblock_prob'] = args.dropblock_prob

net = get_model(args.model, **kwargs)

训练流程

学习率调度

使用余弦退火学习率调度器,并支持 warmup:

lr_sched = lr_scheduler.CosineScheduler(
    args.num_epochs * epoch_size,
    base_lr=(args.lr * num_workers),
    warmup_steps=(args.warmup_epochs * epoch_size),
    warmup_begin_lr=args.warmup_lr
)

混合训练策略

支持 MixUp 和标签平滑两种训练策略:

if args.mixup:
    lam = np.random.beta(args.mixup_alpha, args.mixup_alpha)
    data = [lam*X + (1-lam)*X[::-1] for X in data]
    label = mixup_transform(label, num_classes, lam, eta)
elif args.label_smoothing:
    label = smooth(label, num_classes)

DropBlock 调度

动态调整 DropBlock 概率:

drop_scheduler = DropBlockScheduler(net, 0, 0.1, args.num_epochs)

模型评估

验证阶段使用中心裁剪和标准化:

transform_test = transforms.Compose([
    pth_transforms.ToPIL(),
    ECenterCrop(input_size),
    pth_transforms.Resize((input_size, input_size)),
    pth_transforms.ToNDArray(),
    transforms.ToTensor(),
    normalize
])

使用建议

  1. 大型模型训练:对于输入尺寸大于 320 的模型,建议使用 --auto_aug 参数启用自动数据增强
  2. 正则化配置:通过 --dropblock-prob 控制 DropBlock 强度,通常设置为 0.1 左右
  3. 学习率设置:分布式训练时,基础学习率会自动乘以 worker 数量
  4. 混合精度训练:使用 --dtype float16 启用混合精度训练

总结

ResNeSt 训练脚本提供了丰富的配置选项和训练策略,支持分布式训练、多种数据增强方法和正则化技术。通过合理配置这些参数,用户可以高效地训练出高性能的 ResNeSt 模型。