近年来,Transformer在计算机视觉任务中取得了显著进展。然而,它们的全局建模往往伴随着相当大的计算开销,与人类眼睛高效的信息处理形成鲜明对比。受人类眼睛稀疏扫描机制的启发,作者提出了一个稀疏扫描自注意力机制(SA)。 该机制为每个标记预定义了一系列感兴趣 Anchor 点,并使用局部注意力来高效建模这些 Anchor 点周围的空间信息,避免了多余的全局建模和对局部信息的过度关注。这种方法模仿了人类眼睛的功能,并显著降低了视觉模型的计算负担。基于SA,作者引入了稀疏扫描视觉Transformer(SSViT)。 广泛的实验证明了SSViT在多种任务上的卓越性能。 特别是在ImageNet分类任务中,在没有额外监督或训练数据的情况下,SSViT取得了84.4%/85.7%的top-1准确率,且计算量为4.4G/18.2G FLOPs。 SSViT在下游任务,如目标检测、实例分割和语义分割方面也表现出色。 其鲁棒性进一步在多种数据集上得到了验证。 代码将可在https://github.com/qhfan/SSViT获取。
自从Vision Transformer(ViT)[12]问世以来,它由于其在建模长距离依赖方面的卓越能力而吸引了研究界的广泛关注。然而,作为ViT核心的自注意力机制[61]带来了巨大的计算开销,从而限制了其更广泛的应用。已经提出了几种策略来减轻自注意力的这一局限性。
例如,像Swin-Transformer [40; 11]这样的方法通过为注意力机制分组 Token 来减少计算成本,并使模型更专注于局部信息。像PVT [63; 64; 18; 16; 29]这样的技术通过下采样 Token 来缩小QK矩阵的大小,从而降低计算需求同时保留全局信息。同时,像UniFormer [35; 47]这样的方法在视觉建模的早期阶段放弃了注意力操作,转而采用轻量级的卷积。此外,一些模型[50]通过剪枝冗余 Token 来提高计算效率。
尽管取得了这些进展,但大多数方法主要集中在通过减少自注意力操作中的 Token 数量来提高ViT的效率,常常忽略了人眼处理视觉信息的方式。
与ViT模型相比,人视觉系统的操作明显不那么复杂,但效率却非常高。与在Swin [40],NAT [20],LVT [69]中进行的细粒度局部空间信息建模,或像PVT [63],PVTv2 [64],CMT [18]中看到的不明确的全局信息建模不同,人眼采用了一种稀疏扫描机制,这一点得到了众多生物研究的证实。如图1所示,作者的眼睛在感兴趣的点之间迅速移动,仅在 Anchor 点[43; 14; 48]处深入处理详细信息。这种选择性的注意力机制使大脑能够有效地处理关键视觉信息,而不仅仅是关注局部细节或模糊的全局信息。鉴于人视网膜的黄斑区具有固定的大小,眼睛焦点的每次移动都会感知到一个固定大小的感受野[31; 45]。
如图1所示,作者受人类眼睛的稀疏扫描机制的启发,引入了一种新颖的自注意力机制,称为稀疏扫描自注意力(Sparse Scan Self-Attention,)。对于每个目标 Token ,作者设计了一组均匀分布的集合内 Anchor 点(AoI)。作者对这些AoIs应用局部注意力,处理周围的视觉信息,并利用这些局部数据来更新AoI Token 。每个局部窗口的大小保持不变,反映了人视觉中固定的黄斑区大小。然后,作者将所有AoIs的信息汇总起来更新目标 Token 。建模方法调和了细粒度局部建模和对兴趣 Anchor 点的稀疏建模,与人类眼睛的运作方式非常相似。的方法超越了以往的自注意力机制,提供了一个更像人、更高效、更有效的模型。
在S的基础上,作者开发了稀疏扫描视觉 Transformer (SSViT)。SSViT有效地模拟了人眼的视觉信息处理,并在一系列视觉任务中显示出显著的有效性。如图2所示,SSViT在图像分类准确度上超过了先前的最先进模型,仅用1500万参数和2.4G FLOPs就实现了83.0%的Top-1准确度,无需额外的训练数据或监督。即使模型放大,这种性能优势仍然保持,作者的SSViT-L仅用1亿参数就达到了85.7%的Top-1准确度。除了分类任务之外,SSViT在下游任务(如目标检测、实例分割和语义分割)中也表现出色。SSViT的鲁棒性还通过其在各种数据集上的卓越性能进一步得到证实。
视觉 Transformer 。自从Vision Transformer(ViT)[12]问世以来,由于它的卓越性能而引起了广泛关注。众多研究[40; 11; 16; 15]探索了优化ViT的方法,通过改进其核心运算符——自注意力,以减少其二次计算复杂度并提高性能。一系列方法[11; 40; 76]被提出以减轻自注意力的计算负担。这些技术通过分组 Token 限制每个 Token 可以关注的区域。《Swin-Transformer》[40]例如,将所有 Token 划分为单独的窗口,并在这些窗口内执行自注意力操作。《BiFormer》[76]则动态确定每个 Token 可以关注的窗口。此外,一些方法[63; 64; 18; 29; 22; 47; 4]通过 Token 下采样减少参与自注意力操作 Token 的数量。《PVT》[63; 64]使用平均池化进行直接下采样,从而减少 Token 数量。《CMT》[18]和《PVTv2》[64]将 Token 下采样与卷积结合,以提高模型学习局部特征的能力。《STViT》[29]通过采样超 Token 有效捕获全局依赖性,对它们应用自注意力,然后将它们映射回原始 Token 空间。某些方法[39; 47; 35]选择在模型的早期层放弃计算密集型的自注意力,改为使用更有效的卷积来学习局部特征。然后在模型的更深层部署自注意力来学习全局特征。尽管这些先前的方法显示出有希望的结果并减少了计算复杂性,但值得注意的是,它们的工作机制与人类眼球的运作有着显著差异。
稀疏扫描在人类视觉中。稀疏扫描是人类视觉中的一个关键机制,它能在感官限制的情况下有效处理视觉刺激。神经影像学调查已经确定了控制稀疏扫描行为的关键神经结构,如上丘脑[58]。此外,微扫视,被认为是稀疏扫描的一种变体,对于保持视觉稳定性和引导注意力至关重要[48]。稀疏扫描的损伤与注意缺陷多动障碍(ADHD)和精神分裂症等疾病中观察到的认知缺陷有关[44]。稀疏扫描的动态性,与任务需求和注意力偏好之间的关系,也成为了研究的焦点。证据表明,稀疏扫描表现出对特定任务需求的适应性,灵活地分配视觉资源以增强处理效率[14]。这种适应性特征强调了视觉处理中自下而上的感官输入与自上而下的认知影响之间的复杂相互作用。本质上,稀疏扫描是一个基本机制,它不仅影响基本的感知功能,也影响更复杂的认知过程[55; 54; 43]。
图3:SSViT的说明。SSViT由多个SSViT块组成。一个单独的SSViT块由CPE、和FFN组成。
图3展示了稀疏扫描视觉 Transformer (SSViT)的整体架构。为了处理输入图像,作者将其送入由卷积组成的块嵌入中,得到形状为的标记。遵循之前的分层设计[40; 15; 16],作者将SSViT划分为四个阶段。由分层结构带来的多分辨率表示可以用于下游任务,如目标检测和语义分割。
SSViT块由三个关键组成部分组成:条件位置编码(CPE)[6],稀疏扫描自注意力()和前馈网络(FFN)[61, 12]。完整的SSViT块可以定义为方程1:
对于每个块,输入张量被送入CPE,为每个标记引入位置信息。经过CPE后,被用于扫描每个标记的稀疏兴趣区域。最后的FFN用于整合标记的通道信息。
稀疏扫描自注意力()受到人眼在处理视觉信息时的稀疏扫描机制的启发。它可以分解为三个子过程。首先,为每个标记选择感兴趣 Anchor 点(AoI)。其次,提取由AoIs确定的参考窗口(RWins)内的背景局部信息。最后,AoIs之间的交互。整个过程如图4所示。
感兴趣 Anchor 点。理想情况下,每个标记都应该根据其特性选择自己的合适AoIs。然而,这种方法将赋予模型过度的自由度,导致实现过程繁琐且效率低下。因此,作者放弃了这一做法,转而人工为每个标记定义AoIs。具体来说,假设AoIs的步长为,为每个标记选择的AoIs数量为。对于位于位置的标记,其选择的AoIs可以表示为等式2:
其中是输入特征图,是的AoIs集合。为了简化,在等式2中作者没有考虑边界点的情况。在实际操作中,处理边界点时,和的范围有一定的偏移,以确保所有AoIs保持在特征图边界内。
背景局部信息。如图1所示,当人眼观察某些 Anchor 点时,它也会处理周围的背景信息。基于这一考虑,对于前一步中的每个AoI中的 Anchor 点,作者选择其对应的参考窗口(Rwin),即 Anchor 点的背景。与前面的定义一致,作者假设单个Rwin的大小为。由于人眼视网膜上的中央凹大小保持不变,人眼感知的接受域对于每个AoI是相同的。转化为模型设计,这意味着每个AoI具有相同大小的Rwin。对于位于位置的AoI,其RWin定义为等式3:
其中 是 的 RWin 中的标记集合。与 AoI 的定义类似,在方程式 3 中,作者省略了边界点的情况。对于每个 AoI,在其确定的 RWin 内部,作者使用自注意力来更新 AoI。对于位于 位置的 AoI,这个过程可以表示为方程式 4:
其中 表示标准的自注意力操作。 是可学习的矩阵。 是更新后的 。 是针对 的更新后的 AoI 集合。
Anchor 点之间的交互。在实践中,人眼并不是独立处理每个 Anchor 点。相反,它对在每个 Anchor 点观察到的信息进行交互式建模,从而从图像中推理出深层语义信息。作者也通过自注意力来建模这个过程。对于目标标记 ,在其所有 AoI 都被更新之后,作者使用 和 来更新 ,如方程式 5 所示:
以上三步构成了完整的 。在完成 之后,为了进一步增强模型捕捉局部信息的能力,作者采用局部上下文增强模块来建模局部信息:
其中 是一个简单的深度卷积。值得注意的是,尽管 执行了两轮自注意力计算,实际上,、 和 的投影是在单个操作中完成的。在两轮自注意力计算(方程4和方程5)过程中, 和 被复用,因此没有引入额外的计算或参数开销。
作者对包括ImageNet-1k [9]上的图像分类、COCO [38]上的目标检测和实例分割以及ADE20K [74]上的语义分割在内的广泛视觉任务进行了实验。作者还评估了SSViT在ImageNet-v2 [52]、ImageNet-A [26]和ImageNet-R [25]上的鲁棒性。所有模型都可以使用8个A100 80G GPU进行训练。
作者从头开始在ImageNet-1k [9]上训练作者的模型。为了公平比较,作者采用了与[40]中相同的训练策略,仅以分类损失作为唯一的监督。对于增加随机深度[28]的最大速率,分别设置为0.1、0.15、0.4和0.5,对应于SSViT-T、SSViT-S、SSViT-B和SSViT-L。
与SOTA的比较。作者在Tab.1中展示了将SSViT与众多最先进模型进行比较的结果。SSViT在所有规模上均一致地超越了之前的模型。特别地,SSViT-T仅用15M参数和2.4G FLOPs就达到了83.0%的Top1准确率,超过了之前的最先进水平(SMT[39])0.8%。对于更大的模型,SSViT-L使用100M参数和18.2G FLOPs达到了85.7%的Top1准确率。
与通用/高效模型的严格比较。为了保证公平比较,作者选择了两个 Baseline :通用目的的 Backbone 网络Swin-Transformer [40]和以效率为导向的 Backbone 网络FasterViT [21]。
作者将它们与SSViT进行比较。在比较模型(SS-Swin和SS-FasterViT)中,作者仅将原始Swin-Transformer和FasterViT中的注意力机制替换为,没有引入其他任何修改(例如CPE、Conv Stem等)。
如表2所示,仅将注意力机制简单地替换为,在性能和效率上都带来了显著的优势。具体来说,SS-Swin在所有模型大小上实现了甚至超过Swin 2.0%的改进。同时,SS-FasterViT在参数更少的情况下获得了比FasterViT更高的准确率。
作者使用MMDetection [5]来实现Mask-RCNN [24],Cascade Mask R-CNN [2]和RetinaNet [37],以评估作者的SSViT。对于Mask R-CNN和Cascade Mask R-CNN,作者遵循通常使用的“”设置,对于Mask R-CNN和RetinaNet,作者应用“”设置。按照[40]的做法,在训练期间,作者将图像调整大小,使得较短的边为800像素,同时使较长的边保持在1333像素以内。作者使用AdamW优化器进行模型优化。
表3和表4展示了SSViT在不同检测框架下的性能表现。结果显示,在所有比较中,SSViT始终优于其对应的方法。在“MS”计划下,SSViT超过了最近的SMT,使用Mask R-CNN框架实现了**+2.2的边界框AP和+2.0**的面具AP改进。对于Cascade Mask R-CNN,SSViT仍然保持了相对于SMT的显著性能优势。关于“”计划,SSViT展示了卓越的性能。
特别是,在Mask-RCNN框架下,SSViT-S相较于InternImage-T实现了**+2.2的边界框AP和+1.5**的面具AP提升。
设置。作者使用Semantic FPN [33] 和 UperNet [67] 来评估SSViT的性能,通过MMSegmentation [7]实现这些框架。对于Semantic FPN,作者模仿了PVT [63]的训练设置,训练模型80k次迭代。所有模型都使用的输入分辨率,在测试期间,将图像的短边调整到512像素。UperNet按照Swin [40]的设置训练160K次迭代。作者采用AdamW优化器,权重衰减为0.01,包括1500次迭代的预热。
结果。语义分割的结果详细列于表5中。所有FLOPs都是使用的输入分辨率进行评估的,SSViT-T组除外,该组采用\begintable} \begin{tabular}{cc c c c\hline \hline \multirow2}{*}{ Backbone 网络} & \multicolumn{2}{c{参数 FLOPs} & \multicolumn{4}{c}{Mask R-CNN +MS} \ & (M) & (G) & & & & & & & \ \hline Focal-T [70] & 49 & 291 & 47.2 & 69.4 & 51.9 & 42.7 & 66.5 & 45.9 \ NAT-T [20] & 48 & 28 & 47.8 & 47.9 & 69.0 & 52.2 & 46.6 & 46.0 \ GC-ViT-T [22] & 48 & 291 & 47.9 & 70.1 & 52.8 & 43.2 & 67.0 & 分辨率。在所有设置中,SSViT都提供了卓越的性能。特别地,在Semantic FPN框架内,作者的SSViT-S比FAT-B3显著高出+0.7** mIoU。SSViT-L进一步超过了CSWin-L,高达**+1.6** mIoU。在UperNet框架内,SSViT-S超过了最近的SMT-S,达到**+0.9** mIoU。SSViT-B和SSViT-L也分别超过了它们的对应版本。
表6展示了鲁棒性评估结果。在ImageNet-V2(IN-V2)上,SSViT优于所有竞争对手。例如,SSViT-B在参数和FLOPs相似的情况下,超过了BiFormer-B,提高了1.7。SSViT的优势在ImageNet-A(IN-A)和ImageNet-R(IN-R)上进一步扩大。具体来说,仅基于ImageNet-1k预训练的SSViT-L在IN-A上达到55.0的准确率,在IN-R上达到59.2,明显超过了FAN-Hybrid-L(IN-A: +13.2,IN-R: +6.0)。这强调了SSViT的鲁棒性。
SSViT与DeiT&Swin的比较。如表7所示,作者将SSViT与DeiT[59]和Swin[40]进行比较。通过将DeiT/Swin中的Self-Attention/Window Self-Attention模块唯一替换为,作者构建了SS-DeiT/SS-Swin-T。值得注意的是,尽管SS-DeiT-S所需的FLOPs更少,但其性能却超过了DeiT-S,性能显著提升了**+1.5**,从而突显了的强大作用。与Swin-T相比,SS-Swin-T在广泛的下游任务中展现了实质性的增强。
SA与WSA的比较。Window Self-Attention(WSA),是Swin Transformer[40]中的一个重要组成部分,被广泛采用。在表7中,作者将与WSA进行了对比,显示出在分类和下游任务中都具有显著的优势。具体来说,以SSViT-T为基准,在分类准确率上比WSA提高了**+1.7,在框AP上提高了+3.3**。
SA与CSWSA的比较。CSWSA是CSwin-Transformer[11]中提出的一种对WSA的改进,其性能优于其前身。然而,作者的仍然在关键性能指标上超过了CSWSA。特别是,在分类准确率上比CSWSA高出**+1.4个百分点。在Semantic FPN分割框架内,比CSWSA高出+3.1** mIoU。
Lce。LCE是一种简单的深度卷积组件,用于增强模型捕捉局部特征的能力。作者对LCE进行了消融研究,揭示了其对模型性能提升的贡献。如表7所示的结果表明,LCE使模型的分类准确率提高了**+0.2,框AP提高了+0.3, Mask AP提高了+0.2**。
Cpe。CPE[6]是一种多功能、即插即用的位置编码策略,常用于向模型传授位置信息。CPE仅由一个残差块中的3x3深度卷积组成,如表7所示,CPE提供了适度的性能改进,大约使分类准确率提高了**+0.1**。
Conv Stem。Conv Stem在模型的初始阶段使用,有助于提取精细的局部特征。表7表明,Conv Stem在一定程度上加强了模型在分类和下游任务中的性能,特别是将分类准确率提高了**+0.2,平均交并比(mIoU)提高了+0.4**。
受人类眼睛对视觉信息处理的高效稀疏扫描机制的启发,作者提出了稀疏扫描自注意力机制()。这种机制模拟人类眼睛的程序化操作:首先选择感兴趣的信标,然后在这些信标周围提取局部信息,最后聚合这些信息。利用的力量,作者开发了稀疏扫描视觉 Transformer (SSViT),这是一种为各种视觉任务设计的健壮视觉主干网络。作者在一系列常见的视觉任务上评估SSViT,如图像分类、目标检测、实例分割和语义分割,它在这些任务上均展示了卓越的性能。值得注意的是,SSViT对分布外(OOD)数据也显示出显著的鲁棒性。
作者提供了作者稀疏扫描自注意力机制的代码。
1 import torch.nn as nn
3 import torch
4 from einops import rearrange
5 from atten.functional import natten2dqkrpb, natten2dav
6
7 class S3A(nn.Module):
8
9 def __init__(self, embed_dim, num_heads, window_size, anchor_size, stride):
10 super().__init__()
11 self.embed_dim = embed_dim
12 self.num_heads = num_heads
13 self.window_size = window_size
14 self.anchor_size = anchor_size
15 self.stride = stride
16 self.head_dim = embed_dim // num_heads
17 self.scaling = self.head_dim ** -0.5
18 self.qkv = nn.Conv2d(embed_dim, embed_dim * 3, 1, bias=True)
19 self.out_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=True)
20
21 def forward(self, x: torch.Tensor):
22 '''
23 z:(bchw)
24 '''
25 bsz, _, h, w = x.size()
26 qkv = self.qkv(x) # (b3*chw)
27
28 q, k, v = rearrange(qkv, 'b(mnd)hw->mbnhwd', m=3, n=self.num_heads)
29
30 k = k * self.scaling
31
32 window_size = self.window_size
33 anchor_size = self.anchor_size
34
35 attn = natten2dqkrp(q, k, None, window_size, 1)
36 attn = attn.softmax(dim=-1)
37 v = natten2dav(attn, v, window_size, 1)
38 stride = self.stride
39
40 attn = natten2dqkrp(q, k, None, anchor_size, stride)
41 attn = attn.softmax(dim=-1)
42 v = natten2dav(attn, v, anchor_size, stride)
43
44 res = rearrange(v, 'bnhwd->b(nd)hw')
45 return self.out_proj(res)
[1].Vision Transformer with Sparse Scan Prior.