前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >SAM-Adapter

SAM-Adapter

作者头像
Srlua
发布2025-01-02 08:56:15
发布2025-01-02 08:56:15
10500
代码可运行
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运
运行总次数:0
代码可运行

概述

SAM-Adapter 的核心思想是通过引入轻量级适配器,将任务特定知识注入到冻结的 SAM 模型中,以增强其在下游任务中的适应能力。适配器的设计简洁高效,通过灵活的任务知识输入,提升了模型的性能与泛化能力,特别是在数据稀缺场景下表现突出。

  1. 该文章分析了SAM作为基础模型的局限性,并提出如何利用SAM服务于下游任务的问题;
  2. 其提出的SAM-Adapter,创新性地整合任务特定知识与大模型地通用知识,灵活适应多种任务。

模型详述

1. 使用SAM作为骨干网络

  • 目标:SAM-Adapter的目标是灵活利用SAM预训练模型中的知识;
  • 骨干架构:使用SAM的图像编码器(基于ViT-H/16)作为骨干网络,同时保持其预训练权重不变;使用SAM的掩码解码器(包括修改后的Transformer解码器和动态掩码预测头),在预训练权重的基础上微调;
  • 任务特定知识引入:通过适配器将任务特定知识 FiFi​ 注入网络,利用Prompts技术提升莫i下在下游任务中的泛化能力。

2. 输入任务特定知识 任务特定知识FiFi​可以根据具体任务灵活设计,形式多样。其可以是从数据集中提取的特征(如纹理或频率信息),也可以是手工设计的规则,以及多种信息的组合形式: Fi=∑j=1NwjFjFi​=j=1∑N​wj​Fj​ 其中FjFj​为某种类型的知识或特征,wjwj​为可调节的权重(用于控制组合强度)。

3. Adapters 结构:由两个多层感知器(MLP)和一个激活函数(GELU)组成:

Pi=MLPup(GELU(MLPtuneiFi))Pi​=MLPup​(GELU(MLPtunei​Fi​))

其中,MLPtuneiMLPtunei​是线性层,用于为每个适配器生成任务特定的提示(prompts); MLPupMLPup​ 是一个共享的上投影层,用于调整Transformer特征的维度;GELUGELU是激活函数。 PiPi​是输出的提示,附加到SAM模型的每一层Transformer中。 在该项目的代码实现中,是这样实现adpater的功能的:

代码语言:javascript
代码运行次数:0
复制
class PromptGenerator(nn.Module):
  def __init__(self, ...):
    ...
    self.shared_mlp = nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim)
    self.embedding_generator = nn.Linear(self.embed_dim, self.embed_dim//self.scale_factor)
    for i in range(self.depth):
        lightweight_mlp = nn.Sequential(
             nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim//self.scale_factor),
             nn.GELU()
            )
        setattr(self, 'lightweight_mlp_{}'.format(str(i)), lightweight_mlp)

    self.prompt_generator = PatchEmbed2(img_size=img_size,
                                        patch_size=patch_size, in_chans=3,
                                        embed_dim=self.embed_dim//self.scale_factor)
    ...

  def init_embeddings(self, x):
        N, C, H, W = x.permute(0, 3, 1, 2).shape
        x = x.reshape(N, C, H*W).permute(0, 2, 1)
        return self.embedding_generator(x)

  def init_handcrafted(self, x):
        x = self.fft(x, self.freq_nums)
        return self.prompt_generator(x)

  def get_prompt(self, handcrafted_feature, embedding_feature):
      N, C, H, W = handcrafted_feature.shape
      handcrafted_feature = handcrafted_feature.view(N, C, H*W).permute(0, 2, 1)
      prompts = []
      for i in range(self.depth):
          lightweight_mlp = getattr(self, 'lightweight_mlp_{}'.format(str(i)))
          # prompt = proj_prompt(prompt)
          prompt = lightweight_mlp(handcrafted_feature + embedding_feature)
          prompts.append(self.shared_mlp(prompt))
      return prompts
     ...

(1)这里只摘取了可以显示其大概思路的部分进行展示,至于细节则请参考该项目的具体实现; (2)self.prompt_generator.init_embeddings和self.prompt_generator.init_handcrafted的实现均很简单,分别是线性层和卷积层; (3)在获取handcrafted_feature时,运用了傅里叶变换,然后提取高频信息,对应的是原图像中的边缘、纹理等信息; (4) embedding_feature更偏向图像的全局语义,适合提供通用背景信息, 而handcrafted_feature偏向图像的局部高频特征,适合突出任务关键细节。 两者互补,使得生成的prompts同时具有全局视角和局部任务适应性。

实验

数据集介绍

本文复现使用的是COD10K数据集,其在伪装目标检测(COD)领域具有重要地位,包含10,000张图像,涵盖78个类别(69个伪装类别,9个非伪装类别)。这些图像来自多种自然场景,包括5,066张伪装目标图像、3,000张背景图像和1,934张非伪装目标图像。数据集提供高分辨率图像和精细标注信息,可支持目标检测、分割和边缘检测等任务。其丰富的多样性和高质量标注使其成为伪装目标检测领域的重要研究资源。

复现流程

如图展示了该模型在COD10K的测试集中随机选取的两个样本上的预测效果。从左到右依次为原图、真实标签和模型预测结果。

  1. 下载附件中的项目代码、数据集和权重文件并放置在相应路径下。 数据集(cod10k)和相应的权重文件我已经准备好,网盘链接也放在了附件当中。
  2. 训练模型
代码语言:javascript
代码运行次数:0
复制
python train_single.py --config [config_file_path]

train_single.py是我增加的适用于非分布式环境(单GPU)的训练脚本。

  1. 推理
代码语言:javascript
代码运行次数:0
复制
python test.py --config [config_file_path] --model [model_path]

环境配置

python 3.9 torch2.2.1+cu121 A800

​​

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2025-01-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 概述
  • 模型详述
  • 实验
    • 数据集介绍
    • 复现流程
  • 环境配置
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档