前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >【CVPR 2025】高效视觉Mamba模块EfficientViM,即插即用!

【CVPR 2025】高效视觉Mamba模块EfficientViM,即插即用!

作者头像
小白学视觉
发布2025-03-06 23:37:32
发布2025-03-06 23:37:32
8400
代码可运行
举报
运行总次数:0
代码可运行

点击上方“小白学视觉”,选择加"星标"或“置顶

代码语言:javascript
代码运行次数:0
复制
重磅干货,第一时间送达本文转载自:AI缝合术

一、论文信息

1

论文题目:EfficientViM: Efficient Vision Mamba with Hidden State Mixer based State Space Duality

中文题目:高效ViM:基于隐藏状态混合器的状态空间对偶性的高效视觉Mamba

论文链接:https://arxiv.org/pdf/2411.15241

官方github:https://arxiv.org/pdf/2411.15241

所属机构:韩国大学——计算机科学与工程系

核心速览:本文介绍了一种名为EfficientViM的新型视觉架构,该架构基于隐藏状态混合器(Hidden State Mixer)和状态空间对偶性(State Space Duality, SSD),旨在资源受限环境下高效捕获全局依赖性,同时降低计算成本。

二、论文概要

Highlight

图1. ImageNet-1K 分类上高效网络的比较。我们的EfficientViM系列,标为红色和蓝色星号,展示了最佳的速度-准确率权衡。 ✝ 表示进行蒸馏训练的模型。

1. 研究背景:

  • 研究问题:在资源受限的环境中部署神经网络,如移动和边缘设备,需要构建轻量级的视觉架构。这些架构需要在捕获局部和全局依赖性的同时,保持高效的计算性能。
  • 研究难点:尽管之前的研究通过构建轻量级卷积神经网络(CNN)和使用深度可分离卷积(DWConv)等技术,成功地减少了模型的计算复杂度,但Vision Transformers(ViTs)的自注意力机制的高计算复杂度仍然是设计高效架构的主要瓶颈。此外,尽管状态空间模型(SSM)提供了一种具有线性计算复杂度的替代方案,但其在视觉任务中的应用仍然面临速度较慢的问题。
  • 文献综述:先前的研究尝试通过近似自注意力或限制token数量来降低计算成本,或者开发结合CNN的混合ViTs。然而,这些方法仍然受到自注意力的二次复杂度和SSM的因果约束的限制。最近,一些工作如VSSD和LinFusion进一步改进了SSM,引入了非因果状态空间对偶性(NC-SSD),但这些视觉Mambas的处理速度仍然较慢。本文提出的EfficientViM旨在解决这些挑战,提供一种新的轻量级视觉骨干网络,以实现更快的速度和更高的准确性。

2. 本文贡献:

  • HSM-SSD层设计:EfficientViM的核心是基于隐藏状态混合器的SSD层(HSM-SSD),该层通过在隐藏状态空间内进行通道混合操作,有效降低了计算成本。HSM-SSD层将标准SSD层的线性投影和门控函数从图像特征空间转移到隐藏状态空间,这些隐藏状态被视为样本的压缩潜在表示。
  • 总体结论:EfficientViM提出了一种新颖的基于Mamba的轻量级视觉架构,通过HSM-SSD层有效捕获全局依赖关系,同时显著降低了计算成本。该架构在保持模型泛化能力的同时,通过多阶段隐藏状态融合进一步增强了模型的表示能力。

三、创新方法

1

图 4.(左)EfficientViM的整体架构和(右)块设计。虚线表示用于多阶段隐藏状态融合(MSF)的跳过连接。EfficientViM块中HSM-SSD层的示意图在图2中呈现。

图2. (左) NC-SSD和(右) HSM-SSD层的示意图。在HSM-SSD层中,计算密集型的投影通过HSM中的减少隐藏状态处理,如突出显示所示。红色、蓝色和橙色分别表示需要复杂度为O(LD^2)、O(LND)和O(ND^2)的操作。

Hidden State Mixer (HSM) 是一种用于优化计算效率的机制,它通过减少计算复杂度来提高模型性能。其核心思想是利用一个共享的全局隐藏状态h来执行通道混合,包括门控和输出投影,直接在减少的潜在数组h上进行操作。具体实现原理如下:

1. 线性投影优化:HSM通过首先计算隐藏状态hin,然后将其作为线性投影到隐藏状态,从而将计算复杂度从O(LD^2)降低到O(ND^2),其中N是状态的数量,L是序列长度,D是通道数。这种优化依赖于状态数量N远小于通道数D(即N ≪ D)的情况。

