首页
/ PyTorch-GAN项目中的CycleGAN实现详解

PyTorch-GAN项目中的CycleGAN实现详解

2025-07-05 08:25:25作者:冯爽妲Honey

概述

CycleGAN是一种用于图像到图像转换的无监督学习框架,它能够在没有成对训练数据的情况下学习两个不同域之间的映射关系。本文将以PyTorch-GAN项目中的CycleGAN实现为例,深入解析其核心架构和训练过程。

核心组件

1. 生成器网络

CycleGAN使用了基于残差块的生成器架构:

G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
  • 采用ResNet风格的生成器,包含9个残差块(默认)
  • 输入输出维度相同,实现图像域A到B和B到A的双向转换
  • 使用instance normalization提升生成质量

2. 判别器网络

D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
  • 使用PatchGAN判别器架构
  • 输出一个N×N的矩阵,每个元素对应输入图像的一个局部区域的真实性判断
  • 有助于生成高频细节

训练过程

1. 损失函数

CycleGAN使用了三种关键损失函数:

criterion_GAN = torch.nn.MSELoss()  # 对抗损失
criterion_cycle = torch.nn.L1Loss()  # 循环一致性损失
criterion_identity = torch.nn.L1Loss()  # 身份损失
  • 对抗损失:确保生成的图像在目标域中看起来真实
  • 循环一致性损失:保证转换后的图像能重建回原始图像
  • 身份损失:帮助保留输入图像的色彩分布

2. 优化器配置

optimizer_G = torch.optim.Adam(...)  # 生成器优化器
optimizer_D_A = torch.optim.Adam(...)  # 判别器A优化器
optimizer_D_B = torch.optim.Adam(...)  # 判别器B优化器

使用Adam优化器,初始学习率设为0.0002,动量参数β1=0.5,β2=0.999

3. 学习率调度

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(...)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(...)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(...)

实现线性学习率衰减,从decay_epoch(默认100)开始逐渐降低学习率

关键训练步骤

1. 生成器训练

# 身份损失
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)

# GAN损失
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)

# 循环一致性损失
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)

生成器需要同时优化三种损失,总损失为加权和:

loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

2. 判别器训练

# 真实样本损失
loss_real = criterion_GAN(D_A(real_A), valid)

# 生成样本损失
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)

# 总损失
loss_D_A = (loss_real + loss_fake) / 2

使用历史生成的图像缓冲区(fake_A_buffer)来稳定训练过程

数据预处理

transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), 
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
  • 随机缩放裁剪增强数据多样性
  • 随机水平翻转增加数据量
  • 像素值归一化到[-1,1]范围

模型保存与验证

# 保存模型检查点
torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))

# 生成验证图像
sample_images(batches_done)
  • 定期保存模型权重
  • 在验证集上生成样本图像,直观评估模型性能

训练技巧

  1. 使用历史缓冲区:存储之前生成的图像用于判别器训练,提高稳定性
  2. 学习率衰减:后期降低学习率帮助模型收敛
  3. 损失权重调整:通过λ_cyc和λ_id平衡不同损失项的影响
  4. 批量归一化:生成器中使用instance normalization而非batch normalization

总结

PyTorch-GAN项目中的CycleGAN实现提供了一个清晰、模块化的框架,展示了如何在没有成对数据的情况下实现跨域图像转换。通过精心设计的损失函数和训练策略,该实现能够产生高质量的图像转换结果,是学习无监督图像转换的优秀范例。