SAM-Adapter 的核心思想是通过引入轻量级适配器,将任务特定知识注入到冻结的 SAM 模型中,以增强其在下游任务中的适应能力。适配器的设计简洁高效,通过灵活的任务知识输入,提升了模型的性能与泛化能力,特别是在数据稀缺场景下表现突出。
1. 使用SAM作为骨干网络
2. 输入任务特定知识 任务特定知识FiFi可以根据具体任务灵活设计,形式多样。其可以是从数据集中提取的特征(如纹理或频率信息),也可以是手工设计的规则,以及多种信息的组合形式: Fi=∑j=1NwjFjFi=j=1∑NwjFj 其中FjFj为某种类型的知识或特征,wjwj为可调节的权重(用于控制组合强度)。
3. Adapters 结构:由两个多层感知器(MLP)和一个激活函数(GELU)组成:
Pi=MLPup(GELU(MLPtuneiFi))Pi=MLPup(GELU(MLPtuneiFi))
其中,MLPtuneiMLPtunei是线性层,用于为每个适配器生成任务特定的提示(prompts); MLPupMLPup 是一个共享的上投影层,用于调整Transformer特征的维度;GELUGELU是激活函数。 PiPi是输出的提示,附加到SAM模型的每一层Transformer中。 在该项目的代码实现中,是这样实现adpater的功能的:
class PromptGenerator(nn.Module):
def __init__(self, ...):
...
self.shared_mlp = nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim)
self.embedding_generator = nn.Linear(self.embed_dim, self.embed_dim//self.scale_factor)
for i in range(self.depth):
lightweight_mlp = nn.Sequential(
nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim//self.scale_factor),
nn.GELU()
)
setattr(self, 'lightweight_mlp_{}'.format(str(i)), lightweight_mlp)
self.prompt_generator = PatchEmbed2(img_size=img_size,
patch_size=patch_size, in_chans=3,
embed_dim=self.embed_dim//self.scale_factor)
...
def init_embeddings(self, x):
N, C, H, W = x.permute(0, 3, 1, 2).shape
x = x.reshape(N, C, H*W).permute(0, 2, 1)
return self.embedding_generator(x)
def init_handcrafted(self, x):
x = self.fft(x, self.freq_nums)
return self.prompt_generator(x)
def get_prompt(self, handcrafted_feature, embedding_feature):
N, C, H, W = handcrafted_feature.shape
handcrafted_feature = handcrafted_feature.view(N, C, H*W).permute(0, 2, 1)
prompts = []
for i in range(self.depth):
lightweight_mlp = getattr(self, 'lightweight_mlp_{}'.format(str(i)))
# prompt = proj_prompt(prompt)
prompt = lightweight_mlp(handcrafted_feature + embedding_feature)
prompts.append(self.shared_mlp(prompt))
return prompts
...
(1)这里只摘取了可以显示其大概思路的部分进行展示,至于细节则请参考该项目的具体实现; (2)self.prompt_generator.init_embeddings和self.prompt_generator.init_handcrafted的实现均很简单,分别是线性层和卷积层; (3)在获取handcrafted_feature时,运用了傅里叶变换,然后提取高频信息,对应的是原图像中的边缘、纹理等信息; (4) embedding_feature更偏向图像的全局语义,适合提供通用背景信息, 而handcrafted_feature偏向图像的局部高频特征,适合突出任务关键细节。 两者互补,使得生成的prompts同时具有全局视角和局部任务适应性。
本文复现使用的是COD10K数据集,其在伪装目标检测(COD)领域具有重要地位,包含10,000张图像,涵盖78个类别(69个伪装类别,9个非伪装类别)。这些图像来自多种自然场景,包括5,066张伪装目标图像、3,000张背景图像和1,934张非伪装目标图像。数据集提供高分辨率图像和精细标注信息,可支持目标检测、分割和边缘检测等任务。其丰富的多样性和高质量标注使其成为伪装目标检测领域的重要研究资源。
如图展示了该模型在COD10K的测试集中随机选取的两个样本上的预测效果。从左到右依次为原图、真实标签和模型预测结果。
python train_single.py --config [config_file_path]
train_single.py是我增加的适用于非分布式环境(单GPU)的训练脚本。
python test.py --config [config_file_path] --model [model_path]
python 3.9 torch2.2.1+cu121 A800