首页
/ Microsoft UniLM中的BEATs:基于声学标记器的音频预训练模型解析

Microsoft UniLM中的BEATs:基于声学标记器的音频预训练模型解析

2025-07-05 07:18:07作者:宣聪麟

什么是BEATs?

BEATs(Bootstrapping Audio Pre-Training with Acoustic Tokenizers)是微软研究院开发的一种创新的音频预训练框架。该模型通过自监督学习方式,利用声学标记器(Acoustic Tokenizers)从原始音频数据中学习高质量的音频表示,在多项音频理解任务上达到了state-of-the-art的性能。

BEATs的核心技术原理

BEATs的核心创新在于其独特的预训练框架,主要包含两个关键组件:

  1. 声学标记器(Acoustic Tokenizer):将连续音频信号离散化为一系列标记(tokens),类似于NLP中的词标记化过程。

  2. 音频Transformer编码器:基于标记化后的音频数据进行预训练,学习通用的音频表示。

这种架构的优势在于:

  • 通过离散化表示,模型可以更好地捕捉音频中的关键特征
  • 标记器与编码器的联合训练实现了自举式学习(bootstrapping learning)
  • 适用于各种下游音频任务,如音频分类、语音识别等

BEATs模型系列

BEATs提供了多个版本的预训练模型,每个版本在架构和训练数据上有所不同:

  1. 基础版本:包括Iter1、Iter2和Iter3三个迭代版本

    • Iter1使用随机投影作为初始标记器
    • Iter2和Iter3使用学习到的声学标记器
  2. 增强版本(Iter3+):使用更大规模音频数据集(AS20K和AS2M)训练

    • AS20K:包含2万小时音频数据
    • AS2M:包含200万小时音频数据

如何使用BEATs模型

1. 加载声学标记器

import torch
from Tokenizers import TokenizersConfig, Tokenizers

# 加载预训练标记器
checkpoint = torch.load('/path/to/tokenizer.pt')

# 初始化标记器
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# 使用标记器处理音频
audio_input = torch.randn(1, 10000)  # 16kHz音频输入
padding_mask = torch.zeros(1, 10000).bool()
labels = BEATs_tokenizer.extract_labels(audio_input, padding_mask=padding_mask)

2. 加载预训练模型

import torch
from BEATs import BEATs, BEATsConfig

# 加载预训练模型
checkpoint = torch.load('/path/to/model.pt')

# 初始化模型
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# 提取音频特征表示
audio_input = torch.randn(1, 10000)  # 16kHz音频输入
padding_mask = torch.zeros(1, 10000).bool()
representation = BEATs_model.extract_features(audio_input, padding_mask=padding_mask)[0]

3. 加载微调模型(用于分类任务)

import torch
from BEATs import BEATs, BEATsConfig

# 加载微调模型
checkpoint = torch.load('/path/to/fine-tuned-model.pt')

# 初始化模型
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# 进行音频分类预测
audio_input = torch.randn(3, 10000)  # 3个16kHz音频样本
padding_mask = torch.zeros(3, 10000).bool()
probs = BEATs_model.extract_features(audio_input, padding_mask=padding_mask)[0]

# 输出top5预测结果
for i, (top5_prob, top5_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_labels = [checkpoint['label_dict'][idx.item()] for idx in top5_idx]
    print(f'第{i}个音频的top5预测标签: {top5_labels}, 对应概率: {top5_prob}')

BEATs性能表现

BEATs在多项音频理解任务上展现了卓越的性能:

  1. 与单模型SOTA比较:BEATs超越了之前最好的单模型方法
  2. 与集成模型比较:BEATs单模型性能接近甚至超过某些集成模型
  3. 不同标记器比较:迭代改进的标记器版本带来了持续的性能提升
  4. 不同预训练目标比较:BEATs的预训练策略显著优于传统方法

应用场景

BEATs可应用于多种音频相关任务:

  • 环境声音分类
  • 语音情感识别
  • 音频事件检测
  • 音乐分类与标记
  • 语音识别辅助任务

技术优势

  1. 高效预训练:通过声学标记器实现更高效的音频表示学习
  2. 强泛化能力:学到的特征可迁移到多种下游任务
  3. 可扩展性:框架支持使用更大规模数据进行训练
  4. 灵活性:既可用于特征提取,也可进行端到端微调

最佳实践建议

  1. 模型选择:根据任务需求选择合适的BEATs版本

    • 基础任务:Iter3通常足够
    • 高要求任务:建议使用Iter3+ AS2M版本
  2. 计算资源

    • 预训练需要较强GPU资源
    • 微调和推理可在中等配置GPU上完成
  3. 数据准备

    • 确保音频采样率为16kHz
    • 对于分类任务,建议准备平衡的数据集
  4. 微调技巧

    • 学习率通常设置较小(如1e-5)
    • 可冻结部分底层参数防止过拟合

BEATs代表了音频预训练领域的重要进展,为音频理解任务提供了强大的基础模型。通过合理使用这些预训练模型,研究人员和开发者可以在各种音频应用中取得优异的表现。