之前学习了 VAE 相关内容,本文记录 VAE 的条件版 CVAE(Conditional VAE)。
其中 (x,y) 为一般有监督学习中的数据对。可以看出CVAE相当于一个有监督版本的VAE,它重构/生成的是 y \mid x (VAE重构/生成的是 x )。举个例子,若令 x 表示手写数字的类别标签, y 表示手写数字图像,就可以通过采样 z 生成指定的数字 x 对应的图像 y 。值得一提的是,VAE 中的关于 z 的先验项是 p_{\theta}(z) ,而 CVAE 中的先验项 p_{\theta}(z \mid x) 与 x 有关,在网络实现上就会有一个从 x 到 z 的 “先验网络”。
网络结构包含三个部分:
通过条件改变隐变量的均值,从而控制了隐变量采样的位置,控制最后的输出结果。
先看图 (b),代表了整个从 x 推断到 y 的过程,如果理解的话其实这是一个生成的过程 (生成 y \mid x) :先从输入 x 经过一个先验网络到 z (重参数采样),再由 x 和 z 生成 y 。然而,这篇文章后面的实验都采用了图 (d) 的架构。也就是 x 先通过一个 baseline CNN得到一个 \hat{y} ,再 由 x 和 \hat{y} 共同得到 z 的先验。个人认为这个操作就是为了得到效果保证而启发式地设计的,理论上不太漂亮。
class CVAE(nn.Module):
"""Implementation of CVAE(Conditional Variational Auto-Encoder)"""
def __init__(self, feature_size, class_size, latent_size):
super(CVAE, self).__init__()
self.fc1 = nn.Linear(feature_size + class_size, 200)
self.fc2_mu = nn.Linear(200, latent_size)
self.fc2_log_std = nn.Linear(200, latent_size)
self.fc3 = nn.Linear(latent_size + class_size, 200)
self.fc4 = nn.Linear(200, feature_size)
def encode(self, x, y):
h1 = F.relu(self.fc1(torch.cat([x, y], dim=1))) # concat features and labels
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z, y):
h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) # concat latents and labels
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x, y):
mu, log_std = self.encode(x, y)
z = self.reparametrize(mu, log_std)
recon = self.decode(z, y)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss
参考资料