Apple ML-FERRET 项目训练脚本解析与实现原理
2025-07-07 01:02:21作者:尤辰城Agatha
项目概述
Apple ML-FERRET 是一个基于视觉-语言多模态交互的项目,专注于区域级别的视觉理解和对话能力。本文主要分析其核心训练脚本 train.py
的实现原理和技术细节。
核心架构设计
1. 模型参数配置
训练脚本使用 dataclass
定义了三种主要配置类:
@dataclass
class ModelArguments:
# 模型相关参数
model_name_or_path: str = "facebook/opt-125m" # 基础模型路径
version: str = "v0" # 模型版本
freeze_backbone: bool = False # 是否冻结主干网络
vision_tower: Optional[str] = None # 视觉塔模型路径
mm_vision_select_layer: int = -1 # 选择视觉特征的层
mm_use_im_start_end: bool = False # 是否使用图像起止标记
add_region_feature: bool = False # 是否添加区域特征
2. 数据处理参数
@dataclass
class DataArguments:
data_path: List[str] # 训练数据路径列表
image_folder: List[str] # 图像文件夹路径
image_aspect_ratio: str = 'square_nocrop' # 图像比例处理方式
resized_image_h: int = 336 # 图像高度
resized_image_w: int = 336 # 图像宽度
point_input_sample: str = 'segment_mask|uniform' # 点采样策略
3. 训练参数
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = None # 缓存目录
optim: str = "adamw_torch" # 优化器
model_max_length: int = 512 # 最大序列长度
lora_enable: bool = False # 是否启用LoRA
lora_r: int = 64 # LoRA秩
关键技术实现
1. 多模态数据处理
预处理函数 preprocess_multimodal
负责处理包含图像标记的对话数据:
def preprocess_multimodal(sources, data_args):
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence['value']:
# 处理图像标记
sentence['value'] = sentence['value'].replace(
DEFAULT_IMAGE_TOKEN, '').strip()
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
# 添加图像起止标记
if data_args.mm_use_im_start_end:
replace_token = (DEFAULT_IM_START_TOKEN +
DEFAULT_IMAGE_TOKEN +
DEFAULT_IM_END_TOKEN)
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, replace_token)
return sources
2. 对话模板处理
支持多种对话模板风格,如LLaMA-2格式:
def preprocess_llama_2(sources, tokenizer, has_image=False):
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# 构建对话提示
conversations = []
for source in sources:
conv.messages = []
for sentence in source:
role = roles[sentence["from"]]
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# 特殊处理图像token
if has_image:
input_ids = torch.stack([
tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
for prompt in conversations], dim=0)
return dict(input_ids=input_ids, labels=targets)
3. 参数高效微调
实现了LoRA和适配器的高效保存与加载:
def safe_save_model_for_hf_trainer(trainer, output_dir, save_vision_tower):
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
# 仅保存适配器部分
keys_to_match = ['mm_projector']
weight_to_save = get_mm_adapter_state_maybe_zero_3(
trainer.model.named_parameters(), keys_to_match)
torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
# 保存视觉塔模型
if save_vision_tower:
vision_tower_folder = os.path.join(output_dir, "vision_tower")
trainer.model.model.get_vision_tower().vision_tower.save_pretrained(
vision_tower_folder)
训练流程解析
-
数据预处理阶段:
- 处理多模态对话数据
- 添加图像特殊标记
- 应用对话模板
-
模型初始化阶段:
- 加载基础语言模型
- 初始化视觉编码器
- 配置LoRA等参数高效微调组件
-
训练循环:
- 处理多批次数据
- 计算损失并反向传播
- 应用梯度裁剪和优化器步骤
-
模型保存:
- 选择性保存适配器参数
- 保存视觉编码器
- 处理分布式训练场景
区域特征处理特色
FERRET项目特别关注区域级别的视觉理解:
DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
VOCAB_IMAGE_W = 1000
VOCAB_IMAGE_H = 1000
@dataclass
class ModelArguments:
add_region_feature: bool = False # 启用区域特征
region_geo_sampler: bool = False # 区域几何采样器
sampler_pooler_mode: str = 'mean' # 池化方式: mean/max
最佳实践建议
-
数据准备:
- 确保图像路径配置正确
- 对话数据需符合指定格式
- 合理设置图像分辨率
-
训练配置:
- 小规模实验可先冻结主干网络
- 逐步启用区域特征和LoRA
- 注意调整学习率和批次大小
-
调试技巧:
- 使用
rank0_print
调试主进程 - 检查tokenizer处理后的输入格式
- 验证图像标记是否正确插入
- 使用
通过深入理解FERRET的训练脚本实现,开发者可以更好地定制自己的多模态训练流程,或基于此架构开发新的视觉-语言交互功能。