The Variational Autoencoder (VAE) is a probabilistic generative model belonging to the family of deep generative models. It learns the latent distribution of the data and can generate new samples by sampling points from the latent space. For example, in an image generation task, a VAE can learn the characteristic features of a set of images and synthesize new images based on those features. The image generation process:
- Train the VAE so that it learns a latent representation of the input images.
- After training, sample a random latent vector \( z \) from the standard normal distribution \( \mathcal{N}(0, I) \).
- Pass \( z \) through the decoder to produce a new image (see the sketch below).
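A minimal sketch of the last two steps, assuming `vae` is an already-trained instance of the `VAE` class defined in Section 2 (the class exposes a `decoder` module and a `latent_dim` attribute):

```python
import torch

# Hypothetical usage: `vae` is a trained instance of the VAE class from Section 2.
vae.eval()
with torch.no_grad():
    # Sample z ~ N(0, I) on the same device as the model
    z = torch.randn(1, vae.latent_dim).to(next(vae.parameters()).device)
    new_image = vae.decoder(z)  # decode z into an image tensor
print(new_image.shape)  # torch.Size([1, 3, 96, 96])
```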
1. Data Processing
```python
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt


class CustomDataset(Dataset):
    def __init__(self):
        # STL-10 binaries store images as (N, 3, 96, 96) in column-major order;
        # the transpose converts them to (N, 96, 96, 3) HWC images.
        with open('stl10_binary/train_X.bin', 'rb') as f:
            numpy_images = np.fromfile(f, dtype=np.uint8)
            data1 = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)
        with open('stl10_binary/test_X.bin', 'rb') as f:
            numpy_images = np.fromfile(f, dtype=np.uint8)
            data2 = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)
        with open('stl10_binary/unlabeled_X.bin', 'rb') as f:
            numpy_images = np.fromfile(f, dtype=np.uint8)
            data3 = numpy_images.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)
        # Keep the first 10,000 images from the combined splits
        self.data = np.concatenate((data1, data2, data3), axis=0)[:10000]
        print('Number of samples:', len(self.data))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def show(self):
        # Display 9 randomly chosen images in a 3x3 grid
        indices = np.arange(0, len(self.data))
        indexes = np.random.choice(indices, 9)
        for idx, index in enumerate(indexes):
            plt.subplot(3, 3, idx + 1)
            plt.imshow(self.data[index])
            plt.axis('off')
        plt.show()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


if __name__ == '__main__':
    images = CustomDataset()
    images.show()
    # for idx, image in enumerate(images):
    #     print(image.shape, torch.min(image).item(), torch.max(image).item())
```
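As a quick sanity check (a sketch, reusing the class above), each sample should come out of the transform as a (3, 96, 96) float tensor with values in [0, 1], which matches the Sigmoid output range of the decoder in Section 2:

```python
# Sketch: verify the tensor shape and value range produced by CustomDataset.
dataset = CustomDataset()
sample = dataset[0]
print(sample.shape, sample.min().item(), sample.max().item())
# Expected: torch.Size([3, 96, 96]) with min >= 0.0 and max <= 1.0
```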
2. Building the Model
```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Dimensionality of the latent variable
        self.latent_dim = 256
        # Encoder: compresses the input image into a low-dimensional representation
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),  # [,64,96,96]
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, 1, 1),    # [,64,96,96]
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, 1, 1),    # [,64,96,96]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # [,64,48,48]
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, 1, 1),   # [,128,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, 1, 1),  # [,128,48,48]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # [,128,24,24]
            nn.BatchNorm2d(128),
            nn.Flatten(start_dim=1),
        )
        # Linear heads for the mean and log-variance of the latent distribution
        self.mean = nn.Linear(128 * 24 * 24, self.latent_dim)  # mean
        self.lvar = nn.Linear(128 * 24 * 24, self.latent_dim)  # log-variance
        # Decoder: reconstructs the image from the latent variable
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 64 * 24 * 24),
            nn.Unflatten(1, (64, 24, 24)),
            nn.ConvTranspose2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 128, 3, 2, 1, 1),  # 24x24 -> 48x48
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 32, 3, 1, 1),
            nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),    # 48x48 -> 96x96
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 3, 1, 1),
            nn.Sigmoid(),  # outputs in [0, 1], matching the ToTensor input range
        )

    def reparameterization(self, mean, lvar):
        """Reparameterization trick: makes sampling differentiable."""
        std = torch.exp(0.5 * lvar)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z

    def loss_function(self, pred_images, true_images, mean, lvar, beta=3):
        """Loss = reconstruction (MSE) + beta * KL divergence."""
        MSE = nn.MSELoss(reduction='sum')(pred_images, true_images)
        KLD = -0.5 * torch.sum(1 + lvar - mean.pow(2) - lvar.exp())
        batch_loss = MSE + beta * KLD
        return batch_loss

    @torch.no_grad()
    def generate(self):
        """Generate images from random latent vectors."""
        # 1. Sample latent vectors from the standard normal distribution
        z = torch.randn(3, self.latent_dim).to(next(self.parameters()).device)
        # 2. Decode the random vectors into images
        generated_images = self.decoder(z)
        # 3. Display the generated images (already in [0, 1] thanks to Sigmoid)
        plt.suptitle('generate')
        for idx, image in enumerate(generated_images):
            plt.subplot(1, 3, idx + 1)
            plt.imshow(image.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
        plt.show()

    @torch.no_grad()
    def reconstruct(self, input_images):
        """Reconstruct images from input images."""
        input_images = input_images.to(next(self.parameters()).device)
        # 1. Encode the input images
        latent_vectors = self.encoder(input_images)
        # 2. Compute the latent distribution of the inputs
        mean = self.mean(latent_vectors)
        lvar = self.lvar(latent_vectors)
        # 3. Reparameterize to sample the latent variable
        z = self.reparameterization(mean, lvar)
        # 4. Reconstruct images from the latent variable
        generated_images = self.decoder(z)
        # 5. Display the reconstructed images
        plt.suptitle('reconstruct')
        for idx, image in enumerate(generated_images):
            plt.subplot(1, 3, idx + 1)
            plt.imshow(image.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
        plt.show()

    def forward(self, true_images):
        """Forward pass that returns the training loss."""
        true_images = true_images.to(next(self.parameters()).device)
        # 1. Encode the input images
        latent_vectors = self.encoder(true_images)
        # 2. Compute the latent distribution parameters
        mean = self.mean(latent_vectors)
        lvar = self.lvar(latent_vectors)
        # 3. Reparameterize
        z = self.reparameterization(mean, lvar)
        # 4. Reconstruct images from the latent variable
        pred_images = self.decoder(z)
        # 5. Compute the reconstruction loss plus the KL divergence
        batch_loss = self.loss_function(pred_images, true_images, mean, lvar)
        return batch_loss
```
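In symbols, `reparameterization` and `loss_function` implement the standard VAE objective. With the encoder heads producing a mean \( \mu \) and a log-variance \( \log\sigma^2 \) (`lvar`), the latent sample and the loss computed above are:

\[
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^2\right)
\]

\[
\mathcal{L} = \underbrace{\sum_i \left(\hat{x}_i - x_i\right)^2}_{\text{MSE}} \; + \; \beta \cdot \underbrace{\left(-\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\right)}_{\text{KLD}}
\]

where \( d \) is the latent dimension (256 here) and the KL weight defaults to \( \beta = 3 \), trading reconstruction fidelity against how closely the latent distribution matches \( \mathcal{N}(0, I) \).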
3. Model Training
```python
import pickle
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import shutil

from model import VAE
from custom_dataset import CustomDataset


def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = CustomDataset()
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    estimator = VAE().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=5e-6)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    num_epochs = 500
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_size = 0
        progress = tqdm(range(len(dataloader)))
        for batch_idx, input_images in enumerate(dataloader):
            # Compute the loss
            loss = estimator(input_images)
            # Accumulate loss statistics
            batch_size = input_images.size(0)
            total_loss += loss.item()
            total_size += batch_size
            # Update parameters (the loss is summed, so normalize by batch size)
            optimizer.zero_grad()
            (loss / batch_size).backward()
            optimizer.step()
            progress.set_description(f'Epoch {epoch + 1} Loss: {total_loss / total_size:.5f}')
            progress.update()
        progress.close()
        # scheduler.step()
        if (epoch + 1) % 2 == 0:
            save_path = f'model/{epoch + 1}/'
            if os.path.exists(save_path) or os.path.isdir(save_path):
                shutil.rmtree(save_path)
            # makedirs (rather than mkdir) so the parent 'model/' dir is created too
            os.makedirs(save_path)
            pickle.dump(estimator, open(save_path + 'estimator.pkl', 'wb'))
            estimator.generate()
            estimator.reconstruct(input_images[:3])


if __name__ == '__main__':
    train()
```
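After training, a saved checkpoint can be reloaded for sampling. A minimal sketch, assuming a checkpoint exists at the path pattern used in `train()` (the epoch number 100 is illustrative) and that the `VAE` class from `model.py` is importable so `pickle` can restore the object:

```python
import pickle
from model import VAE  # the class must be importable for pickle.load to work

# 'model/100/' is an illustrative epoch directory saved by train()
with open('model/100/estimator.pkl', 'rb') as f:
    estimator = pickle.load(f)
estimator.eval()
estimator.generate()  # sample 3 latent vectors and display the decoded images
```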