I have been studying autoencoders recently, so here is a reproduction. The engineering-related code has been stripped out, keeping only the core. If you have multiple GPUs, set accelerate to True; otherwise turn it off.
The Encoder and Decoder follow the Stable Diffusion design. The encoder's last layer is changed to output a mean and a variance (i.e. a plain VAE); the latent feature map is produced by sampling from that distribution, and the sampled feature map is then quantized with VQ. The reconstructed images are still somewhat blurry. I suspect this is because some source images had already been degraded: interpolated up to 256*256 or saved with lossy JPEG compression, so the model learned that noise as well.
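In code terms, the pipeline is roughly the following (a minimal sketch using the classes defined in the model code below; model and x are placeholder names):

    z = model.encode(x)                        # encoder predicts mean/log-variance, a latent is sampled from it
    z_q, indices, vq_loss = model.quantize(z)  # nearest-codebook-entry quantization (VectorQuantizer)
    x_recon = model.decode(z_q)                # decoder reconstructs the image from the quantized latent

This matches what VQVAE.forward does in the model code below.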
Data source:
Konachan anime avatar dataset — 飞桨AI Studio星河社区 (PaddlePaddle AI Studio)
Sample results
epoch 0 step 100
epoch 6 step 10000
epoch 50 step 85000
epoch 100 step 176700
Model code
import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1):
        super(ConvBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.GroupNorm(groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        )

    def forward(self, x):
        return self.conv_block(x)


class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            ConvBlock(in_channels, out_channels, groups=groups),
            ConvBlock(out_channels, out_channels, groups=groups),
        )
        if in_channels != out_channels:
            self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        else:
            self.skip_conn = nn.Identity()

    def forward(self, x):
        return self.conv_block(x) + self.skip_conn(x)


class AttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(AttentionBlock, self).__init__()
        self.q_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.k_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.v_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.out_conv = ConvBlock(out_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        if in_channels != out_channels:
            self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        else:
            self.skip_conn = nn.Identity()

    def forward(self, x):
        q = self.q_conv(x)
        k = self.k_conv(x)
        v = self.v_conv(x)
        attention = torch.einsum('bchw,bcHW->bhwHW', q, k)
        attention = attention / math.sqrt(q.shape[-1])
        attention = attention.softmax(dim=-1)
        out = torch.einsum('bhwHW,bcHW->bchw', attention, v)
        out = self.out_conv(out)
        return out + self.skip_conn(x)


class MiddleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(MiddleBlock, self).__init__()
        self.conv_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels, groups=groups),
            AttentionBlock(out_channels, out_channels, groups=groups),
            ResnetBlock(out_channels, out_channels, groups=groups),
        )

    def forward(self, x):
        return self.conv_block(x)


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.conv(x)
        return x


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 2, 0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = F.pad(x, pad, mode='constant', value=0)
        x = self.conv(x)
        return x


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        self.down_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels),
            ResnetBlock(out_channels, out_channels),
        )

    def forward(self, x):
        return self.down_block(x)


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels),
            ResnetBlock(out_channels, out_channels),
        )

    def forward(self, x):
        return self.up_block(x)


class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, groups=32):
        super(Encoder, self).__init__()
        self.conv = nn.Conv2d(in_channels, 128, 3, 1, 1)
        self.res_block = self.create_resnet_block(128, 128, 2, groups=groups)
        self.res_block2 = self.create_resnet_block(128, 256, 2, groups=groups)
        self.res_block3 = self.create_resnet_block(256, 512, 2, groups=groups)
        self.down_block = DownBlock(512, 512)
        self.middle_block = MiddleBlock(512, 512, groups=groups)
        self.conv_block = ConvBlock(512, z_channels * 2, groups=groups)

    @staticmethod
    def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
        res_blocks = []
        for _ in range(num_blocks):
            res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
        res_blocks.append(DownSample(in_channels, out_channels))
        return nn.Sequential(*res_blocks)

    def forward(self, x):
        x = self.conv(x)
        x = self.res_block(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.down_block(x)
        x = self.middle_block(x)
        x = self.conv_block(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, groups=32):
        super(Decoder, self).__init__()
        self.conv = nn.Conv2d(in_channels, 512, 3, 1, 1)
        self.middle_block = MiddleBlock(512, 512, groups=groups)
        self.resnet_block = self.create_resnet_block(512, 512, 3, groups=groups)
        self.resnet_block2 = self.create_resnet_block(512, 256, 3, groups=groups)
        self.resnet_block3 = self.create_resnet_block(256, 128, 3, groups=groups)
        self.up_block = UpBlock(128, 128)
        self.conv_block = ConvBlock(128, 3, groups=groups)

    @staticmethod
    def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
        res_blocks = []
        for _ in range(num_blocks):
            res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
        res_blocks.append(UpSample(in_channels, out_channels))
        return nn.Sequential(*res_blocks)

    def forward(self, x):
        x = self.conv(x)
        x = self.middle_block(x)
        x = self.resnet_block(x)
        x = self.resnet_block2(x)
        x = self.resnet_block3(x)
        x = self.up_block(x)
        x = self.conv_block(x)
        return x


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(torch.pow(self.mean - other.mean, 2) / other.var
                                       + self.var / other.var - 1.0 - self.logvar + other.logvar,
                                       dim=[1, 2, 3])

    def nll(self, sample, dims=None):
        if dims is None:
            dims = [1, 2, 3]
        if self.deterministic:
            return torch.Tensor([0.])
        log_two_pi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(log_two_pi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        return self.mean


class VectorQuantizer(nn.Module):
    """Vector quantization layer with optional EMA codebook updates."""

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, decay=0.99, epsilon=1e-5, ema=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.decay = decay
        self.epsilon = epsilon
        self.ema = ema
        # Codebook initialization
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.normal_()
        # self.embedding.requires_grad_(False)
        # EMA statistics
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('_ema_w', self.embedding.weight.data.clone())

    def forward(self, z):
        # Reshape: [B, D, H, W] -> [B, H, W, D]
        z = z.permute(0, 2, 3, 1)
        z_flattened = z.reshape(-1, self.embedding_dim)
        # Distances to the codebook entries
        distances = torch.cdist(z_flattened, self.embedding.weight, p=2.0) ** 2
        # Nearest-neighbour codes
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(z.shape)
        quantized = quantized.permute(0, 3, 1, 2)
        vq_loss = self.beta * F.mse_loss(quantized.detach(), z.permute(0, 3, 1, 2))
        # EMA update
        if self.training and self.ema:
            with torch.no_grad():
                # Update the EMA statistics
                encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=z.device)
                encodings.scatter_(1, encoding_indices.view(-1, 1), 1)
                updated_ema_cluster_size = self._ema_cluster_size * self.decay + (1 - self.decay) * torch.sum(encodings, 0)
                # Laplace smoothing
                n = torch.sum(updated_ema_cluster_size)
                updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon)
                                            / (n + self.num_embeddings * self.epsilon) * n)
                dw = torch.matmul(encodings.t(), z_flattened)
                updated_ema_w = self._ema_w * self.decay + (1 - self.decay) * dw
                # Update the codebook
                self._ema_cluster_size.data.copy_(updated_ema_cluster_size)
                self.embedding.weight.data.copy_(updated_ema_w / updated_ema_cluster_size.unsqueeze(1))
        else:
            codebook_loss = F.mse_loss(quantized, z.permute(0, 3, 1, 2).detach())
            vq_loss = vq_loss + codebook_loss
        # Straight-through estimator
        quantized = z.permute(0, 3, 1, 2) + (quantized - z.permute(0, 3, 1, 2)).detach()
        return quantized, encoding_indices, vq_loss


class VAE(nn.Module):
    def __init__(self, in_channels, groups=32, z_channels=4, embedding_dim=4):
        super(VAE, self).__init__()
        self.scale_factor = 0.18215
        self.encoder = Encoder(in_channels, z_channels, groups=groups)
        self.decoder = Decoder(z_channels, groups=groups)
        self.quant_conv = nn.Conv2d(z_channels * 2, embedding_dim * 2, 1, 1, 0)
        self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1, 1, 0)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        out = posterior.sample()
        out = self.scale_factor * out
        return out

    def decode(self, z):
        z = 1. / self.scale_factor * z
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, x):
        z = self.encode(x)
        dec = self.decode(z)
        return dec

    def generate(self, x):
        x = self.decoder(x)
        return x


class VQVAE(VAE):
    def __init__(self, in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196, beta=0.25,
                 decay=0.99, epsilon=1e-5):
        super(VQVAE, self).__init__(in_channels, groups, z_channels, embedding_dim)
        self.quantize = VectorQuantizer(num_embeddings, embedding_dim, ema=True, beta=beta, decay=decay,
                                        epsilon=epsilon)

    def forward(self, x):
        z = self.encode(x)
        quantized, _, vq_loss = self.quantize(z)
        x_recon = self.decode(quantized)
        return x_recon, vq_loss

    def calculate_balance_facter(self, perceptual_loss, gan_loss):
        last_layer = self.decoder.conv_block.conv_block[-1]
        last_layer_weight = last_layer.weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
        gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]
        alpha = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
        alpha = torch.clamp(alpha, 0, 1e4).detach()
        return 0.8 * alpha
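Before training, a quick shape sanity check helps confirm the three 2x down-sampling and three 2x up-sampling stages round-trip cleanly (a minimal sketch, assuming the model code above is saved as vae.py, which is what the training script below imports):

import torch
from vae import VQVAE

model = VQVAE(in_channels=3, groups=8, z_channels=4, embedding_dim=4).eval()
x = torch.randn(1, 3, 256, 256)  # dummy image batch
with torch.no_grad():
    x_recon, vq_loss = model(x)
print(x_recon.shape)   # expected torch.Size([1, 3, 256, 256]); the latent is 8x smaller spatially
print(float(vq_loss))  # scalar commitment (+ codebook) loss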
Training script
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from accelerate import Accelerator, DistributedDataParallelKwargs
from lpips import LPIPS
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import VGG19_Weights
from tqdm import tqdm

from vae import VQVAE


# --------------------------
# Adversarial components
# --------------------------
class Discriminator(nn.Module):
    """Multi-scale discriminator"""

    def __init__(self, in_channels=3, base_channels=4, num_layers=3):
        super().__init__()
        layers = [nn.Conv2d(in_channels, base_channels, 4, 2, 1), nn.LeakyReLU(0.2)]
        channels = base_channels
        for _ in range(1, num_layers):
            layers += [nn.Conv2d(channels, channels * 2, 4, 2, 1),
                       nn.InstanceNorm2d(channels * 2),
                       nn.LeakyReLU(0.2)]
            channels *= 2
        layers += [nn.Conv2d(channels, channels, 4, 1, 0),
                   nn.InstanceNorm2d(channels),
                   nn.LeakyReLU(0.2),
                   nn.Conv2d(channels, 1, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class PerceptualLoss(nn.Module):
    def __init__(self, layers=None):
        super(PerceptualLoss, self).__init__()
        if layers is None:
            layers = ['1', '2', '4', '7']
        self.layers = layers
        self.vgg = torchvision.models.vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
        self.vgg.requires_grad_(False)
        for name, module in self.vgg.named_modules():
            if name in layers:
                module.register_forward_hook(self.forward_hook)
        self.features = []

    def forward_hook(self, module, input, output):
        self.features.append(output)

    def forward(self, x, x_recon):
        x_and_x_recon = torch.cat((x, x_recon), dim=0)
        self.features = []
        self.vgg(x_and_x_recon)
        x_and_x_recon_features = self.features
        loss = torch.tensor(0.0, device=x.device)
        for i, layer in enumerate(self.layers):
            x_feature = x_and_x_recon_features[i][:x.shape[0]]
            x_norm_factor = torch.sqrt(torch.mean(x_feature ** 2, dim=1, keepdim=True))
            x_feature = x_feature / x_norm_factor
            x_recon_feature = x_and_x_recon_features[i][x.shape[0]:]
            x_recon_norm_factor = torch.sqrt(torch.mean(x_recon_feature ** 2, dim=1, keepdim=True))
            x_recon_feature = x_recon_feature / x_recon_norm_factor
            loss += F.l1_loss(x_feature, x_recon_feature)
        return loss
# --------------------------
# Training loop
# --------------------------
def train_vqgan(dataloader, epochs=100, mixed_precision=False, accelerate=False, disc_start=10000, rec_factor=1,
                perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
                num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5):
    os.makedirs('results', exist_ok=True)
    # Initialize the models
    model = VQVAE(in_channels, groups, z_channels, embedding_dim, num_embeddings, beta, decay, epsilon)
    discriminator = Discriminator()
    # perceptual_loss_fn = PerceptualLoss()
    perceptual_loss_fn = LPIPS().eval()
    # Optimizers
    opt_ae = Adam(list(model.encoder.parameters()) + list(model.decoder.parameters())
                  + list(model.quantize.parameters()), lr=learning_rate, betas=(0.5, 0.9))
    opt_disc = Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
    gradient_accumulation_steps = 4
    step = 0
    start_epoch = 0
    if os.path.exists("vqgan.pth"):
        state_dict = torch.load("vqgan.pth")
        step = state_dict.get("step", 0)
        start_epoch = state_dict.get("epoch", 0)
        model.load_state_dict(state_dict.get("model", {}))
        discriminator.load_state_dict(state_dict.get("discriminator", {}))
        opt_ae.load_state_dict(state_dict.get("opt_ae", {}))
        opt_disc.load_state_dict(state_dict.get("opt_disc", {}))
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if accelerate:
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,
                                  mixed_precision='fp16' if mixed_precision else 'no',
                                  kwargs_handlers=[ddp_kwargs])
        # Wrap everything with the accelerator
        model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader = accelerator.prepare(
            model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader)
        device = accelerator.device
    else:
        accelerator = None
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        discriminator = discriminator.to(device)
        perceptual_loss_fn = perceptual_loss_fn.to(device)
    for epoch in range(start_epoch, epochs):
        with tqdm(range(len(dataloader))) as pbar:
            for _, batch in zip(pbar, dataloader):
                x, _ = batch
                x = x.to(device)
                if accelerator is not None:
                    # Generator update
                    with accelerator.autocast():
                        disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
                            accelerator, disc_start, discriminator, model, perceptual_factor, perceptual_loss_fn,
                            rec_factor, step, x)
                    opt_ae.zero_grad()
                    accelerator.backward(total_loss, retain_graph=True)
                    opt_disc.zero_grad()
                    accelerator.backward(disc_loss)
                    opt_ae.step()
                    opt_disc.step()
                else:
                    # Generator update
                    with torch.amp.autocast(device, enabled=mixed_precision):
                        disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
                            accelerator, disc_start, discriminator, model, perceptual_factor, perceptual_loss_fn,
                            rec_factor, step, x)
                    opt_ae.zero_grad()
                    total_loss.backward(retain_graph=True)
                    opt_disc.zero_grad()
                    disc_loss.backward()
                    opt_ae.step()
                    opt_disc.step()
                pbar.set_postfix(TotalLoss=np.round(total_loss.cpu().detach().numpy().item(), 5),
                                 DiscLoss=np.round(disc_loss.cpu().detach().numpy().item(), 3),
                                 PerceptualLoss=np.round(perceptual_loss.cpu().detach().numpy().item(), 5),
                                 RecLoss=np.round(rec_loss.cpu().detach().numpy().item(), 5),
                                 GenLoss=np.round(g_loss.cpu().detach().numpy().item(), 5),
                                 VqLoss=np.round(vq_loss.cpu().detach().numpy().item(), 5))
                pbar.update(0)
                # Logging
                if step % 100 == 0:
                    if accelerator:
                        if accelerator.is_main_process:
                            with torch.no_grad():
                                fake_image = x_recon[:4].permute(0, 2, 3, 1).contiguous()
                                means = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 1, 3).to(fake_image.device)
                                stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 1, 3).to(fake_image.device)
                                fake_image = fake_image * stds + means
                                fake_image.clamp_(0, 1)
                                fake_image = fake_image.permute(0, 3, 1, 2).contiguous()
                                real_image = x[:4].permute(0, 2, 3, 1).contiguous()
                                real_image = real_image * stds + means
                                real_image.clamp_(0, 1)
                                real_image = real_image.permute(0, 3, 1, 2).contiguous()
                                real_fake_images = torch.cat((real_image, fake_image))
                                torchvision.utils.save_image(real_fake_images,
                                                             os.path.join("results", f"{epoch}_{step}.jpg"),
                                                             nrow=4)
                    else:
                        with torch.no_grad():
                            real_fake_images = torch.cat((x[:4], x_recon.add(1).mul(0.5)[:4]))
                            torchvision.utils.save_image(real_fake_images,
                                                         os.path.join("results", f"{epoch}_{step}.jpg"),
                                                         nrow=4)
                step += 1
        if accelerate:
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_discriminator = accelerator.unwrap_model(discriminator)
                # Save a checkpoint
                state_dict = {"model": unwrapped_model.state_dict(),
                              "discriminator": unwrapped_discriminator.state_dict(),
                              "opt_ae": opt_ae.state_dict(),
                              "opt_disc": opt_disc.state_dict(),
                              "step": step,
                              "epoch": epoch}
                torch.save(state_dict, "vqgan.pth")
        else:
            # Save a checkpoint
            state_dict = {"model": model.state_dict(),
                          "discriminator": discriminator.state_dict(),
                          "opt_ae": opt_ae.state_dict(),
                          "opt_disc": opt_disc.state_dict(),
                          "step": step,
                          "epoch": epoch}
            torch.save(state_dict, "vqgan.pth")
    return model, discriminator, opt_ae, opt_disc


def train_step(accelerator, disc_start, discriminator, model, perceptual_factor, perceptual_loss_fn, rec_factor, step,
               x):
    x_recon, vq_loss = model(x)
    disc_real = discriminator(x)
    disc_faker = discriminator(x_recon)
    disc_factor = 0 if disc_start > step else 1
    perceptual_loss = perceptual_loss_fn(x, x_recon).mean()
    rec_loss = F.l1_loss(x_recon, x)
    perceptual_rec_loss = perceptual_factor * perceptual_loss + rec_factor * rec_loss
    perceptual_rec_loss = perceptual_rec_loss.mean()
    g_loss = -torch.mean(disc_faker)
    if accelerator:
        balance_facter = model.module.calculate_balance_facter(perceptual_rec_loss, g_loss)
    else:
        balance_facter = model.calculate_balance_facter(perceptual_rec_loss, g_loss)
    total_loss = perceptual_rec_loss + vq_loss + disc_factor * balance_facter * g_loss
    d_real_loss = F.binary_cross_entropy_with_logits(disc_real, torch.ones_like(disc_real))
    d_fake_loss = F.binary_cross_entropy_with_logits(disc_faker, torch.zeros_like(disc_faker))
    disc_loss = disc_factor * 0.5 * (d_real_loss + d_fake_loss)
    return disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon


def get_imagenet_dataloader(batch_size=32, data_path="datasets/faces"):
    # Data loading
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_dataset = ImageFolder(data_path, transform=transform)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
# --------------------------
# Usage example
# --------------------------
if __name__ == "__main__":
    # Data loading (example)
    train_loader = get_imagenet_dataloader(batch_size=12, data_path="faces")
    # Start training
    train_vqgan(train_loader, epochs=100, mixed_precision=True, accelerate=True, disc_start=10000, rec_factor=1,
                perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
                num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5)
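When accelerate=True, the script should be started through the Accelerate launcher so that one process is spawned per GPU; with accelerate=False it can be run as a plain Python script. A sketch, assuming the training script above is saved as train.py (the file name is an assumption, not from the original):

    accelerate config           # one-time interactive setup (GPU count, fp16, etc.)
    accelerate launch train.py  # multi-GPU training via Accelerate
    python train.py             # single-process run when accelerate=False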