conda create -n vqvae-env python=3.10
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
pip install matplotlib==3.5.3 -i https://pypi.tuna.tsinghua.edu.cn/simple/
Environment: Windows 11 + Python 3.10 + PyCharm 2021.1.3
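Before moving on, a quick sanity check of the install (a minimal sketch, assuming the cu126 wheel from the command above) is:

import torch
import torchvision

print(torch.__version__)           # should report a +cu126 build
print(torch.cuda.is_available())   # True if the CUDA driver is usable
print(torchvision.__version__)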
1. Data Preprocessing
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms


def demo():
    # STL-10 stores images as raw uint8 bytes in column-major order, so after
    # reshaping to (N, 3, 96, 96) we transpose to (N, 96, 96, 3) for display.
    file_path = 'stl10_binary/train_X.bin'
    with open(file_path, 'rb') as f:
        numpy_images = np.fromfile(f, dtype=np.uint8)
    numpy_images = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

    # Preview the first nine images
    for idx in range(9):
        plt.subplot(3, 3, idx + 1)
        plt.imshow(numpy_images[idx])
        plt.axis('off')
    plt.show()

    # Image preprocessing:
    # transforms.ToTensor() takes an HWC uint8 NumPy array (or PIL image) and
    # returns a CHW float tensor, scaling pixel values from [0, 255] to [0.0, 1.0].
    # transforms.Normalize(mean, std) then applies (x - mean) / std per channel;
    # with mean = std = 0.5 this maps [0, 1] to [-1, 1] (x -> 2x - 1), which
    # matches the Tanh output range of the decoder defined in the next section.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])

    torch_images = []
    for numpy_image in numpy_images:
        torch_images.append(transform(numpy_image))
    torch_images = torch.stack(torch_images)

    # Save the preprocessed tensor for the training script
    print('Data shape:', torch_images.shape)
    os.makedirs('data', exist_ok=True)
    with open('data/images.pkl', 'wb') as f:
        pickle.dump(torch_images, f)


if __name__ == '__main__':
    demo()
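A quick way to verify the saved file (a minimal sketch, assuming data/images.pkl was written by the script above) is to reload it and undo the normalization for display:

import pickle
import matplotlib.pyplot as plt

with open('data/images.pkl', 'rb') as f:
    images = pickle.load(f)

# For the STL-10 train split this should print torch.Size([5000, 3, 96, 96])
# with values in roughly [-1.0, 1.0]
print(images.shape, images.min().item(), images.max().item())

# Undo Normalize((0.5,)*3, (0.5,)*3): x_orig = x * 0.5 + 0.5
first = images[0].permute(1, 2, 0) * 0.5 + 0.5
plt.imshow(first)
plt.axis('off')
plt.show()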
2. Model Definition and Training
import pickle
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            # in: (batch_size, 3, 96, 96) -> out: (batch_size, 32, 48, 48)
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 32, 48, 48) -> out: (batch_size, 64, 24, 24)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 64, 24, 24) -> out: (batch_size, 128, 12, 12)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 128, 12, 12) -> out: (batch_size, 256, 6, 6)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 256 * 6 * 6) -> out: (batch_size, latent_dim)
            nn.Flatten(start_dim=1),
            nn.Linear(256 * 6 * 6, latent_dim)
        )

    def forward(self, inputs):
        return self.model(inputs)
class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            # in: (batch_size, latent_dim) -> out: (batch_size, 256 * 6 * 6)
            nn.Linear(latent_dim, 256 * 6 * 6),
            nn.Unflatten(1, (256, 6, 6)),
            # in: (batch_size, 256, 6, 6) -> out: (batch_size, 128, 12, 12)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 128, 12, 12) -> out: (batch_size, 64, 24, 24)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 64, 24, 24) -> out: (batch_size, 32, 48, 48)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # in: (batch_size, 32, 48, 48) -> out: (batch_size, 3, 96, 96)
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            # Tanh keeps outputs in [-1, 1], matching the normalized inputs
            nn.Tanh(),
        )

    def forward(self, inputs):
        return self.model(inputs)
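
# Illustrative helper (not part of the original listing): verifies that a dummy
# batch round-trips Encoder -> Decoder back to the input resolution.
def check_shapes():
    x = torch.randn(2, 3, 96, 96)
    z = Encoder()(x)
    print(z.shape)      # torch.Size([2, 256])
    x_hat = Decoder()(z)
    print(x_hat.shape)  # torch.Size([2, 3, 96, 96])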
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=2048, embedding_dim=256, commitment_cost=1.0):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Initialize the codebook with small uniform values
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # Squared Euclidean distance between each input vector and every codebook entry
        distances = torch.cdist(z, self.embedding.weight, p=2).pow(2)
        # Nearest-neighbor lookup in the codebook
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices)
        # Codebook loss pulls the selected code vectors toward the encoder outputs;
        # the commitment loss pulls the encoder outputs toward their code vectors
        e_latent_loss = torch.mean((quantized.detach() - z) ** 2)
        q_latent_loss = torch.mean((quantized - z.detach()) ** 2)
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        # Straight-through estimator: the forward pass outputs the quantized
        # vectors, while gradients flow back to z as if quantization were identity
        quantized = z + (quantized - z).detach()
        return quantized, loss
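
# Illustrative check (not part of the original listing): the straight-through
# trick means the output is a codebook vector, yet gradients still reach z.
def check_quantizer():
    vq = VectorQuantizer()
    z = torch.randn(4, 256, requires_grad=True)
    quantized, vq_loss = vq(z)
    print(quantized.shape, vq_loss.item())  # torch.Size([4, 256]), scalar loss
    quantized.sum().backward()
    print(z.grad is not None)               # True: gradients bypass the argmin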
def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with open('data/images.pkl', 'rb') as f:
        images = pickle.load(f)
    dataloader = DataLoader(images, batch_size=256, shuffle=True)

    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    vectors = VectorQuantizer().to(device)
    optimizer = optim.AdamW(
        list(encoder.parameters()) + list(vectors.parameters()) + list(decoder.parameters()),
        lr=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    num_epochs = 300
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, encoder_images in enumerate(dataloader):
            # 1. Compress the input images into the latent space
            encoder_images = encoder_images.to(device)
            encoder_outputs = encoder(encoder_images)
            # 2. Quantize the latent vectors against the codebook
            vector, quantizer_loss = vectors(encoder_outputs)
            # 3. Reconstruct the images from the quantized latents
            decoder_outputs = decoder(vector)
            # 4. Total loss = reconstruction loss + codebook/commitment loss
            optimizer.zero_grad()
            recon_loss = torch.mean((decoder_outputs - encoder_images) ** 2)
            loss = recon_loss + quantizer_loss
            total_loss += loss.item()
            # 5. Backpropagate and update the parameters
            loss.backward()
            optimizer.step()
        # Decay the learning rate once per epoch
        scheduler.step()
        print(f'Epoch {epoch + 1} Lr: {scheduler.get_last_lr()[0]:.10f} Loss: {total_loss:.6f}')

        # Every 10 epochs, decode a few randomly sampled codebook vectors
        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                codebook = vectors.embedding.weight
                sampled_indices = torch.randint(0, codebook.size(0), (3,), device=device)
                sampled_codes = codebook[sampled_indices]
                generated_images = decoder(sampled_codes)
            for idx, image in enumerate(generated_images):
                plt.subplot(1, 3, idx + 1)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = (image + 1) / 2  # map [-1, 1] back to [0, 1] for display
                plt.imshow(image)
                plt.axis('off')
            plt.show()


if __name__ == '__main__':
    train()
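The loop above only decodes random codebook entries; a more direct check of training quality is to reconstruct real images. A minimal sketch, assuming train() has just run and encoder, vectors, and decoder are still in scope (the script above does not yet save checkpoints, so persisting them with torch.save would be a natural extension):

import pickle
import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def show_reconstructions(encoder, vectors, decoder, device, n=3):
    with open('data/images.pkl', 'rb') as f:
        images = pickle.load(f)[:n].to(device)
    # encode -> quantize -> decode; vectors(...) returns (quantized, loss)
    recon = decoder(vectors(encoder(images))[0])
    # Top row: originals. Bottom row: reconstructions.
    for idx in range(n):
        for row, batch in enumerate((images, recon)):
            plt.subplot(2, n, row * n + idx + 1)
            plt.imshow(((batch[idx].permute(1, 2, 0).cpu() + 1) / 2).numpy())
            plt.axis('off')
    plt.show()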
