VQ-VAE (Work in Progress)

conda create -n vqvae-env python=3.10
conda activate vqvae-env
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
pip install matplotlib==3.5.3 -i https://pypi.tuna.tsinghua.edu.cn/simple/

Environment: Windows 11 + Python 3.10 + PyCharm 2021.1.3

1. Data Preprocessing

import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms


def demo():
    file_path = 'stl10_binary/train_X.bin'
    with open(file_path, 'rb') as f:
        numpy_images = np.fromfile(f, dtype=np.uint8)
        numpy_images = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

    # Preview the first few images
    for idx in range(9):
        plt.subplot(3, 3, idx + 1)
        plt.imshow(numpy_images[idx])
        plt.axis('off')
    plt.show()

    # Image preprocessing
    # transforms.ToTensor():
    #   Input: a NumPy array or PIL image of shape (height, width, channels).
    #   Output: a PyTorch tensor of shape (channels, height, width).
    #   Dtype: pixel values are scaled from [0, 255] (integers) to [0.0, 1.0] (floats).
    #
    # transforms.Normalize() standardizes each channel: subtract the mean and
    # divide by the standard deviation. Here mean=0.5 and std=0.5 are used for
    # all three RGB channels, which maps [0, 1] to [-1, 1] (matching the Tanh
    # output of the decoder below). ImageNet statistics (mean=[0.485, 0.456, 0.406],
    # std=[0.229, 0.224, 0.225]) are another common choice.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale to [-1, 1]
    ])

    torch_images = []
    for numpy_image in numpy_images:
        torch_image = transform(numpy_image)
        torch_images.append(torch_image)

    torch_images = torch.stack(torch_images)

    # Save the preprocessed images
    print('Data shape:', torch_images.shape)
    with open('data/images.pkl', 'wb') as f:
        pickle.dump(torch_images, f)


if __name__ == '__main__':
    demo()
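
To verify the saved file, you can load it back and invert the normalization for display. A minimal sketch (it assumes data/images.pkl was produced by the script above):

import pickle
import matplotlib.pyplot as plt

with open('data/images.pkl', 'rb') as f:
    images = pickle.load(f)
print(images.shape)  # torch.Size([5000, 3, 96, 96]) for the STL-10 train split

# Invert Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):
# x * 0.5 + 0.5 maps [-1, 1] back to [0, 1]
image = images[0] * 0.5 + 0.5
plt.imshow(image.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C) for matplotlib
plt.axis('off')
plt.show()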

2. Model Definition and Training
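
Before the code, it helps to state the objective it implements. The standard VQ-VAE loss from van den Oord et al. (2017) combines a reconstruction term, a codebook term, and a commitment term:

L = \|x - \hat{x}\|_2^2 + \|\mathrm{sg}[z_e(x)] - e\|_2^2 + \beta\,\|z_e(x) - \mathrm{sg}[e]\|_2^2

where sg[·] is the stop-gradient operator (.detach() in PyTorch), z_e(x) is the encoder output, e is the nearest codebook vector, and β corresponds to commitment_cost below. The first term trains the encoder and decoder, the second moves the codebook toward the encoder outputs (q_latent_loss), and the third keeps the encoder committed to its chosen codes (e_latent_loss).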

import pickle
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# from encoder import Encoder
# from decoder import Decoder
# from vector_quantizer import VectorQuantizer


class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            # input: (batch_size, 3, 96, 96) -> output: (batch_size, 32, 48, 48)
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 32, 48, 48) -> output: (batch_size, 64, 24, 24)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 64, 24, 24) -> output: (batch_size, 128, 12, 12)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 128, 12, 12) -> output: (batch_size, 256, 6, 6)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # flatten (batch_size, 256, 6, 6) -> (batch_size, 9216), then project to (batch_size, latent_dim)
            nn.Flatten(start_dim=1),
            nn.Linear(256 * 6 * 6, latent_dim)
        )

    def forward(self, inputs):
        return self.model(inputs)


class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Decoder, self).__init__()

        self.model = nn.Sequential(
            # input: (batch_size, latent_dim) -> output: (batch_size, 256, 6, 6)
            nn.Linear(latent_dim, 256 * 6 * 6),
            nn.Unflatten(1, (256, 6, 6)),

            # input: (batch_size, 256, 6, 6) -> output: (batch_size, 128, 12, 12)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 128, 12, 12) -> output: (batch_size, 64, 24, 24)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 64, 24, 24) -> output: (batch_size, 32, 48, 48)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            # input: (batch_size, 32, 48, 48) -> output: (batch_size, 3, 96, 96)
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized inputs
        )

    def forward(self, inputs):
        return self.model(inputs)


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=2048, embedding_dim=256, commitment_cost=1.0):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Initialize the codebook
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # Squared Euclidean distance between each input z and every codebook vector
        distances = torch.cdist(z, self.embedding.weight, p=2).pow(2)

        # Nearest-neighbor lookup in the codebook
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices)

        # Codebook loss (moves embeddings toward encoder outputs) and
        # commitment loss (keeps encoder outputs close to their embeddings)
        e_latent_loss = torch.mean((quantized.detach() - z) ** 2)
        q_latent_loss = torch.mean((quantized - z.detach()) ** 2)
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: the forward pass uses the quantized values,
        # while the backward pass copies gradients from quantized directly to z
        quantized = z + (quantized - z).detach()
        return quantized, loss


def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = pickle.load(open('data/images.pkl', 'rb'))
    dataloader = DataLoader(images, batch_size=256, shuffle=True)
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    vectors = VectorQuantizer().to(device)

    optimizer = optim.AdamW(list(encoder.parameters()) + list(vectors.parameters()) + list(decoder.parameters()), lr=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    num_epochs = 300
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, encoder_images in enumerate(dataloader):

            # 1. Encode the input images into the latent space
            encoder_images = encoder_images.to(device)
            encoder_outputs = encoder(encoder_images)

            # 2. Quantize the latent vectors against the codebook
            vector, quantizer_loss = vectors(encoder_outputs)

            # 3. Reconstruct the images from the quantized latents
            decoder_outputs = decoder(vector)

            optimizer.zero_grad()

            # 4. Reconstruction loss plus the quantizer (codebook + commitment) loss
            recon_loss = torch.mean((decoder_outputs - encoder_images) ** 2)
            loss = recon_loss + quantizer_loss

            total_loss += loss.item()
            # 5. Backpropagation and parameter update
            loss.backward()
            # Optional: gradient clipping can help stabilize training
            # torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(vectors.parameters()) + list(decoder.parameters()), max_norm=1.0)

            optimizer.step()
        scheduler.step()
        print(f'Epoch {epoch + 1} Lr: {scheduler.get_last_lr()[0]:.10f} Loss: {total_loss:.6f} ')

        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                codebook = vectors.embedding.weight  # the learned codebook
                sampled_indices = torch.randint(0, codebook.size(0), (3,), device=device)  # random codebook indices
                sampled_codes = codebook[sampled_indices]  # look up the corresponding code vectors

                # Note: decoding randomly sampled codes is only a rough sanity check;
                # coherent generation requires a learned prior over the code indices,
                # which is the missing piece behind the "Work in Progress" title.
                generated_images = decoder(sampled_codes)
                for idx, image in enumerate(generated_images):
                    plt.subplot(1, 3, idx + 1)
                    image = image.permute(1, 2, 0).cpu().numpy()
                    image = (image + 1) / 2  # map [-1, 1] back to [0, 1] for display
                    plt.imshow(image)
                    plt.axis('off')
                plt.show()


if __name__ == '__main__':
    train()
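
As a sanity check on the straight-through estimator, the short sketch below (standalone, reusing the VectorQuantizer class defined above) confirms that gradients flow back to the encoder output even though the nearest-neighbor lookup itself is non-differentiable:

import torch

vq = VectorQuantizer(num_embeddings=2048, embedding_dim=256)
z = torch.randn(4, 256, requires_grad=True)  # stand-in for encoder outputs

quantized, vq_loss = vq(z)
print(quantized.shape)  # torch.Size([4, 256])

# Backprop through a dummy downstream loss: thanks to
# quantized = z + (quantized - z).detach(), the gradient w.r.t. z
# is the identity pass-through of the upstream gradient.
quantized.sum().backward()
print(z.grad is not None)                          # True
print(torch.allclose(z.grad, torch.ones_like(z)))  # True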
