2D高斯泼溅(2d-gaussian-splatting)训练过程深度解析
2025-07-10 08:27:47作者:段琳惟
项目概述
2D高斯泼溅(2d-gaussian-splatting)是一种基于高斯分布的图像渲染技术,它通过优化一组2D高斯参数来重建目标图像。该项目中的train.py文件是训练过程的核心实现,本文将深入解析其工作原理和实现细节。
训练流程架构
训练过程主要包含以下几个关键组件:
- 场景初始化:创建包含高斯模型和相机参数的场景对象
- 优化器配置:设置学习率、正则化参数等优化选项
- 渲染循环:迭代优化高斯参数
- 密度调整:动态增加或减少高斯分布数量
- 评估与保存:定期评估模型性能并保存检查点
核心训练循环解析
1. 初始化阶段
训练开始时,系统会进行以下初始化操作:
gaussians = GaussianModel(dataset.sh_degree) # 创建高斯模型
scene = Scene(dataset, gaussians) # 创建场景
gaussians.training_setup(opt) # 配置优化器
高斯模型使用球谐函数(SH)来表示光照特性,初始化时指定了SH的最大度数。场景对象则负责管理训练数据和高斯分布。
2. 主训练循环
训练过程采用迭代优化的方式,每次迭代包含以下步骤:
- 学习率调整:根据当前迭代次数动态调整学习率
- 球谐度数提升:每1000次迭代增加SH的度数,增强光照表达能力
- 随机视角选择:从训练集中随机选择一个相机视角进行渲染
- 损失计算:计算渲染图像与真实图像之间的差异
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
image = render_pkg["render"]
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
损失函数结合了L1损失和结构相似性(SSIM)损失,通过λ参数平衡两者权重。
3. 正则化技术
为了防止过拟合和提升渲染质量,训练中引入了两种正则化项:
- 距离正则化:在迭代3000次后激活,控制高斯分布的空间分布
- 法线正则化:在迭代7000次后激活,约束表面法线的一致性
normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None]
normal_loss = lambda_normal * (normal_error).mean()
dist_loss = lambda_dist * (rend_dist).mean()
total_loss = loss + dist_loss + normal_loss
4. 密度自适应调整
高斯泼溅的一个关键特性是能动态调整高斯分布的数量和位置:
- 密度统计:收集各高斯分布在屏幕空间的可见性和影响范围
- 分裂与修剪:根据梯度信息决定是否分裂大高斯或修剪小高斯
- 不透明度重置:定期重置高斯分布的不透明度参数
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, size_threshold)
训练监控与评估
1. 进度监控
训练过程中会实时显示以下指标:
- 总损失值
- 距离正则化损失
- 法线正则化损失
- 当前高斯分布数量
loss_dict = {
"Loss": f"{ema_loss_for_log:.{5}f}",
"distort": f"{ema_dist_for_log:.{5}f}",
"normal": f"{ema_normal_for_log:.{5}f}",
"Points": f"{len(gaussians.get_xyz)}"
}
2. 定期评估
在指定的测试迭代次数时,系统会:
- 在测试集和部分训练集上评估模型
- 计算L1损失和PSNR指标
- 保存渲染结果和深度图等可视化信息
validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
{'name': 'train', 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
实用功能
1. 检查点机制
支持从检查点恢复训练,便于长时间训练任务的中断恢复:
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint)
gaussians.restore(model_params, opt)
2. 可视化界面
集成了网络GUI功能,可以实时查看训练进度和渲染结果:
if network_gui.conn == None:
network_gui.try_connect(dataset.render_items)
训练参数配置
通过命令行参数可以灵活配置训练过程:
- 模型参数:控制高斯模型的初始设置
- 优化参数:学习率、正则化权重等
- 流水线参数:渲染管线的具体设置
- 测试/保存间隔:指定评估和保存模型的迭代次数
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)
技术要点总结
- 动态密度调整:根据渲染误差自适应调整高斯分布密度,平衡细节和效率
- 渐进式训练:分阶段引入不同的正则化约束,稳定训练过程
- 多损失组合:结合像素级和结构级损失函数,提升渲染质量
- 高效渲染:利用GPU加速的渲染管线实现快速迭代
通过这种训练方式,2D高斯泼溅能够高效地学习表示复杂图像内容,在保持渲染质量的同时优化计算资源的使用。