首页
/ 深入解析AntixK/PyTorch-VAE中的WAE-MMD模型实现

深入解析AntixK/PyTorch-VAE中的WAE-MMD模型实现

2025-07-07 04:27:53作者:余洋婵Anita

本文将对基于PyTorch实现的Wasserstein自编码器(MMD变体)进行详细解析,帮助读者理解其核心思想和实现细节。

1. WAE-MMD模型概述

WAE-MMD(Wasserstein Auto-Encoder with Maximum Mean Discrepancy)是一种基于Wasserstein距离的自编码器变体,它使用最大均值差异(MMD)作为正则化项来匹配潜在空间分布与先验分布。相比传统VAE,WAE-MMD具有以下优势:

  1. 避免了VAE中需要计算KL散度的限制
  2. 通过MMD直接度量分布差异,训练更稳定
  3. 生成质量通常优于标准VAE

2. 模型架构解析

2.1 编码器结构

编码器由多个卷积块组成,每个块包含:

  • 卷积层(kernel_size=3, stride=2)
  • 批归一化层
  • LeakyReLU激活函数
modules.append(
    nn.Sequential(
        nn.Conv2d(in_channels, out_channels=h_dim,
                  kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(h_dim),
        nn.LeakyReLU())
)

最终通过全连接层将特征映射到潜在空间。

2.2 解码器结构

解码器采用转置卷积结构,基本模块包含:

  • 转置卷积层(kernel_size=3, stride=2)
  • 批归一化层
  • LeakyReLU激活函数
modules.append(
    nn.Sequential(
        nn.ConvTranspose2d(hidden_dims[i],
                           hidden_dims[i + 1],
                           kernel_size=3,
                           stride=2,
                           padding=1,
                           output_padding=1),
        nn.BatchNorm2d(hidden_dims[i + 1]),
        nn.LeakyReLU())
)

最终输出层使用Tanh激活函数将像素值限制在[-1,1]范围内。

3. 核心算法实现

3.1 MMD计算

MMD(最大均值差异)是WAE-MMD的核心,用于度量潜在变量分布与先验分布(Gaussian)的差异:

def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:
    prior_z = torch.randn_like(z)  # 从先验分布采样
    
    # 计算三种核矩阵
    prior_z__kernel = self.compute_kernel(prior_z, prior_z)
    z__kernel = self.compute_kernel(z, z)
    priorz_z__kernel = self.compute_kernel(prior_z, z)
    
    # MMD计算公式
    mmd = reg_weight * prior_z__kernel.mean() + \
          reg_weight * z__kernel.mean() - \
          2 * reg_weight * priorz_z__kernel.mean()
    return mmd

3.2 核函数实现

代码实现了两种核函数:

  1. RBF核(径向基函数核)
  2. IMQ核(逆多元二次核)
def compute_inv_mult_quad(self, x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor:
    z_dim = x2.size(-1)
    C = 2 * z_dim * self.z_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))
    result = kernel.sum() - kernel.diag().sum()
    return result

IMQ核通常在实践中表现更好,避免了RBF核可能导致的梯度消失问题。

4. 损失函数

WAE-MMD的损失函数由两部分组成:

  1. 重构损失:输入与输出之间的均方误差
  2. MMD损失:潜在分布与先验分布的差异
recons_loss = F.mse_loss(recons, input)
mmd_loss = self.compute_mmd(z, reg_weight)
loss = recons_loss + mmd_loss

5. 模型使用方法

5.1 采样生成

def sample(self, num_samples:int, current_device: int, **kwargs) -> Tensor:
    z = torch.randn(num_samples, self.latent_dim)  # 从先验分布采样
    z = z.to(current_device)
    samples = self.decode(z)  # 解码生成样本
    return samples

5.2 图像重建

def generate(self, x: Tensor, **kwargs) -> Tensor:
    return self.forward(x)[0]  # 返回重建结果

6. 关键参数说明

  • latent_dim: 潜在空间维度
  • reg_weight: MMD正则化权重
  • kernel_type: 核函数类型(rbf/imq)
  • latent_var: 潜在空间方差

7. 总结

WAE-MMD通过结合自编码器的重构能力和MMD分布匹配,实现了高质量的生成模型。相比VAE,它不需要对潜在变量的后验分布进行严格假设,训练过程更加稳定。本文分析的实现采用了标准的卷积自编码器架构,读者可以根据需要调整网络结构或尝试不同的核函数。