首页
/ 基于wiseodd/generative-models的Wasserstein GAN梯度惩罚实现解析

基于wiseodd/generative-models的Wasserstein GAN梯度惩罚实现解析

2025-07-07 04:16:58作者:羿妍玫Ivan

引言

生成对抗网络(GAN)是近年来深度学习领域最具突破性的技术之一,而Wasserstein GAN(WGAN)则是GAN系列中非常重要的改进版本。本文将深入解析wiseodd/generative-models项目中WGAN-GP(带梯度惩罚的Wasserstein GAN)的TensorFlow实现,帮助读者理解其核心思想和实现细节。

WGAN-GP的核心思想

传统GAN使用JS散度作为衡量生成分布与真实分布差异的指标,而WGAN则创新性地使用Wasserstein距离(又称Earth-Mover距离)。WGAN-GP在WGAN基础上进一步改进,通过梯度惩罚(Gradient Penalty)替代权重裁剪,解决了WGAN训练不稳定的问题。

代码结构解析

1. 参数设置与数据准备

mb_size = 32        # 批处理大小
X_dim = 784         # 输入维度(MNIST图像展平后)
z_dim = 10          # 噪声维度
h_dim = 128         # 隐藏层维度
lam = 10            # 梯度惩罚系数
n_disc = 5          # 判别器训练次数/生成器训练次数比
lr = 1e-4           # 学习率

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

这部分设置了模型的基本参数,并加载了MNIST数据集。值得注意的是,WGAN-GP通常需要比传统GAN更小的学习率(这里设为1e-4)。

2. 网络架构

生成器(G)架构

def G(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    return G_prob

生成器采用简单的两层全连接网络:

  1. 第一层:ReLU激活函数
  2. 第二层:Sigmoid激活函数(将输出限制在[0,1]区间)

判别器(D)架构

def D(X):
    D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
    out = tf.matmul(D_h1, D_W2) + D_b2
    return out

判别器同样采用两层全连接网络,但特别注意:

  • 最后一层不使用Sigmoid激活函数
  • 输出为实数(不限制范围),这是WGAN与普通GAN的重要区别

3. 梯度惩罚实现

这是WGAN-GP最核心的创新点:

eps = tf.random_uniform([mb_size, 1], minval=0., maxval=1.)
X_inter = eps*X + (1. - eps)*G_sample
grad = tf.gradients(D(X_inter), [X_inter])[0]
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2, axis=1))
grad_pen = lam * tf.reduce_mean((grad_norm - 1)**2)

这段代码实现了:

  1. 在真实样本和生成样本之间随机插值
  2. 计算判别器对这些插值点的梯度
  3. 惩罚梯度范数偏离1的情况(Lipschitz约束)

4. 损失函数与优化

D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + grad_pen
G_loss = -tf.reduce_mean(D_fake)

D_solver = (tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
            .minimize(D_loss, var_list=theta_D))
G_solver = (tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
            .minimize(G_loss, var_list=theta_G))

WGAN的损失函数特点:

  • 判别器试图最大化真实样本与生成样本的判别器输出差距
  • 生成器试图最大化生成样本的判别器输出
  • 加入了梯度惩罚项

训练过程分析

训练循环展示了WGAN-GP的关键训练技巧:

for it in range(1000000):
    for _ in range(n_disc):  # 判别器训练多次
        # 训练判别器...
    
    # 训练生成器...
    
    # 定期保存生成样本

值得注意的是:

  1. 判别器训练次数(n_disc=5)多于生成器,这是WGAN训练的常见策略
  2. 使用Adam优化器但降低beta1(动量参数)为0.5,有助于训练稳定性
  3. 定期可视化生成结果,方便监控训练过程

实现细节与技巧

  1. 权重初始化:使用Xavier初始化,有助于网络训练的稳定性

    def xavier_init(size):
        in_dim = size[0]
        xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
        return tf.random_normal(shape=size, stddev=xavier_stddev)
    
  2. 噪声采样:从均匀分布中采样噪声向量

    def sample_z(m, n):
        return np.random.uniform(-1., 1., size=[m, n])
    
  3. 结果可视化:使用matplotlib的gridspec创建规整的图像网格

    def plot(samples):
        # 创建4x4的图像网格
        # ...
    

总结

通过wiseodd/generative-models中的WGAN-GP实现,我们可以学习到:

  1. WGAN-GP通过梯度惩罚优雅地实现了Lipschitz约束,比权重裁剪更稳定
  2. 判别器输出不需要经过激活函数,直接作为Wasserstein距离的估计
  3. 训练时需要平衡判别器和生成器的更新频率
  4. 适当调整优化器参数(如降低Adam的beta1)有助于训练稳定性

这个实现虽然简洁,但包含了WGAN-GP的所有关键要素,是学习Wasserstein GAN的优秀参考代码。读者可以在此基础上进行扩展,如尝试更复杂的网络结构、应用到其他数据集等。