点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达本文转载自:AI缝合术
一、论文信息
1
论文题目:EfficientViM: Efficient Vision Mamba with Hidden State Mixer based State Space Duality
中文题目:高效ViM:基于隐藏状态混合器的状态空间对偶性的高效视觉Mamba
论文链接:https://arxiv.org/pdf/2411.15241
官方github:https://arxiv.org/pdf/2411.15241
所属机构:韩国大学——计算机科学与工程系
核心速览:本文介绍了一种名为EfficientViM的新型视觉架构,该架构基于隐藏状态混合器(Hidden State Mixer)和状态空间对偶性(State Space Duality, SSD),旨在资源受限环境下高效捕获全局依赖性,同时降低计算成本。
二、论文概要
Highlight
图1. ImageNet-1K 分类上高效网络的比较。我们的EfficientViM系列,标为红色和蓝色星号,展示了最佳的速度-准确率权衡。 ✝ 表示进行蒸馏训练的模型。
1. 研究背景:
2. 本文贡献:
三、创新方法
1
图 4.(左)EfficientViM的整体架构和(右)块设计。虚线表示用于多阶段隐藏状态融合(MSF)的跳过连接。EfficientViM块中HSM-SSD层的示意图在图2中呈现。
图2. (左) NC-SSD和(右) HSM-SSD层的示意图。在HSM-SSD层中,计算密集型的投影通过HSM中的减少隐藏状态处理,如突出显示所示。红色、蓝色和橙色分别表示需要复杂度为O(LD^2)、O(LND)和O(ND^2)的操作。
Hidden State Mixer (HSM) 是一种用于优化计算效率的机制,它通过减少计算复杂度来提高模型性能。其核心思想是利用一个共享的全局隐藏状态h来执行通道混合,包括门控和输出投影,直接在减少的潜在数组h上进行操作。具体实现原理如下:
1. 线性投影优化:HSM通过首先计算隐藏状态hin,然后将其作为线性投影到隐藏状态,从而将计算复杂度从O(LD^2)降低到O(ND^2),其中N是状态的数量,L是序列长度,D是通道数。这种优化依赖于状态数量N远小于通道数D(即N ≪ D)的情况。
2. 减少隐藏状态的使用:在HSM-SSD层中,通过减少隐藏状态来处理计算密集型的投影,如图2b所示。这使得计算成本主要依赖于状态的数量,而不是通道数。
3. 门控和输出投影的直接应用:HSM直接在隐藏状态上应用门控和投影,而不是先计算Ch再进行门控和投影。这减少了计算复杂度,并且当N较小时,捕捉全局上下文的总复杂度变得可以忽略。
4. 多阶段隐藏状态融合(MSF):为了进一步提升性能,HSM引入了MSF机制,该机制融合了来自网络多个阶段的隐藏状态来生成预测logits。通过这种方式,模型能够整合低层次和高层次的特征,增强模型在推理时的泛化能力。
5. 单头设计:HSM-SSD采用单头设计,通过设置∆ ∈ RL×N和ˆa ∈ RN来估计每个状态的token重要性。这种设计避免了多头配置中的内存绑定操作,提高了吞吐量,并且在保持竞争性能的同时,实现了更高的效率。
四、实验分析
1. 图像分类性能:EfficientViM在ImageNet-1K分类任务中表现出色,与先前的高效网络相比,在速度和准确性方面均有所超越。特别是EfficientViM-M2在保持与MobileViTV2 0.75和FastViT-T8相似性能的同时,实现了约4倍的加速。
2. 扩展性分析:EfficientViM在高分辨率图像上的性能同样优异,与SHViT相比,在5122分辨率下实现了超过15%的速度提升。此外,通过与教师模型RegNetY-160进行蒸馏训练,EfficientViM在速度-准确性权衡上进一步确立了优势。
3. 内存效率:尽管EfficientViM的参数数量相对较多,但其在设备上的内存使用量由推理时的内存I/O决定,而非仅由参数数量决定。EfficientViM在保持最佳吞吐量的同时,显示出与具有较低参数数量的模型相当的内存效率。
表3. 在ImageNet-1K 分类任务上高效网络的比较。结果按准确率排序。我们还指出了每个方法与EfficientViM相比的相对吞吐量Thrrel。
表4. EfficientViM与vision Mambas的比较。Thrrel是与EfficientViM-M4相比的相对吞吐量。
表6. 使用蒸馏目标训练后高效网络的比较。Thrrel是与EfficientViM在每个分割中的相对吞吐量比较。
表B. 在COCO-2017上的实例分割、物体检测及结果。
五、代码
1
温馨提示:对于所有推文中出现的代码,如果您在微信中复制的代码排版错乱,请复制该篇推文的链接,在任意浏览器中打开,再复制相应代码,即可成功在开发环境中运行!或者进入官方github仓库找到对应代码进行复制!
import torch
import torch.nn as nn
import math
# 论文题目:EfficientViM: Efficient Vision Mamba with Hidden State Mixer based State Space Duality
# 中文题目:高效ViM:基于隐藏状态混合器的状态空间对偶性的高效视觉Mamba
# 论文链接:https://arxiv.org/pdf/2411.15241
# 官方github:https://arxiv.org/pdf/2411.15241
# 所属机构:韩国大学——计算机科学与工程系
# 代码整理:微信公众号:AI缝合术
class LayerNorm2D(nn.Module):
"""LayerNorm for channels of 2D tensor(B C H W)"""
def __init__(self, num_channels, eps=1e-5, affine=True):
super(LayerNorm2D, self).__init__()
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if self.affine:
self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
mean = x.mean(dim=1, keepdim=True) # (B, 1, H, W)
var = x.var(dim=1, keepdim=True, unbiased=False) # (B, 1, H, W)
x_normalized = (x - mean) / torch.sqrt(var + self.eps) # (B, C, H, W)
if self.affine:
x_normalized = x_normalized * self.weight + self.bias
return x_normalized
class LayerNorm1D(nn.Module):
"""LayerNorm for channels of 1D tensor(B C L)"""
def __init__(self, num_channels, eps=1e-5, affine=True):
super(LayerNorm1D, self).__init__()
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if self.affine:
self.weight = nn.Parameter(torch.ones(1, num_channels, 1))
self.bias = nn.Parameter(torch.zeros(1, num_channels, 1))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
mean = x.mean(dim=1, keepdim=True) # (B, 1, H, W)
var = x.var(dim=1, keepdim=True, unbiased=False) # (B, 1, H, W)
x_normalized = (x - mean) / torch.sqrt(var + self.eps) # (B, C, H, W)
if self.affine:
x_normalized = x_normalized * self.weight + self.bias
return x_normalized
class ConvLayer2D(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm2d, act_layer=nn.ReLU, bn_weight_init=1):
super(ConvLayer2D, self).__init__()
self.conv = nn.Conv2d(
in_dim,
out_dim,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding),
dilation=(dilation, dilation),
groups=groups,
bias=False
)
self.norm = norm(num_features=out_dim) if norm else None
self.act = act_layer() if act_layer else None
if self.norm:
torch.nn.init.constant_(self.norm.weight, bn_weight_init)
torch.nn.init.constant_(self.norm.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class ConvLayer1D(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm1d, act_layer=nn.ReLU, bn_weight_init=1):
super(ConvLayer1D, self).__init__()
self.conv = nn.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=False
)
self.norm = norm(num_features=out_dim) if norm else None
self.act = act_layer() if act_layer else None
if self.norm:
torch.nn.init.constant_(self.norm.weight, bn_weight_init)
torch.nn.init.constant_(self.norm.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class FFN(nn.Module):
def __init__(self, in_dim, dim):
super().__init__()
self.fc1 = ConvLayer2D(in_dim, dim, 1)
self.fc2 = ConvLayer2D(dim, in_dim, 1, act_layer=None, bn_weight_init=0)
def forward(self, x):
x = self.fc2(self.fc1(x))
return x
class HSMSSD(nn.Module):
def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim = 64):
super().__init__()
self.ssd_expand = ssd_expand
self.d_inner = int(self.ssd_expand * d_model)
self.state_dim = state_dim
self.BCdt_proj = ConvLayer1D(d_model, 3*state_dim, 1, norm=None, act_layer=None)
conv_dim = self.state_dim*3
self.dw = ConvLayer2D(conv_dim, conv_dim, 3,1,1, groups=conv_dim, norm=None, act_layer=None, bn_weight_init=0)
self.hz_proj = ConvLayer1D(d_model, 2*self.d_inner, 1, norm=None, act_layer=None)
self.out_proj = ConvLayer1D(self.d_inner, d_model, 1, norm=None, act_layer=None, bn_weight_init=0)
A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
self.A = torch.nn.Parameter(A)
self.act = nn.SiLU()
self.D = nn.Parameter(torch.ones(1))
self.D._no_weight_decay = True
def forward(self, x):
batch, _, L= x.shape
H = int(math.sqrt(L))
BCdt = self.dw(self.BCdt_proj(x).view(batch,-1, H, H)).flatten(2)
B,C,dt = torch.split(BCdt, [self.state_dim, self.state_dim, self.state_dim], dim=1)
A = (dt + self.A.view(1,-1,1)).softmax(-1)
AB = (A * B)
h = x @ AB.transpose(-2,-1)
h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
h = self.out_proj(h * self.act(z)+ h * self.D)
y = h @ C # B C N, B C L -> B C L
y = y.view(batch,-1,H,H).contiguous()# + x * self.D # B C H W
return y, h
class EfficientViMBlock(nn.Module):
def __init__(self, dim, mlp_ratio=4., ssd_expand=1, state_dim=64):
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.mixer = HSMSSD(d_model=dim, ssd_expand=ssd_expand,state_dim=state_dim)
self.norm = LayerNorm1D(dim)
self.dwconv1 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
self.dwconv2 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
self.ffn = FFN(in_dim=dim, dim=int(dim * mlp_ratio))
#LayerScale
self.alpha = nn.Parameter(1e-4 * torch.ones(4,dim), requires_grad=True)
def forward(self, x):
alpha = torch.sigmoid(self.alpha).view(4,-1,1,1)
# DWconv1
x = (1-alpha[0]) * x + alpha[0] * self.dwconv1(x)
# HSM-SSD
x_prev = x
x, h = self.mixer(self.norm(x.flatten(2)))
x = (1-alpha[1]) * x_prev + alpha[1] * x
# DWConv2
x = (1-alpha[2]) * x + alpha[2] * self.dwconv2(x)
# FFN
x = (1-alpha[3]) * x + alpha[3] * self.ffn(x)
# return x, h
return x
if __name__ == "__main__":
# 将模块移动到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建测试输入张量
x = torch.randn(1, 32, 256, 256).to(device)
# 初始化 evim 模块
evim = EfficientViMBlock(dim=32).to(device)
print(evim)
# 前向传播
print("\n微信公众号: AI缝合术!\n")
output = evim(x)
# 打印输入和输出张量的形状
print("输入张量形状:", x.shape)
print("输出张量形状:", output.shape)
运行结果