Marigold项目深度估计模型训练全解析
2025-07-10 06:02:05作者:卓艾滢Kingsley
概述
Marigold是一个基于深度学习的单目深度估计项目,其训练脚本(train.py)提供了完整的模型训练流程。本文将深入解析该训练脚本的技术实现细节,帮助读者理解如何高效训练一个高质量的深度估计模型。
训练流程架构
Marigold的训练流程采用了模块化设计,主要包含以下几个关键组件:
- 配置管理:使用OmegaConf加载和管理YAML配置文件
- 数据准备:支持多种数据集和混合采样策略
- 模型构建:基于预训练模型初始化
- 训练循环:支持梯度累积和多种训练策略
- 评估与可视化:训练过程中的验证和结果可视化
核心功能详解
1. 配置与初始化
训练脚本首先处理命令行参数和配置文件:
parser = argparse.ArgumentParser(description="Train your cute model!")
parser.add_argument("--config", type=str, default="config/train_marigold.yaml")
parser.add_argument("--resume_run", action="store", default=None)
parser.add_argument("--output_dir", type=str, default=None)
...
支持的功能包括:
- 从配置文件启动新训练
- 从检查点恢复训练
- 自定义输出目录
- 定时自动退出
- 禁用WandB日志等
2. 数据加载与处理
数据加载系统设计精巧,支持多种配置:
train_dataset: BaseDepthDataset = get_dataset(
cfg_data.train,
base_data_dir=base_data_dir,
mode=DatasetMode.TRAIN,
augmentation_args=cfg.augmentation,
depth_transform=depth_transform,
)
关键特性:
- 混合数据集支持:可以组合多个数据集,按概率采样
- 数据增强:通过配置文件灵活配置
- 深度值归一化:统一不同数据集的深度范围
- 高效加载:支持SLURM集群的本地临时存储加速
3. 模型初始化
MarigoldPipeline是核心模型类:
model = MarigoldPipeline.from_pretrained(
os.path.join(base_ckpt_dir, cfg.model.pretrained_path),
**_pipeline_kwargs
)
模型特点:
- 基于预训练权重初始化
- 支持自定义参数扩展
- 兼容多种骨干网络
4. 训练策略
训练器采用抽象工厂模式:
trainer_cls = get_trainer_cls(cfg.trainer.name)
trainer = trainer_cls(
cfg=cfg,
model=model,
train_dataloader=train_loader,
...
)
训练特性:
- 梯度累积:支持大batch size训练
- 学习率调度:可恢复的调度状态
- 多阶段验证:支持多个验证集
- 可视化监控:训练过程可视化
关键技术点
混合数据集采样
Marigold实现了创新的MixedBatchSampler:
mixed_sampler = MixedBatchSampler(
src_dataset_ls=dataset_ls,
batch_size=cfg.dataloader.max_train_batch_size,
drop_last=True,
prob=cfg_data.train.prob_ls,
shuffle=True,
generator=loader_generator,
)
这种采样器可以:
- 按指定概率从不同数据集采样
- 保证每个batch来自同一数据集
- 支持随机种子复现
深度归一化处理
深度估计任务中,不同数据集的深度范围差异很大,Marigold通过DepthNormalizerBase实现统一处理:
depth_transform: DepthNormalizerBase = get_depth_normalizer(
cfg_normalizer=cfg.depth_normalization
)
支持多种归一化策略,如:
- 线性缩放
- 对数变换
- 分位数归一化
分布式训练支持
脚本内置对SLURM集群的支持:
- 自动检测SLURM环境
- 数据复制到本地临时存储加速IO
- 作业ID记录
if is_on_slurm() and (not args.do_not_copy_data):
base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data")
...
训练监控与调试
Marigold集成了多种监控工具:
- TensorBoard日志:记录训练指标
- WandB集成:实验跟踪与管理
- 代码快照:保存训练时的完整代码状态
- 可视化输出:定期保存预测结果图像
# 初始化WandB
wandb_cfg_dic = {
"config": dict(cfg),
"name": job_name,
"mode": "online",
**cfg.wandb,
}
wandb_run = init_wandb(enable=True, **wandb_cfg_dic)
最佳实践建议
- 配置管理:合理组织YAML配置文件,区分不同实验设置
- 数据准备:确保各数据集的深度值范围合理,必要时调整归一化策略
- 混合训练:平衡不同数据集采样概率,防止模型偏向特定数据分布
- 监控调整:密切关注验证集指标,及时调整学习率等参数
- 资源利用:在集群环境中充分利用本地临时存储加速数据加载
总结
Marigold的训练脚本提供了一个高度可配置、模块化的深度估计模型训练框架。通过本文的解析,我们可以看到其在数据加载、模型训练、实验管理等方面的精心设计。这些设计不仅保证了训练过程的稳定性,也为研究者的实验创新提供了充分的灵活性。理解这些实现细节,将有助于开发者更好地使用和扩展Marigold项目。