Donut项目训练脚本解析与使用指南
2025-07-07 06:34:31作者:侯霆垣
项目概述
Donut是一个基于Transformer的文档理解模型,能够处理各种文档图像理解任务,如文档分类、文档视觉问答(DocVQA)等。该项目的训练脚本(train.py)提供了完整的模型训练流程实现,本文将深入解析其核心功能和使用方法。
训练脚本核心组件
1. 配置系统
训练脚本使用了sconf
库来管理配置,支持从YAML文件加载配置并通过命令行参数覆盖。这种设计使得实验管理更加灵活:
config = Config(args.config) # 从YAML加载配置
config.argv_update(left_argv) # 用命令行参数更新配置
2. 自定义检查点处理
脚本实现了CustomCheckpointIO
类来处理模型检查点,这种设计可以灵活控制保存和加载的内容:
class CustomCheckpointIO(CheckpointIO):
def save_checkpoint(self, checkpoint, path, storage_options=None):
del checkpoint["state_dict"] # 不保存state_dict
torch.save(checkpoint, path)
def load_checkpoint(self, path, storage_options=None):
# 分别加载模型参数和其他检查点信息
checkpoint = torch.load(path + "artifacts.ckpt")
state_dict = torch.load(path + "pytorch_model.bin")
checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
return checkpoint
3. 训练流程控制
主训练函数train()
包含了完整的训练流程:
- 设置随机种子保证可重复性
- 初始化模型和数据模块
- 准备训练和验证数据集
- 配置日志记录器和回调函数
- 启动训练过程
关键功能解析
数据集处理
Donut支持多任务训练,可以同时加载多个数据集。对于不同任务,会添加特定的特殊标记:
# 为不同任务添加特殊标记
if task_name == "rvlcdip":
model_module.model.decoder.add_special_tokens([...]) # 文档分类标签
if task_name == "docvqa":
model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"]) # 问答任务标记
训练器配置
脚本使用了PyTorch Lightning的Trainer,并进行了多项配置:
trainer = pl.Trainer(
devices=torch.cuda.device_count(), # 使用所有可用GPU
strategy="ddp", # 分布式数据并行
max_epochs=config.max_epochs, # 最大训练轮数
precision=16, # 混合精度训练
callbacks=[lr_callback, checkpoint_callback, bar] # 学习率监控、模型保存、进度条
)
进度条定制
自定义的ProgressBar
类增强了训练过程中的信息展示:
class ProgressBar(pl.callbacks.TQDMProgressBar):
def get_metrics(self, trainer, model):
items = super().get_metrics(trainer, model)
items["exp_name"] = f"{self.config.get('exp_name', '')}" # 显示实验名称
items["exp_version"] = f"{self.config.get('exp_version', '')}" # 显示实验版本
return items
使用指南
1. 准备配置文件
创建一个YAML配置文件,指定模型参数、数据路径和训练超参数。
2. 启动训练
使用以下命令启动训练:
python train.py --config path/to/config.yaml
3. 可选参数
--exp_version
: 指定实验版本标识,默认为当前时间戳- 其他配置参数可以通过命令行覆盖
4. 恢复训练
如果需要从检查点恢复训练,可以在配置文件中设置resume_from_checkpoint_path
参数。
最佳实践
- 多任务训练:利用Donut支持多任务的特点,可以同时训练文档分类和问答任务
- 特殊标记:根据任务需求添加适当的特殊标记,如文档类型或答案选项
- 混合精度:默认使用FP16训练可以节省显存并加速训练
- 分布式训练:脚本自动支持多GPU训练,充分利用硬件资源
总结
Donut的训练脚本提供了灵活且强大的训练流程,支持多任务学习、分布式训练和细粒度的训练控制。通过理解其内部机制,用户可以更好地定制自己的文档理解模型训练过程。