Simple-Faster-RCNN-Pytorch 训练流程详解
项目概述
Simple-Faster-RCNN-Pytorch 是一个基于 PyTorch 实现的简化版 Faster R-CNN 目标检测框架。该项目保留了 Faster R-CNN 的核心思想,同时简化了部分实现细节,使其更易于理解和学习。本文将重点解析其训练脚本 train.py 的实现逻辑和技术细节。
训练脚本核心架构
1. 初始化设置
脚本首先进行了一些必要的初始化设置:
# 解决文件描述符限制问题
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1]))
# 设置matplotlib后端
matplotlib.use('agg')
这部分代码主要解决了 Linux 系统下文件描述符数量的限制问题,这对于处理大量图像数据非常重要。同时设置了 matplotlib 的后端为 'agg',这是非交互式后端,适合在服务器环境下运行。
2. 评估函数实现
eval()
函数负责模型在验证集上的性能评估:
def eval(dataloader, faster_rcnn, test_num=10000):
# 初始化预测和真实值的存储列表
pred_bboxes, pred_labels, pred_scores = list(), list(), list()
gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
# 遍历验证数据集
for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)):
sizes = [sizes[0][0].item(), sizes[1][0].item()]
# 获取模型预测结果
pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])
# 收集结果用于评估
gt_bboxes += list(gt_bboxes_.numpy())
gt_labels += list(gt_labels_.numpy())
gt_difficults += list(gt_difficults_.numpy())
pred_bboxes += pred_bboxes_
pred_labels += pred_labels_
pred_scores += pred_scores_
if ii == test_num: break
# 使用PASCAL VOC评估指标
result = eval_detection_voc(
pred_bboxes, pred_labels, pred_scores,
gt_bboxes, gt_labels, gt_difficults,
use_07_metric=True)
return result
该函数实现了标准的 PASCAL VOC 评估流程,包括预测框的收集、真实标注的收集以及最终 mAP 指标的计算。
3. 主训练流程
train()
函数是整个训练过程的核心:
3.1 数据准备
dataset = Dataset(opt)
dataloader = data_.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers)
testset = TestDataset(opt)
test_dataloader = data_.DataLoader(testset, batch_size=1, num_workers=opt.test_num_workers, shuffle=False, pin_memory=True)
这里创建了训练集和测试集的 DataLoader,注意训练集使用了 shuffle=True 进行数据打乱,而测试集则保持顺序不变。batch_size 设置为 1 是 Faster R-CNN 的常见做法,因为不同图像的物体数量可能不同。
3.2 模型初始化
faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
模型基于 VGG16 作为 backbone 构建 Faster R-CNN,并使用自定义的 Trainer 类来管理训练过程。Trainer 封装了损失计算、反向传播等训练细节。
3.3 训练循环
训练过程采用标准的 epoch 循环:
for epoch in range(opt.epoch):
trainer.reset_meters()
for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
scale = at.scalar(scale)
img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
trainer.train_step(img, bbox, label, scale)
# 可视化相关操作
if (ii + 1) % opt.plot_every == 0:
# 绘制损失曲线、真实框和预测框
trainer.vis.plot_many(trainer.get_meter_data())
ori_img_ = inverse_normalize(at.tonumpy(img[0]))
gt_img = visdom_bbox(ori_img_, at.tonumpy(bbox_[0]), at.tonumpy(label_[0]))
trainer.vis.img('gt_img', gt_img)
_bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
pred_img = visdom_bbox(ori_img_, at.tonumpy(_bboxes[0]), at.tonumpy(_labels[0]).reshape(-1), at.tonumpy(_scores[0]))
trainer.vis.img('pred_img', pred_img)
每个 epoch 中,训练器会遍历整个数据集,执行以下操作:
- 将数据转移到 GPU
- 调用
train_step
执行单步训练 - 定期进行可视化展示,包括损失曲线、真实标注框和预测框
3.4 验证与模型保存
eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
trainer.vis.plot('test_map', eval_result['map'])
if eval_result['map'] > best_map:
best_map = eval_result['map']
best_path = trainer.save(best_map=best_map)
每个 epoch 结束后会在验证集上评估模型性能,并保存表现最好的模型。这里使用 mAP (mean Average Precision) 作为主要评估指标。
3.5 学习率调整
if epoch == 9:
trainer.load(best_path)
trainer.faster_rcnn.scale_lr(opt.lr_decay)
lr_ = lr_ * opt.lr_decay
在第 10 个 epoch 时,会重新加载最佳模型并降低学习率,这是一种常见的学习率衰减策略。
关键技术点解析
-
多任务损失:Faster R-CNN 同时优化 RPN 网络和 Fast R-CNN 网络的损失,包括分类损失和回归损失。
-
Anchor机制:RPN 网络使用不同尺度和长宽比的 anchor 作为候选区域的基础。
-
ROI Pooling:将不同大小的候选区域转换为固定大小的特征图,便于后续分类和回归。
-
NMS处理:在预测阶段使用非极大值抑制去除冗余的检测框。
-
Visdom可视化:项目使用 Visdom 进行训练过程的可视化监控,包括损失曲线、预测结果等。
训练技巧与建议
-
学习率设置:初始学习率为 0.001,在第 10 个 epoch 后衰减为原来的 0.1 倍。
-
数据增强:虽然代码中没有显示的数据增强操作,但可以通过修改 Dataset 类来增加随机翻转等增强策略。
-
早停机制:代码在第 14 个 epoch 停止训练,可以根据实际需求调整。
-
模型保存:只保存验证集上表现最好的模型,避免存储空间浪费。
-
调试技巧:设置 opt.debug_file 可以在特定条件下触发 ipdb 调试器。
总结
Simple-Faster-RCNN-Pytorch 的训练脚本实现了一个完整的目标检测模型训练流程,包括数据加载、模型训练、验证评估和可视化等关键环节。通过分析这个实现,我们可以深入理解 Faster R-CNN 的工作原理和训练细节。该实现简洁明了,非常适合学习和研究 Faster R-CNN 算法。