生成对抗网络(GAN)图像生成

生成对抗网络(Generative Adversarial Network,GAN)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)组成:生成器试图生成以假乱真的样本,判别器则试图区分真实样本与生成样本。二者交替训练、相互博弈,最终使生成器能够生成逼真的数据。

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model design notes:
# 1. The activation after the hidden linear layer is leaky_relu (it replaced sigmoid).
# 2. A normalization layer follows the activation; BatchNorm was swapped for LayerNorm.

class Discriminator(nn.Module):
    """MLP discriminator that scores each input as real (high logit) or fake (low logit).

    Layers: flatten -> Linear(in_features, hidden_features) -> leaky_relu
    -> LayerNorm(hidden_features) -> Linear(hidden_features, 1).
    The output is a raw logit (no sigmoid), meant to be paired with
    ``nn.BCEWithLogitsLoss``.

    Args:
        in_features: Number of values in one flattened input sample
            (784 = 28 * 28 for the default single-channel 28x28 images).
        hidden_features: Width of the hidden layer.
        negative_slope: Negative slope of the leaky ReLU activation.
    """

    def __init__(self, in_features=784, hidden_features=256, negative_slope=0.02):
        super().__init__()
        self.negative_slope = negative_slope
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, 1)
        # LayerNorm normalizes each sample independently, so its behaviour does
        # not depend on batch size or train/eval mode (BatchNorm was used before).
        self.norm = nn.LayerNorm(hidden_features)

    def forward(self, inputs):
        """Return a (batch, 1) tensor of real/fake logits.

        Every dimension after the first is flattened. Image batches such as
        (batch, 1, 28, 28) and flat (batch, in_features) tensors are both accepted.
        """
        # torch.flatten avoids constructing a new nn.Flatten module on every call.
        hidden = torch.flatten(inputs, start_dim=1)
        hidden = F.leaky_relu(self.linear1(hidden), negative_slope=self.negative_slope)
        hidden = self.norm(hidden)
        return self.linear2(hidden)


class Generator(nn.Module):
    """MLP generator that maps latent vectors to images with pixel values in (0, 1).

    Layers: Linear(latent_dim, hidden_features) -> leaky_relu
    -> LayerNorm(hidden_features) -> Linear(hidden_features, C*H*W) -> sigmoid,
    then a reshape to (batch, C, H, W).

    Args:
        latent_dim: Size of each input latent vector.
        hidden_features: Width of the hidden layer.
        image_shape: (C, H, W) of one generated image; (1, 28, 28) by default.
        negative_slope: Negative slope of the leaky ReLU activation.
    """

    def __init__(self, latent_dim=32, hidden_features=256, image_shape=(1, 28, 28),
                 negative_slope=0.02):
        super().__init__()
        self.image_shape = tuple(image_shape)
        self.negative_slope = negative_slope
        channels, height, width = self.image_shape
        self.linear1 = nn.Linear(latent_dim, hidden_features)
        self.linear2 = nn.Linear(hidden_features, channels * height * width)
        # Per-sample normalization; BatchNorm was used in an earlier version.
        self.norm = nn.LayerNorm(hidden_features)

    def forward(self, inputs):
        """Map a (batch, latent_dim) tensor to a (batch, C, H, W) image tensor."""
        hidden = F.leaky_relu(self.linear1(inputs), negative_slope=self.negative_slope)
        hidden = self.norm(hidden)
        # torch.sigmoid is the canonical spelling (some PyTorch releases warn
        # that F.sigmoid is deprecated); it squashes pixel values into (0, 1).
        pixels = torch.sigmoid(self.linear2(hidden))
        return pixels.reshape(pixels.size(0), *self.image_shape)


# Generator inputs. PyTorch's global RNG is seeded once, at import time,
# so the latent draws below are reproducible from run to run.
torch.manual_seed(100)
def generator_random_inputs(*args):
    """Draw a batch of latent vectors for the generator and move it to ``device``.

    ``args`` is forwarded unchanged as the tensor shape to ``torch.randn`` and
    ``torch.rand`` (e.g. ``(batch, 32)``). The result is the element-wise sum of
    a standard-normal draw and a uniform [0, 1) draw, so its mean is 0.5, not 0.
    """
    sample_shape = args
    latent = torch.randn(*sample_shape)          # Gaussian part, drawn first
    latent = latent + torch.rand(*sample_shape)  # uniform offset, drawn second
    return latent.to(device)


def _show_samples(gen_model, epoch, num_images=3):
    """Plot ``num_images`` freshly generated samples side by side for one epoch."""
    with torch.no_grad():
        plt.figure(figsize=(6, 3))
        for index in range(num_images):
            gen_seed = generator_random_inputs(1, 32)
            outputs = gen_model(gen_seed)
            # (1, C, H, W) -> (H, W, C); squeeze(-1) drops C only when it is 1,
            # so single-channel images are drawn as 2-D arrays.
            image = outputs.squeeze(0).permute(1, 2, 0).squeeze(-1)
            plt.subplot(1, num_images, index + 1)
            # cmap only affects 2-D (single-channel) images.
            plt.imshow(image.cpu().numpy(), cmap='gray')
        plt.suptitle(f'EPOCH {epoch + 1}')
        plt.show()


