基于 VAE 生成图像

变分自编码器(Variational Autoencoder, VAE)是一种概率生成模型,属于深度生成模型的范畴。它能够学习数据的潜在分布,并通过采样潜在空间中的点来生成新的数据。例如,在图像生成任务中,VAE 可以学习到一组图像的特征,并基于这些特征生成新的图像。使用 VAE 生成图像的过程如下:

  • 训练 VAE,使其学习输入图像的潜在表示。
  • 训练完成后,从标准正态分布 \( \mathcal{N}(0, I) \) 中随机采样一个潜在向量 \( z \)。
  • 通过解码器,将 \( z \) 转换为图像,从而生成新图像。

1. 数据处理

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt


class CustomDataset(Dataset):
    """Unlabelled 96 x 96 RGB images read from the raw ``stl10_binary`` files.

    The files are consumed in the order train, test, unlabeled and reading
    stops as soon as ``max_samples`` images have been collected, so a file
    that would contribute nothing is never opened or loaded into memory.

    Args:
        root: Directory that contains the ``*_X.bin`` files.
        max_samples: Upper bound on the number of images kept in memory.
    """

    # Read order matters: the first ``max_samples`` images of this sequence are kept.
    FILE_NAMES = ('train_X.bin', 'test_X.bin', 'unlabeled_X.bin')
    # Each image is stored as 3 planes of 96 x 96 uint8 values.
    IMAGE_BYTES = 3 * 96 * 96

    def __init__(self, root='stl10_binary', max_samples=10000):
        # uint8 array of shape [N, 96, 96, 3] (height, width, channel).
        self.data = self._load(root, max_samples)

        print('数据量:', len(self.data))

        # ToTensor turns an [H, W, C] uint8 array into a [C, H, W] float tensor
        # scaled to [0, 1], which matches the Sigmoid output of the decoder.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    @classmethod
    def _load(cls, root, max_samples):
        """Collect at most ``max_samples`` images from the files under ``root``."""
        chunks = []
        remaining = max_samples
        for name in cls.FILE_NAMES:
            if remaining <= 0:
                break
            chunk = cls._read_images(f'{root}/{name}', remaining)
            chunks.append(chunk)
            remaining -= len(chunk)
        if not chunks:
            # Nothing requested: return an empty array with the usual layout.
            return np.empty((0, 96, 96, 3), dtype=np.uint8)
        return np.concatenate(chunks, axis=0)

    @classmethod
    def _read_images(cls, path, limit):
        """Read up to ``limit`` images from one binary file as [N, 96, 96, 3]."""
        with open(path, 'rb') as f:
            # Learn the file size first so that no more than needed is read.
            f.seek(0, 2)
            available = f.tell() // cls.IMAGE_BYTES
            f.seek(0)
            count = min(limit, available) * cls.IMAGE_BYTES
            raw = np.fromfile(f, dtype=np.uint8, count=count)
        # Stored as [N, channel, column, row]; swap to [N, row, column, channel].
        return raw.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

    def show(self):
        """Plot nine randomly chosen images (drawn with replacement) in a 3 x 3 grid."""
        indexes = np.random.choice(len(self.data), 9)
        for idx, index in enumerate(indexes):
            plt.subplot(3, 3, idx + 1)
            plt.imshow(self.data[index])
            plt.axis('off')
        plt.show()

    def __len__(self):
        """Return the number of images held in memory."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``self.transform`` when one is set."""
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


if __name__ == '__main__':
    # Quick visual sanity check: build the dataset and plot a random sample.
    CustomDataset().show()

    # To inspect tensor shapes and value ranges instead, iterate the dataset:
    # for image in CustomDataset():
    #     print(image.shape, torch.min(image).item(), torch.max(image).item())

2. 模型搭建

import torch.nn as nn
import torch
import matplotlib.pyplot as plt


