首页
/ Marigold项目深度估计模型训练全解析

Marigold项目深度估计模型训练全解析

2025-07-10 06:02:05作者:卓艾滢Kingsley

概述

Marigold是一个基于深度学习的单目深度估计项目,其训练脚本(train.py)提供了完整的模型训练流程。本文将深入解析该训练脚本的技术实现细节,帮助读者理解如何高效训练一个高质量的深度估计模型。

训练流程架构

Marigold的训练流程采用了模块化设计,主要包含以下几个关键组件:

  1. 配置管理:使用OmegaConf加载和管理YAML配置文件
  2. 数据准备:支持多种数据集和混合采样策略
  3. 模型构建:基于预训练模型初始化
  4. 训练循环:支持梯度累积和多种训练策略
  5. 评估与可视化:训练过程中的验证和结果可视化

核心功能详解

1. 配置与初始化

训练脚本首先处理命令行参数和配置文件:

parser = argparse.ArgumentParser(description="Train your cute model!")
parser.add_argument("--config", type=str, default="config/train_marigold.yaml")
parser.add_argument("--resume_run", action="store", default=None)
parser.add_argument("--output_dir", type=str, default=None)
...

支持的功能包括:

  • 从配置文件启动新训练
  • 从检查点恢复训练
  • 自定义输出目录
  • 定时自动退出
  • 禁用WandB日志等

2. 数据加载与处理

数据加载系统设计精巧,支持多种配置:

train_dataset: BaseDepthDataset = get_dataset(
    cfg_data.train,
    base_data_dir=base_data_dir,
    mode=DatasetMode.TRAIN,
    augmentation_args=cfg.augmentation,
    depth_transform=depth_transform,
)

关键特性:

  • 混合数据集支持:可以组合多个数据集,按概率采样
  • 数据增强:通过配置文件灵活配置
  • 深度值归一化:统一不同数据集的深度范围
  • 高效加载:支持SLURM集群的本地临时存储加速

3. 模型初始化

MarigoldPipeline是核心模型类:

model = MarigoldPipeline.from_pretrained(
    os.path.join(base_ckpt_dir, cfg.model.pretrained_path), 
    **_pipeline_kwargs
)

模型特点:

  • 基于预训练权重初始化
  • 支持自定义参数扩展
  • 兼容多种骨干网络

4. 训练策略

训练器采用抽象工厂模式:

trainer_cls = get_trainer_cls(cfg.trainer.name)
trainer = trainer_cls(
    cfg=cfg,
    model=model,
    train_dataloader=train_loader,
    ...
)

训练特性:

  • 梯度累积:支持大batch size训练
  • 学习率调度:可恢复的调度状态
  • 多阶段验证:支持多个验证集
  • 可视化监控:训练过程可视化

关键技术点

混合数据集采样

Marigold实现了创新的MixedBatchSampler:

mixed_sampler = MixedBatchSampler(
    src_dataset_ls=dataset_ls,
    batch_size=cfg.dataloader.max_train_batch_size,
    drop_last=True,
    prob=cfg_data.train.prob_ls,
    shuffle=True,
    generator=loader_generator,
)

这种采样器可以:

  1. 按指定概率从不同数据集采样
  2. 保证每个batch来自同一数据集
  3. 支持随机种子复现

深度归一化处理

深度估计任务中,不同数据集的深度范围差异很大,Marigold通过DepthNormalizerBase实现统一处理:

depth_transform: DepthNormalizerBase = get_depth_normalizer(
    cfg_normalizer=cfg.depth_normalization
)

支持多种归一化策略,如:

  • 线性缩放
  • 对数变换
  • 分位数归一化

分布式训练支持

脚本内置对SLURM集群的支持:

  • 自动检测SLURM环境
  • 数据复制到本地临时存储加速IO
  • 作业ID记录
if is_on_slurm() and (not args.do_not_copy_data):
    base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data")
    ...

训练监控与调试

Marigold集成了多种监控工具:

  1. TensorBoard日志:记录训练指标
  2. WandB集成:实验跟踪与管理
  3. 代码快照:保存训练时的完整代码状态
  4. 可视化输出:定期保存预测结果图像
# 初始化WandB
wandb_cfg_dic = {
    "config": dict(cfg),
    "name": job_name,
    "mode": "online",
    **cfg.wandb,
}
wandb_run = init_wandb(enable=True, **wandb_cfg_dic)

最佳实践建议

  1. 配置管理:合理组织YAML配置文件,区分不同实验设置
  2. 数据准备:确保各数据集的深度值范围合理,必要时调整归一化策略
  3. 混合训练:平衡不同数据集采样概率,防止模型偏向特定数据分布
  4. 监控调整:密切关注验证集指标,及时调整学习率等参数
  5. 资源利用:在集群环境中充分利用本地临时存储加速数据加载

总结

Marigold的训练脚本提供了一个高度可配置、模块化的深度估计模型训练框架。通过本文的解析,我们可以看到其在数据加载、模型训练、实验管理等方面的精心设计。这些设计不仅保证了训练过程的稳定性,也为研究者的实验创新提供了充分的灵活性。理解这些实现细节,将有助于开发者更好地使用和扩展Marigold项目。