首页
/ DenseNet训练流程详解:从理论到实现

DenseNet训练流程详解:从理论到实现

2025-07-08 05:05:05作者:管翌锬

概述

本文深入解析DenseNet项目中的训练流程实现,重点分析train.lua文件的核心逻辑。DenseNet是一种密集连接卷积网络,通过特征重用机制显著提升了模型性能并减少了参数数量。训练脚本作为模型优化的核心部分,其实现细节直接影响最终模型效果。

训练器初始化

训练器(Trainer)类是整个训练过程的核心控制器,其初始化过程包含以下关键组件:

  1. 模型与损失函数:接收DenseNet模型实例和损失函数(criterion)
  2. 优化器状态:配置SGD优化器的超参数,包括:
    • 基础学习率(learningRate)
    • 动量(momentum)
    • Nesterov动量(nesterov)
    • 权重衰减(weightDecay)
function Trainer:__init(model, criterion, opt, optimState)
   self.model = model
   self.criterion = criterion
   self.optimState = optimState or {
      learningRate = opt.LR,
      learningRateDecay = 0.0,
      momentum = opt.momentum,
      nesterov = true,
      dampening = 0.0,
      weightDecay = opt.weightDecay,
   }
   self.opt = opt
   self.params, self.gradParams = model:getParameters()
end

训练过程详解

1. 学习率调度策略

DenseNet实现了两种学习率调度方式:

  1. 多步衰减(MultiStep):在特定epoch区间进行阶梯式衰减
  2. 余弦退火(Cosine):遵循余弦曲线平滑调整学习率
-- 多步衰减策略
function Trainer:learningRate(epoch)
   local decay = 0
   if self.opt.dataset == 'imagenet' then
      decay = math.floor((epoch - 1) / 30)
   elseif self.opt.dataset == 'cifar10' then
      decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0
   end
   return self.opt.LR * math.pow(0.1, decay)
end

-- 余弦退火策略
function Trainer:learningRateCosine(epoch, iter, nBatches)
   local nEpochs = self.opt.nEpochs
   local T_total = nEpochs * nBatches
   local T_cur = ((epoch-1) % nEpochs) * nBatches + iter
   return 0.5 * self.opt.LR * (1 + torch.cos(math.pi * T_cur / T_total))
end

2. 单epoch训练流程

每个训练epoch包含以下关键步骤:

  1. 模型模式设置:将模型切换至训练模式(启用BatchNorm和Dropout)
  2. 数据加载:从dataloader获取批量数据
  3. 前向传播:计算模型输出和损失
  4. 反向传播:计算梯度
  5. 参数更新:使用SGD优化器更新权重
  6. 性能评估:计算top1和top5准确率
function Trainer:train(epoch, dataloader)
   self.model:training()
   for n, sample in dataloader:run() do
      -- 数据加载
      self:copyInputs(sample)
      
      -- 前向传播
      local output = self.model:forward(self.input):float()
      local loss = self.criterion:forward(self.model.output, self.target)
      
      -- 反向传播
      self.model:zeroGradParameters()
      self.criterion:backward(self.model.output, self.target)
      self.model:backward(self.input, self.criterion.gradInput)
      
      -- 参数更新
      optim.sgd(feval, self.params, self.optimState)
      
      -- 性能评估
      local top1, top5 = self:computeScore(output, sample.target, 1)
   end
end

3. 验证过程

验证流程与训练类似,但有几点关键区别:

  1. 模型模式:切换为评估模式(禁用BatchNorm和Dropout)
  2. 数据增强:支持tenCrop(10-crop)测试增强
  3. 无梯度计算:不进行反向传播和参数更新
function Trainer:test(epoch, dataloader)
   self.model:evaluate()
   for n, sample in dataloader:run() do
      -- 仅前向传播
      local output = self.model:forward(self.input):float()
      local top1, top5 = self:computeScore(output, sample.target, nCrops)
   end
   self.model:training()
end

关键技术点解析

1. 准确率计算

准确率计算支持标准单图和10-crop测试两种情况:

function Trainer:computeScore(output, target, nCrops)
   if nCrops > 1 then
      -- 对10-crop结果取平均
      output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
         :sum(2):squeeze(2)
   end
   
   -- 计算top1和top5准确率
   local _ , predictions = output:float():topk(5, 2, true, true)
   local correct = predictions:eq(target:long():view(batchSize, 1):expandAs(predictions))
   
   local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)
   local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize)
   
   return top1 * 100, top5 * 100
end

2. 数据加载优化

针对不同硬件配置优化数据加载:

function Trainer:copyInputs(sample)
   -- 单GPU使用普通CUDA张量
   -- 多GPU使用pinned memory提升传输效率
   self.input = self.input or (self.opt.nGPU == 1
      and torch[self.opt.tensorType:match('torch.(%a+)')]()
      or getCudaTensorType(self.opt.tensorType))
   self.target = self.target or torch.CudaLongTensor()
   
   self.input:resize(sample.input:size()):copy(sample.input)
   self.target:resize(sample.target:size()):copy(sample.target)
end

训练实践建议

  1. 学习率策略选择

    • 对于小数据集(如CIFAR),多步衰减通常足够
    • 对于大数据集(如ImageNet),余弦退火可能带来更好效果
  2. 批量大小调整

    • 根据GPU内存调整batch size
    • 较大batch size可配合学习率warmup
  3. 正则化配置

    • 适当调整weightDecay防止过拟合
    • DenseNet本身具有较强的正则化能力
  4. 训练监控

    • 定期验证集评估
    • 监控训练/验证损失曲线

通过深入理解DenseNet训练流程的实现细节,开发者可以更好地调整模型参数,优化训练过程,从而获得性能更优的密集连接网络模型。