class VAE(nn.Module):
    """Convolutional variational autoencoder for 3 x 96 x 96 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a 256-dimensional latent vector; the decoder maps a latent
    vector back to an image whose values lie in [0, 1] (Sigmoid output).
    ``forward`` returns the training loss rather than the reconstruction.
    """

    def __init__(self):
        super().__init__()

        # Dimensionality of the latent variable z.
        self.latent_dim = 256

        # Compresses the input image into a flat feature vector.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),  # [,64,96,96]
            nn.ReLU(),
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 64, 3, 1, 1),  # [,64,96,96]
            nn.ReLU(),
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 64, 3, 1, 1),  # [,64,96,96]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # [,64,48,48]
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, 3, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 128, 3, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 128, 3, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # [,128,24,24]
            nn.BatchNorm2d(128),

            nn.Flatten(start_dim=1),  # [,128*24*24]
        )

        # Heads producing the parameters of the latent distribution.
        self.mean = nn.Linear(128 * 24 * 24, self.latent_dim)  # mean
        self.lvar = nn.Linear(128 * 24 * 24, self.latent_dim)  # log-variance

        # Reconstructs an image from a latent vector.
        self.decoder = nn.Sequential(

            nn.Linear(self.latent_dim, 64 * 24 * 24),
            nn.Unflatten(1, (64, 24, 24)),  # [,64,24,24]

            nn.ConvTranspose2d(64, 128, 3, 1, 1),  # [,128,24,24]
            nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.ConvTranspose2d(128, 128, 3, 2, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.ConvTranspose2d(128, 64, 3, 1, 1),  # [,64,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 3, 1, 1),  # [,32,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(32),

            nn.ConvTranspose2d(32, 32, 3, 1, 1),  # [,32,48,48]
            nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),  # [,16,96,96]
            nn.ReLU(),
            nn.BatchNorm2d(16),

            nn.ConvTranspose2d(16, 3, 3, 1, 1),  # [,3,96,96]
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def _encode(self, images):
        """Return ``(mean, lvar)`` of the latent distribution for ``images``."""
        features = self.encoder(images)
        return self.mean(features), self.lvar(features)

    def reparameterization(self, mean, lvar):
        """Sample ``z = mean + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps ``z`` differentiable with respect
        to ``mean`` and ``lvar``.

        Args:
            mean: Mean of the latent distribution.
            lvar: Log-variance of the latent distribution, same shape as ``mean``.

        Returns:
            A latent sample with the same shape as ``mean``.
        """
        std = torch.exp(0.5 * lvar)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z

    def loss_function(self, pred_images, true_images, mean, lvar, beta=3):
        """Return the summed reconstruction loss plus the weighted KL term.

        Args:
            pred_images: Decoder output.
            true_images: Target images, same shape as ``pred_images``.
            mean: Mean of the latent distribution.
            lvar: Log-variance of the latent distribution.
            beta: Weight of the KL divergence relative to the MSE term.

        Returns:
            Scalar tensor: the loss summed (not averaged) over the batch.
        """
        MSE = nn.MSELoss(reduction='sum')(pred_images, true_images)
        # KL divergence between N(mean, exp(lvar)) and the standard normal prior.
        KLD = -0.5 * torch.sum(1 + lvar - mean.pow(2) - lvar.exp())
        batch_loss = MSE + beta * KLD
        return batch_loss

    @staticmethod
    def _to_displayable(image):
        """Convert one [C, H, W] tensor into an [H, W, C] array for ``plt.imshow``.

        The decoder ends in a Sigmoid, so values are already in [0, 1] and
        must not be rescaled.
        """
        return image.detach().permute(1, 2, 0).cpu().numpy()

    def _show_images(self, title, images):
        """Plot ``images`` side by side in a single row under ``title``."""
        plt.suptitle(title)
        for idx, image in enumerate(images):
            plt.subplot(1, len(images), idx + 1)
            plt.imshow(self._to_displayable(image))
            plt.axis('off')
        plt.show()

    @torch.no_grad()
    def generate(self, num_images=3):
        """Decode random latent vectors and display the resulting images.

        Args:
            num_images: Number of images to sample and plot.
        """
        was_training = self.training
        # Evaluation mode: BatchNorm uses its running statistics and the tiny
        # visualisation batch does not overwrite them.
        self.eval()
        try:
            # 1. Sample latent vectors from the standard normal prior.
            z = torch.randn(num_images, self.latent_dim).to(next(self.parameters()).device)
            # 2. Decode them into images.
            generated_images = self.decoder(z)
        finally:
            self.train(was_training)
        # 3. Display the generated images.
        self._show_images('generate', generated_images)

    @torch.no_grad()
    def reconstruct(self, input_images):
        """Encode ``input_images``, sample a latent vector for each, and display the decodings.

        Args:
            input_images: Batch of images shaped like the training input.
        """
        was_training = self.training
        # See ``generate`` for why evaluation mode is used here.
        self.eval()
        try:
            input_images = input_images.to(next(self.parameters()).device)
            # 1. Compute the latent distribution of the inputs.
            mean, lvar = self._encode(input_images)
            # 2. Sample latent vectors via reparameterization.
            z = self.reparameterization(mean, lvar)
            # 3. Decode them into images.
            generated_images = self.decoder(z)
        finally:
            self.train(was_training)
        # 4. Display the reconstructed images.
        self._show_images('reconstruct', generated_images)

    def forward(self, true_images):
        """Return the summed VAE loss for a batch of images.

        Args:
            true_images: Batch of input images; moved to the model's device.

        Returns:
            Scalar tensor from ``loss_function`` with the default ``beta``.
        """
        # 1. Move the batch to the device the parameters live on.
        true_images = true_images.to(next(self.parameters()).device)
        # 2. Compute the latent distribution of the inputs.
        mean, lvar = self._encode(true_images)
        # 3. Sample latent vectors via reparameterization.
        z = self.reparameterization(mean, lvar)
        # 4. Decode the latent vectors into images.
        pred_images = self.decoder(z)
        # 5. Reconstruction loss plus KL divergence.
        batch_loss = self.loss_function(pred_images, true_images, mean, lvar)
        return batch_loss

3. 模型训练

import pickle
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import shutil
from model import VAE
from custom_dataset import CustomDataset


def train(num_epochs=500, batch_size=64, lr=5e-6, save_every=2):
    """Train the VAE on ``CustomDataset`` and periodically checkpoint it.

    Every ``save_every`` epochs the whole model is pickled to
    ``model/<epoch>/estimator.pkl`` and sample generations and
    reconstructions are displayed.

    Args:
        num_epochs: Number of passes over the dataset.
        batch_size: Number of images per optimisation step.
        lr: Learning rate for Adam.
        save_every: Checkpoint and visualise every this many epochs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = CustomDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    estimator = VAE().to(device)
    # The learning rate stays constant: no scheduler is stepped during training.
    optimizer = optim.Adam(estimator.parameters(), lr=lr)

    for epoch in range(num_epochs):
        total_loss = 0.0
        total_size = 0
        # Wrapping the loader lets tqdm advance and close the bar by itself.
        progress = tqdm(dataloader)
        for input_images in progress:
            # The model's forward pass returns the loss summed over the batch.
            loss = estimator(input_images)
            current_size = input_images.size(0)
            total_loss += loss.item()
            total_size += current_size

            # Optimise the per-sample mean so the step size does not depend
            # on the (possibly smaller) size of the last batch.
            optimizer.zero_grad()
            (loss / current_size).backward()
            optimizer.step()

            progress.set_description(f'Epoch {epoch + 1} Loss: {total_loss / total_size:.5f}')

        if (epoch + 1) % save_every == 0:
            save_path = f'model/{epoch + 1}/'
            # Start from a clean directory; makedirs also creates the parent
            # ``model/`` folder, which a plain mkdir would fail on.
            if os.path.isdir(save_path):
                shutil.rmtree(save_path)
            os.makedirs(save_path)
            with open(os.path.join(save_path, 'estimator.pkl'), 'wb') as f:
                pickle.dump(estimator, f)
            estimator.generate()
            # Visualise reconstructions of the first images of the last batch.
            estimator.reconstruct(input_images[:3])


# Start training only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    train()
未经允许不得转载:一亩三分地 » 基于 VAE 生成图像
评论 (0)

4 + 6 =