2. 减少隐藏状态的使用:在HSM-SSD层中,通过减少隐藏状态来处理计算密集型的投影,如图2b所示。这使得计算成本主要依赖于状态的数量,而不是通道数。

3. 门控和输出投影的直接应用:HSM直接在隐藏状态上应用门控和投影,而不是先计算Ch再进行门控和投影。这减少了计算复杂度,并且当N较小时,捕捉全局上下文的总复杂度变得可以忽略。

4. 多阶段隐藏状态融合(MSF):为了进一步提升性能,HSM引入了MSF机制,该机制融合了来自网络多个阶段的隐藏状态来生成预测logits。通过这种方式,模型能够整合低层次和高层次的特征,增强模型在推理时的泛化能力。

5. 单头设计:HSM-SSD采用单头设计,通过设置∆ ∈ RL×N和ˆa ∈ RN来估计每个状态的token重要性。这种设计避免了多头配置中的内存绑定操作,提高了吞吐量,并且在保持竞争性能的同时,实现了更高的效率。

四、实验分析

1. 图像分类性能:EfficientViM在ImageNet-1K分类任务中表现出色,与先前的高效网络相比,在速度和准确性方面均有所超越。特别是EfficientViM-M2在保持与MobileViTV2 0.75和FastViT-T8相似性能的同时,实现了约4倍的加速。

2. 扩展性分析:EfficientViM在高分辨率图像上的性能同样优异,与SHViT相比,在5122分辨率下实现了超过15%的速度提升。此外,通过与教师模型RegNetY-160进行蒸馏训练,EfficientViM在速度-准确性权衡上进一步确立了优势。

3. 内存效率:尽管EfficientViM的参数数量相对较多,但其在设备上的内存使用量由推理时的内存I/O决定,而非仅由参数数量决定。EfficientViM在保持最佳吞吐量的同时,显示出与具有较低参数数量的模型相当的内存效率。

表3. 在ImageNet-1K 分类任务上高效网络的比较。结果按准确率排序。我们还指出了每个方法与EfficientViM相比的相对吞吐量Thrrel。

表4. EfficientViM与vision Mambas的比较。Thrrel是与EfficientViM-M4相比的相对吞吐量。

表6. 使用蒸馏目标训练后高效网络的比较。Thrrel是与EfficientViM在每个分割中的相对吞吐量比较。

表B. 在COCO-2017上的实例分割、物体检测及结果。

五、代码

1

温馨提示:对于所有推文中出现的代码,如果您在微信中复制的代码排版错乱,请复制该篇推文的链接,在任意浏览器中打开,再复制相应代码,即可成功在开发环境中运行!或者进入官方github仓库找到对应代码进行复制!

代码语言:javascript
代码运行次数:0
复制
import torch
import torch.nn as nn
import math
# 论文题目:EfficientViM: Efficient Vision Mamba with Hidden State Mixer based State Space Duality
# 中文题目:高效ViM:基于隐藏状态混合器的状态空间对偶性的高效视觉Mamba
# 论文链接:https://arxiv.org/pdf/2411.15241
# 官方github:https://arxiv.org/pdf/2411.15241
# 所属机构:韩国大学——计算机科学与工程系
# 代码整理:微信公众号:AI缝合术
class LayerNorm2D(nn.Module):
    """LayerNorm for channels of 2D tensor(B C H W)"""
    def __init__(self, num_channels, eps=1e-5, affine=True):
        super(LayerNorm2D, self).__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
            self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)  # (B, 1, H, W)
        var = x.var(dim=1, keepdim=True, unbiased=False)  # (B, 1, H, W)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)  # (B, C, H, W)
        if self.affine:
            x_normalized = x_normalized * self.weight + self.bias
        return x_normalized
