Learning-to-See-in-the-Dark项目中的Sony相机数据训练解析
2025-07-08 01:19:46作者:魏侃纯Zoe
项目概述
Learning-to-See-in-the-Dark是一个专注于低光环境下图像增强的深度学习项目,旨在通过神经网络将低光条件下拍摄的原始图像转换为高质量的明亮图像。本文重点解析项目中针对Sony相机数据的训练脚本(train_Sony.py)的实现细节和技术要点。
数据准备与预处理
数据目录结构
脚本中定义了三个关键目录:
input_dir
: 存储短曝光(低光)的Sony RAW格式图像gt_dir
: 存储长曝光(正常光)的对应图像作为ground truthcheckpoint_dir
和result_dir
: 分别保存训练过程中的模型检查点和结果输出
RAW图像处理
项目使用Sony相机拍摄的.ARW格式RAW文件,通过rawpy
库进行读取和处理。关键预处理步骤包括:
- 黑电平校正:通过减去512的黑电平值并归一化到0-1范围
- Bayer模式转换:将原始Bayer阵列图像转换为4通道格式,分别对应RGGB四个颜色分量
- 曝光补偿:根据短曝光和长曝光时间的比值进行亮度补偿
def pack_raw(raw):
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - 512, 0) / (16383 - 512) # 黑电平校正
# Bayer模式到4通道转换
out = np.concatenate((im[0:H:2, 0:W:2, :], # R
im[0:H:2, 1:W:2, :], # G1
im[1:H:2, 1:W:2, :], # G2
im[1:H:2, 0:W:2, :]), axis=2) # B
return out
网络架构设计
项目采用了一个基于U-Net结构的全卷积网络,主要特点包括:
- 编码器-解码器结构:通过4层下采样和上采样构建多尺度特征
- 跳跃连接:将编码器各层的特征与解码器对应层连接,保留空间细节
- 激活函数:使用Leaky ReLU(α=0.2)缓解梯度消失问题
- 输出处理:最后使用depth_to_space操作将12通道特征图转换为3通道RGB图像
网络结构概览:
- 编码器部分:conv1-conv5,逐步下采样提取特征
- 解码器部分:up6-up9,逐步上采样重建图像
- 跳跃连接:将编码器特征与解码器特征拼接
def network(input):
# 编码器部分
conv1 = slim.conv2d(input, 32, [3,3], activation_fn=lrelu) # 第1层
pool1 = slim.max_pool2d(conv1, [2,2]) # 下采样
# ... 中间层省略 ...
conv5 = slim.conv2d(pool4, 512, [3,3], activation_fn=lrelu) # 最深层
# 解码器部分
up6 = upsample_and_concat(conv5, conv4, 256, 512) # 上采样+跳跃连接
conv6 = slim.conv2d(up6, 256, [3,3], activation_fn=lrelu)
# ... 中间层省略 ...
# 输出层
conv10 = slim.conv2d(conv9, 12, [1,1], activation_fn=None)
out = tf.depth_to_space(conv10, 2) # 通道到空间转换
return out
训练策略
损失函数
采用简单的L1损失(平均绝对误差)作为优化目标:
G_loss = tf.reduce_mean(tf.abs(out_image - gt_image))
优化器配置
使用Adam优化器,初始学习率设为1e-4,在训练超过2000轮后降至1e-5:
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
数据增强
在训练过程中实施了多种数据增强策略:
- 随机裁剪512x512的patch进行训练
- 随机水平/垂直翻转
- 随机转置(行列交换)
- 输入值裁剪到[0,1]范围
# 随机裁剪
xx = np.random.randint(0, W - ps)
yy = np.random.randint(0, H - ps)
input_patch = input_images[...][:, yy:yy+ps, xx:xx+ps, :]
# 随机翻转和转置
if np.random.randint(2) == 1:
input_patch = np.flip(input_patch, axis=1) # 垂直翻转
if np.random.randint(2) == 1:
input_patch = np.transpose(input_patch, (0,2,1,3)) # 转置
训练过程管理
- 每500轮保存一次模型和示例结果
- 支持从最近的检查点恢复训练
- 在内存中缓存预处理后的图像数据加速训练
技术亮点
- RAW图像直接处理:直接在RAW域进行处理,充分利用传感器原始数据
- 多曝光处理:能处理不同曝光比(100x,250x,300x)的输入图像
- 内存优化:预处理后数据缓存在内存中,大幅减少IO时间
- 轻量级网络:使用全卷积结构,参数量适中,适合处理高分辨率图像
实际应用建议
- 数据准备:确保短曝光和长曝光图像严格对齐,最好使用三脚架固定相机拍摄
- 参数调整:可根据显存大小调整patch size(ps),平衡训练速度和效果
- 监控训练:定期检查保存的结果图像,观察模型收敛情况
- 硬件要求:训练需要较大显存,建议使用至少8GB显存的GPU
通过这个训练脚本,项目实现了从极低光RAW图像到高质量RGB图像的端到端转换,为低光环境下的计算机视觉应用提供了有效解决方案。