首页
/ 2D高斯泼溅(2d-gaussian-splatting)训练过程深度解析

2D高斯泼溅(2d-gaussian-splatting)训练过程深度解析

2025-07-10 08:27:47作者:段琳惟

项目概述

2D高斯泼溅(2d-gaussian-splatting)是一种基于高斯分布的图像渲染技术,它通过优化一组2D高斯参数来重建目标图像。该项目中的train.py文件是训练过程的核心实现,本文将深入解析其工作原理和实现细节。

训练流程架构

训练过程主要包含以下几个关键组件:

  1. 场景初始化:创建包含高斯模型和相机参数的场景对象
  2. 优化器配置:设置学习率、正则化参数等优化选项
  3. 渲染循环:迭代优化高斯参数
  4. 密度调整:动态增加或减少高斯分布数量
  5. 评估与保存:定期评估模型性能并保存检查点

核心训练循环解析

1. 初始化阶段

训练开始时,系统会进行以下初始化操作:

gaussians = GaussianModel(dataset.sh_degree)  # 创建高斯模型
scene = Scene(dataset, gaussians)  # 创建场景
gaussians.training_setup(opt)  # 配置优化器

高斯模型使用球谐函数(SH)来表示光照特性,初始化时指定了SH的最大度数。场景对象则负责管理训练数据和高斯分布。

2. 主训练循环

训练过程采用迭代优化的方式,每次迭代包含以下步骤:

  1. 学习率调整:根据当前迭代次数动态调整学习率
  2. 球谐度数提升:每1000次迭代增加SH的度数,增强光照表达能力
  3. 随机视角选择:从训练集中随机选择一个相机视角进行渲染
  4. 损失计算:计算渲染图像与真实图像之间的差异
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
image = render_pkg["render"]
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))

损失函数结合了L1损失和结构相似性(SSIM)损失,通过λ参数平衡两者权重。

3. 正则化技术

为了防止过拟合和提升渲染质量,训练中引入了两种正则化项:

  1. 距离正则化:在迭代3000次后激活,控制高斯分布的空间分布
  2. 法线正则化:在迭代7000次后激活,约束表面法线的一致性
normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None]
normal_loss = lambda_normal * (normal_error).mean()
dist_loss = lambda_dist * (rend_dist).mean()
total_loss = loss + dist_loss + normal_loss

4. 密度自适应调整

高斯泼溅的一个关键特性是能动态调整高斯分布的数量和位置:

  1. 密度统计:收集各高斯分布在屏幕空间的可见性和影响范围
  2. 分裂与修剪:根据梯度信息决定是否分裂大高斯或修剪小高斯
  3. 不透明度重置:定期重置高斯分布的不透明度参数
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
    gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, size_threshold)

训练监控与评估

1. 进度监控

训练过程中会实时显示以下指标:

  • 总损失值
  • 距离正则化损失
  • 法线正则化损失
  • 当前高斯分布数量
loss_dict = {
    "Loss": f"{ema_loss_for_log:.{5}f}",
    "distort": f"{ema_dist_for_log:.{5}f}",
    "normal": f"{ema_normal_for_log:.{5}f}",
    "Points": f"{len(gaussians.get_xyz)}"
}

2. 定期评估

在指定的测试迭代次数时,系统会:

  1. 在测试集和部分训练集上评估模型
  2. 计算L1损失和PSNR指标
  3. 保存渲染结果和深度图等可视化信息
validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()}, 
                     {'name': 'train', 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})

实用功能

1. 检查点机制

支持从检查点恢复训练,便于长时间训练任务的中断恢复:

if checkpoint:
    (model_params, first_iter) = torch.load(checkpoint)
    gaussians.restore(model_params, opt)

2. 可视化界面

集成了网络GUI功能,可以实时查看训练进度和渲染结果:

if network_gui.conn == None:
    network_gui.try_connect(dataset.render_items)

训练参数配置

通过命令行参数可以灵活配置训练过程:

  • 模型参数:控制高斯模型的初始设置
  • 优化参数:学习率、正则化权重等
  • 流水线参数:渲染管线的具体设置
  • 测试/保存间隔:指定评估和保存模型的迭代次数
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)

技术要点总结

  1. 动态密度调整:根据渲染误差自适应调整高斯分布密度,平衡细节和效率
  2. 渐进式训练:分阶段引入不同的正则化约束,稳定训练过程
  3. 多损失组合:结合像素级和结构级损失函数,提升渲染质量
  4. 高效渲染:利用GPU加速的渲染管线实现快速迭代

通过这种训练方式,2D高斯泼溅能够高效地学习表示复杂图像内容,在保持渲染质量的同时优化计算资源的使用。