首页
/ wiseodd/generative-models中的F-GAN实现解析

wiseodd/generative-models中的F-GAN实现解析

2025-07-07 04:14:47作者:蔡怀权

概述

本文主要分析wiseodd/generative-models项目中关于F-GAN(f-divergence GAN)的TensorFlow实现。F-GAN是一种基于f-散度的生成对抗网络框架,它通过不同的f-散度选择可以衍生出多种GAN变体。该实现展示了如何使用TensorFlow构建一个基础的F-GAN模型,并在MNIST数据集上进行训练。

F-GAN理论基础

F-GAN的核心思想是利用f-散度来衡量生成分布和真实分布之间的差异。f-散度是一类更广泛的散度度量,包括KL散度、JS散度等作为特例。在F-GAN中,判别器实际上是在估计f-散度的变分下界。

代码结构解析

1. 数据准备与参数设置

mb_size = 32       # 批处理大小
X_dim = 784        # 输入维度(MNIST图像展平后)
z_dim = 64         # 隐变量维度
h_dim = 128        # 隐藏层维度
lr = 1e-3          # 学习率
d_steps = 3        # 判别器训练步数

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

这部分代码设置了模型的基本超参数,并加载了MNIST数据集。值得注意的是,这里使用了较小的批处理大小(32)和相对简单的网络结构。

2. 网络架构

生成器网络

def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    return G_prob

生成器采用了两层全连接网络:

  1. 第一层:隐变量z → 隐藏层(ReLU激活)
  2. 第二层:隐藏层 → 输出层(Sigmoid激活)

判别器网络

def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    out = tf.matmul(D_h1, D_W2) + D_b2
    return out

判别器同样采用了两层全连接网络,但没有在输出层使用激活函数,因为不同的f-散度需要不同的输出处理方式。

3. 损失函数实现

该实现提供了多种f-散度的选择,通过注释不同的损失函数部分可以切换不同的GAN变体:

""" Total Variation """
# D_loss = -(tf.reduce_mean(0.5 * tf.nn.tanh(D_real)) -
#            tf.reduce_mean(0.5 * tf.nn.tanh(D_fake)))
# G_loss = -tf.reduce_mean(0.5 * tf.nn.tanh(D_fake))

""" Forward KL """
# D_loss = -(tf.reduce_mean(D_real) - tf.reduce_mean(tf.exp(D_fake - 1)))
# G_loss = -tf.reduce_mean(tf.exp(D_fake - 1))

""" Reverse KL """
# D_loss = -(tf.reduce_mean(-tf.exp(D_real)) - tf.reduce_mean(-1 - D_fake))
# G_loss = -tf.reduce_mean(-1 - D_fake)

""" Pearson Chi-squared """
D_loss = -(tf.reduce_mean(D_real) - tf.reduce_mean(0.25*D_fake**2 + D_fake))
G_loss = -tf.reduce_mean(0.25*D_fake**2 + D_fake)

""" Squared Hellinger """
# D_loss = -(tf.reduce_mean(1 - tf.exp(D_real)) -
#            tf.reduce_mean((1 - tf.exp(D_fake)) / (tf.exp(D_fake))))
# G_loss = -tf.reduce_mean((1 - tf.exp(D_fake)) / (tf.exp(D_fake)))

默认使用的是Pearson Chi-squared散度,这种选择会影响模型的训练动态和生成质量。

4. 训练过程

for it in range(1000000):
    X_mb, _ = mnist.train.next_batch(mb_size)
    z_mb = sample_z(mb_size, z_dim)

    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, z: z_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={z: z_mb})

训练过程采用标准的GAN训练方式,交替训练判别器和生成器。每1000次迭代会保存生成的样本图像。

关键实现细节

  1. 参数初始化:使用了Xavier初始化方法,这对于深度神经网络的训练稳定性很重要。

  2. 隐变量采样:从均匀分布中采样隐变量,范围在[-1, 1]之间。

  3. 优化器:使用Adam优化器,学习率设置为0.001。

  4. 生成样本可视化:使用matplotlib将生成的样本保存为图像,方便观察训练进展。

实际应用建议

  1. 尝试不同f-散度:可以取消注释不同的损失函数部分,比较不同f-散度对生成质量的影响。

  2. 调整网络结构:可以尝试增加网络深度或宽度,观察对生成效果的影响。

  3. 监控训练过程:除了损失函数值,还应该定期检查生成的样本质量。

  4. 超参数调优:可以尝试调整学习率、批处理大小等超参数,找到最佳配置。

总结

这个F-GAN实现展示了如何使用TensorFlow构建一个灵活的GAN框架,通过选择不同的f-散度可以得到不同的GAN变体。代码结构清晰,适合作为学习GAN原理和实践的起点。对于想要深入理解GAN内部机制的开发者来说,这是一个很好的参考实现。