wiseodd/generative-models中的F-GAN实现解析
概述
本文将深入分析wiseodd/generative-models项目中关于F-GAN的PyTorch实现。F-GAN是一种基于f-散度的生成对抗网络框架,它通过不同的f函数可以推导出多种GAN变体。这个实现展示了如何使用PyTorch构建一个基础的F-GAN模型,并在MNIST数据集上进行训练。
F-GAN理论基础
F-GAN的核心思想是利用f-散度(f-divergence)来衡量生成分布与真实分布之间的差异。f-散度是一类通用的散度度量,包括KL散度、JS散度等常见散度作为特例。F-GAN通过选择不同的凸函数f,可以推导出不同的GAN变体。
代码结构解析
1. 数据准备与参数设置
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32
z_dim = 10
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
lr = 1e-3
这部分代码完成了:
- 加载MNIST数据集
- 设置批处理大小(mb_size)为32
- 潜在空间维度(z_dim)设为10
- 输入维度(X_dim)由MNIST图像决定(784维)
- 隐藏层维度(h_dim)设为128
- 学习率(lr)设为0.001
2. 网络架构
生成器(G)和判别器(D)都采用简单的全连接网络:
G = torch.nn.Sequential(
torch.nn.Linear(z_dim, h_dim),
torch.nn.ReLU(),
torch.nn.Linear(h_dim, X_dim),
torch.nn.Sigmoid()
)
D = torch.nn.Sequential(
torch.nn.Linear(X_dim, h_dim),
torch.nn.ReLU(),
torch.nn.Linear(h_dim, 1),
)
生成器将10维噪声通过两层全连接网络映射到784维(MNIST图像维度),使用ReLU激活函数和Sigmoid输出。判别器则将784维输入映射到1维输出(真假判断)。
3. 优化器设置
G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)
使用Adam优化器分别优化生成器和判别器,学习率相同。
F-GAN的核心实现
1. 散度选择
代码中实现了多种f-散度对应的损失函数,通过注释可以切换:
# 总变分散度(Total Variation)
# D_loss = -(torch.mean(0.5 * torch.tanh(D_real)) - torch.mean(0.5 * torch.tanh(D_fake)))
# 前向KL散度(Forward KL)
# D_loss = -(torch.mean(D_real) - torch.mean(torch.exp(D_fake - 1)))
# 反向KL散度(Reverse KL)
D_loss = -(torch.mean(-torch.exp(D_real)) - torch.mean(-1 - D_fake))
# Pearson χ²散度
# D_loss = -(torch.mean(D_real) - torch.mean(0.25*D_fake**2 + D_fake))
# Squared Hellinger散度
# D_loss = -(torch.mean(1 - torch.exp(D_real)) - torch.mean((1 - torch.exp(D_fake)) / (torch.exp(D_fake))))
每种散度对应不同的f函数,会导致模型有不同的优化目标。默认使用的是反向KL散度。
2. 训练循环
训练过程遵循标准的GAN训练流程,交替训练判别器和生成器:
- 采样真实数据和噪声数据
- 计算判别器损失并更新判别器
- 计算生成器损失并更新生成器
- 定期输出训练状态和生成样本
for it in range(1000000):
# 采样数据
z = Variable(torch.randn(mb_size, z_dim))
X, _ = mnist.train.next_batch(mb_size)
X = Variable(torch.from_numpy(X))
# 判别器训练
G_sample = G(z)
D_real = D(X)
D_fake = D(G_sample)
D_loss = ... # 选择一种散度计算损失
D_loss.backward()
D_solver.step()
reset_grad()
# 生成器训练
G_sample = G(z)
D_fake = D(G_sample)
G_loss = ... # 对应散度的生成器损失
G_loss.backward()
G_solver.step()
reset_grad()
# 定期输出
if it % 1000 == 0:
print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}'.format(it, D_loss.data[0], G_loss.data[0]))
# 保存生成样本图像...
关键点解析
-
f-散度的实现:代码中展示了如何将不同的f-散度转化为GAN的损失函数,这是F-GAN的核心贡献。
-
生成器设计:使用简单的全连接网络,适合MNIST这种低分辨率图像生成。
-
判别器设计:同样使用全连接网络,输出单个标量表示输入的真实性。
-
训练稳定性:使用Adam优化器有助于稳定训练过程,这是GAN训练中的常见做法。
实际应用建议
-
散度选择:不同散度适合不同场景。反向KL倾向于生成"安全"样本,前向KL倾向于覆盖所有模式,Pearson χ²对异常值更鲁棒。
-
网络架构:对于更复杂的数据(如CIFAR、ImageNet),应考虑使用卷积网络。
-
训练技巧:可以尝试添加梯度惩罚、谱归一化等技术来稳定训练。
-
评估指标:实际应用中应引入FID、IS等量化指标评估生成质量。
总结
这个F-GAN实现清晰地展示了如何将f-散度理论转化为实际的GAN训练代码。通过修改f函数,可以方便地实现多种GAN变体,为研究不同散度对生成效果的影响提供了良好基础。代码结构清晰,适合作为学习F-GAN的入门示例。