DeepMind Research项目中的对象注意力推理模型解析
2025-07-06 02:50:35作者:范垣楠Rhoda
模型概述
本文要解析的是DeepMind Research项目中一个基于对象注意力的推理模型,该模型主要用于处理CLEVRER数据集中的视觉推理任务。CLEVRER是一个视频问答数据集,要求模型理解视频中的物理交互并回答相关问题。
模型架构核心组件
1. 嵌入层(Embedding Layer)
模型使用Sonnet框架的snt.Embed
层来处理问题和选项的文本输入:
- 问题词汇表大小(QUESTION_VOCAB_SIZE): 82
- 答案词汇表大小(ANSWER_VOCAB_SIZE): 22
- 嵌入维度(EMBED_DIM): 16(实际使用时会额外拼接2维ID向量)
self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)
2. Transformer模块
模型的核心是一个多层Transformer结构:
- 层数(transformer_layers): 28层
- 注意力头数(num_heads): 10个
- 头尺寸(head_size): 128
- 支持相对位置编码(use_relative_positions)
self._memory_transformer = transformer.TransformerTower(
value_size=embed_dim + 2,
num_heads=num_heads,
num_layers=transformer_layers,
use_relative_positions=use_relative_positions,
causal=False)
3. 输出层
模型针对不同类型的问答任务设计了不同的输出层:
- 多项选择题(Multiple Choice): 使用线性层+ReLU+单输出单元
- 描述性问题(Descriptive): 使用线性层+ReLU+答案词汇表大小的输出
self._final_layer_mc = snt.Sequential(
[snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
self._final_layer_descriptive = snt.Sequential(
[snt.Linear(head_size), tf.nn.relu,
snt.Linear(ANSWER_VOCAB_SIZE)])
关键设计特点
1. 对象注意力机制
模型通过以下方式实现对象级别的注意力:
- 视觉输入(monet_latents)被处理为[batch, frames, num_objects, embed_dim]的张量
- 支持对象随机打乱(shuffle_objects),增强模型对对象顺序的鲁棒性
if self._shuffle_objects:
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
vision_embedding = tf.random.shuffle(vision_embedding)
vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
2. 多模态融合
模型通过拼接ID向量来区分不同模态的信息:
- 语言嵌入拼接[1,0]表示问题文本
- 视觉嵌入拼接[0,1]表示视觉特征
- 这种设计让模型能区分不同来源的信息
def append_ids(tensor, id_vector, axis):
id_vector = tf.constant(id_vector, tf.float32)
for a in range(len(tensor.shape)):
if a != axis:
id_vector = tf.expand_dims(id_vector, axis=a)
tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
id_tensor = tf.tile(id_vector, tiling_vector)
return tf.concat([tensor, id_tensor], axis=axis)
3. 双任务支持
模型设计支持两种问答任务:
- 描述性问题:直接预测答案词汇的分布
- 多项选择题:为每个选项计算分数并选择最高分选项
数据处理流程
1. 描述性问题处理流程
- 嵌入问题文本
- 添加ID标识
- 处理视觉特征(可选打乱对象顺序)
- 通过Transformer融合多模态信息
- 输出答案词汇分布
def apply_model_descriptive(self, inputs):
question_embedding = self._embed(question)
question_embedding = append_ids(question_embedding, [0, 1], 2)
# ...处理视觉特征...
output = self._apply_transformers(lang_embedding, vision_embedding)
output = self._final_layer_descriptive(output)
2. 多项选择题处理流程
- 分别嵌入问题和每个选项
- 为每个选项构建问题-选项对
- 处理视觉特征(可选打乱对象顺序)
- 对每个选项单独通过Transformer处理
- 计算每个选项的得分并输出
def apply_model_mc(self, inputs):
question_embedding = self._embed(question)
choices_embedding = snt.BatchApply(self._embed)(choices)
# ...为每个选项构建输入...
for c in range(NUM_CHOICES):
output = self._apply_transformers(
lang_embedding[:, c, :, :], vision_embedding)
output_per_choice.append(output)
# ...计算选项得分...
模型训练与应用
虽然代码中主要展示了模型的前向传播部分,但从设计可以看出:
-
训练目标:
- 描述性问题:交叉熵损失(分类任务)
- 多项选择题:选择正确选项(排序任务)
-
推理特点:
- 支持批处理
- 对视觉对象的顺序具有鲁棒性
- 能够同时处理语言和视觉信息
总结
这个对象注意力推理模型展示了如何利用Transformer架构处理多模态推理任务。其核心创新点在于:
- 通过对象级别的注意力机制处理视觉输入
- 使用ID向量区分不同模态信息
- 统一架构支持多种问答形式
- 对对象顺序的鲁棒性设计
这种设计在需要理解场景中对象交互的视觉推理任务中表现出色,为多模态推理提供了可借鉴的架构范式。