class LayerNorm1D(nn.Module):
    """LayerNorm for channels of 1D tensor(B C L)"""
    def __init__(self, num_channels, eps=1e-5, affine=True):
        super(LayerNorm1D, self).__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(1, num_channels, 1))
            self.bias = nn.Parameter(torch.zeros(1, num_channels, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)  # (B, 1, H, W)
        var = x.var(dim=1, keepdim=True, unbiased=False)  # (B, 1, H, W)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)  # (B, C, H, W)
        if self.affine:
            x_normalized = x_normalized * self.weight + self.bias
        return x_normalized
class ConvLayer2D(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm2d, act_layer=nn.ReLU, bn_weight_init=1):
        super(ConvLayer2D, self).__init__()
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=(padding, padding),
            dilation=(dilation, dilation),
            groups=groups,
            bias=False
        )
        self.norm = norm(num_features=out_dim) if norm else None
        self.act = act_layer() if act_layer else None
        
        if self.norm:
            torch.nn.init.constant_(self.norm.weight, bn_weight_init)
            torch.nn.init.constant_(self.norm.bias, 0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
    
    
class ConvLayer1D(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm1d, act_layer=nn.ReLU, bn_weight_init=1):
        super(ConvLayer1D, self).__init__()
        self.conv = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False
        )
        self.norm = norm(num_features=out_dim) if norm else None
        self.act = act_layer() if act_layer else None
        
        if self.norm:
            torch.nn.init.constant_(self.norm.weight, bn_weight_init)
            torch.nn.init.constant_(self.norm.bias, 0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
class FFN(nn.Module):
    def __init__(self, in_dim, dim):
        super().__init__()
        self.fc1 = ConvLayer2D(in_dim, dim, 1)
        self.fc2 = ConvLayer2D(dim, in_dim, 1, act_layer=None, bn_weight_init=0)
        
    def forward(self, x):
        x = self.fc2(self.fc1(x))
        return x
class HSMSSD(nn.Module):
    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim = 64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim
        self.BCdt_proj = ConvLayer1D(d_model, 3*state_dim, 1, norm=None, act_layer=None)
        conv_dim = self.state_dim*3
        self.dw = ConvLayer2D(conv_dim, conv_dim, 3,1,1, groups=conv_dim, norm=None, act_layer=None, bn_weight_init=0) 
        self.hz_proj = ConvLayer1D(d_model, 2*self.d_inner, 1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, 1, norm=None, act_layer=None, bn_weight_init=0)
        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)
        self.act = nn.SiLU()
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True
    def forward(self, x):
        batch, _, L= x.shape
        H = int(math.sqrt(L))
        
        BCdt = self.dw(self.BCdt_proj(x).view(batch,-1, H, H)).flatten(2)
        B,C,dt = torch.split(BCdt, [self.state_dim, self.state_dim,  self.state_dim], dim=1) 
        A = (dt + self.A.view(1,-1,1)).softmax(-1) 
        
        AB = (A * B) 
        h = x @ AB.transpose(-2,-1) 
        
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1) 
        h = self.out_proj(h * self.act(z)+ h * self.D)
        y = h @ C # B C N, B C L -> B C L
        
        y = y.view(batch,-1,H,H).contiguous()# + x * self.D  # B C H W
        return y, h
    
class EfficientViMBlock(nn.Module):
    def __init__(self, dim, mlp_ratio=4., ssd_expand=1, state_dim=64):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        
        self.mixer = HSMSSD(d_model=dim, ssd_expand=ssd_expand,state_dim=state_dim)  
        self.norm = LayerNorm1D(dim)
        
        self.dwconv1 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
        self.dwconv2 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
        
        self.ffn = FFN(in_dim=dim, dim=int(dim * mlp_ratio))
        
        #LayerScale
        self.alpha = nn.Parameter(1e-4 * torch.ones(4,dim), requires_grad=True)
        
    def forward(self, x):
        alpha = torch.sigmoid(self.alpha).view(4,-1,1,1)
        
        # DWconv1
        x = (1-alpha[0]) * x + alpha[0] * self.dwconv1(x)
        
        # HSM-SSD
        x_prev = x
        x, h = self.mixer(self.norm(x.flatten(2))) 
        x = (1-alpha[1]) * x_prev + alpha[1] * x
        
        # DWConv2
        x = (1-alpha[2]) * x + alpha[2] * self.dwconv2(x)
        
        # FFN
        x = (1-alpha[3]) * x + alpha[3] * self.ffn(x)
        # return x, h
        return x
if __name__ == "__main__":
    # 将模块移动到 GPU(如果可用)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 创建测试输入张量
    x = torch.randn(1, 32, 256, 256).to(device)
    # 初始化 evim 模块
    evim = EfficientViMBlock(dim=32).to(device)
    print(evim)
    # 前向传播    
    print("\n微信公众号: AI缝合术!\n")
    output = evim(x)
    
    # 打印输入和输出张量的形状
    print("输入张量形状:", x.shape)
    print("输出张量形状:", output.shape)

运行结果

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-03-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小白学视觉 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档