conda create -n vqvae-env python=3.10
conda activate vqvae-env
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
pip install matplotlib==3.5.3 -i https://pypi.tuna.tsinghua.edu.cn/simple/

Environment: Windows 11 + Python 3.10 + PyCharm 2021.1.3
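Before moving on, it is worth confirming that the CUDA build of PyTorch actually installed (a quick check; the cu126 wheel also requires a sufficiently recent NVIDIA driver):

import torch

print(torch.__version__)          # e.g. 2.6.0+cu126
print(torch.version.cuda)         # e.g. 12.6
print(torch.cuda.is_available())  # should print True on a working GPU setup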
1. Data Preprocessing
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms


def demo():
    file_path = 'stl10_binary/train_X.bin'
    with open(file_path, 'rb') as f:
        numpy_images = np.fromfile(f, dtype=np.uint8)
    # STL-10 stores images column-major: reshape to (N, 3, 96, 96),
    # then transpose to (N, 96, 96, 3) HWC with the correct orientation.
    numpy_images = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

    # Preview the first 9 images
    for idx in range(9):
        plt.subplot(3, 3, idx + 1)
        plt.imshow(numpy_images[idx])
        plt.axis('off')
    plt.show()

    # Image preprocessing
    # transforms.ToTensor():
    #   Input: a NumPy array or PIL image of shape (height, width, channels).
    #   Output: a PyTorch tensor of shape (channels, height, width).
    #   Dtype: pixel values are scaled from [0, 255] (uint8) to [0.0, 1.0] (float).
    # transforms.Normalize() standardizes each channel: (x - mean) / std.
    # Here mean = std = (0.5, 0.5, 0.5), which maps [0, 1] to [-1, 1] and
    # matches the Tanh output of the decoder defined later. (The ImageNet
    # statistics mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] are
    # the usual alternative when working with ImageNet-pretrained models.)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])

    torch_images = []
    for numpy_image in numpy_images:
        torch_images.append(transform(numpy_image))
    torch_images = torch.stack(torch_images)

    # Save the preprocessed tensor
    print('Data shape:', torch_images.shape)
    os.makedirs('data', exist_ok=True)
    with open('data/images.pkl', 'wb') as f:
        pickle.dump(torch_images, f)


if __name__ == '__main__':
    demo()
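After running the script, a quick sanity check confirms the dump has the expected shape and value range (a minimal sketch; the STL-10 training split contains 5,000 images):

import pickle

images = pickle.load(open('data/images.pkl', 'rb'))
print(images.shape, images.dtype)                # expect torch.Size([5000, 3, 96, 96]) torch.float32
print(images.min().item(), images.max().item())  # values should lie within [-1.0, 1.0]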
2. Model Definition and Training
import pickle
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# from encoder import Encoder
# from decoder import Decoder
# from vector_quantizer import VectorQuantizer


class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            # input: (batch_size, 3, 96, 96) -> output: (batch_size, 32, 48, 48)
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 32, 48, 48) -> output: (batch_size, 64, 24, 24)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 64, 24, 24) -> output: (batch_size, 128, 12, 12)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 128, 12, 12) -> output: (batch_size, 256, 6, 6)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),
            # flatten: (batch_size, 256, 6, 6) -> (batch_size, 9216) -> (batch_size, 256)
            nn.Flatten(start_dim=1),
            nn.Linear(256 * 6 * 6, latent_dim)
        )

    def forward(self, inputs):
        return self.model(inputs)


class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            # (batch_size, 256) -> (batch_size, 9216) -> (batch_size, 256, 6, 6)
            nn.Linear(latent_dim, 256 * 6 * 6),
            nn.Unflatten(1, (256, 6, 6)),
            # input: (batch_size, 256, 6, 6) -> output: (batch_size, 128, 12, 12)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 128, 12, 12) -> output: (batch_size, 64, 24, 24)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 64, 24, 24) -> output: (batch_size, 32, 48, 48)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            # nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),
            # input: (batch_size, 32, 48, 48) -> output: (batch_size, 3, 96, 96)
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized inputs
        )

    def forward(self, inputs):
        return self.model(inputs)


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=2048, embedding_dim=256, commitment_cost=1.0):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Initialize the codebook
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # Squared Euclidean distance between z and every codebook vector
        distances = torch.cdist(z, self.embedding.weight, p=2).pow(2)
        # distances = (torch.sum(z ** 2, dim=1, keepdim=True) +
        #              torch.sum(self.embedding.weight ** 2, dim=1) -
        #              2 * torch.matmul(z, self.embedding.weight.t()))

        # Nearest codebook vector for each input
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices)

        # Commitment loss (pulls z toward the codes) and
        # codebook loss (pulls the codes toward z)
        e_latent_loss = torch.mean((quantized.detach() - z) ** 2)
        q_latent_loss = torch.mean((quantized - z.detach()) ** 2)
        # e_latent_loss = nn.MSELoss()(quantized.detach(), z)
        # q_latent_loss = nn.MSELoss()(quantized, z.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: argmin is not differentiable,
        # so copy the decoder's gradient straight through to z
        quantized = z + (quantized - z).detach()
        return quantized, loss


def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = pickle.load(open('data/images.pkl', 'rb'))
    dataloader = DataLoader(images, batch_size=256, shuffle=True)

    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    vectors = VectorQuantizer().to(device)

    optimizer = optim.AdamW(
        list(encoder.parameters()) + list(vectors.parameters()) + list(decoder.parameters()),
        lr=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    num_epochs = 300
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, encoder_images in enumerate(dataloader):
            # 1. Encode the input images into the latent space
            encoder_images = encoder_images.to(device)
            encoder_outputs = encoder(encoder_images)
            # 2. Quantize the latent vectors against the codebook
            vector, quantizer_loss = vectors(encoder_outputs)
            # 3. Reconstruct the images from the quantized latents
            decoder_outputs = decoder(vector)

            optimizer.zero_grad()
            # 4. Reconstruction loss plus the quantizer loss
            # recon_loss = nn.MSELoss()(decoder_outputs, encoder_images)
            recon_loss = torch.mean((decoder_outputs - encoder_images) ** 2)
            loss = recon_loss + quantizer_loss
            total_loss += loss.item()
            # 5. Backpropagate and update the parameters
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(vectors.parameters()) + list(decoder.parameters()), max_norm=1.0)
            optimizer.step()
        scheduler.step()
        print(f'Epoch {epoch + 1} Lr: {scheduler.get_last_lr()[0]:.10f} Loss: {total_loss:.6f}')

        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                codebook = vectors.embedding.weight  # the codebook
                sampled_indices = torch.randint(0, codebook.size(0), (3,), device=device)  # random indices
                sampled_codes = codebook[sampled_indices]  # look up the code vectors
                generated_images = decoder(sampled_codes)
                for idx, image in enumerate(generated_images):
                    plt.subplot(1, 3, idx + 1)
                    image = image.permute(1, 2, 0).cpu().numpy()
                    image = (image + 1) / 2  # map [-1, 1] back to [0, 1]
                    plt.imshow(image)
                    plt.axis('off')
                plt.show()


if __name__ == '__main__':
    train()
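Random codebook indices mostly decode to blurry texture, so a more telling check is to push real training images through the full encoder → quantizer → decoder path and compare inputs with reconstructions. A minimal sketch, assuming the classes and data/images.pkl from the scripts above are available in scope:

import pickle

import matplotlib.pyplot as plt
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
images = pickle.load(open('data/images.pkl', 'rb'))

encoder = Encoder().to(device)
decoder = Decoder().to(device)
vectors = VectorQuantizer().to(device)
# (load trained weights here if they were saved with torch.save;
#  freshly initialized models will only produce noise)

with torch.no_grad():
    batch = images[:3].to(device)
    quantized, _ = vectors(encoder(batch))
    recons = decoder(quantized)

# Top row: originals; bottom row: reconstructions
for idx in range(3):
    for row, img in enumerate((batch[idx], recons[idx])):
        plt.subplot(2, 3, row * 3 + idx + 1)
        plt.imshow((img.permute(1, 2, 0).cpu().numpy() + 1) / 2)  # [-1, 1] -> [0, 1]
        plt.axis('off')
plt.show()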