掩码自动编码器MAE是一种可扩展的计算机视觉自监督学习器。MAE的基本思路是:屏蔽输入图像的随机补丁,并重建丢失的像素,它基于两个核心设计。
本文所涉及的所有资源的获取方式:这里
MAE的掩码自编码器是一种简单地自编码方法,它在给定原始信号的部分观测值的情况下重建原始信号。和所有的自编码器一样,MAE有一个将观察到的信号映射到潜在表示的编码器,以及一个从潜在表示重建原始信号的解码器。与经典的自编码器不同,MAE采用了一种非对称设计,允许编码器仅对部分观察到的信号进行操作(没有掩码标记),并采用了一个轻量级解码器,该解码器根据潜在表示和掩码标记重建全部信号。
掩码 与ViT相同,MAE将图像划分为规则的非重叠块,之后,MAE对补丁的子集进行采样,并屏蔽(即移除)剩余的补丁。MAE的采样策略很简单,对补丁随机采样,不进行替换,遵循均匀分布。具有高掩码比(去除的补丁所占的比值)的随机采样在很大程度上消除了冗余,从而产生了一个无法通过从可见的相邻补丁外推来解决的任务,均匀分布防止了潜在的中心偏差(即图像中心附近的掩码补丁越多),最后一个高度稀疏的输入为设计一个有效地编码器创建了机会。
MAE编码器 MAE的编码器是一个ViT,但只应用与可见的、未屏蔽的补丁。就像在标准的ViT中一样,MAE的编码器通过添加了位置嵌入的线性投影来嵌入补丁,然后通过一系列Transformer块来处理结果集。然而,MAE的编码器只对全集的一小部分(例如25%)进行操作。这使MAE能够仅使用一小部分计算和内存来训练非常大的编码器。
MAE解码器 MAE解码器的输入是由编码器的可见补丁和掩码令牌组成的完整令牌集。每个掩码标记是一个共享的、学习的向量,指示要预测的丢失补丁的存在。MAE将位置嵌入添加到该全集中的所有令牌中,如果没有这一点,掩码令牌将没有关于其在图像中的位置信息。
MAE解码器仅在预训练期间用于执行图像重建任务(只有编码器用于产生用于识别的图像表示。)因此,可以以独立于编码器设计的方式灵活地设计解码器架构。
重建目标 MAE通过预测每个掩码补丁的像素值来重建输入,解码器输出中的每个元素是表示补丁的像素值的矢量。解码器的最后一层是线性投影,其输出通道的数量等于块中像素值的数量。对解码器的输出进行重构以形成重构图像。MAE的损失函数在像素空间中计算重建图像和原始图像之间的均方误差(MSE),与BERT相同,MAE只计算掩码补丁上的损失。 MAE还研究了一种变体,其重建目标是每个被屏蔽补丁的归一化像素。具体来说,MAE计算一个Patch中所有像素的均值和标准差,并使用它们对该patch进行归一化。使用归一化像素作为重建的目标提高了表示质量。
简单地实现 首先,MAE为每个输入补丁生成一个标记(通过添加位置嵌入的线性投影),接下来,MAE随机打乱令牌列表,并根据屏蔽比率删除列表的最后一部分。这个过程为编码器生成一小部分标记,相当于采样补丁而不进行替换。编码后,MAE将一个掩码令牌列表添加到编码补丁列表中,并对这个完整列表纪念性unshuffle(反转随机混洗操作),以将所有标记与其目标对齐。编码器应用于该完整列表(添加了位置嵌入)。如前所述,不需要稀疏运算,这种简单地实现引入了可忽略不计的开销,因为混洗和取消混洗操作很快。
MAE随机掩码图像
具有GAN损失的情况
随机掩码
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
# 确定需要保存多少个patch
len_keep = int(L * (1 - mask_ratio))
# [1,196] 用batch此时输入的图片可能不止一个,196表示patch的个数
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# 默认按升序排序,此时返回的是序号,首先获取从低到高排列的序号
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
# 获取ids_shuffle从低到高排列的序号,这样就能还原原始的noise的情况
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep] # 保存数据少的情况
# [1,49,1024] dim=0 按列进行索引,dim=1按行进行索引,获取x的取值
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask 为0表示没有被掩码,1表示被掩码
# 将是否被掩码通过mask表示出来
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
编码器
def forward_encoder(self, x, mask_ratio):
# embed patches [1,3,224,224]->[1,196,1024]
x = self.patch_embed(x)
# add pos embed w/o cls token 除了全局特征,全部加上了位置信息
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# id_restore保存的是原来的位置
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1) # [1,1,1024]
# [1,50,1024] 要包含一个class的情况
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
解码器
def forward_decoder(self, x, ids_restore):
# embed tokens [1,50,1024]->[1,50,512]
x = self.decoder_embed(x)
# append mask tokens to sequence 获取被掩码的token [1,147,512]
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
# 将经过编码的数据和原始的初始化为0的数据编码在一起。
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# 将编码的和为编码的重新转变为原始的patch大小,其实本质上只需要考虑编码的位置,因为其余都是随机初始化的
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection 将其转换为所有像素
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
# linux系统下python=3.7
conda create -n mae python=3.7
conda activate mae
# 下载torch
wget https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp37-cp37m-linux_x86_64.whl
pip install 'torch的下载地址'
# 下载torchvision
wget https://download.pytorch.org/whl/cu116/torchvision-0.14.0%2Bcu116-cp37-cp37m-linux_x86_64.whl
pip install 'torchvision的下载地址'
pip install timm==0.4.5
pip install ipykernel
pip install matplotlib
pip install tensorboard
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有