ClearerVoice-Studio项目语音增强模型训练全解析
2025-07-10 04:04:59作者:劳婵绚Shirley
项目概述
ClearerVoice-Studio是一个专注于语音增强技术的开源项目,其中的train/speech_enhancement/train.py
文件是该项目的核心训练脚本。该脚本支持多种先进的语音增强模型训练,包括FRCRN、MossFormer2和MossFormerGAN等模型架构。
训练流程解析
1. 初始化设置
训练脚本首先进行一系列初始化设置,确保实验的可重复性:
random.seed(args.seed)
np.random.seed(args.seed)
os.environ['PYTORCH_SEED'] = str(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(args.seed)
这些设置包括随机种子固定、CUDA确定性模式等,确保每次运行结果一致。
2. 设备与分布式训练配置
脚本自动检测并配置训练设备(CPU或GPU),并支持分布式训练:
device = torch.device('cuda') if args.use_cuda else torch.device('cpu')
args.device = device
if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', rank=args.local_rank,
init_method='env://', world_size=args.world_size)
3. 模型架构选择
项目支持多种语音增强模型架构:
- FRCRN_SE_16K:基于全频带卷积循环网络的语音增强模型
- MossFormer2_SE_48K:改进的Transformer架构语音增强模型
- MossFormerGAN_SE_16K:结合GAN的语音增强模型
模型通过network_wrapper
工厂函数动态创建:
model = network_wrapper(args).se_network
model = model.to(device)
对于GAN架构,还会额外创建判别器模型:
if args.network=='MossFormerGAN_SE_16K':
discriminator = network_wrapper(args).discriminator
discriminator = discriminator.to(device)
4. 优化器配置
根据不同的模型架构,脚本配置不同的优化策略:
- FRCRN_SE_16K:使用Adam优化器,支持权重衰减
- MossFormer2_SE_48K:标准Adam优化器
- MossFormerGAN_SE_16K:使用AdamW优化器,为生成器和判别器分别配置
if args.network=='FRCRN_SE_16K':
params = model.get_params(args.weight_decay)
optimizer = torch.optim.Adam(params, lr=args.init_learning_rate)
elif args.network=='MossFormer2_SE_48K':
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_learning_rate)
elif args.network=='MossFormerGAN_SE_16K':
optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_learning_rate)
optimizer_disc = torch.optim.AdamW(discriminator.parameters(), lr=args.init_learning_rate*2)
5. 数据加载与采样
脚本使用自定义的get_dataloader
函数加载训练、验证和测试数据:
train_sampler, train_generator = get_dataloader(args,'train')
_, val_generator = get_dataloader(args, 'val')
if args.tt_list is not None:
_, test_generator = get_dataloader(args, 'test')
数据加载器支持多种配置参数,包括:
- 采样率(16K/48K)
- FFT参数(窗口长度、步长、FFT长度)
- 梅尔滤波器数量
- 最大音频长度
- 批处理大小等
6. 训练过程管理
训练过程由Solver
类管理,它封装了完整的训练逻辑:
solver = Solver(args=args,
model = model,
optimizer = optimizer,
discriminator = discriminator,
optimizer_disc = optimizer_disc,
train_data = train_generator,
validation_data = val_generator,
test_data = test_generator
)
solver.train()
关键参数解析
训练脚本支持丰富的配置参数,主要分为以下几类:
1. 实验设置
--mode
:运行模式(train/inference)--use-cuda
:是否使用GPU--network
:模型架构选择--checkpoint_dir
:模型保存路径--train_from_last_checkpoint
:是否从检查点继续训练
2. 数据配置
--tr-list
/--cv-list
/--tt-list
:训练/验证/测试数据列表--sampling-rate
:音频采样率--max_length
:训练音频最大长度--num_workers
:数据加载线程数
3. 特征提取参数
--window-len
/--window-inc
:STFT窗口参数--fft-len
:FFT长度--num-mels
:梅尔滤波器数量
4. 训练优化参数
--batch_size
:批处理大小--max-epoch
:最大训练轮次--init_learning_rate
:初始学习率--weight-decay
:权重衰减系数--clip-grad-norm
:梯度裁剪阈值
训练技巧与最佳实践
-
学习率设置:不同模型架构需要不同的学习率策略,GAN模型通常需要更小的学习率
-
批处理大小:根据GPU显存合理设置,可使用梯度累积(
accu_grad
)模拟更大批次 -
特征提取:合理设置STFT参数,平衡时频分辨率
-
正则化:适当使用权重衰减防止过拟合
-
混合精度训练:可考虑添加AMP支持加速训练
总结
ClearerVoice-Studio的训练脚本设计灵活,支持多种先进的语音增强模型架构。通过合理的参数配置,用户可以训练出高质量的语音增强模型,适用于各种噪声环境下的语音处理任务。脚本的模块化设计也便于扩展新的模型架构和训练策略。