基于wiseodd/generative-models的Wasserstein GAN梯度惩罚实现解析
2025-07-07 04:16:58作者:羿妍玫Ivan
引言
生成对抗网络(GAN)是近年来深度学习领域最具突破性的技术之一,而Wasserstein GAN(WGAN)则是GAN系列中非常重要的改进版本。本文将深入解析wiseodd/generative-models项目中WGAN-GP(带梯度惩罚的Wasserstein GAN)的TensorFlow实现,帮助读者理解其核心思想和实现细节。
WGAN-GP的核心思想
传统GAN使用JS散度作为衡量生成分布与真实分布差异的指标,而WGAN则创新性地使用Wasserstein距离(又称Earth-Mover距离)。WGAN-GP在WGAN基础上进一步改进,通过梯度惩罚(Gradient Penalty)替代权重裁剪,解决了WGAN训练不稳定的问题。
代码结构解析
1. 参数设置与数据准备
mb_size = 32 # 批处理大小
X_dim = 784 # 输入维度(MNIST图像展平后)
z_dim = 10 # 噪声维度
h_dim = 128 # 隐藏层维度
lam = 10 # 梯度惩罚系数
n_disc = 5 # 判别器训练次数/生成器训练次数比
lr = 1e-4 # 学习率
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
这部分设置了模型的基本参数,并加载了MNIST数据集。值得注意的是,WGAN-GP通常需要比传统GAN更小的学习率(这里设为1e-4)。
2. 网络架构
生成器(G)架构
def G(z):
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
生成器采用简单的两层全连接网络:
- 第一层:ReLU激活函数
- 第二层:Sigmoid激活函数(将输出限制在[0,1]区间)
判别器(D)架构
def D(X):
D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
out = tf.matmul(D_h1, D_W2) + D_b2
return out
判别器同样采用两层全连接网络,但特别注意:
- 最后一层不使用Sigmoid激活函数
- 输出为实数(不限制范围),这是WGAN与普通GAN的重要区别
3. 梯度惩罚实现
这是WGAN-GP最核心的创新点:
eps = tf.random_uniform([mb_size, 1], minval=0., maxval=1.)
X_inter = eps*X + (1. - eps)*G_sample
grad = tf.gradients(D(X_inter), [X_inter])[0]
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2, axis=1))
grad_pen = lam * tf.reduce_mean((grad_norm - 1)**2)
这段代码实现了:
- 在真实样本和生成样本之间随机插值
- 计算判别器对这些插值点的梯度
- 惩罚梯度范数偏离1的情况(Lipschitz约束)
4. 损失函数与优化
D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + grad_pen
G_loss = -tf.reduce_mean(D_fake)
D_solver = (tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
.minimize(D_loss, var_list=theta_D))
G_solver = (tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
.minimize(G_loss, var_list=theta_G))
WGAN的损失函数特点:
- 判别器试图最大化真实样本与生成样本的判别器输出差距
- 生成器试图最大化生成样本的判别器输出
- 加入了梯度惩罚项
训练过程分析
训练循环展示了WGAN-GP的关键训练技巧:
for it in range(1000000):
for _ in range(n_disc): # 判别器训练多次
# 训练判别器...
# 训练生成器...
# 定期保存生成样本
值得注意的是:
- 判别器训练次数(n_disc=5)多于生成器,这是WGAN训练的常见策略
- 使用Adam优化器但降低beta1(动量参数)为0.5,有助于训练稳定性
- 定期可视化生成结果,方便监控训练过程
实现细节与技巧
-
权重初始化:使用Xavier初始化,有助于网络训练的稳定性
def xavier_init(size): in_dim = size[0] xavier_stddev = 1. / tf.sqrt(in_dim / 2.) return tf.random_normal(shape=size, stddev=xavier_stddev)
-
噪声采样:从均匀分布中采样噪声向量
def sample_z(m, n): return np.random.uniform(-1., 1., size=[m, n])
-
结果可视化:使用matplotlib的gridspec创建规整的图像网格
def plot(samples): # 创建4x4的图像网格 # ...
总结
通过wiseodd/generative-models中的WGAN-GP实现,我们可以学习到:
- WGAN-GP通过梯度惩罚优雅地实现了Lipschitz约束,比权重裁剪更稳定
- 判别器输出不需要经过激活函数,直接作为Wasserstein距离的估计
- 训练时需要平衡判别器和生成器的更新频率
- 适当调整优化器参数(如降低Adam的beta1)有助于训练稳定性
这个实现虽然简洁,但包含了WGAN-GP的所有关键要素,是学习Wasserstein GAN的优秀参考代码。读者可以在此基础上进行扩展,如尝试更复杂的网络结构、应用到其他数据集等。