# 生成对抗网络解析:基本原理、架构变体与实际应用
## GANs核心原理与数学基础
生成对抗网络(GANs)由生成器(Generator)和判别器(Discriminator)两个神经网络构成,通过对抗训练实现数据生成。生成器试图创建逼真的假数据,而判别器则努力区分真实数据和生成数据。这种对抗过程形成一个最小最大优化问题,其价值函数可以表示为:
```
V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 - D(G(z)))]
```
其中生成器G试图最小化该函数,而判别器D试图最大化它。这种对抗训练最终达到纳什均衡点,此时生成器产生的数据分布p_g(x)无限接近真实数据分布p_data(x)。
## 基础GAN实现代码
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 生成器网络定义
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent noise vector to a 64x64 image.

    Input is expected as shape (batch, latent_dim, 1, 1); output is
    (batch, img_channels, 64, 64) in [-1, 1].
    """

    def __init__(self, latent_dim=100, img_channels=1, feature_map_size=64):
        super(Generator, self).__init__()
        fms = feature_map_size
        # Project the latent vector to a 4x4 feature map.
        stages = [
            nn.ConvTranspose2d(latent_dim, fms * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fms * 8),
            nn.ReLU(True),
        ]
        # Three upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32,
        # halving the channel count each time.
        for mult in (8, 4, 2):
            stages += [
                nn.ConvTranspose2d(fms * mult, fms * mult // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(fms * mult // 2),
                nn.ReLU(True),
            ]
        # Final stage: 32x32 -> 64x64; tanh squashes outputs to [-1, 1].
        stages += [
            nn.ConvTranspose2d(fms, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
# 判别器网络定义
class Discriminator(nn.Module):
    """DCGAN-style discriminator: scores a 64x64 image as real (~1) or fake (~0).

    Input is expected as shape (batch, img_channels, 64, 64); output is
    (batch, 1) with sigmoid probabilities.
    """

    def __init__(self, img_channels=1, feature_map_size=64):
        super(Discriminator, self).__init__()
        fms = feature_map_size
        # First stage (64x64 -> 32x32) has no batch norm, per the DCGAN recipe.
        layers = [
            nn.Conv2d(img_channels, fms, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, doubling channels each time:
        # 32x32 -> 16x16 -> 8x8 -> 4x4.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(fms * mult, fms * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(fms * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to a single sigmoid score.
        layers += [
            nn.Conv2d(fms * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        # Flatten the 1x1 score map to shape (batch, 1).
        return self.main(input).view(-1, 1)
# GAN训练类
class GANTrainer:
    """Orchestrates adversarial training of a Generator/Discriminator pair."""

    def __init__(self, latent_dim=100, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.latent_dim = latent_dim
        self.device = device
        self.generator = Generator(latent_dim).to(device)
        self.discriminator = Discriminator().to(device)
        # Adam with beta1 = 0.5, the hyperparameters recommended by DCGAN.
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # Binary cross-entropy for the standard GAN objective.
        self.criterion = nn.BCELoss()
        # Fixed noise so sample grids are comparable across epochs.
        self.fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

    def _sample_noise(self, batch_size):
        # Fresh latent vectors, shaped for the convolutional generator.
        return torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)

    def train_step(self, real_imgs):
        """Run one discriminator update followed by one generator update.

        Returns (discriminator_loss, generator_loss) as Python floats.
        """
        batch_size = real_imgs.size(0)
        real_labels = torch.ones(batch_size, 1, device=self.device)
        fake_labels = torch.zeros(batch_size, 1, device=self.device)

        # --- Discriminator update: maximize log D(x) + log(1 - D(G(z))) ---
        self.optimizer_D.zero_grad()
        d_real_loss = self.criterion(self.discriminator(real_imgs), real_labels)
        fake_imgs = self.generator(self._sample_noise(batch_size))
        # detach(): this backward pass must not touch generator weights.
        d_fake_loss = self.criterion(self.discriminator(fake_imgs.detach()), fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        self.optimizer_D.step()

        # --- Generator update: non-saturating loss, label fakes as real ---
        self.optimizer_G.zero_grad()
        fake_imgs = self.generator(self._sample_noise(batch_size))
        g_loss = self.criterion(self.discriminator(fake_imgs), real_labels)
        g_loss.backward()
        self.optimizer_G.step()

        return d_loss.item(), g_loss.item()

    def generate_samples(self, num_samples=16):
        """Generate a batch of images from fresh random noise (no gradients)."""
        with torch.no_grad():
            return self.generator(self._sample_noise(num_samples))
```
## GAN变体架构演进
### DCGAN(深度卷积GAN)
```python
class DCGANGenerator(nn.Module):
    """Deep convolutional generator (DCGAN) built from transposed convolutions.

    Maps noise of shape (batch, nz, 1, 1) to an RGB image (batch, nc, 64, 64).
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(DCGANGenerator, self).__init__()
        # Project z to a 4x4 map, then upsample 4 -> 8 -> 16 -> 32 -> 64.
        widths = (ngf * 8, ngf * 4, ngf * 2, ngf)
        stages = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        for cin, cout in zip(widths, widths[1:]):
            stages += [
                nn.ConvTranspose2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        # Output layer: tanh keeps pixel values in [-1, 1].
        stages += [
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
```
### WGAN与WGAN-GP
```python
class WGAN_Critic(nn.Module):
    """WGAN critic: outputs an unbounded per-sample realism score (no sigmoid)."""

    def __init__(self, nc=3, ndf=64):
        super(WGAN_Critic, self).__init__()
        body = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, doubling the width each time.
        # InstanceNorm rather than BatchNorm: the gradient penalty is defined
        # per sample, so batch statistics must not couple samples together.
        width = ndf
        for _ in range(3):
            body += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width *= 2
        # Collapse the 4x4 map to a single unbounded score.
        body.append(nn.Conv2d(width, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*body)

    def forward(self, input):
        # One scalar score per sample, shape (batch,).
        return self.main(input).view(-1)
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Compute the WGAN-GP gradient penalty term.

    Samples random points on straight lines between real and fake samples
    and penalizes the critic wherever its gradient norm at those points
    deviates from 1 (soft 1-Lipschitz constraint).

    Args:
        critic: module/callable mapping a batch of samples to per-sample scores.
        real_samples: real data batch, shape (B, C, H, W).
        fake_samples: generated batch with the same shape as ``real_samples``.
        device: torch device used for the interpolation coefficients.

    Returns:
        Scalar tensor: mean of ``(||grad|| - 1)**2`` over the batch.

    Note: the original version contained a stray ``>`` token inside the
    ``torch.autograd.grad`` call (an injected artifact) that made the
    function a syntax error; it has been removed.
    """
    batch_size = real_samples.size(0)
    # One interpolation coefficient per sample, broadcast over (C, H, W).
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    critic_interpolated = critic(interpolated)
    # Gradient of the critic output w.r.t. the interpolated inputs.
    # create_graph=True keeps the penalty differentiable so it can be
    # backpropagated into the critic's parameters.
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten per-sample gradients and penalize deviation of the norm from 1.
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty
```
### Conditional GAN(条件GAN)
```python
class ConditionalGenerator(nn.Module):
    """Conditional generator: synthesizes an image for a given class label.

    The label is embedded and concatenated onto the noise channels, so the
    first transposed conv sees ``latent_dim + n_classes`` input channels.
    """

    def __init__(self, n_classes=10, latent_dim=100, img_channels=3):
        super(ConditionalGenerator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        widths = (512, 256, 128, 64)
        # Project noise+label to 4x4, then upsample to 64x64.
        stages = [
            nn.ConvTranspose2d(latent_dim + n_classes, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        for cin, cout in zip(widths, widths[1:]):
            stages += [
                nn.ConvTranspose2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        stages += [
            nn.ConvTranspose2d(widths[-1], img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, noise, labels):
        # Embed labels to (B, n_classes, 1, 1) and stack onto the noise channels.
        label_embedded = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)
        return self.main(torch.cat([noise, label_embedded], dim=1))
class ConditionalDiscriminator(nn.Module):
    """Conditional discriminator: scores an (image, label) pair.

    The label embedding is broadcast over the image's spatial extent and
    concatenated as extra input channels.
    """

    def __init__(self, n_classes=10, img_channels=3):
        super(ConditionalDiscriminator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        widths = (64, 128, 256, 512)
        stages = [
            nn.Conv2d(img_channels + n_classes, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsample 64x64 -> 4x4, doubling channels at each stage.
        for cin, cout in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to one unbounded score per sample.
        stages.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*stages)

    def forward(self, img, labels):
        # Expand the label embedding to (B, n_classes, H, W) so it can be
        # concatenated with the image along the channel axis.
        emb = self.label_embedding(labels)
        emb = emb.view(emb.size(0), -1, 1, 1).expand(-1, -1, img.size(2), img.size(3))
        return self.main(torch.cat([img, emb], dim=1)).view(-1)
```
### StyleGAN系列
```python
class StyleGANMappingNetwork(nn.Module):
    """StyleGAN mapping network: transforms a latent vector z into a style vector w.

    A stack of ``n_layers`` Linear + LeakyReLU(0.2) layers; the first layer
    maps latent_dim -> style_dim, the rest keep style_dim.
    """

    def __init__(self, latent_dim=512, style_dim=512, n_layers=8):
        super(StyleGANMappingNetwork, self).__init__()
        dims = [latent_dim] + [style_dim] * n_layers
        layers = []
        for cin, cout in zip(dims, dims[1:]):
            layers.append(nn.Linear(cin, cout))
            layers.append(nn.LeakyReLU(0.2))
        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        return self.mapping(z)
class AdaIN(nn.Module):
    """Adaptive instance normalization: style-conditioned scale and shift.

    Instance-normalizes the feature map, then applies a per-sample,
    per-channel affine transform derived from the style vector.
    """

    def __init__(self, channels, style_dim):
        super(AdaIN, self).__init__()
        # affine=False: the scale/bias come from the style, not learned per-channel.
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_scale = nn.Linear(style_dim, channels)
        self.style_bias = nn.Linear(style_dim, channels)

    def forward(self, x, style):
        # Shape the projected style to (B, C, 1, 1) for spatial broadcasting.
        scale = self.style_scale(style).unsqueeze(2).unsqueeze(3)
        bias = self.style_bias(style).unsqueeze(2).unsqueeze(3)
        return scale * self.norm(x) + bias
class StyleGANGeneratorBlock(nn.Module):
    """One StyleGAN synthesis block: 2x upsample, then two style-modulated convs.

    Doubles the spatial resolution of ``x`` and applies two
    conv -> AdaIN -> LeakyReLU stages conditioned on ``style``.
    """

    def __init__(self, in_channels, out_channels, style_dim):
        super(StyleGANGeneratorBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.adain1 = AdaIN(out_channels, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.adain2 = AdaIN(out_channels, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, style):
        # Upsample first, then refine at the new resolution.
        out = self.upsample(x)
        out = self.lrelu1(self.adain1(self.conv1(out), style))
        out = self.lrelu2(self.adain2(self.conv2(out), style))
        return out
```
## 训练策略与稳定化技巧
```python
class GANStabilization:
    """Collection of GAN training stabilization techniques.

    Note: the original ``r1_regularization`` body began with a line of
    injected spam URLs that made the method a syntax error; it has been
    removed.
    """

    @staticmethod
    def spectral_norm(module):
        """Wrap a module with spectral normalization (Lipschitz control)."""
        return nn.utils.spectral_norm(module)

    @staticmethod
    def r1_regularization(discriminator, real_samples, gamma=10):
        """R1 regularization: penalize the discriminator's gradient on real data.

        Args:
            discriminator: callable mapping samples to per-sample scores.
            real_samples: real batch of shape (B, C, H, W). Its
                ``requires_grad`` flag is enabled in place.
            gamma: penalty weight.

        Returns:
            Scalar tensor ``gamma / 2 * E[||grad D(x)||^2]``.
        """
        real_samples.requires_grad_(True)
        real_output = discriminator(real_samples)
        # Gradient of the summed scores w.r.t. the inputs; create_graph=True
        # keeps the penalty differentiable for the discriminator update.
        gradients = torch.autograd.grad(
            outputs=real_output.sum(),
            inputs=real_samples,
            create_graph=True
        )[0]
        # Squared gradient norm per sample, averaged over the batch.
        gradient_penalty = gradients.square().sum([1, 2, 3])
        return gamma * 0.5 * gradient_penalty.mean()

    @staticmethod
    def ema_model_update(model, ema_model, decay=0.999):
        """Update ``ema_model`` in place as an exponential moving average of ``model``.

        ema_param <- decay * ema_param + (1 - decay) * param
        """
        with torch.no_grad():
            for param, ema_param in zip(model.parameters(), ema_model.parameters()):
                ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)

    @staticmethod
    def progressive_growing(generator, discriminator, current_resolution, target_resolution):
        """Progressive-growing placeholder.

        NOTE(review): layer insertion is not implemented; this is a stub that
        returns the models unchanged.
        """
        if current_resolution < target_resolution:
            # New layers would be inserted here as resolution grows.
            pass
        return generator, discriminator
## GAN评价指标实现
```python
class GANMetrics:
    """Evaluation metrics for GAN sample quality."""

    @staticmethod
    def inception_score(generated_images, classifier, n_splits=10):
        """Compute the Inception Score (IS) of generated images.

        Splits the samples into ``n_splits`` groups and for each computes
        exp(E[KL(p(y|x) || p(y))]) from the classifier's softmax outputs.

        Args:
            generated_images: batch of generated samples accepted by ``classifier``.
            classifier: callable returning per-class logits.
            n_splits: number of splits for the mean/std estimate.

        Returns:
            (mean, std) of the per-split scores as Python floats.
        """
        n_samples = len(generated_images)
        scores = []
        for i in range(n_splits):
            subset = generated_images[i * n_samples // n_splits: (i + 1) * n_samples // n_splits]
            with torch.no_grad():
                preds = classifier(subset)
                preds = torch.nn.functional.softmax(preds, dim=1)
            # Marginal class distribution p(y) over this split.
            py = preds.mean(dim=0)
            # Mean KL(p(y|x) || p(y)), exponentiated.
            scores.append(torch.sum(preds * (torch.log(preds) - torch.log(py)), dim=1).mean().exp())
        return torch.stack(scores).mean().item(), torch.stack(scores).std().item()

    @staticmethod
    def fid_score(real_features, generated_features):
        """Compute the Frechet Inception Distance between two feature sets.

        FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2})

        Bug fix: the square root must be a *matrix* square root of the
        covariance product; the previous elementwise ``torch.sqrt`` produced
        incorrect values (e.g. nonzero FID for identical distributions).

        Args:
            real_features: tensor of shape (N, D), one feature row per sample.
            generated_features: tensor of shape (M, D).

        Returns:
            FID value as a Python float.
        """
        from scipy import linalg  # matrix square root is not available in torch
        import numpy as np

        mu_real = real_features.mean(dim=0)
        mu_gen = generated_features.mean(dim=0)
        # torch.cov treats rows as variables, so transpose to (D, N).
        sigma_real = torch.cov(real_features.t())
        sigma_gen = torch.cov(generated_features.t())
        diff = mu_real - mu_gen
        # Matrix square root of the covariance product; numerical error can
        # introduce a tiny imaginary component, which is discarded.
        cov_mean = linalg.sqrtm((sigma_real @ sigma_gen).double().numpy())
        if np.iscomplexobj(cov_mean):
            cov_mean = cov_mean.real
        cov_mean = torch.from_numpy(cov_mean).to(sigma_real.dtype)
        fid = diff.dot(diff) + torch.trace(sigma_real + sigma_gen - 2 * cov_mean)
        return fid.item()
```
## 实际应用案例
### 图像超分辨率
```python
class SRGANGenerator(nn.Module):
    """SRGAN generator: upscales a low-resolution RGB image by ``scale_factor``.

    Fixes over the previous version:
    - adds the missing initial feature-extraction conv: the residual blocks
      expect 64 channels but were fed the 3-channel LR image directly,
      which crashed at the first block;
    - honors ``scale_factor`` (it was accepted but ignored; upsampling was
      hard-coded to 4x). The default of 4 reproduces the original layout.
    """

    def __init__(self, scale_factor=4):
        super(SRGANGenerator, self).__init__()
        if scale_factor < 1 or scale_factor & (scale_factor - 1):
            raise ValueError("scale_factor must be a power of 2")
        # Initial feature extraction: 3-channel image -> 64 feature maps.
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, 9, 1, 4),
            nn.PReLU(),
        )
        # 16 residual blocks operating at low resolution.
        self.residual_blocks = nn.Sequential(*[
            ResidualBlock(64) for _ in range(16)
        ])
        # One PixelShuffle stage per factor of 2: (256, H, W) -> (64, 2H, 2W).
        up = []
        factor = scale_factor
        while factor > 1:
            up += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.ReLU(True),
            ]
            factor //= 2
        self.upsample_layers = nn.Sequential(*up)
        self.final_conv = nn.Conv2d(64, 3, 9, 1, 4)

    def forward(self, lr_img):
        """Map (B, 3, H, W) -> (B, 3, H * scale_factor, W * scale_factor) in [-1, 1]."""
        features = self.initial(lr_img)
        features = self.residual_blocks(features)
        upscaled = self.upsample_layers(features)
        hr_img = self.final_conv(upscaled)
        # tanh keeps outputs in [-1, 1] to match normalized training images.
        return torch.tanh(hr_img)
class ResidualBlock(nn.Module):
    """Residual block: conv-BN-PReLU-conv-BN with an identity skip connection.

    Preserves both channel count and spatial size, so the skip is a plain add.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        # Two conv-BN stages, PReLU between them, then the identity skip.
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return out + x
```
### 图像到图像转换
```python
class Pix2PixGenerator(nn.Module):
    """Pix2Pix generator (U-Net architecture).

    Eight stride-2 encoder stages reduce a 256x256 input to a 1x1
    bottleneck; eight decoder stages mirror them back up. Each decoder
    output is concatenated with the matching encoder feature map (skip
    connections), which is why decoders 2-8 take doubled input channels.

    Note: the original ``__init__`` began with a line of injected spam URLs
    that made the class a syntax error; it has been removed.
    """

    def __init__(self, input_channels=3, output_channels=3, ngf=64):
        super(Pix2PixGenerator, self).__init__()
        # Encoder: each block halves the spatial size (256 -> ... -> 1).
        self.encoder1 = self.downsample_block(input_channels, ngf, batch_norm=False)
        self.encoder2 = self.downsample_block(ngf, ngf * 2)
        self.encoder3 = self.downsample_block(ngf * 2, ngf * 4)
        self.encoder4 = self.downsample_block(ngf * 4, ngf * 8)
        self.encoder5 = self.downsample_block(ngf * 8, ngf * 8)
        self.encoder6 = self.downsample_block(ngf * 8, ngf * 8)
        self.encoder7 = self.downsample_block(ngf * 8, ngf * 8)
        # No batch norm at the 1x1 bottleneck (normalizing a 1x1 map is
        # degenerate for small batches).
        self.encoder8 = self.downsample_block(ngf * 8, ngf * 8, batch_norm=False)
        # Decoder: each block doubles the spatial size; dropout on the first
        # three stages serves as the generator's noise source.
        self.decoder1 = self.upsample_block(ngf * 8, ngf * 8, dropout=True)
        self.decoder2 = self.upsample_block(ngf * 16, ngf * 8, dropout=True)
        self.decoder3 = self.upsample_block(ngf * 16, ngf * 8, dropout=True)
        self.decoder4 = self.upsample_block(ngf * 16, ngf * 8)
        self.decoder5 = self.upsample_block(ngf * 16, ngf * 4)
        self.decoder6 = self.upsample_block(ngf * 8, ngf * 2)
        self.decoder7 = self.upsample_block(ngf * 4, ngf)
        self.decoder8 = self.upsample_block(ngf * 2, output_channels,
                                            activation='tanh', batch_norm=False)

    def downsample_block(self, in_channels, out_channels, batch_norm=True):
        """Build one encoder stage: Conv(stride 2) -> [BatchNorm] -> LeakyReLU."""
        layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def upsample_block(self, in_channels, out_channels, dropout=False,
                       activation='relu', batch_norm=True):
        """Build one decoder stage: ConvT(stride 2) -> [BatchNorm] -> [Dropout] -> ReLU/Tanh."""
        layers = [nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        if dropout:
            layers.append(nn.Dropout(0.5))
        if activation == 'relu':
            layers.append(nn.ReLU(True))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)

    def forward(self, input):
        """Translate ``input`` (B, input_channels, 256, 256) to the target domain."""
        # Encode, keeping every intermediate activation for the skips.
        e1 = self.encoder1(input)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        e6 = self.encoder6(e5)
        e7 = self.encoder7(e6)
        e8 = self.encoder8(e7)
        # Decode: concatenate each stage's output with the mirror-image
        # encoder feature map along the channel axis.
        d1 = torch.cat([self.decoder1(e8), e7], dim=1)
        d2 = torch.cat([self.decoder2(d1), e6], dim=1)
        d3 = torch.cat([self.decoder3(d2), e5], dim=1)
        d4 = torch.cat([self.decoder4(d3), e4], dim=1)
        d5 = torch.cat([self.decoder5(d4), e3], dim=1)
        d6 = torch.cat([self.decoder6(d5), e2], dim=1)
        d7 = torch.cat([self.decoder7(d6), e1], dim=1)
        return self.decoder8(d7)
```
## 训练监控与可视化
```python
class GANTrainingMonitor:
    """Tracks GAN training losses/accuracies, writes TensorBoard logs, and
    saves sample grids and loss-curve plots under ``log_dir``.

    Fix: ``visualize_samples`` previously hard-coded ``device='cuda'``, which
    crashed on CPU-only machines; the noise now follows the generator's own
    device. Imports of ``os``/``SummaryWriter``/``torchvision`` are done
    locally so the class is self-contained.
    """

    def __init__(self, log_dir='logs'):
        import os
        from torch.utils.tensorboard import SummaryWriter

        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)
        # Rolling history used by plot_losses().
        self.losses = {
            'd_loss': [],
            'g_loss': [],
            'd_real': [],
            'd_fake': []
        }
        # TensorBoard writer rooted at log_dir.
        self.writer = SummaryWriter(log_dir)

    def update(self, d_loss, g_loss, d_real_acc, d_fake_acc, epoch, iteration):
        """Record one iteration's losses/accuracies and log them to TensorBoard."""
        self.losses['d_loss'].append(d_loss)
        self.losses['g_loss'].append(g_loss)
        self.losses['d_real'].append(d_real_acc)
        self.losses['d_fake'].append(d_fake_acc)
        # NOTE(review): assumes fewer than 1000 iterations per epoch, otherwise
        # global steps from consecutive epochs overlap — confirm against the
        # training loop.
        global_step = epoch * 1000 + iteration
        self.writer.add_scalar('Loss/Discriminator', d_loss, global_step)
        self.writer.add_scalar('Loss/Generator', g_loss, global_step)
        self.writer.add_scalar('Accuracy/D_Real', d_real_acc, global_step)
        self.writer.add_scalar('Accuracy/D_Fake', d_fake_acc, global_step)

    def visualize_samples(self, generator, epoch, n_samples=16):
        """Generate a sample grid, save it as PNG, and log it to TensorBoard.

        Assumes the generator takes (B, 100, 1, 1) noise.
        """
        import torchvision

        # Follow the generator's device instead of hard-coding 'cuda'.
        device = next(generator.parameters()).device
        with torch.no_grad():
            noise = torch.randn(n_samples, 100, 1, 1, device=device)
            samples = generator(noise)
        grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)
        torchvision.utils.save_image(grid,
                                     f'{self.log_dir}/samples_epoch_{epoch:04d}.png')
        self.writer.add_image('Generated_Samples', grid, epoch)

    def plot_losses(self):
        """Plot loss and accuracy curves and save them under log_dir."""
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(self.losses['d_loss'], label='Discriminator Loss')
        plt.plot(self.losses['g_loss'], label='Generator Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('GAN Training Losses')
        plt.subplot(1, 2, 2)
        plt.plot(self.losses['d_real'], label='D Real Accuracy')
        plt.plot(self.losses['d_fake'], label='D Fake Accuracy')
        plt.xlabel('Iteration')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.title('Discriminator Accuracy')
        plt.tight_layout()
        plt.savefig(f'{self.log_dir}/loss_curves.png')
        plt.close()
```
从基础GAN到StyleGAN的演进,反映了生成对抗网络在理论深度和应用广度上的持续发展。这些技术突破不仅推动了计算机视觉领域的前进,也为跨模态生成、数据增强和创意应用开辟了新的可能性。