首页
/ DeepMind Research项目中的对象注意力推理模型解析

DeepMind Research项目中的对象注意力推理模型解析

2025-07-06 02:50:35作者:范垣楠Rhoda

模型概述

本文要解析的是DeepMind Research项目中一个基于对象注意力的推理模型,该模型主要用于处理CLEVRER数据集中的视觉推理任务。CLEVRER是一个视频问答数据集,要求模型理解视频中的物理交互并回答相关问题。

模型架构核心组件

1. 嵌入层(Embedding Layer)

模型使用Sonnet框架的snt.Embed层来处理问题和选项的文本输入:

  • 问题词汇表大小(QUESTION_VOCAB_SIZE): 82
  • 答案词汇表大小(ANSWER_VOCAB_SIZE): 22
  • 嵌入维度(EMBED_DIM): 16(实际使用时会额外拼接2维ID向量)
self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)

2. Transformer模块

模型的核心是一个多层Transformer结构:

  • 层数(transformer_layers): 28层
  • 注意力头数(num_heads): 10个
  • 头尺寸(head_size): 128
  • 支持相对位置编码(use_relative_positions)
self._memory_transformer = transformer.TransformerTower(
    value_size=embed_dim + 2,
    num_heads=num_heads,
    num_layers=transformer_layers,
    use_relative_positions=use_relative_positions,
    causal=False)

3. 输出层

模型针对不同类型的问答任务设计了不同的输出层:

  • 多项选择题(Multiple Choice): 使用线性层+ReLU+单输出单元
  • 描述性问题(Descriptive): 使用线性层+ReLU+答案词汇表大小的输出
self._final_layer_mc = snt.Sequential(
    [snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
self._final_layer_descriptive = snt.Sequential(
    [snt.Linear(head_size), tf.nn.relu,
     snt.Linear(ANSWER_VOCAB_SIZE)])

关键设计特点

1. 对象注意力机制

模型通过以下方式实现对象级别的注意力:

  • 视觉输入(monet_latents)被处理为[batch, frames, num_objects, embed_dim]的张量
  • 支持对象随机打乱(shuffle_objects),增强模型对对象顺序的鲁棒性
if self._shuffle_objects:
    vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
    vision_embedding = tf.random.shuffle(vision_embedding)
    vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])

2. 多模态融合

模型通过拼接ID向量来区分不同模态的信息:

  • 语言嵌入拼接[1,0]表示问题文本
  • 视觉嵌入拼接[0,1]表示视觉特征
  • 这种设计让模型能区分不同来源的信息
def append_ids(tensor, id_vector, axis):
    id_vector = tf.constant(id_vector, tf.float32)
    for a in range(len(tensor.shape)):
        if a != axis:
            id_vector = tf.expand_dims(id_vector, axis=a)
    tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
    id_tensor = tf.tile(id_vector, tiling_vector)
    return tf.concat([tensor, id_tensor], axis=axis)

3. 双任务支持

模型设计支持两种问答任务:

  1. 描述性问题:直接预测答案词汇的分布
  2. 多项选择题:为每个选项计算分数并选择最高分选项

数据处理流程

1. 描述性问题处理流程

  1. 嵌入问题文本
  2. 添加ID标识
  3. 处理视觉特征(可选打乱对象顺序)
  4. 通过Transformer融合多模态信息
  5. 输出答案词汇分布
def apply_model_descriptive(self, inputs):
    question_embedding = self._embed(question)
    question_embedding = append_ids(question_embedding, [0, 1], 2)
    # ...处理视觉特征...
    output = self._apply_transformers(lang_embedding, vision_embedding)
    output = self._final_layer_descriptive(output)

2. 多项选择题处理流程

  1. 分别嵌入问题和每个选项
  2. 为每个选项构建问题-选项对
  3. 处理视觉特征(可选打乱对象顺序)
  4. 对每个选项单独通过Transformer处理
  5. 计算每个选项的得分并输出
def apply_model_mc(self, inputs):
    question_embedding = self._embed(question)
    choices_embedding = snt.BatchApply(self._embed)(choices)
    # ...为每个选项构建输入...
    for c in range(NUM_CHOICES):
        output = self._apply_transformers(
            lang_embedding[:, c, :, :], vision_embedding)
        output_per_choice.append(output)
    # ...计算选项得分...

模型训练与应用

虽然代码中主要展示了模型的前向传播部分,但从设计可以看出:

  1. 训练目标

    • 描述性问题:交叉熵损失(分类任务)
    • 多项选择题:选择正确选项(排序任务)
  2. 推理特点

    • 支持批处理
    • 对视觉对象的顺序具有鲁棒性
    • 能够同时处理语言和视觉信息

总结

这个对象注意力推理模型展示了如何利用Transformer架构处理多模态推理任务。其核心创新点在于:

  1. 通过对象级别的注意力机制处理视觉输入
  2. 使用ID向量区分不同模态信息
  3. 统一架构支持多种问答形式
  4. 对对象顺序的鲁棒性设计

这种设计在需要理解场景中对象交互的视觉推理任务中表现出色,为多模态推理提供了可借鉴的架构范式。