def train(epochs=500, batch_size=100, data_root='data', save_dir='model2'):
    """Train the MNIST GAN, checkpointing both models and plotting samples each epoch.

    Each step has two phases:

    1. The discriminator is trained once on a batch of real MNIST images
       (label 1) mixed with an equal number of detached generator samples
       (label 0).
    2. The generator is trained twice, each time pushing the discriminator
       to label fresh generated samples as real.

    Args:
        epochs: Number of passes over the MNIST training set.
        batch_size: Number of real images per discriminator batch. The same
            number of fakes is added, so the discriminator sees 2 * batch_size
            samples; each generator step also uses batch_size samples.
        data_root: Directory passed to ``MNIST(root=...)``. The data must
            already be present because ``download=False``.
        save_dir: Directory for checkpoints and the loss history; it is created
            if missing.

    Side effects:
        Every epoch, writes ``<save_dir>/<epoch>-gen.pth`` and
        ``<save_dir>/<epoch>-dis.pth`` and shows a matplotlib figure.
        At the end, writes ``<save_dir>/train_loss.pth`` holding the
        per-epoch mean losses.
    """
    import os

    # Create the checkpoint directory up front; torch.save does not create it.
    os.makedirs(save_dir, exist_ok=True)

    gen_model = Generator().to(device)
    dis_model = Discriminator().to(device)
    train_data = MNIST(root=data_root,
                       download=False,
                       train=True,
                       transform=transforms.Compose([transforms.ToTensor(), ]))

    def collate_function(batch_data):
        """Build one discriminator batch: real images (label 1) plus equally many fakes (label 0)."""
        # Real samples: stack per-sample image tensors along a new batch dim.
        positive_inputs = torch.stack([image for image, _ in batch_data]).to(device)

        # Fake samples: one per real image keeps the two classes balanced.
        # no_grad because these fakes only serve as fixed discriminator inputs.
        with torch.no_grad():
            noise = generator_random_inputs(len(positive_inputs), 32)
            negative_inputs = gen_model(noise)

        batch_inputs = torch.cat([positive_inputs, negative_inputs])
        batch_labels = torch.cat([
            torch.ones(len(positive_inputs), dtype=torch.float32, device=device),
            torch.zeros(len(negative_inputs), dtype=torch.float32, device=device),
        ])
        return batch_inputs, batch_labels

    # Keep num_workers at its default (0): the collate function runs gen_model,
    # which may live on a CUDA device.
    dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size,
                            collate_fn=collate_function)
    gen_optimizer = optim.Adam(gen_model.parameters(), lr=1e-4)
    dis_optimizer = optim.Adam(dis_model.parameters(), lr=1e-5)
    # Both models output raw logits, so use the numerically stable logits variant.
    criterion = nn.BCEWithLogitsLoss()

    generate_numbers = batch_size  # generated images per generator step
    # The generator's target: the discriminator should call its samples real.
    gen_labels = torch.ones(generate_numbers, dtype=torch.float32, device=device)

    gen_losses = []
    dis_losses = []

    for epoch in range(epochs):
        gen_sum_loss, dis_sum_loss = 0.0, 0.0
        dis_sum_size, gen_sum_size = 0, 0

        with tqdm(total=len(dataloader), ncols=100) as progress:
            for dis_inputs, dis_labels in dataloader:
                # 1. Train the discriminator on the real + fake batch.
                dis_outputs = dis_model(dis_inputs)
                # squeeze(-1), not squeeze(): a batch of size 1 must stay shape (1,)
                # to match the label shape expected by the loss.
                dis_loss = criterion(dis_outputs.squeeze(-1), dis_labels)
                dis_optimizer.zero_grad()
                dis_loss.backward()
                dis_optimizer.step()
                dis_sum_loss += dis_loss.item() * len(dis_labels)
                dis_sum_size += len(dis_labels)

                # 2. Train the generator twice per discriminator step.
                for _ in range(2):
                    generator_inputs = generator_random_inputs(generate_numbers, 32)
                    gen_outputs = gen_model(generator_inputs)
                    dis_outputs = dis_model(gen_outputs)
                    gen_loss = criterion(dis_outputs.squeeze(-1), gen_labels)

                    # This backward also writes grads into dis_model; they are
                    # cleared by dis_optimizer.zero_grad() before its next step.
                    gen_optimizer.zero_grad()
                    gen_loss.backward()
                    gen_optimizer.step()
                    gen_sum_loss += gen_loss.item() * generate_numbers
                    gen_sum_size += generate_numbers

                # 3. Update the progress bar with running mean losses.
                progress.set_description('epoch %03d dis %.6f gen %.6f' % (
                    epoch + 1, dis_sum_loss / dis_sum_size, gen_sum_loss / gen_sum_size))
                progress.update()

        dis_losses.append(dis_sum_loss / dis_sum_size)
        gen_losses.append(gen_sum_loss / gen_sum_size)

        torch.save(gen_model.state_dict(), os.path.join(save_dir, '%d-gen.pth' % epoch))
        torch.save(dis_model.state_dict(), os.path.join(save_dir, '%d-dis.pth' % epoch))

        _show_samples(gen_model, epoch)

    torch.save({'gen': gen_losses, 'dis': dis_losses},
               os.path.join(save_dir, 'train_loss.pth'))


if __name__ == '__main__':
    # Entry point: start the long-running training only when this file is
    # executed as a script, never as a side effect of importing it.
    train()
未经允许不得转载:一亩三分地 » 对抗生成网络(GAN)图像生成
评论 (0)

8 + 6 =