首页
/ wiseodd/generative-models中的F-GAN实现解析

wiseodd/generative-models中的F-GAN实现解析

2025-07-07 04:13:42作者:尤峻淳Whitney

概述

本文将深入分析wiseodd/generative-models项目中关于F-GAN的PyTorch实现。F-GAN是一种基于f-散度的生成对抗网络框架,它通过不同的f函数可以推导出多种GAN变体。这个实现展示了如何使用PyTorch构建一个基础的F-GAN模型,并在MNIST数据集上进行训练。

F-GAN理论基础

F-GAN的核心思想是利用f-散度(f-divergence)来衡量生成分布与真实分布之间的差异。f-散度是一类通用的散度度量,包括KL散度、JS散度等常见散度作为特例。F-GAN通过选择不同的凸函数f,可以推导出不同的GAN变体。

代码结构解析

1. 数据准备与参数设置

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32
z_dim = 10
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
lr = 1e-3

这部分代码完成了:

  • 加载MNIST数据集
  • 设置批处理大小(mb_size)为32
  • 潜在空间维度(z_dim)设为10
  • 输入维度(X_dim)由MNIST图像决定(784维)
  • 隐藏层维度(h_dim)设为128
  • 学习率(lr)设为0.001

2. 网络架构

生成器(G)和判别器(D)都采用简单的全连接网络:

G = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

D = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
)

生成器将10维噪声通过两层全连接网络映射到784维(MNIST图像维度),使用ReLU激活函数和Sigmoid输出。判别器则将784维输入映射到1维输出(真假判断)。

3. 优化器设置

G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

使用Adam优化器分别优化生成器和判别器,学习率相同。

F-GAN的核心实现

1. 散度选择

代码中实现了多种f-散度对应的损失函数,通过注释可以切换:

# 总变分散度(Total Variation)
# D_loss = -(torch.mean(0.5 * torch.tanh(D_real)) - torch.mean(0.5 * torch.tanh(D_fake)))

# 前向KL散度(Forward KL)
# D_loss = -(torch.mean(D_real) - torch.mean(torch.exp(D_fake - 1)))

# 反向KL散度(Reverse KL)
D_loss = -(torch.mean(-torch.exp(D_real)) - torch.mean(-1 - D_fake))

# Pearson χ²散度
# D_loss = -(torch.mean(D_real) - torch.mean(0.25*D_fake**2 + D_fake))

# Squared Hellinger散度
# D_loss = -(torch.mean(1 - torch.exp(D_real)) - torch.mean((1 - torch.exp(D_fake)) / (torch.exp(D_fake))))

每种散度对应不同的f函数,会导致模型有不同的优化目标。默认使用的是反向KL散度。

2. 训练循环

训练过程遵循标准的GAN训练流程,交替训练判别器和生成器:

  1. 采样真实数据和噪声数据
  2. 计算判别器损失并更新判别器
  3. 计算生成器损失并更新生成器
  4. 定期输出训练状态和生成样本
for it in range(1000000):
    # 采样数据
    z = Variable(torch.randn(mb_size, z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))
    
    # 判别器训练
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)
    D_loss = ...  # 选择一种散度计算损失
    D_loss.backward()
    D_solver.step()
    reset_grad()
    
    # 生成器训练
    G_sample = G(z)
    D_fake = D(G_sample)
    G_loss = ...  # 对应散度的生成器损失
    G_loss.backward()
    G_solver.step()
    reset_grad()
    
    # 定期输出
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}'.format(it, D_loss.data[0], G_loss.data[0]))
        # 保存生成样本图像...

关键点解析

  1. f-散度的实现:代码中展示了如何将不同的f-散度转化为GAN的损失函数,这是F-GAN的核心贡献。

  2. 生成器设计:使用简单的全连接网络,适合MNIST这种低分辨率图像生成。

  3. 判别器设计:同样使用全连接网络,输出单个标量表示输入的真实性。

  4. 训练稳定性:使用Adam优化器有助于稳定训练过程,这是GAN训练中的常见做法。

实际应用建议

  1. 散度选择:不同散度适合不同场景。反向KL倾向于生成"安全"样本,前向KL倾向于覆盖所有模式,Pearson χ²对异常值更鲁棒。

  2. 网络架构:对于更复杂的数据(如CIFAR、ImageNet),应考虑使用卷积网络。

  3. 训练技巧:可以尝试添加梯度惩罚、谱归一化等技术来稳定训练。

  4. 评估指标:实际应用中应引入FID、IS等量化指标评估生成质量。

总结

这个F-GAN实现清晰地展示了如何将f-散度理论转化为实际的GAN训练代码。通过修改f函数,可以方便地实现多种GAN变体,为研究不同散度对生成效果的影响提供了良好基础。代码结构清晰,适合作为学习F-GAN的入门示例。