前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >条件变分自编码器 CVAE

条件变分自编码器 CVAE

作者头像
为为为什么
发布2023-05-20 14:54:37
发布2023-05-20 14:54:37
1.8K00
代码可运行
举报
文章被收录于专栏:又见苍岚又见苍岚
运行总次数:0
代码可运行

之前学习了 VAE 相关内容,本文记录 VAE 的条件版 CVAE(Conditional VAE)。

CVAE原理

  • 在条件变分自编码器(CVAE)中,模型的输出就不是 \mathbf{x}_j了,而是对应于输入\mathbf{x}_i的任务相关数据\mathbf{y}_i 也就是说输入的是条件,输出是在条件约束下的数据样本; 比如手写数字生成任务中:输入 x 可以是想输出的数字,比如 6,输出 y 则是数字 6 的手写图片。
  • 因此,我们采样的时候,不再是从 P(Z) 中直接采样,而是从 P(Z|X) 中进行采样,因此假设变成了, P(Z \mid X)=N(Z \mid \mu(X), I)
  • 套路和VAE是一样的,这次的最大似然估计变成了 \log p_{\theta}(\mathbf{Y}\mid\mathbf{X}) ,即:

  • 则ELBO(Empirical Lower Bound)为 \ell(p_{\theta}, q_{\phi}),进一步:

其中 (x,y) 为一般有监督学习中的数据对。可以看出CVAE相当于一个有监督版本的VAE,它重构/生成的是 y \mid x (VAE重构/生成的是 x )。举个例子,若令 x 表示手写数字的类别标签, y 表示手写数字图像,就可以通过采样 z 生成指定的数字 x 对应的图像 y 。值得一提的是,VAE 中的关于 z 的先验项是 p_{\theta}(z) ,而 CVAE 中的先验项 p_{\theta}(z \mid x) x 有关,在网络实现上就会有一个从 x z 的 “先验网络”。

网络结构

网络结构包含三个部分:

  1. 先验网络 p_{\theta}(\mathbf{z}\mid\mathbf{X}),如下图(b)所示
  2. Recognition 网络 q_{\phi}(\mathbf{z}\mid\mathbf{X},\mathbf{Y}), 如下图©所示
  3. Decoder网络 p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z}),如下图(b)所示

通过条件改变隐变量的均值,从而控制了隐变量采样的位置,控制最后的输出结果。

先看图 (b),代表了整个从 x 推断到 y 的过程,如果理解的话其实这是一个生成的过程 (生成 y \mid x) :先从输入 x 经过一个先验网络到 z (重参数采样),再由 xz 生成 y 。然而,这篇文章后面的实验都采用了图 (d) 的架构。也就是 x 先通过一个 baseline CNN得到一个 \hat{y} ,再 由 x \hat{y} 共同得到 z 的先验。个人认为这个操作就是为了得到效果保证而启发式地设计的,理论上不太漂亮。

对比 VAE

  • VAE的变分下界为:
\mathcal{L}(\phi, \theta ; x)=-K L\left(q_{\phi}(z \mid x) | p_{\theta}(z)\right)+\mathbb{E}{q{\phi}(z \mid x)}\left[\log p_{\theta}(z \mid x)\right] \leq \log p_{\theta}(x)
  • CVAE的变分下界为:
\mathcal{L}(\phi, \theta ; x, y)=-K L\left(q_{\phi}(z \mid x, y) | p_{\theta}(z \mid x)\right)+\mathbb{E}{q{\phi}(z \mid x, y)}\left[\log p_{\theta}(y \mid x, z)\right] \leq \log p_{\theta}(y \mid x)

示例代码

代码语言:javascript
代码运行次数:0
运行
复制
class CVAE(nn.Module):
    """Implementation of CVAE(Conditional Variational Auto-Encoder)"""
    def __init__(self, feature_size, class_size, latent_size):
        super(CVAE, self).__init__()
        self.fc1 = nn.Linear(feature_size + class_size, 200)
        self.fc2_mu = nn.Linear(200, latent_size)
        self.fc2_log_std = nn.Linear(200, latent_size)
        self.fc3 = nn.Linear(latent_size + class_size, 200)
        self.fc4 = nn.Linear(200, feature_size)
    def encode(self, x, y):
        h1 = F.relu(self.fc1(torch.cat([x, y], dim=1)))  # concat features and labels
        mu = self.fc2_mu(h1)
        log_std = self.fc2_log_std(h1)
        return mu, log_std
    def decode(self, z, y):
        h3 = F.relu(self.fc3(torch.cat([z, y], dim=1)))  # concat latents and labels
        recon = torch.sigmoid(self.fc4(h3))  # use sigmoid because the input image's pixel is between 0-1
        return recon
    def reparametrize(self, mu, log_std):
        std = torch.exp(log_std)
        eps = torch.randn_like(std)  # simple from standard normal distribution
        z = mu + eps * std
        return z
    def forward(self, x, y):
        mu, log_std = self.encode(x, y)
        z = self.reparametrize(mu, log_std)
        recon = self.decode(z, y)
        return recon, mu, log_std
    def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
        recon_loss = F.mse_loss(recon, x, reduction="sum")  # use "mean" may have a bad effect on gradients
        kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
        kl_loss = torch.sum(kl_loss)
        loss = recon_loss + kl_loss
        return loss

原始论文

https://101.43.39.125/HexoFiles/js/vvd_js/pdfjs/web/viewer.html?file=https://101.43.39.125/HexoFiles/vvd_file_mt/202305181756155.pdf

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023年5月18日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • CVAE原理
  • 网络结构
  • 对比 VAE
  • 示例代码
  • 原始论文
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档