DenseNet训练流程详解:从理论到实现
2025-07-08 05:05:05作者:管翌锬
概述
本文深入解析DenseNet项目中的训练流程实现,重点分析train.lua文件的核心逻辑。DenseNet是一种密集连接卷积网络,通过特征重用机制显著提升了模型性能并减少了参数数量。训练脚本作为模型优化的核心部分,其实现细节直接影响最终模型效果。
训练器初始化
训练器(Trainer)类是整个训练过程的核心控制器,其初始化过程包含以下关键组件:
- 模型与损失函数:接收DenseNet模型实例和损失函数(criterion)
- 优化器状态:配置SGD优化器的超参数,包括:
- 基础学习率(learningRate)
- 动量(momentum)
- Nesterov动量(nesterov)
- 权重衰减(weightDecay)
function Trainer:__init(model, criterion, opt, optimState)
self.model = model
self.criterion = criterion
self.optimState = optimState or {
learningRate = opt.LR,
learningRateDecay = 0.0,
momentum = opt.momentum,
nesterov = true,
dampening = 0.0,
weightDecay = opt.weightDecay,
}
self.opt = opt
self.params, self.gradParams = model:getParameters()
end
训练过程详解
1. 学习率调度策略
DenseNet实现了两种学习率调度方式:
- 多步衰减(MultiStep):在特定epoch区间进行阶梯式衰减
- 余弦退火(Cosine):遵循余弦曲线平滑调整学习率
-- 多步衰减策略
function Trainer:learningRate(epoch)
local decay = 0
if self.opt.dataset == 'imagenet' then
decay = math.floor((epoch - 1) / 30)
elseif self.opt.dataset == 'cifar10' then
decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0
end
return self.opt.LR * math.pow(0.1, decay)
end
-- 余弦退火策略
function Trainer:learningRateCosine(epoch, iter, nBatches)
local nEpochs = self.opt.nEpochs
local T_total = nEpochs * nBatches
local T_cur = ((epoch-1) % nEpochs) * nBatches + iter
return 0.5 * self.opt.LR * (1 + torch.cos(math.pi * T_cur / T_total))
end
2. 单epoch训练流程
每个训练epoch包含以下关键步骤:
- 模型模式设置:将模型切换至训练模式(启用BatchNorm和Dropout)
- 数据加载:从dataloader获取批量数据
- 前向传播:计算模型输出和损失
- 反向传播:计算梯度
- 参数更新:使用SGD优化器更新权重
- 性能评估:计算top1和top5准确率
function Trainer:train(epoch, dataloader)
self.model:training()
for n, sample in dataloader:run() do
-- 数据加载
self:copyInputs(sample)
-- 前向传播
local output = self.model:forward(self.input):float()
local loss = self.criterion:forward(self.model.output, self.target)
-- 反向传播
self.model:zeroGradParameters()
self.criterion:backward(self.model.output, self.target)
self.model:backward(self.input, self.criterion.gradInput)
-- 参数更新
optim.sgd(feval, self.params, self.optimState)
-- 性能评估
local top1, top5 = self:computeScore(output, sample.target, 1)
end
end
3. 验证过程
验证流程与训练类似,但有几点关键区别:
- 模型模式:切换为评估模式(禁用BatchNorm和Dropout)
- 数据增强:支持tenCrop(10-crop)测试增强
- 无梯度计算:不进行反向传播和参数更新
function Trainer:test(epoch, dataloader)
self.model:evaluate()
for n, sample in dataloader:run() do
-- 仅前向传播
local output = self.model:forward(self.input):float()
local top1, top5 = self:computeScore(output, sample.target, nCrops)
end
self.model:training()
end
关键技术点解析
1. 准确率计算
准确率计算支持标准单图和10-crop测试两种情况:
function Trainer:computeScore(output, target, nCrops)
if nCrops > 1 then
-- 对10-crop结果取平均
output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
:sum(2):squeeze(2)
end
-- 计算top1和top5准确率
local _ , predictions = output:float():topk(5, 2, true, true)
local correct = predictions:eq(target:long():view(batchSize, 1):expandAs(predictions))
local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)
local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize)
return top1 * 100, top5 * 100
end
2. 数据加载优化
针对不同硬件配置优化数据加载:
function Trainer:copyInputs(sample)
-- 单GPU使用普通CUDA张量
-- 多GPU使用pinned memory提升传输效率
self.input = self.input or (self.opt.nGPU == 1
and torch[self.opt.tensorType:match('torch.(%a+)')]()
or getCudaTensorType(self.opt.tensorType))
self.target = self.target or torch.CudaLongTensor()
self.input:resize(sample.input:size()):copy(sample.input)
self.target:resize(sample.target:size()):copy(sample.target)
end
训练实践建议
-
学习率策略选择:
- 对于小数据集(如CIFAR),多步衰减通常足够
- 对于大数据集(如ImageNet),余弦退火可能带来更好效果
-
批量大小调整:
- 根据GPU内存调整batch size
- 较大batch size可配合学习率warmup
-
正则化配置:
- 适当调整weightDecay防止过拟合
- DenseNet本身具有较强的正则化能力
-
训练监控:
- 定期验证集评估
- 监控训练/验证损失曲线
通过深入理解DenseNet训练流程的实现细节,开发者可以更好地调整模型参数,优化训练过程,从而获得性能更优的密集连接网络模型。