首页
/ Donut项目训练脚本解析与使用指南

Donut项目训练脚本解析与使用指南

2025-07-07 06:34:31作者:侯霆垣

项目概述

Donut是一个基于Transformer的文档理解模型,能够处理各种文档图像理解任务,如文档分类、文档视觉问答(DocVQA)等。该项目的训练脚本(train.py)提供了完整的模型训练流程实现,本文将深入解析其核心功能和使用方法。

训练脚本核心组件

1. 配置系统

训练脚本使用了sconf库来管理配置,支持从YAML文件加载配置并通过命令行参数覆盖。这种设计使得实验管理更加灵活:

config = Config(args.config)  # 从YAML加载配置
config.argv_update(left_argv)  # 用命令行参数更新配置

2. 自定义检查点处理

脚本实现了CustomCheckpointIO类来处理模型检查点,这种设计可以灵活控制保存和加载的内容:

class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        del checkpoint["state_dict"]  # 不保存state_dict
        torch.save(checkpoint, path)
    
    def load_checkpoint(self, path, storage_options=None):
        # 分别加载模型参数和其他检查点信息
        checkpoint = torch.load(path + "artifacts.ckpt")
        state_dict = torch.load(path + "pytorch_model.bin")
        checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
        return checkpoint

3. 训练流程控制

主训练函数train()包含了完整的训练流程:

  1. 设置随机种子保证可重复性
  2. 初始化模型和数据模块
  3. 准备训练和验证数据集
  4. 配置日志记录器和回调函数
  5. 启动训练过程

关键功能解析

数据集处理

Donut支持多任务训练,可以同时加载多个数据集。对于不同任务,会添加特定的特殊标记:

# 为不同任务添加特殊标记
if task_name == "rvlcdip":
    model_module.model.decoder.add_special_tokens([...])  # 文档分类标签
if task_name == "docvqa":
    model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"])  # 问答任务标记

训练器配置

脚本使用了PyTorch Lightning的Trainer,并进行了多项配置:

trainer = pl.Trainer(
    devices=torch.cuda.device_count(),  # 使用所有可用GPU
    strategy="ddp",  # 分布式数据并行
    max_epochs=config.max_epochs,  # 最大训练轮数
    precision=16,  # 混合精度训练
    callbacks=[lr_callback, checkpoint_callback, bar]  # 学习率监控、模型保存、进度条
)

进度条定制

自定义的ProgressBar类增强了训练过程中的信息展示:

class ProgressBar(pl.callbacks.TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items["exp_name"] = f"{self.config.get('exp_name', '')}"  # 显示实验名称
        items["exp_version"] = f"{self.config.get('exp_version', '')}"  # 显示实验版本
        return items

使用指南

1. 准备配置文件

创建一个YAML配置文件,指定模型参数、数据路径和训练超参数。

2. 启动训练

使用以下命令启动训练:

python train.py --config path/to/config.yaml

3. 可选参数

  • --exp_version: 指定实验版本标识,默认为当前时间戳
  • 其他配置参数可以通过命令行覆盖

4. 恢复训练

如果需要从检查点恢复训练,可以在配置文件中设置resume_from_checkpoint_path参数。

最佳实践

  1. 多任务训练:利用Donut支持多任务的特点,可以同时训练文档分类和问答任务
  2. 特殊标记:根据任务需求添加适当的特殊标记,如文档类型或答案选项
  3. 混合精度:默认使用FP16训练可以节省显存并加速训练
  4. 分布式训练:脚本自动支持多GPU训练,充分利用硬件资源

总结

Donut的训练脚本提供了灵活且强大的训练流程,支持多任务学习、分布式训练和细粒度的训练控制。通过理解其内部机制,用户可以更好地定制自己的文档理解模型训练过程。