首页
/ FacebookResearch/MoCo 自监督学习训练脚本深度解析

FacebookResearch/MoCo 自监督学习训练脚本深度解析

2025-07-08 03:59:56作者:钟日瑜

概述

本文将对 Facebook Research 团队提出的 MoCo (Momentum Contrast) 自监督学习框架的训练脚本 main_moco.py 进行深入解析。MoCo 是一种用于无监督视觉表示学习的对比学习方法,通过构建动态字典来实现对比学习,在多个视觉任务上取得了优异的表现。

脚本核心功能

该训练脚本主要实现了以下功能:

  1. 模型构建与初始化
  2. 数据加载与增强
  3. 分布式训练支持
  4. 训练流程控制
  5. 学习率调度
  6. 模型保存与恢复

关键组件解析

1. 参数配置系统

脚本使用 argparse 模块提供了丰富的训练参数配置:

# 基础训练参数
parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument("-b", "--batch-size", default=256, type=int, metavar="N", help="mini-batch size")

# MoCo 特有参数
parser.add_argument("--moco-dim", default=128, type=int, help="feature dimension")
parser.add_argument("--moco-k", default=65536, type=int, help="queue size")
parser.add_argument("--moco-m", default=0.999, type=float, help="momentum of updating key encoder")
parser.add_argument("--moco-t", default=0.07, type=float, help="softmax temperature")

# MoCo v2 增强选项
parser.add_argument("--mlp", action="store_true", help="use mlp head")
parser.add_argument("--aug-plus", action="store_true", help="use moco v2 data augmentation")
parser.add_argument("--cos", action="store_true", help="use cosine lr schedule")

2. 数据增强策略

MoCo 提供了两种数据增强方案:

if args.aug_plus:
    # MoCo v2 增强策略
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
else:
    # MoCo v1 增强策略
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]

3. 分布式训练支持

脚本提供了完善的分布式训练支持:

if args.multiprocessing_distributed:
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
else:
    main_worker(args.gpu, ngpus_per_node, args)

4. 模型构建

MoCo 模型的核心构建逻辑:

model = deeplearning.cross_image_ssl.moco.builder.MoCo(
    models.__dict__[args.arch],  # 基础网络架构
    args.moco_dim,              # 特征维度
    args.moco_k,                # 队列大小
    args.moco_m,                # 动量系数
    args.moco_t,                # 温度参数
    args.mlp                    # 是否使用MLP头
)

5. 训练流程

训练过程的核心循环:

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)
    adjust_learning_rate(optimizer, epoch, args)
    
    # 训练一个epoch
    train(train_loader, model, criterion, optimizer, epoch, args)
    
    # 保存检查点
    save_checkpoint(...)

关键技术点

1. 动量更新机制

MoCo 的核心创新之一是使用动量更新机制来维护键编码器(key encoder)的参数:

# 动量更新公式
θ_k ← m * θ_k + (1 - m) * θ_q

其中 θ_k 是键编码器的参数,θ_q 是查询编码器的参数,m 是动量系数(默认0.999)。

2. 动态字典实现

MoCo 通过队列(queue)实现动态字典:

  • 队列大小由 --moco-k 参数控制(默认65536)
  • 新批次的特征入队,最旧的特征出队
  • 这种设计允许模型接触到大量负样本,而不会增加计算负担

3. 损失函数

MoCo 使用 InfoNCE 损失函数:

L_q = -log[exp(q·k_+/τ) / (exp(q·k_+/τ) + Σ exp(q·k_i/τ))]

其中 τ 是温度参数(由 --moco-t 控制,默认0.07)。

训练优化技巧

1. 学习率调度

脚本提供了两种学习率调度方式:

def adjust_learning_rate(optimizer, epoch, args):
    if args.cos:  # 余弦退火
        lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs))
    else:         # 阶梯式下降
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.0

2. 混合精度训练

虽然脚本中没有显式使用混合精度训练,但可以通过以下方式启用:

scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output, target = model(im_q=images[0], im_k=images[1])
    loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

使用建议

  1. 数据准备:确保数据集按照 ImageFolder 格式组织
  2. 硬件配置:建议使用多GPU训练以获得最佳效果
  3. 参数调优
    • 对于小数据集,可以减小 --moco-k
    • 调整 --moco-t 可以影响对比学习的难易程度
    • 使用 --aug-plus--mlp 启用MoCo v2的改进
  4. 监控训练:关注 top1/top5 准确率指标的变化趋势

总结

main_moco.py 脚本实现了 MoCo 自监督学习框架的完整训练流程,包含了模型构建、数据加载、分布式训练、学习率调度等关键组件。通过深入理解该脚本的实现细节,研究人员可以更好地应用 MoCo 方法,或在其基础上进行创新和改进。