首页
/ 深入解析Donut模型架构:基于Swin Transformer的文档理解系统

深入解析Donut模型架构:基于Swin Transformer的文档理解系统

2025-07-07 06:33:43作者:幸俭卉

Donut是一个端到端的OCR-free文档理解Transformer模型,它通过创新的架构设计实现了无需传统OCR步骤的文档理解能力。本文将深入解析Donut模型的核心组件和工作原理。

模型整体架构

Donut模型由两大核心组件构成:

  1. SwinEncoder:基于Swin Transformer的图像编码器,负责将文档图像转换为特征表示
  2. BARTDecoder:基于多语言BART的文本解码器,负责根据编码特征生成结构化输出

这种编码器-解码器架构使得Donut能够直接从图像输入生成结构化文本输出,跳过了传统OCR中间步骤。

SwinEncoder详解

核心特性

SwinEncoder是基于Swin Transformer架构设计的,具有以下特点:

  1. 长轴对齐处理:通过align_long_axis参数控制是否自动旋转长宽不匹配的图像
  2. 窗口注意力机制:使用可配置的window_size参数控制局部注意力的范围
  3. 分层特征提取:通过encoder_layer参数配置各层的深度

关键实现细节

class SwinEncoder(nn.Module):
    def __init__(self, input_size, align_long_axis, window_size, encoder_layer, name_or_path=None):
        # 初始化代码
        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=4,
            embed_dim=128,
            num_heads=[4, 8, 16, 32],
            num_classes=0,
        )
  • 输入预处理:包含标准化和随机填充选项
  • 位置编码调整:动态调整相对位置偏置表以适应不同窗口大小
  • 权重初始化:默认使用预训练的Swin-Base模型权重

图像预处理流程

  1. 转换为RGB格式
  2. 根据长轴设置决定是否旋转
  3. 调整大小并保持纵横比
  4. 应用随机或中心填充
  5. 标准化处理

BARTDecoder详解

核心特性

BARTDecoder基于多语言BART模型,具有以下特点:

  1. 因果语言模型:配置为纯解码器模式(is_decoder=True)
  2. 交叉注意力机制:能够关注编码器输出的图像特征
  3. 特殊令牌支持:可动态添加如<sep/>等特殊令牌

关键实现细节

class BARTDecoder(nn.Module):
    def __init__(self, decoder_layer, max_position_embeddings, name_or_path=None):
        # 初始化代码
        self.model = MBartForCausalLM(
            config=MBartConfig(
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
        )
  • 位置嵌入调整:支持动态调整最大序列长度
  • 权重初始化:默认使用Asian-BART-ECJK预训练权重
  • 生成优化:重写了prepare_inputs_for_inference方法以支持生成过程

特殊功能实现

  1. 位置嵌入调整:通过resize_bart_abs_pos_emb方法动态调整位置编码
  2. 特殊令牌处理:支持动态添加新令牌并调整嵌入层大小
  3. 生成过程优化:实现了高效的缓存机制以加速自回归生成

Donut整合模型

配置类

DonutConfig类统一管理模型的所有配置参数,包括:

  • 图像处理相关:input_size, align_long_axis, window_size
  • 架构相关:encoder_layer, decoder_layer
  • 序列处理相关:max_position_embeddings, max_length

前向传播流程

  1. 图像通过SwinEncoder获取特征表示
  2. 文本提示与图像特征一起输入BARTDecoder
  3. 计算生成序列与目标序列的交叉熵损失
def forward(self, image_tensors, decoder_input_ids, decoder_labels):
    encoder_outputs = self.encoder(image_tensors)
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_outputs,
        labels=decoder_labels,
    )
    return decoder_outputs

推理接口

提供便捷的inference方法支持端到端文档理解:

  1. 支持直接输入PIL图像或预处理后的张量
  2. 可配置返回JSON格式或原始输出
  3. 可选返回注意力权重用于分析

模型设计亮点

  1. 端到端训练:直接从图像到结构化输出,无需OCR中间步骤
  2. 灵活的图像处理:自动处理不同长宽比和尺寸的文档图像
  3. 高效的解码:优化的生成过程支持长序列输出
  4. 多语言支持:基于多语言BART的decoder天然支持多种语言

应用场景

Donut模型特别适合以下场景:

  • 文档信息提取(发票、收据等)
  • 表格数据识别
  • 文档问答系统
  • 任何需要从文档图像中提取结构化信息的任务

通过本文的解析,我们可以看到Donut模型如何巧妙地将视觉编码器和文本解码器结合起来,实现了真正意义上的端到端文档理解。这种架构设计不仅简化了传统OCR流程,还通过联合训练提高了整体性能。