首页
/ Airbnb Aerosolve多分类训练配置详解

Airbnb Aerosolve多分类训练配置详解

2025-07-08 05:17:47作者:余洋婵Anita

本文深入解析Aerosolve项目中用于多分类任务的训练配置文件demo_multiclass_train.conf,帮助读者理解如何配置和使用Aerosolve进行多类别分类任务。

配置文件概述

该配置文件展示了如何使用Aerosolve的通用管道(generic pipeline)进行多类别分类任务。示例使用了20 Newsgroups数据集,目标是训练一个能够预测消息所属新闻组的模型。

核心配置解析

基础设置

job_name : "Generic Pipeline Multiclass"
training_data_version : 1
model_version: "a"
model_type : "full_rank_linear"
  • job_name:定义任务名称
  • training_data_version:训练数据版本号
  • model_version:模型版本标识
  • model_type:指定使用全秩线性模型(full_rank_linear)

数据路径配置

prefix : "hdfs://airfs-silver/user/"${USER}"/multiclass_demo_pipeline"
training_data : ${prefix}"/training_data"${training_data_version}
eval_data : ${prefix}"/eval_data"${training_data_version}

这些配置定义了HDFS上存储训练数据、评估数据和模型输出的路径模板。实际使用时应根据环境调整prefix

数据查询配置

generic_hive_query : """
  select
    concat(group_name, ":1.0") as LABEL,
    message as S_RAW
"""

这个Hive查询定义了如何从原始数据中提取特征:

  • LABEL:类别标签,格式为"[class1]:[weight1]"
  • S_RAW:原始消息文本

训练流程配置

配置文件定义了几个关键阶段:

1. 调试阶段

debug_example {
  hive_query : ${generic_hive_query}" from "${demo_table}
  is_multiclass : true
  count : 10
}

这个阶段用于检查原始数据样本,count参数限制只查看10条记录。

2. 数据准备阶段

make_training {
  training_hive_query : ${generic_hive_query}" from "${demo_table}
  training_output : ${training_data}
  eval_hive_query : ${generic_hive_query}" from "${demo_table_eval}
  eval_output : ${eval_data}
  is_multiclass : true
  num_shards : 20
}
  • 从Hive表中提取训练和评估数据
  • num_shards:指定数据分片数量,影响并行处理效率

3. 模型训练阶段

train_model {
  input : ${training_data}
  subsample : ${train_subsample}
  model_config : ${model_config}
}

使用准备好的训练数据训练模型,subsample可用于控制采样比例。

4. 模型评估阶段

eval_model {
  input : ${eval_data}
  subsample : ${eval_subsample}
  bins : 11
  model_config : ${model_config}
  is_probability : false
  is_multiclass : true
  metric_to_maximize : "!HOLD_F1"
  model_name : ${model_name}
}
  • bins:指定评估时使用的分箱数
  • metric_to_maximize:指定优化指标,这里使用F1分数

特征转换配置

配置文件定义了几个重要的特征转换步骤:

tokenize_string {
  transform : default_string_tokenizer
  field1: "S"
  regex : """[\s\p{Punct}]"""
  output : "TOKENS"
  generate_bigrams : false
}

normalize_tokens {
  transform : normalize_float
  field1: "TOKENS"
}
  1. 字符串分词:使用空白和标点符号作为分隔符
  2. 词元归一化:对分词结果进行归一化处理
  3. 删除原始字符串:处理完成后删除原始特征

这些转换通过combined_transform组合在一起。

模型配置详解

full_rank_linear_model_config {
  trainer : "full_rank_linear"
  model_output : ${model_name}
  rank_key : "LABEL"
  loss : "hinge"
  cache : "memory"
  iterations : 20
  lambda : 1.0
  min_count : 3
}
  • trainer:指定使用全秩线性模型
  • loss:损失函数,支持"softmax"或"hinge"
  • iterations:训练迭代次数
  • lambda:正则化参数
  • min_count:特征最小出现次数阈值

实际应用建议

  1. 数据准备:确保输入数据格式符合要求,特别是LABEL字段的格式
  2. 路径配置:根据实际环境调整HDFS路径
  3. 参数调优:根据数据规模调整num_shards,根据模型性能调整iterationslambda
  4. 特征工程:可以修改tokenize_string中的正则表达式来调整分词策略

通过理解这个配置文件,用户可以快速上手Aerosolve的多分类任务,并根据实际需求调整配置参数。