生成对抗网络(GAN)详解(代码实现)
GANs 的基本概念
This framework can yield specific training algorithms for many kinds of model and optimization algorithm. In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron. We refer to this special case as adversarial nets.
该框架能够为多种模型和优化算法生成特定的训练算法。在本文中,我们探讨了生成模型通过多层感知机传递随机噪声来生成样本,而判别模型也是多层感知机的这种特殊情况。我们将这种特殊情况称为对抗网络。
GANs 由两个核心部分组成:
- 生成器(Generator):从随机噪声生成逼真的数据。比如,它可以从一堆随机数字生成类似真实手写数字的图像。
- 判别器(Discriminator):判断输入的数据是真实的(来自数据集)还是生成的(来自生成器)。它的输出是一个概率值,表示“这个数据看起来是真的”的可能性。
这两个网络是对抗训练的:生成器努力“骗过”判别器,生成更真实的数据;而判别器努力提高自己的分辨能力。这种对抗过程最终让生成器学会生成非常逼真的图像。
实现步骤
以生成 MNIST 数据集中的手写数字为例,逐步实现
1. 数据准备
首先,我们需要准备训练数据。MNIST 数据集包含 6 万张 28x28 像素的灰度手写数字图像。我们会用这些真实图像来训练判别器,并作为生成器的学习目标。
2. 网络架构
生成器
生成器的任务是从随机噪声生成图像。假设输入是一个 100 维的随机噪声向量,输出是一个 28x28 的图像(展开为 784 维向量)。一个简单的生成器结构可能是:
- 输入层:100 维噪声向量
- 隐藏层 1:256 个神经元,ReLU 激活函数
- 隐藏层 2:512 个神经元,ReLU 激活函数
- 隐藏层 3:1024 个神经元,ReLU 激活函数
- 输出层:784 个神经元(28x28),Tanh 激活函数(将像素值限制在 -1 到 1 之间)
判别器
判别器的任务是判断图像的真假。输入是 28x28 的图像(展开为 784 维向量),输出是一个概率值。一个简单的判别器结构可能是:
- 输入层:784 个神经元(28x28 的图像)
- 隐藏层 1:512 个神经元,Leaky ReLU 激活函数
- 隐藏层 2:256 个神经元,Leaky ReLU 激活函数
- 输出层:1 个神经元,Sigmoid 激活函数(输出 0 到 1 之间的概率)
3. 训练过程
GANs 的训练是一个博弈过程,包含以下步骤:
- 从真实数据中采样:取一批真实图像(比如 64 张 MNIST 图像)。
- 生成假图像:从随机噪声中生成一批假图像。
- 训练判别器:
- 用真实图像(标记为 1)和假图像(标记为 0)训练判别器。
- 计算损失并更新判别器参数。
- 训练生成器:
- 生成一批假图像,让判别器判断,但这次目标是让判别器认为它们是真的(标记为 1)。
- 计算损失并更新生成器参数。
这个过程不断重复,直到生成器生成的图像足够逼真。
代码示例
下面是用 PyTorch 实现的一个简单 GAN 示例。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 设置超参数
batch_size = 64 # 每批次处理的图像数量
lr = 0.0002 # 学习率
num_epochs = 100 # 训练轮数
noise_dim = 100 # 噪声向量的维度# 数据加载和预处理
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量transforms.Normalize((0.5,), (0.5,)) # 标准化到 [-1, 1]
])
mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(noise_dim, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 1024),nn.ReLU(),nn.Linear(1024, 784),nn.Tanh() # 输出范围 [-1, 1])def forward(self, z):return self.model(z)# 定义判别器
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(784, 512),nn.LeakyReLU(0.2), # Leaky ReLU 防止梯度消失nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid() # 输出概率 [0, 1])def forward(self, img):img_flat = img.view(img.size(0), -1) # 将图像展平为 784 维向量return self.model(img_flat)# 实例化模型
generator = Generator()
discriminator = Discriminator()# 定义损失函数和优化器
criterion = nn.BCELoss() # 二元交叉熵损失
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)# 训练循环
for epoch in range(num_epochs):for i, (real_imgs, _) in enumerate(data_loader):batch_size = real_imgs.size(0)# 训练判别器optimizer_d.zero_grad()real_labels = torch.ones(batch_size, 1) # 真实图像标签为 1fake_labels = torch.zeros(batch_size, 1) # 假图像标签为 0# 用真实图像训练判别器real_output = discriminator(real_imgs)d_loss_real = criterion(real_output, real_labels)# 用生成图像训练判别器z = torch.randn(batch_size, noise_dim) # 随机噪声fake_imgs = generator(z)fake_output = discriminator(fake_imgs.detach()) # detach 防止梯度传回生成器d_loss_fake = criterion(fake_output, fake_labels)# 总判别器损失d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_d.step()# 训练生成器optimizer_g.zero_grad()z = torch.randn(batch_size, noise_dim)fake_imgs = generator(z)fake_output = discriminator(fake_imgs)g_loss = criterion(fake_output, real_labels) # 目标是让判别器认为是真的g_loss.backward()optimizer_g.step()# 每轮打印损失print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
代码解释
-
数据加载:
- 使用
torchvision
加载 MNIST 数据集。 transforms.Normalize((0.5,), (0.5,))
将像素值从 [0, 1] 标准化到 [-1, 1],与生成器的 Tanh 输出一致。
- 使用
-
生成器:
- 输入 100 维噪声,经过多层全连接网络,输出 784 维向量(28x28 图像)。
- 使用
Tanh
激活函数,确保输出范围与数据匹配。
-
判别器:
- 输入 784 维图像向量,输出一个概率值。
- 使用
LeakyReLU
防止梯度消失,Sigmoid
输出概率。
-
训练过程:
- 判别器:用真实图像和生成图像分别计算损失,然后更新参数。
- 生成器:生成假图像,让判别器判断,并希望判别器输出接近 1,更新生成器参数。
- 使用
Adam
优化器和二元交叉熵损失(BCELoss
)。
-
输出:
- 每轮训练后,打印判别器损失(
d_loss
)和生成器损失(g_loss
),观察训练进展。
- 每轮训练后,打印判别器损失(
总结
通过这个例子,可以看到 GANs 的具体实现其实是将生成器和判别器的对抗思想转化为代码。生成器从噪声生成图像,判别器判断真假,两者交替训练,最终生成器能生成逼真的手写数字。如果运行这段代码,经过足够多的轮次(比如 100 轮),会发现生成的图像越来越接近真实的 MNIST 数字。
附:论文中提到了马尔可夫链(Markov Chain)和近似推理网络(Unrolled Approximate Inference Networks),它们是什么?
1. 马尔可夫链(Markov Chain)
什么是马尔可夫链?
马尔可夫链是一种随机过程,用来描述系统在不同状态之间转换的规律。它的核心特点是无记忆性,也就是说,未来的状态只依赖于当前状态,而与更早的状态历史无关。这种性质也被称为“马尔可夫性质”。
生活中的例子
想象一个简单的天气预测场景:
- 今天是晴天,明天是晴天的概率是 70%,下雨的概率是 30%。
- 今天是雨天,明天是晴天的概率是 50%,下雨的概率也是 50%。
- 明天天气如何,只取决于今天是什么天气,而不管昨天或更早的天气。
在这个例子中,天气的状态(晴天或雨天)形成了一个马尔可夫链。
在生成模型中的应用
在生成模型中,马尔可夫链常被用来从复杂的概率分布中抽取样本。例如,马尔可夫链蒙特卡罗(MCMC)方法利用马尔可夫链逐步调整样本,使其逐渐逼近目标分布(比如真实图像的分布)。不过,这种方法需要运行很多步才能稳定,计算成本较高,尤其是在高维数据(如图像)上,收敛速度较慢。
代码示例:模拟天气变化
下面是一个简单的Python代码,模拟基于马尔可夫链的天气变化:
import numpy as np# 状态转移矩阵
# 行是当前状态,列是下一状态
P = np.array([[0.7, 0.3], # 晴天 -> 晴天: 0.7, 晴天 -> 雨天: 0.3[0.5, 0.5]]) # 雨天 -> 晴天: 0.5, 雨天 -> 雨天: 0.5# 初始状态:0 = 晴天
current_state = 0# 模拟 10 天的天气
for day in range(1, 11):next_state = np.random.choice([0, 1], p=P[current_state])print(f"第 {day} 天: {'晴天' if next_state == 0 else '雨天'}")current_state = next_state
输出示例(结果随机):
第 1 天: 晴天
第 2 天: 晴天
第 3 天: 雨天
第 4 天: 晴天
...
代码解释:
P
是状态转移矩阵,定义了从当前状态到下一状态的概率。np.random.choice
根据概率随机选择下一状态。- 每一天的天气只取决于前一天,体现了马尔可夫链的无记忆性。
2. 展开的近似推理网络(Unrolled Approximate Inference Networks)
什么是推理网络?
在生成模型中,例如变分自编码器(VAE),我们需要推断隐藏变量(也叫潜在变量)的分布。推理网络是一个神经网络,用来近似这个后验分布。简单来说,给定观测数据(比如一张图片),推理网络预测潜在变量的可能分布(比如图片的特征)。
“展开的”是什么意思?
“展开的”推理网络指的是在训练或推理过程中,网络通过多步迭代来更精确地逼近真实的后验分布,而不是只进行一次预测。这种方法类似于反复优化答案,每次都比上一次更接近正确结果。
生活中的例子
假设你在猜一个谜语:
- 普通推理网络:听完问题后,直接给出一个答案。
- 展开的推理网络:先给一个初步答案,然后再思考、调整,重复几次,直到答案更准确。
在生成模型中的应用
在VAE中,普通的推理网络通过单次前向传播预测潜在变量的分布。而展开的推理网络会进行多步优化(比如通过梯度下降或更复杂的网络结构),提高对后验分布的近似精度。这样可以生成更高质量的样本,但也会增加计算复杂度和训练时间。
代码示例:VAE中的推理网络
下面是一个简单的VAE推理网络实现(基于PyTorch),并示意如何“展开”它:
import torch
import torch.nn as nn# 定义推理网络(编码器)
class InferenceNetwork(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(InferenceNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2_mean = nn.Linear(hidden_dim, latent_dim)self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)def forward(self, x):h = torch.relu(self.fc1(x))mean = self.fc2_mean(h)logvar = self.fc2_logvar(h)return mean, logvar# 定义生成网络(解码器)
generative_net = nn.Sequential(nn.Linear(10, 256),nn.ReLU(),nn.Linear(256, 784),nn.Sigmoid()
)# 简单的VAE模型
class VAE(nn.Module):def __init__(self, inference_net, generative_net):super(VAE, self).__init__()self.inference_net = inference_netself.generative_net = generative_netdef forward(self, x):mean, logvar = self.inference_net(x)z = self.reparameterize(mean, logvar)recon_x = self.generative_net(z)return recon_x, mean, logvardef reparameterize(self, mean, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mean + eps * std# 实例化模型
inference_net = InferenceNetwork(input_dim=784, hidden_dim=256, latent_dim=10)
vae = VAE(inference_net, generative_net)
代码解释:
InferenceNetwork
是一个神经网络,输入是数据(比如展平的28×28图像,维度为784),输出是潜在变量的均值和方差。reparameterize
是VAE中的重参数化技巧,用于从分布中采样潜在变量z
。generative_net
将潜在变量z
转换回重构数据。
示意“展开的”推理网络
普通推理网络只运行一次 forward
,而展开的推理网络可以通过多步迭代优化潜在变量。例如:
def unrolled_inference(x, inference_net, generative_net, num_steps=5, learning_rate=0.01):mean, logvar = inference_net(x)z = mean # 初始估计for _ in range(num_steps):recon_x = generative_net(z)loss = ((recon_x - x) ** 2).sum() # 重构损失grad = torch.autograd.grad(loss, z, retain_graph=True)[0]z = z - learning_rate * grad # 梯度下降更新 zreturn z# 示例调用
x = torch.randn(1, 784) # 假设输入数据
z = unrolled_inference(x, inference_net, generative_net)
代码解释:
- 从推理网络得到初始的
mean
和logvar
,并以mean
作为潜在变量z
的起点。 - 通过多步梯度下降优化
z
,使重构数据recon_x
更接近输入x
。 - 这种多步迭代体现了“展开”的思想,提高了对潜在变量分布的估计精度。
总结
-
马尔可夫链
- 是一种基于状态转移的随机过程,具有无记忆性。
- 在生成模型中,用于从复杂分布中抽取样本(如MCMC),但计算效率较低。
- 示例:天气变化模拟。
-
展开的近似推理网络
- 是一种通过神经网络近似后验分布的方法,“展开”指的是多步迭代优化。
- 在VAE等模型中提高潜在变量估计的精度,生成更高质量样本,但计算成本较高。
- 示例:VAE中通过梯度下降优化潜在变量。
这两个概念在生成模型中很重要,但侧重点不同:马尔可夫链关注采样,展开的推理网络关注推理精度。
附:后验分布与先验分布
什么是先验分布?
先验分布(Prior Distribution)是贝叶斯统计中的一个核心概念,它是指在没有看到任何数据之前,我们对未知参数的初步假设或信念。这个信念用概率分布的形式来表达,反映了我们在数据到来之前对参数可能取值的看法。
生活中的例子
为了更好地理解先验分布,我们可以用一个简单的例子来说明。假设你在猜测一枚硬币是否公平:
- 在没有抛硬币之前,你可能会假设这枚硬币是公平的(正面和反面概率各为50%),或者你可能怀疑它有偏差(比如正面概率更高)。
- 这种初步的猜测就是先验分布。例如:
- 如果你认为硬币正面概率 p p p 在0到1之间完全等可能,这就是一个均匀分布。
- 如果你根据经验觉得硬币更可能是公平的( p p p 接近0.5),可以用一个特定的分布(比如Beta分布)来表示这种偏好。
数学上的定义
在贝叶斯统计中,先验分布通常用 p ( θ ) p(\theta) p(θ) 表示,其中 θ \theta θ 是我们想要估计的未知参数。
- 比如,在硬币的例子中, θ \theta θ 是硬币正面朝上的概率 p p p。如果我们假设 p p p 在0到1之间等可能,那么先验分布可以写成:
p ( θ ) = 1 , 0 ≤ θ ≤ 1 p(\theta) = 1, \quad 0 \leq \theta \leq 1 p(θ)=1,0≤θ≤1
这表示在没有数据之前,我们对 θ \theta θ 的任何可能值都没有特别的偏好。
先验分布在贝叶斯定理中的作用
先验分布在贝叶斯统计中之所以重要,是因为它通过贝叶斯定理与数据结合,更新我们的信念。贝叶斯定理的公式是:
p ( θ ∣ x ) = p ( x ∣ θ ) ⋅ p ( θ ) p ( x ) p(\theta | x) = \frac{p(x | \theta) \cdot p(\theta)}{p(x)} p(θ∣x)=p(x)p(x∣θ)⋅p(θ)
其中:
- p ( θ ∣ x ) p(\theta | x) p(θ∣x):后验分布,表示在看到数据 x x x 后,我们对 θ \theta θ 的更新信念。
- p ( x ∣ θ ) p(x | \theta) p(x∣θ):似然函数,表示在给定 θ \theta θ 的情况下,数据 x x x 出现的概率。
- p ( θ ) p(\theta) p(θ):先验分布,数据之前的信念。
- p ( x ) p(x) p(x):归一化常数,确保概率总和为1。
从这个公式可以看出,先验分布 p ( θ ) p(\theta) p(θ) 是计算后验分布的一个关键部分,它直接影响我们对参数的最终估计。
先验分布的类型
先验分布可以根据我们对参数的了解程度分为两种:
- 无信息先验(Uninformative Prior):当我们对参数几乎一无所知时,可以选择一个均匀分布,表示对所有可能值没有偏好。
- 信息先验(Informative Prior):如果我们有某些领域的知识或历史数据,可以选择一个具体的分布来反映这些信息,比如基于之前的实验结果假设硬币正面概率更可能接近0.5。
先验分布的重要性
先验分布在贝叶斯统计中有以下几个关键作用:
- 它允许我们将主观信念或先验知识融入统计模型中。
- 当数据量较少时,先验分布对后验分布的影响较大;当数据量足够多时,似然函数(数据本身)的影响会更显著。
- 选择一个合适的先验分布是贝叶斯分析中的重要步骤,因为它会影响最终的结论。
简单来说,先验分布是贝叶斯统计中在看到数据之前对未知参数的初步假设,用概率分布来表达我们的信念。通过贝叶斯定理,先验分布与数据(似然函数)结合,更新为更准确的后验分布。它既体现了主观性(基于我们的知识或猜测),又为统计推断提供了一个起点。
什么是后验分布?
后验分布是贝叶斯统计中的一个核心概念。简单来说,它是在给定观测数据的情况下,对未知参数的概率分布进行更新后的结果。它结合了我们事先的信念(先验分布)和新的数据信息,反映了参数更真实的可能性。
基本概念
在贝叶斯统计中,我们通常会从以下步骤理解后验分布:
- 先验分布(Prior Distribution)
在看到任何数据之前,我们对未知参数会有一个初步的假设或信念,这个假设用概率分布表示,称为先验分布。比如,我们可能假设一个硬币是公平的(正反面概率各50%),或者认为它可能有某种偏差。 - 观测数据(Data)
当我们收集到新的数据后,比如抛硬币10次得到7次正面和3次反面,这些数据会提供额外的信息。 - 后验分布(Posterior Distribution)
通过贝叶斯定理,我们结合先验分布和观测数据,更新对参数的信念,得到后验分布。后验分布比先验分布更贴近实际情况,因为它融合了数据带来的新证据。
一个生活中的例子
假设你在猜测一个硬币是否公平:
- 先验分布:在抛硬币之前,你假设硬币正面朝上的概率 p p p 是均匀分布的(即 p p p 在0到1之间等可能)。
- 观测数据:你抛了10次硬币,结果是7次正面,3次反面。
- 后验分布:利用贝叶斯定理更新信念后,你可能会发现 p p p 更可能接近0.7,而不是原来的均匀分布。
这里的后验分布就是:在知道抛硬币结果后,硬币正面概率 p p p 的新概率分布。
数学上的解释
贝叶斯定理告诉我们如何计算后验分布:
p ( θ ∣ x ) = p ( x ∣ θ ) ⋅ p ( θ ) p ( x ) p(\theta | x) = \frac{p(x | \theta) \cdot p(\theta)}{p(x)} p(θ∣x)=p(x)p(x∣θ)⋅p(θ)
- p ( θ ∣ x ) p(\theta | x) p(θ∣x):后验分布(参数 θ \theta θ 在给定数据 x x x 下的分布)。
- p ( x ∣ θ ) p(x | \theta) p(x∣θ):似然函数(数据 x x x 在给定参数 θ \theta θ 下的概率)。
- p ( θ ) p(\theta) p(θ):先验分布(数据之前对 θ \theta θ 的信念)。
- p ( x ) p(x) p(x):归一化常数(确保概率总和为1)。
简单来说,后验分布是先验分布和似然函数的"结合",反映了数据对我们信念的更新。
在实际中的用途
后验分布非常重要,因为它可以用来:
- 估计参数:比如硬币正面概率的具体值。
- 预测未来:根据后验分布预测下一次抛硬币的结果。
- 决策:基于后验分布选择最优策略。
在生成模型中的例子
在像变分自编码器(VAE)这样的模型中,后验分布也有类似的应用。假设我们用VAE生成手写数字图像:
- 潜在变量 z z z:代表图像的抽象特征,比如笔画粗细或倾斜角度。
- 后验分布 p ( z ∣ x ) p(z|x) p(z∣x):给定一张具体图像 x x x,潜在变量 z z z 的分布。
由于直接计算 p ( z ∣ x ) p(z|x) p(z∣x) 很困难,我们用神经网络(推理网络)来近似它。
总结
- 后验分布是贝叶斯统计中基于观测数据更新参数分布的结果。
- 它结合了先验信念和数据信息,比先验更准确。
- 在生活中,它就像根据新证据调整你的猜测;在模型中,它帮助我们理解数据背后的隐藏变量。
条件GAN的完整代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt# --------------------------
# 超参数设置
# --------------------------
batch_size = 64
lr = 0.0002
num_epochs = 50
noise_dim = 100 # 噪声向量维度
num_classes = 10 # MNIST类别数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# --------------------------
# 数据加载和预处理
# --------------------------
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True, drop_last=True)# --------------------------
# 生成器(条件GAN的Generator)
# --------------------------
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 嵌入层:将类别标签映射到一个与噪声维度相同的向量self.label_emb = nn.Embedding(num_classes, noise_dim)# 主体网络:输入为噪声+标签拼接向量,形状:(B, 2*noise_dim, 1, 1)self.main = nn.Sequential(nn.ConvTranspose2d(noise_dim * 2, 256, kernel_size=7, stride=1, padding=0, bias=False), # -> (B, 256, 7, 7)nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False), # -> (B, 128, 14, 14)nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False), # -> (B, 1, 28, 28)nn.Tanh() # 输出范围[-1,1])def forward(self, z, labels):""":param z: (B, noise_dim) 随机噪声:param labels: (B,) 样本类别"""label_embedding = self.label_emb(labels) # (B, noise_dim)x = torch.cat([z, label_embedding], dim=1) # (B, 2*noise_dim)x = x.view(-1, 2 * noise_dim, 1, 1) # 重塑为 (B, 2*noise_dim, 1, 1)return self.main(x) # 输出 (B, 1, 28, 28)# --------------------------
# 判别器(条件GAN的Discriminator)
# --------------------------
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 嵌入层:将标签映射到28x28形状的向量(1通道)self.label_emb = nn.Embedding(num_classes, 28 * 28)# 主体网络:输入图像与标签拼接,通道数=1+1=2self.main = nn.Sequential(nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1, bias=False), # -> (B, 64, 14, 14)nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # -> (B, 128, 7, 7)nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False), # -> (B, 256, 4, 4)nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False), # -> (B, 1, 1, 1)nn.Sigmoid())def forward(self, img, labels):""":param img: (B, 1, 28, 28) 输入图像:param labels: (B,) 样本类别"""label_embedding = self.label_emb(labels).view(-1, 1, 28, 28) # 将标签扩展为 (B, 1, 28, 28)x = torch.cat([img, label_embedding], dim=1) # 拼接后形状 (B, 2, 28, 28)return self.main(x).view(-1, 1) # 输出 (B, 1)# --------------------------
# 权重初始化函数
# --------------------------
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)# --------------------------
# 初始化网络、损失函数、优化器
# --------------------------
generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))# 可选:使用学习率调度器,逐步降低学习率(如ReduceLROnPlateau或StepLR)
# scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=20, gamma=0.5)
# scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=20, gamma=0.5)# --------------------------
# 开始训练 cGAN
# --------------------------
print(f"开始训练条件GAN (cGAN) - 设备: {device}")for epoch in range(num_epochs):for real_imgs, labels in data_loader:real_imgs, labels = real_imgs.to(device), labels.to(device)batch_size_cur = real_imgs.size(0)# 为真实图像构造标签(使用标签平滑)real_target = torch.full((batch_size_cur, 1), 0.9, device=device)# 假图像标签加上轻微随机扰动fake_target = torch.zeros(batch_size_cur, 1, device=device) + 0.1 * torch.rand(batch_size_cur, 1, device=device)# --- 训练判别器 ---discriminator.zero_grad()# 判别器判断真实图像real_output = discriminator(real_imgs, labels)d_loss_real = criterion(real_output, real_target)# 生成对应标签的假图像z = torch.randn(batch_size_cur, noise_dim, device=device)fake_imgs = generator(z, labels)fake_output = discriminator(fake_imgs.detach(), labels)d_loss_fake = criterion(fake_output, fake_target)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_d.step()# --- 训练生成器 ---# 为平衡训练,每个批次对生成器更新两次for _ in range(2):generator.zero_grad()# 使用相同标签生成假图像z = torch.randn(batch_size_cur, noise_dim, device=device)fake_imgs = generator(z, labels)# 生成器目标:使判别器将假图像判定为真实g_output = discriminator(fake_imgs, labels)g_loss = criterion(g_output, real_target)g_loss.backward()optimizer_g.step()# 可选:更新学习率调度器# scheduler_g.step()# scheduler_d.step()print(f"Epoch [{epoch+1}/{num_epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")# --------------------------
# 保存比较图:使用同一批次标签生成与真实对比
# --------------------------
generator.eval()
real_batch, labels = next(iter(data_loader))
real_batch, labels = real_batch[:8].to(device), labels[:8].to(device)
with torch.no_grad():z = torch.randn(8, noise_dim, device=device)fake_batch = generator(z, labels)# 保存比较图,每行一对,左侧真实,右侧生成
fig, axes = plt.subplots(8, 2, figsize=(4, 16))
for i in range(8):axes[i, 0].imshow(real_batch[i][0].cpu().numpy(), cmap='gray', vmin=-1, vmax=1)axes[i, 0].set_title(f"Real {labels[i].item()}")axes[i, 0].axis('off')axes[i, 1].imshow(fake_batch[i][0].cpu().numpy(), cmap='gray', vmin=-1, vmax=1)axes[i, 1].set_title(f"Fake {labels[i].item()}")axes[i, 1].axis('off')plt.savefig('mnist_cgan_comparison.png', dpi=300, bbox_inches='tight')
plt.close()# --------------------------
# 保存生成样本网格:随机标签下生成16张图像
# --------------------------
z = torch.randn(16, noise_dim, device=device)
rand_labels = torch.randint(0, num_classes, (16,), device=device)
with torch.no_grad():fake_samples = generator(z, rand_labels)
save_image(fake_samples, 'mnist_cgan_samples.png', nrow=4, normalize=True)print("训练完成,已保存 'mnist_cgan_comparison.png' 和 'mnist_cgan_samples.png'")