Microsoft UniLM中的BEATs:基于声学标记器的音频预训练模型解析
2025-07-05 07:18:07作者:宣聪麟
什么是BEATs?
BEATs(Bootstrapping Audio Pre-Training with Acoustic Tokenizers)是微软研究院开发的一种创新的音频预训练框架。该模型通过自监督学习方式,利用声学标记器(Acoustic Tokenizers)从原始音频数据中学习高质量的音频表示,在多项音频理解任务上达到了state-of-the-art的性能。
BEATs的核心技术原理
BEATs的核心创新在于其独特的预训练框架,主要包含两个关键组件:
-
声学标记器(Acoustic Tokenizer):将连续音频信号离散化为一系列标记(tokens),类似于NLP中的词标记化过程。
-
音频Transformer编码器:基于标记化后的音频数据进行预训练,学习通用的音频表示。
这种架构的优势在于:
- 通过离散化表示,模型可以更好地捕捉音频中的关键特征
- 标记器与编码器的联合训练实现了自举式学习(bootstrapping learning)
- 适用于各种下游音频任务,如音频分类、语音识别等
BEATs模型系列
BEATs提供了多个版本的预训练模型,每个版本在架构和训练数据上有所不同:
-
基础版本:包括Iter1、Iter2和Iter3三个迭代版本
- Iter1使用随机投影作为初始标记器
- Iter2和Iter3使用学习到的声学标记器
-
增强版本(Iter3+):使用更大规模音频数据集(AS20K和AS2M)训练
- AS20K:包含2万小时音频数据
- AS2M:包含200万小时音频数据
如何使用BEATs模型
1. 加载声学标记器
import torch
from Tokenizers import TokenizersConfig, Tokenizers
# 加载预训练标记器
checkpoint = torch.load('/path/to/tokenizer.pt')
# 初始化标记器
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()
# 使用标记器处理音频
audio_input = torch.randn(1, 10000) # 16kHz音频输入
padding_mask = torch.zeros(1, 10000).bool()
labels = BEATs_tokenizer.extract_labels(audio_input, padding_mask=padding_mask)
2. 加载预训练模型
import torch
from BEATs import BEATs, BEATsConfig
# 加载预训练模型
checkpoint = torch.load('/path/to/model.pt')
# 初始化模型
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()
# 提取音频特征表示
audio_input = torch.randn(1, 10000) # 16kHz音频输入
padding_mask = torch.zeros(1, 10000).bool()
representation = BEATs_model.extract_features(audio_input, padding_mask=padding_mask)[0]
3. 加载微调模型(用于分类任务)
import torch
from BEATs import BEATs, BEATsConfig
# 加载微调模型
checkpoint = torch.load('/path/to/fine-tuned-model.pt')
# 初始化模型
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()
# 进行音频分类预测
audio_input = torch.randn(3, 10000) # 3个16kHz音频样本
padding_mask = torch.zeros(3, 10000).bool()
probs = BEATs_model.extract_features(audio_input, padding_mask=padding_mask)[0]
# 输出top5预测结果
for i, (top5_prob, top5_idx) in enumerate(zip(*probs.topk(k=5))):
top5_labels = [checkpoint['label_dict'][idx.item()] for idx in top5_idx]
print(f'第{i}个音频的top5预测标签: {top5_labels}, 对应概率: {top5_prob}')
BEATs性能表现
BEATs在多项音频理解任务上展现了卓越的性能:
- 与单模型SOTA比较:BEATs超越了之前最好的单模型方法
- 与集成模型比较:BEATs单模型性能接近甚至超过某些集成模型
- 不同标记器比较:迭代改进的标记器版本带来了持续的性能提升
- 不同预训练目标比较:BEATs的预训练策略显著优于传统方法
应用场景
BEATs可应用于多种音频相关任务:
- 环境声音分类
- 语音情感识别
- 音频事件检测
- 音乐分类与标记
- 语音识别辅助任务
技术优势
- 高效预训练:通过声学标记器实现更高效的音频表示学习
- 强泛化能力:学到的特征可迁移到多种下游任务
- 可扩展性:框架支持使用更大规模数据进行训练
- 灵活性:既可用于特征提取,也可进行端到端微调
最佳实践建议
-
模型选择:根据任务需求选择合适的BEATs版本
- 基础任务:Iter3通常足够
- 高要求任务:建议使用Iter3+ AS2M版本
-
计算资源:
- 预训练需要较强GPU资源
- 微调和推理可在中等配置GPU上完成
-
数据准备:
- 确保音频采样率为16kHz
- 对于分类任务,建议准备平衡的数据集
-
微调技巧:
- 学习率通常设置较小(如1e-5)
- 可冻结部分底层参数防止过拟合
BEATs代表了音频预训练领域的重要进展,为音频理解任务提供了强大的基础模型。通过合理使用这些预训练模型,研究人员和开发者可以在各种音频应用中取得优异的表现。