在这项工作中,作者介绍了Vision-RWKV(VRWKV),这是为了适应将RWKV架构用于视觉任务而设计的。这种适应性保留了RWKV的核心结构和优势,同时整合了关键的修改,使其适合处理视觉数据。 具体来说,作者引入了一种针对视觉任务的四向移位(Q-Shift)操作,并将原始的因果RWKV注意力机制修改为双向全局注意力机制。Q-Shift操作扩展了单个标记的语义范围,而双向注意力使得在RNN形式的前向和后向计算全局注意力时具有线性计算复杂度。作者主要对RWKV注意力机制中的指数进行修改,释放衰减向量的限制,并将绝对位置偏置转换为相对偏置。 这些变化增强了模型的性能,同时确保了可扩展性和稳定性。这样,Vision-RWKV继承了RWKV在处理全局信息和稀疏输入方面的效率,同时也能够建模视觉任务的局部概念。作者在需要的地方实施了层尺度和层归一化,以稳定模型在不同尺度下的输出。这些调整在模型扩大规模时显著提高了稳定性。
在本节中,作者提出了Vision-RWKV(VRWKV),这是一种具有线性复杂度注意力机制的高效视觉编码器。作者的原则是保留原始RWKV架构的优点,仅进行必要的修改,使其能够灵活地应用于视觉任务中,支持稀疏输入,并在规模扩大后确保训练过程的稳定性。VRWKV概述展示在图2中。
VRWKV采用了类似于ViT的块堆叠图像编码器设计,其中每个块由一个空间混合模块和一个通道混合模块组成。空间混合模块充当注意力机制,执行线性复杂度的全局注意力计算,而通道混合模块则作为一个前馈网络(FFN),在通道维度上进行特征融合。整个VRWKV包括一个图像块嵌入层和一个由
个相同的VRWKV编码器层堆叠而成,其中每个层保持输入分辨率不变。
数据流。 首先,作者将大小为
的图像转换为
个 Patch ,其中
表示 Patch 的大小。经过线性投影后的 Patch 加上位置嵌入,得到形状为
的图像标记,其中
表示标记的总数。这些标记随后被输入到具有
层的 VRWKV 编码器中。
在每一层中,首先将标记(tokens)输入到空间混合模块,该模块起着全局注意力机制的作用。具体来说,如图2(b)所示,输入的标记首先进行移位,并输入到三个并行的线性层中,以获得矩阵
:
在这里,
和
会被传递到一个线性复杂度的双向注意力机制中,以计算全局注意力结果
,并与
相乘,后者控制输出
的概率:
算子
表示sigmoid函数,而
表示逐元素的乘法操作。Q-Shift是一个专为适应视觉任务设计的标记移位函数。在输出线性投影之后,特征通过层归一化来稳定。
随后,这些标记被传递到通道混合模块中进行通道融合。
和
的获取方式与空间混合类似:
在这里,
是经过激活函数后的
的线性投影,而输出
在输出投影之前也受到门机制
的控制:
同时,建立从标记到每个规范化层的残差连接[21],以确保在深层网络中训练梯度不会消失。
与普通的RWKV不同,作者对原有的注意力机制进行了以下修改,以适应视觉任务:
(当前标记)扩展到
(最后一个标记),在求和公式中确保所有标记在计算每个结果时相互可见。因此,原始的因果注意力转变为双向全局注意力。
的绝对值,并将其除以总标记数(表示为
),以表示不同尺寸图像中标记的相对偏置。
在指数项中为正,使得指数衰减注意力可以关注不同通道中离当前标记较远的标记。
这种简单而必要的修改实现了全局注意力的计算,并最大程度地保留了RWKV的低复杂性和对视觉任务的适应性。
类似于RWKV中的注意力机制,双向注意力也可以等价地用求和形式(为了清晰)以及RNN形式(在实际实现中)表达。
求和形式。第
个标记的注意力计算结果由以下公式给出:
在这里,
表示 Token 的总数,等于
,
和
是两个可学习的
维向量,分别表示通道方向的空间衰减和表示当前 Token 的增益。
和
分别表示
和
的第
个特征。
该求和公式表明输出
是沿 Token 维度从
到
对
的加权求和,产生一个
维向量。它表示对第
个 Token 应用注意力操作得到的结果。权重由空间衰减向量
, Token 之间的相对偏置
,以及
共同确定。
RNN形式。 在实际实现中,上述方程(5)可以转化为RNN的递归公式形式,通过固定的FLOPs数量可以得到每个标记的结果。通过将方程(5)中的分子和分母求和项以
为界进行拆分,作者可以得到4个隐藏状态:
可以递归计算的公式如下:隐藏状态的更新仅需要增加或减少一个求和项,并乘以或除以
。那么第
个结果可以表示为:
其中,
的每一个结果可以通过以下方式得到:
每个更新步骤都会为一个标记产生一个注意力结果(即
),因此整个
矩阵需要
个步骤。
当输入
和
是形状为
的矩阵时,计算
矩阵的计算成本由以下公式给出:
在这里,数字13大致来自于对
的更新,指数运算的计算,以及
的计算。
是更新步骤的总数,等于图像标记的数量。上述近似表明前向过程的复杂性为
。算子的反向传播仍然可以表示为更复杂的RNN形式,其计算复杂性为
。反向传播的具体公式在附录中提供。
通过引入指数衰减机制,可以将全局注意力的复杂性从二次降低到一次,从而大大提高模型在高分辨率图像上的计算效率。然而,一维衰减并不符合二维图像中的相邻关系。因此,在每次空间混合和通道混合模块的第一步中,作者引入了四向 Token 移动(Q-Shift)。Q-Shift操作允许所有 Token 与其相邻 Token 进行移动和线性插值,如下所示:
下标
表示通过对可学习向量
的控制,对
和
进行3种插值,分别用于后续的
计算。
和
分别表示标记
的行索引和列索引,":"是一种不包括结束索引的切片操作。Q-Shift使不同通道的注意力机制在内部优先关注邻近标记,而不会引入许多额外的FLOPs。Q-Shift操作还增大了每个标记的感受野,这极大地提升了标记在后层中的覆盖范围。
模型层数的增加以及递归过程中指数项的累积都可能导致模型输出不稳定,影响训练过程的稳定性。为了减轻这种不稳定性,作者采用了两种简单但有效的修改方法来稳定模型规模的扩展:
),使得最大衰减和增长是有界的。
在遵循ViT之后,表1中指定了VRWKV变体的超参数,包括嵌入维度、线性投影中的隐藏维度以及深度。由于VRWKV-L模型的深度增加,作者在适当的位置加入了如第3.4节所讨论的额外的层归一化,以确保输出稳定性。
作者全面评估了VRWKV方法在性能、可扩展性、灵活性和效率方面替代ViT的可能性。作者在广泛使用的图像分类数据集ImageNet上验证了模型的有效性。对于下游的密集预测任务,作者选择了在COCO数据集上的检测任务以及ADE20K数据集上的语义分割任务。
设置。 对于-Tiny/Small/Base模型,作者从零开始在ImageNet-1K 上进行有监督训练。遵循DeiT 的训练策略和数据增强方法,作者使用批量大小为1024,使用AdamW 优化器,基础学习率为5e-4,权重衰减为0.05,并采用余弦退火调度。图像被裁剪为
分辨率用于训练和验证。对于-Large模型,作者首先在ImageNet-22K上以批量大小4096和分辨率
预训练30个周期,然后在高分辨率
的ImageNet-1K上微调20个周期。
结果。 作者在ImageNet-1K数据集上比较了VRWKV与其他分层和非分层 Backbone 网络的结果。如表2所示,在相同的参数数量、计算复杂度以及训练/测试分辨率下,VRWKV取得了与ViT相当的结果。
例如,与ViT-L相比,VRWKV-L在
的分辨率下实现了相似的前1准确率85.3%,计算成本略有降低。当模型尺寸较小时,VRWKV展现了更高的 Baseline 性能。在VRWKV-T和DeiT-T的FLOPs均为1.3G的情况下,VRWKV-T比DeiT-T高出2.9个百分点。在VRWKV中对线性注意力机制的探索和利用证明了其在视觉任务中的潜力,使其成为使用全局注意力机制的传统ViT模型的一个可行替代品。从微小到大尺寸模型的表现也表明,VRWKV模型具有与ViT相似的伸缩性。
设置。 在检测任务中,作者采用Mask R-CNN作为检测Head。对于-Tiny/Small/Base模型,主干网络使用了在ImageNet-1K上预训练300个周期的权重。对于-Large模型,则使用了在ImageNet-22K上预训练的权重。所有模型都采用
训练计划(即12个周期),批量大小为16,使用AdamW优化器,初始学习率为1e-4,权重衰减为0.05。
结果。在表3中,作者报告了使用VRWKV和ViT作为 Backbone 网络在COCO val数据集上的检测结果。正如图1(a)和表3所示的结果,由于在密集预测任务中使用了窗口注意力,具有全局注意力的VRWKV可以比ViT以更低的FLOPs实现更好的性能。
例如,与ViT-T
相比,VRWKV-T的 Backbone FLOPs大约降低了30%,AP
提高了0.6个百分点。同样,VRWKV-L在FLOPs更低的情况下,相比ViT-L
,AP
增加了1.9个百分点。
此外,作者还比较了使用全局注意力的VRWKV和ViT的性能。例如,VRWKV-S在FLOPs降低了45%的情况下,与ViT-S实现了相似的性能。这证明了VRWKV的全局注意力机制在密集预测任务中的有效性,以及与原始注意力机制相比在计算复杂度上的优势。
设置。 在语义分割任务中,作者使用UperNet 作为分割头。具体来说,所有ViT模型在分割任务中使用全局注意力。对于 -Tiny/Small/Base 模型, Backbone 网络使用在ImageNet-1K上预训练的权重。而对于 -Large 模型,使用在ImageNet-22K上预训练的权重。作者采用AdamW优化器,对于 -Small/Base/Large 模型的初始学习率为6e-5,对于 -Tiny 模型为12e-5,批量大小为16,权重衰减为0.01。所有模型都在ADE20K数据集的训练集上训练160k次迭代。
结果。 如表4所示,在用于语义分割时,基于VRWKV的模型一致优于基于全局注意力机制的ViT模型,并且效率更高。例如,VRWKV-S比ViT-S的准确度高1个百分点,同时浮点运算量减少了14%。VRWKV-L取得了与ViT-L相当的53.5 mIoU结果,而其 Backbone 网的计算量则少了25G FLOPs。
这些结果表明,VRWKV Backbone 网与ViT Backbone 网相比,能为语义分割提取更好的特征,并且在效率上也有所提高,这得益于线性复杂度注意力机制。
设置。 作者在ImageNet-1K上对微小尺寸的VRWKV进行消融研究,以验证Q-Shift和双向注意力等不同关键组成部分的有效性。实验设置与第4.1节保持一致。
标记移位。 作者比较了不使用标记移位、使用RWKV中的原始移位方法以及Q-Shift的性能。如表5所示,移位方法的变体显示出性能上的差异。不使用标记移位的变体1性能较差,为71.5,比VRWKV模型低3.6分。即便使用了全局注意力,采用原始标记移位的模型与VRWKV 模型之间仍有0.7分的差距。
双向注意力。 双向注意力机制使模型能够在原始RWKV注意力内部具有因果 Mask 的同时实现全局注意力。第3种变体的结果表明,全局注意力机制使top-1准确率提高了2.3个百分点。
有效感受野(ERF)。作者根据[11]的分析,研究了不同设计对模型ERF的影响,并在图3(a)中进行了可视化。作者可视化了输入尺寸为1024×1024的中心像素的ERF。在图3(a)中,“No Shift”表示没有采用标记移位方法(Q-Shift),“RWKV Attn”表示在没有修改情况下,使用原始RWKV注意力机制进行视觉任务。
从图中的比较来看,除了“RWKV Attn”模型外,所有模型都实现了全局注意力,而VRWKV-T模型的全局容量优于ViT-T模型。尽管有Q-Shift的辅助,由于输入分辨率的较大,“RWKV Attn”中的中心像素仍然无法关注到底部图像的像素。 “No Shift”和Q-Shift的结果显示,Q-Shift方法扩展了感受野的核心范围,增强了全局注意力的归纳偏好。
效率分析。 作者逐步将输入分辨率从
提升到
,并比较了VRWKV-T与ViT-T的推理和内存效率。这些结果是在Nvidia A100 GPU上测试的,如图1所示。从图1(b)中呈现的曲线可以看出,在较低分辨率下,例如大约200个图像 Token 的
时,VRWKV-T与ViT-T的内存使用相当,尽管与ViT-T相比,VRWKV-T的FPS略低。然而,随着分辨率的增加,得益于其线性注意力机制,VRWKV-T的FPS迅速超过了ViT-T。
此外,VRWKV-T的类RNN计算框架确保了内存使用的缓慢增长。当分辨率达到
(相当于16384个 Token )时,VRWKV-T的推理速度是ViT-T的10倍,并且与ViT-T相比,其内存消耗减少了80%。
作者还比较了双向加权键值(Bi-WKV)和闪存注意力的速度,结果如图3(b)所示。闪存注意力在低分辨率下效率很高,但由于其二次复杂度,随着分辨率的增加,其速度会迅速下降。在高分辨率场景中,线性算子Bi-WKV展现了显著的速度优势。例如,当输入为
(即16384个标记)以及根据ViT-B和VRWKV-B设置的通道数和头数时,Bi-WKV算子在推理运行时比闪存注意力快
,在前向和反向传递结合时快
。
MAE预训练。 与ViT类似,VRWKV模型能够处理稀疏输入,并从MAE预训练中受益。仅仅通过修改Q-Shift以执行双向移位操作,VRWKV就可以使用MAE进行预训练。预训练的权重可以通过Q-Shift方法直接用于其他任务的微调。遵循与ViT相同的MAE预训练设置,并类似于第4.1节中的后续分类训练,VRWKV-L在ImageNet-1K验证集上的top-1准确度从85.3%提升到了85.5%,这显示出其能够从 Mask 图像建模中获取视觉先验。
作者提出了Vision-RWKV(VRWKV),一个具有线性计算复杂度注意力机制的视觉编码器。作者展示了其在包括分类、密集预测和 Mask 图像建模预训练在内的综合视觉任务中,作为ViT的替代 Backbone 网的能力。与性能和可扩展性相当的情况下,VRWKV展现出更低计算复杂度和内存消耗。
得益于其低复杂性,VRWKV能够在那些ViT难以承受全局注意力高计算开销的任务中实现更好的性能。作者希望VRWKV能够成为ViT的高效且低成本的替代方案,展示了线性复杂度 Transformer 在视觉领域的强大潜力。
[1].Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures.