首页
/ Simple-Faster-RCNN-Pytorch 训练流程详解

Simple-Faster-RCNN-Pytorch 训练流程详解

2025-07-09 02:04:45作者:瞿蔚英Wynne

项目概述

Simple-Faster-RCNN-Pytorch 是一个基于 PyTorch 实现的简化版 Faster R-CNN 目标检测框架。该项目保留了 Faster R-CNN 的核心思想,同时简化了部分实现细节,使其更易于理解和学习。本文将重点解析其训练脚本 train.py 的实现逻辑和技术细节。

训练脚本核心架构

1. 初始化设置

脚本首先进行了一些必要的初始化设置:

# 解决文件描述符限制问题
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1]))

# 设置matplotlib后端
matplotlib.use('agg')

这部分代码主要解决了 Linux 系统下文件描述符数量的限制问题,这对于处理大量图像数据非常重要。同时设置了 matplotlib 的后端为 'agg',这是非交互式后端,适合在服务器环境下运行。

2. 评估函数实现

eval() 函数负责模型在验证集上的性能评估:

def eval(dataloader, faster_rcnn, test_num=10000):
    # 初始化预测和真实值的存储列表
    pred_bboxes, pred_labels, pred_scores = list(), list(), list()
    gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
    
    # 遍历验证数据集
    for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)):
        sizes = [sizes[0][0].item(), sizes[1][0].item()]
        # 获取模型预测结果
        pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])
        # 收集结果用于评估
        gt_bboxes += list(gt_bboxes_.numpy())
        gt_labels += list(gt_labels_.numpy())
        gt_difficults += list(gt_difficults_.numpy())
        pred_bboxes += pred_bboxes_
        pred_labels += pred_labels_
        pred_scores += pred_scores_
        if ii == test_num: break

    # 使用PASCAL VOC评估指标
    result = eval_detection_voc(
        pred_bboxes, pred_labels, pred_scores,
        gt_bboxes, gt_labels, gt_difficults,
        use_07_metric=True)
    return result

该函数实现了标准的 PASCAL VOC 评估流程,包括预测框的收集、真实标注的收集以及最终 mAP 指标的计算。

3. 主训练流程

train() 函数是整个训练过程的核心:

3.1 数据准备

dataset = Dataset(opt)
dataloader = data_.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers)
testset = TestDataset(opt)
test_dataloader = data_.DataLoader(testset, batch_size=1, num_workers=opt.test_num_workers, shuffle=False, pin_memory=True)

这里创建了训练集和测试集的 DataLoader,注意训练集使用了 shuffle=True 进行数据打乱,而测试集则保持顺序不变。batch_size 设置为 1 是 Faster R-CNN 的常见做法,因为不同图像的物体数量可能不同。

3.2 模型初始化

faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()

模型基于 VGG16 作为 backbone 构建 Faster R-CNN,并使用自定义的 Trainer 类来管理训练过程。Trainer 封装了损失计算、反向传播等训练细节。

3.3 训练循环

训练过程采用标准的 epoch 循环:

for epoch in range(opt.epoch):
    trainer.reset_meters()
    for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
        scale = at.scalar(scale)
        img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
        trainer.train_step(img, bbox, label, scale)
        
        # 可视化相关操作
        if (ii + 1) % opt.plot_every == 0:
            # 绘制损失曲线、真实框和预测框
            trainer.vis.plot_many(trainer.get_meter_data())
            ori_img_ = inverse_normalize(at.tonumpy(img[0]))
            gt_img = visdom_bbox(ori_img_, at.tonumpy(bbox_[0]), at.tonumpy(label_[0]))
            trainer.vis.img('gt_img', gt_img)
            
            _bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
            pred_img = visdom_bbox(ori_img_, at.tonumpy(_bboxes[0]), at.tonumpy(_labels[0]).reshape(-1), at.tonumpy(_scores[0]))
            trainer.vis.img('pred_img', pred_img)

每个 epoch 中,训练器会遍历整个数据集,执行以下操作:

  1. 将数据转移到 GPU
  2. 调用 train_step 执行单步训练
  3. 定期进行可视化展示,包括损失曲线、真实标注框和预测框

3.4 验证与模型保存

eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
trainer.vis.plot('test_map', eval_result['map'])

if eval_result['map'] > best_map:
    best_map = eval_result['map']
    best_path = trainer.save(best_map=best_map)

每个 epoch 结束后会在验证集上评估模型性能,并保存表现最好的模型。这里使用 mAP (mean Average Precision) 作为主要评估指标。

3.5 学习率调整

if epoch == 9:
    trainer.load(best_path)
    trainer.faster_rcnn.scale_lr(opt.lr_decay)
    lr_ = lr_ * opt.lr_decay

在第 10 个 epoch 时,会重新加载最佳模型并降低学习率,这是一种常见的学习率衰减策略。

关键技术点解析

  1. 多任务损失:Faster R-CNN 同时优化 RPN 网络和 Fast R-CNN 网络的损失,包括分类损失和回归损失。

  2. Anchor机制:RPN 网络使用不同尺度和长宽比的 anchor 作为候选区域的基础。

  3. ROI Pooling:将不同大小的候选区域转换为固定大小的特征图,便于后续分类和回归。

  4. NMS处理:在预测阶段使用非极大值抑制去除冗余的检测框。

  5. Visdom可视化:项目使用 Visdom 进行训练过程的可视化监控,包括损失曲线、预测结果等。

训练技巧与建议

  1. 学习率设置:初始学习率为 0.001,在第 10 个 epoch 后衰减为原来的 0.1 倍。

  2. 数据增强:虽然代码中没有显示的数据增强操作,但可以通过修改 Dataset 类来增加随机翻转等增强策略。

  3. 早停机制:代码在第 14 个 epoch 停止训练,可以根据实际需求调整。

  4. 模型保存:只保存验证集上表现最好的模型,避免存储空间浪费。

  5. 调试技巧:设置 opt.debug_file 可以在特定条件下触发 ipdb 调试器。

总结

Simple-Faster-RCNN-Pytorch 的训练脚本实现了一个完整的目标检测模型训练流程,包括数据加载、模型训练、验证评估和可视化等关键环节。通过分析这个实现,我们可以深入理解 Faster R-CNN 的工作原理和训练细节。该实现简洁明了,非常适合学习和研究 Faster R-CNN 算法。