A Generative Adversarial Network (GAN) is a deep learning model composed of a generator and a discriminator; the two networks are trained against each other (adversarial training) so that the generator learns to produce realistic data.
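For reference, adversarial training can be summarized by the minimax objective from the original GAN paper (Goodfellow et al., 2014), where the discriminator D tries to tell real samples x from generated samples G(z), and the generator G tries to fool it:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The training script below optimizes a binary cross-entropy version of this objective via nn.BCEWithLogitsLoss: the discriminator sees real images labelled 1 and generated images labelled 0, while the generator is trained with label 1 on its own outputs (the non-saturating form).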
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Changes in this version:
# 1. The activation after the linear layer is changed from sigmoid to leaky_relu.
# 2. A BatchNorm layer is added after the activation; BatchNorm is then swapped for LayerNorm.
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear1 = nn.Linear(784, 256)
        self.linear2 = nn.Linear(256, 1)
        # self.norm = nn.BatchNorm1d(num_features=256, affine=False)
        self.norm = nn.LayerNorm(256)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, start_dim=1)
        inputs = self.linear1(inputs)
        inputs = F.leaky_relu(inputs, negative_slope=0.02)
        inputs = self.norm(inputs)
        output = self.linear2(inputs)  # raw logits; the loss applies the sigmoid
        return output


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(32, 256)
        self.linear2 = nn.Linear(256, 784)
        # self.norm = nn.BatchNorm1d(num_features=256, affine=False)
        self.norm = nn.LayerNorm(256)

    def forward(self, inputs):
        inputs = self.linear1(inputs)
        inputs = F.leaky_relu(inputs, negative_slope=0.02)
        inputs = self.norm(inputs)
        inputs = self.linear2(inputs)
        inputs = torch.sigmoid(inputs)  # pixel values in [0, 1]
        output = inputs.reshape([inputs.size(0), 1, 28, 28])
        return output


# Random noise inputs for the generator
torch.manual_seed(100)


def generator_random_inputs(*args):
    inputs = torch.randn(*args)
    noises = torch.rand(*args)
    return (inputs + noises).to(device)


def train():
    gen_model = Generator().to(device)
    dis_model = Discriminator().to(device)
    train_data = MNIST(root='data', download=False, train=True,
                       transform=transforms.Compose([transforms.ToTensor()]))
    batch_size = 100
    # Make sure the checkpoint directory exists before saving into it
    os.makedirs('model2', exist_ok=True)

    def collate_function(batch_data):
        # Positive (real) samples
        positive_inputs = []
        for single_input, single_label in batch_data:
            positive_inputs.append(single_input.unsqueeze(0))
        positive_inputs = torch.cat(positive_inputs).to(device)
        # Negative (generated) sample: one fake image is appended to each batch
        negative_inputs = generator_random_inputs(1, 32)
        negative_inputs = gen_model(negative_inputs).detach()
        # Concatenate real and generated samples, labelled 1 and 0 respectively
        batch_inputs = torch.cat([positive_inputs, negative_inputs]).to(device)
        batch_labels = [1] * len(positive_inputs) + [0] * len(negative_inputs)
        batch_labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
        return batch_inputs, batch_labels

    dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size,
                            collate_fn=collate_function)
    gen_optimizer = optim.Adam(gen_model.parameters(), lr=1e-4)
    dis_optimizer = optim.Adam(dis_model.parameters(), lr=1e-5)
    # criterion = nn.CrossEntropyLoss()
    criterion = nn.BCEWithLogitsLoss()

    gen_losses = []
    dis_losses = []
    for epoch in range(500):
        progress = tqdm(range(len(dataloader)), ncols=100)
        gen_sum_loss, dis_sum_loss = 0.0, 0.0
        dis_sum_size, gen_sum_size = 0, 0
        for dis_inputs, dis_labels in dataloader:
            # 1. Train the discriminator
            dis_outputs = dis_model(dis_inputs)
            dis_loss = criterion(dis_outputs.squeeze(), dis_labels)
            dis_optimizer.zero_grad()
            dis_loss.backward()
            dis_optimizer.step()
            dis_sum_loss += dis_loss.item() * len(dis_labels)
            dis_sum_size += len(dis_labels)

            # 2. Train the generator (two generator steps per discriminator step)
            for _ in range(2):
                generate_numbers = 100  # number of images to generate
                generator_inputs = generator_random_inputs(generate_numbers, 32)
                gen_outputs = gen_model(generator_inputs)
                # The generator is trained to make the discriminator output "real" (label 1)
                gen_labels = torch.tensor([1] * generate_numbers, dtype=torch.float32, device=device)
                dis_outputs = dis_model(gen_outputs)
                gen_loss = criterion(dis_outputs.squeeze(), gen_labels)
                gen_optimizer.zero_grad()
                gen_loss.backward()
                gen_optimizer.step()
                gen_sum_loss += gen_loss.item() * generate_numbers
                gen_sum_size += generate_numbers

            # 3. Update the progress bar
            progress.set_description('epoch %03d dis %.6f gen %.6f' % (
                epoch + 1, dis_sum_loss / dis_sum_size, gen_sum_loss / gen_sum_size))
            progress.update()
        progress.close()

        dis_losses.append(dis_sum_loss / dis_sum_size)
        gen_losses.append(gen_sum_loss / gen_sum_size)
        torch.save(gen_model.state_dict(), 'model2/%d-gen.pth' % epoch)
        torch.save(dis_model.state_dict(), 'model2/%d-dis.pth' % epoch)

        # Plot three generated digit images at the end of each epoch
        with torch.no_grad():
            plt.figure(figsize=(6, 3))
            for index in range(3):
                gen_seed = generator_random_inputs(1, 32)
                outputs = gen_model(gen_seed)
                # Squeeze to (28, 28) so that plt.imshow accepts the array
                outputs = outputs.squeeze().cpu().numpy()
                plt.subplot(1, 3, index + 1)
                plt.imshow(outputs)
                # plt.title(f'seed {gen_seed.item():.5f}')
            plt.suptitle(f'EPOCH {epoch + 1}')
            plt.show()

    torch.save({'gen': gen_losses, 'dis': dis_losses}, 'model2/train_loss.pth')


if __name__ == '__main__':
    train()
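As a usage sketch (not part of the original script), a saved generator checkpoint can be reloaded later to sample new digits. The helper below assumes it is appended to the training script above, so that Generator, generator_random_inputs, device, torch and plt are already in scope; the sample() name and the checkpoint path 'model2/499-gen.pth' are only examples and should point at whichever epoch was actually saved.

def sample(checkpoint_path='model2/499-gen.pth', count=9):
    # Rebuild the generator and load trained weights (hypothetical helper;
    # use any '<epoch>-gen.pth' file that exists under model2/)
    gen_model = Generator().to(device)
    gen_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    gen_model.eval()

    with torch.no_grad():
        seeds = generator_random_inputs(count, 32)
        images = gen_model(seeds)  # shape (count, 1, 28, 28), values in [0, 1]

    # Show the samples in a 3x3 grid (count is assumed to be at most 9 here)
    plt.figure(figsize=(6, 6))
    for index in range(count):
        plt.subplot(3, 3, index + 1)
        plt.imshow(images[index, 0].cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()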