深入解析CSAILVision语义分割模型的训练流程
2025-07-08 03:51:02作者:殷蕙予
概述
本文将深入分析语义分割模型训练脚本的实现细节,帮助读者理解PyTorch框架下语义分割模型的完整训练流程。该训练脚本实现了从数据加载、模型构建到优化策略等一系列关键功能。
核心组件解析
1. 模型构建
训练脚本使用ModelBuilder
类构建编码器(encoder)和解码器(decoder):
net_encoder = ModelBuilder.build_encoder(
arch=cfg.MODEL.arch_encoder.lower(),
fc_dim=cfg.MODEL.fc_dim,
weights=cfg.MODEL.weights_encoder)
net_decoder = ModelBuilder.build_decoder(
arch=cfg.MODEL.arch_decoder.lower(),
fc_dim=cfg.MODEL.fc_dim,
num_class=cfg.DATASET.num_class,
weights=cfg.MODEL.weights_decoder)
编码器通常采用ResNet等骨干网络提取特征,解码器则负责将低分辨率特征图上采样到输入图像尺寸并预测每个像素的类别。
2. 损失函数
使用负对数似然损失(NLLLoss)作为损失函数,并忽略特定索引(通常用于标记无效区域):
crit = nn.NLLLoss(ignore_index=-1)
3. 数据加载
训练数据通过TrainDataset
类加载,支持多GPU并行训练:
dataset_train = TrainDataset(
cfg.DATASET.root_dataset,
cfg.DATASET.list_train,
cfg.DATASET,
batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
训练流程详解
1. 单epoch训练
train()
函数实现了一个完整epoch的训练过程:
- 数据加载:从迭代器中获取一个batch的数据
- 学习率调整:根据当前迭代次数调整学习率
- 前向传播:计算损失和准确率
- 反向传播:更新模型参数
- 日志记录:定期输出训练状态
def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
# 初始化指标记录器
batch_time = AverageMeter()
data_time = AverageMeter()
ave_total_loss = AverageMeter()
ave_acc = AverageMeter()
# 主训练循环
for i in range(cfg.TRAIN.epoch_iters):
# 加载数据
batch_data = next(iterator)
# 前向传播
loss, acc = segmentation_module(batch_data)
# 反向传播
loss.backward()
for optimizer in optimizers:
optimizer.step()
# 记录指标
# ...
2. 学习率调整策略
采用多项式衰减策略调整学习率:
def adjust_learning_rate(optimizers, cur_iter, cfg):
scale_running_lr = ((1. - float(cur_iter) / cfg.TRAIN.max_iters) ** cfg.TRAIN.lr_pow)
cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder * scale_running_lr
cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder * scale_running_lr
# 更新优化器学习率
# ...
这种策略在训练初期使用较大学习率,随着训练进行逐渐衰减,有助于模型收敛。
3. 参数分组优化
对模型参数进行分组,不同组采用不同的权重衰减策略:
def group_weight(module):
group_decay = [] # 需要权重衰减的参数
group_no_decay = [] # 不需要权重衰减的参数
# 分类参数
# ...
return [dict(params=group_decay),
dict(params=group_no_decay, weight_decay=.0)]
这种分组优化策略在实践中能提高模型性能。
多GPU训练支持
脚本支持多GPU并行训练,通过UserScatteredDataParallel
实现数据并行:
if len(gpus) > 1:
segmentation_module = UserScatteredDataParallel(
segmentation_module,
device_ids=gpus)
# 同步批归一化
patch_replication_callback(segmentation_module)
模型保存与恢复
训练过程中定期保存模型检查点:
def checkpoint(nets, history, cfg, epoch):
torch.save(history, f'{cfg.DIR}/history_epoch_{epoch}.pth')
torch.save(dict_encoder, f'{cfg.DIR}/encoder_epoch_{epoch}.pth')
torch.save(dict_decoder, f'{cfg.DIR}/decoder_epoch_{epoch}.pth')
支持从指定epoch恢复训练:
if cfg.TRAIN.start_epoch > 0:
cfg.MODEL.weights_encoder = f'{cfg.DIR}/encoder_epoch_{cfg.TRAIN.start_epoch}.pth'
cfg.MODEL.weights_decoder = f'{cfg.DIR}/decoder_epoch_{cfg.TRAIN.start_epoch}.pth'
配置系统
训练脚本使用灵活的配置系统,支持通过YAML文件和命令行参数配置训练过程:
cfg.merge_from_file(args.cfg) # 从YAML文件加载配置
cfg.merge_from_list(args.opts) # 从命令行参数更新配置
总结
该训练脚本实现了语义分割模型训练的全流程,具有以下特点:
- 模块化设计,各组件职责清晰
- 支持多GPU并行训练
- 灵活的学习率调整策略
- 完善的检查点机制
- 可配置的训练参数
通过深入理解这个训练脚本的实现,开发者可以更好地应用于自己的语义分割任务,或基于此进行二次开发。