Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >已开源!通过高度驱动的注意力网络改善城市场景语义分割 | CVPR2020

已开源!通过高度驱动的注意力网络改善城市场景语义分割 | CVPR2020

作者头像
AI算法修炼营
发布于 2020-05-27 15:12:29
发布于 2020-05-27 15:12:29
83200
代码可运行
举报
文章被收录于专栏:AI算法修炼营AI算法修炼营
运行总次数:0
代码可运行

论文地址:https://arxiv.org/abs/2003.05128

代码地址:https://github.com/shachoi/HANet(已开源)

该论文利用了城市场景图像的内在特征,并提出了一个通用的附加模块,称为高度驱动的注意力网络(HANet),用于改善城市场景图像的语义分割。

将城市场景图像进行垂直方向的分割后(分为上部、中部、下部),像素级类别分布彼此之间存在显著差异。同样,城市场景图像具有其自身独特的特征,但是大多数语义分割网络并未反映出体系结构中的此类独特属性。HANet网络架构结合了利用垂直属性来有效处理城市场景数据集的能力。HANet根据像素的垂直位置来选择相关特征并进行像素类

经过广泛的定量分析表明,HANet模块能既简单又经济高效地添加到现有模型中。在基于ResNet-101的分割模型中,该方法在Cityscapes基准上实现了新的SOTA性能。此外,文中通过可视化和解释注意力图来表明所提出的模型与在城市场景中观察到的事实是一致的。

1. 简介

由于城市现场图像是由安装在汽车前部的摄像头捕获的,因此城市现场数据集仅由道路行驶图片组成这导致有可能根据空间位置,特别是在垂直位置,引入共同的结构先验。

下图显示了垂直位置上的城市场景数据集的类别分布。尽管少数类别的像素在整个图像区域中都是主要的(图1(a)),但类别分布对垂直位置有很大的依赖性。也就是说,图像的下部主要由道路组成,而中间部分则包含各种相对较小的对象。在上部,建筑物,植被和天空是主要对象,如图1(b)所示。

可以看出,类别分布极为不平衡,主要的前五类的概率:道路,建筑物,植被,汽车和人行道。占主导地位的类占据了整个数据集的88%。如上所述,如果将图像分为三个区域:上部,中部和下部,则类别分布完全不同。

如果能够识别出图像中任意像素所属的部分,将有助于语义分割中的像素级分类。提出了一种新型的高度驱动的注意力网络(HANet),作为城市场景图像语义分割的通用附加模块。给定一个输入特征图,HANet提取代表每个水平划分部分的“高度上下文信息”,然后从高度上下文信息中预测每个水平部分中特征或类别。

论文主要贡献:

  • 提出了一种新颖的轻量级附加模块HANet,可以轻松地将其添加到现有模型中,并通过根据像素的垂直位置通道的注意力驱动来提高性能。通过广泛的实验,我们证明了该方法的有效性和广泛适用性.
  • 通过将HANet添加到DeeplabV3+的baseline中,在Cityscapes数据集上获得了最新的性能,而计算和内存开销可忽略不计。
  • 可视化并解释各个渠道上的注意力权重,并以实验方式证实了高度位置对于改善城市场景中的片段化性能至关重要。

2. 背景

语义分割的模型中,在捕获高级语义特征的同时保持特征图的分辨率对于实现语义分割的高性能至关重要。主要方法有:

  • 跳级连接(利用编码器层中较早存在的高分辨率特征来恢复解码器层中的对象边界)
  • 空洞卷积(在不增加计算量的情况下,增加感受野的大小)
  • 自注意力机制(捕获远程依赖)
  • 关注类别的边界信息

3. 方法

根据空间位置的不同,城市场景图像通常包含共同的结构先验。就类别分布而言,图像的每一行都有明显不同的统计信息。从这个意义上说,在城市场景分割的像素级分类过程中,可以分别捕获表示每一行的全局上下文信息即高度上下文信息来估计信道的权重。

因此,提出了HANet,其目的是:i)提取高度方向的上下文信息,ii)使用上下文计算高度驱动的注意权重以表示每行的特征(中间层)或类(最后一层)的重要性。。

(a)width-wise pooling

压缩空间维。在宽度合并操作的最大合并和平均合并之间进行选择是一个超参数,并根据经验设置为平均合并

(b,d)interpolation for coarse attention

合并操作后,模型生成矩阵Z∈RC`×H`。但是,并不是矩阵Z的所有行对于计算有效的关注图都是必要的。因此,先经过插值进行下采样(b),同时,由于由下采样表示构造的注意图也是粗糙的,因此还需要通过上采样将注意图转换为与给定的高层次特征图Xh具有等效的高维(图2(d))

(c)computing height-driven attention map

高度驱动的通道式注意力图A是由卷积层获得的,这些卷积层将宽度合并和插值后的特征图ˆZ作为输入。在生成注意力图过程中,采用卷积层而不是全连接层(与SENet有区别),以便在估计注意力图时考虑相邻行之间的关系,因为每一行都与其相邻行相关。同时,为了允许这些多个功能和标签,在计算注意力图时使用了S形函数,而不是softmax函数。这些由N个卷积层组成的运算可以写成:

(c)positon encoding (这部分需要结合重点理解)

当人类识别出驾驶场景时,他们对特定物体的垂直位置具有先验知识(例如,道路和天空分别出现在下部和上部)。受此观察的启发,将NLP领域的正弦位置编码添加到HANet中。具体位置编码定义为:

在计算注意力图之后,可以将给定的较高级特征图Xh转换为通过A和Xh的元素乘积获取的新表示。每个通道的单个缩放向量是由每个单独的行或多个连续行每组派生的,因此该向量与水平方向一起进行计算,公式为:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''
    def cal_angle(position, hid_idx):
        if d_hid > 50:
            cycle = 10
        elif d_hid > 5:
            cycle = 100
        else:
            cycle = 10000
        cycle = 10 if d_hid > 50 else 100
        return position / np.power(cycle, 2 * (hid_idx // 2) / d_hid)
    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.
    return torch.FloatTensor(sinusoid_table)

class PosEmbedding2D(nn.Module):
    
    def __init__(self, pos_rfactor, dim):
        super(PosEmbedding2D, self).__init__()

        self.pos_layer_h = nn.Embedding((128//pos_rfactor)+1, dim)
        self.pos_layer_w = nn.Embedding((128//pos_rfactor)+1, dim)
        initialize_embedding(self.pos_layer_h)
        initialize_embedding(self.pos_layer_w)

    def forward(self, x, pos):
        pos_h, pos_w = pos
        pos_h = pos_h.unsqueeze(1)
        pos_w = pos_w.unsqueeze(1)
        pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2:], mode='nearest').long()  # B X 1 X H X W
        pos_w = nn.functional.interpolate(pos_w.float(), size=x.shape[2:], mode='nearest').long()  # B X 1 X H X W
        pos_h = self.pos_layer_h(pos_h).transpose(1,4).squeeze(4)   # B X 1 X H X W X C
        pos_w = self.pos_layer_w(pos_w).transpose(1,4).squeeze(4)   # B X 1 X H X W X C
        x = x + pos_h + pos_w
        return x

class PosEncoding1D(nn.Module):
    
    def __init__(self, pos_rfactor, dim, pos_noise=0.0):
        super(PosEncoding1D, self).__init__()
        print("use PosEncoding1D")
        self.sel_index = torch.tensor([0]).cuda()
        pos_enc = (get_sinusoid_encoding_table((128//pos_rfactor)+1, dim) + 1)
        self.pos_layer = nn.Embedding.from_pretrained(embeddings=pos_enc, freeze=True)
        self.pos_noise = pos_noise
        self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1

        self.pos_rfactor = pos_rfactor
        if pos_noise > 0.0:
            self.min = 0.0 #torch.tensor([0]).cuda()
            self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda()
            self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise]))

    def forward(self, x, pos, return_posmap=False):
        pos_h, _ = pos # B X H X W
        pos_h = pos_h//self.pos_rfactor
        pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H
        pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48

        if self.training is True and self.pos_noise > 0.0:
            #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long()
            pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(), 
                            min=-self.noise_clamp, max=self.noise_clamp)
            pos_h = torch.clamp(pos_h, min=self.min, max=self.max)
            #pos_h = torch.where(pos_h < self.min_tensor, self.min_tensor, pos_h)
            #pos_h = torch.where(pos_h > self.max_tensor, self.max_tensor, pos_h)

        pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3)   # B X 1 X 48 X 80 > B X 80 X 48 X 1 
        x = x + pos_h
        if return_posmap:
            return x, self.pos_layer.weight # 33 X 80
        return x

class PosEmbedding1D(nn.Module):
    
    def __init__(self, pos_rfactor, dim, pos_noise=0.0):
        super(PosEmbedding1D, self).__init__()
        print("use PosEmbedding1D")
        self.sel_index = torch.tensor([0]).cuda()
        self.pos_layer = nn.Embedding((128//pos_rfactor)+1, dim)
        initialize_embedding(self.pos_layer)
        self.pos_noise = pos_noise
        self.pos_rfactor = pos_rfactor
        self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1

        if pos_noise > 0.0:
            self.min = 0.0 #torch.tensor([0]).cuda()
            self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda()
            self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise]))

    def forward(self, x, pos, return_posmap=False):
        pos_h, _ = pos # B X H X W
        pos_h = pos_h//self.pos_rfactor
        pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H
        pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48

        if self.training is True and self.pos_noise > 0.0:
            #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long()
            pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(),
                            min=-self.noise_clamp, max=self.noise_clamp)
            pos_h = torch.clamp(pos_h, min=self.min, max=self.max)

        pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3)   # B X 1 X 48 X 80 > B X 80 X 48 X 1 
        x = x + pos_h
        if return_posmap:
            return x, self.pos_layer.weight # 33 X 80
        return x

采用DeepLabv3 + 作为语义分割任务的基准。 DeepLabv3 +具有带ASPP的编解码器架构,该架构采用各种扩张速率对应于不同rate的空洞率。在从骨干网络对高级表示进行编码之后,将HANet添加到五个不同层的分段网络。这是因为高级特征与垂直位置的相关性更强

HANet具体的结构图如上图所示。在Pytorch中使用二维自适应平均池化操作2来实现针对粗略注意的宽度方向池化和插值。此后,应用了dropout层和三个一维卷积层。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class HANet_Conv(nn.Module):
    
    def __init__(self, in_channel, out_channel, kernel_size=3, r_factor=64, layer=3, pos_injection=2, is_encoding=1,
                pos_rfactor=8, pooling='mean', dropout_prob=0.0, pos_noise=0.0):
        super(HANet_Conv, self).__init__()

        self.pooling = pooling
        self.pos_injection = pos_injection
        self.layer = layer
        self.dropout_prob = dropout_prob
        self.sigmoid = nn.Sigmoid()

        if r_factor > 0:
            mid_1_channel = math.ceil(in_channel / r_factor)
        elif r_factor < 0:
            r_factor = r_factor * -1
            mid_1_channel = in_channel * r_factor

        if self.dropout_prob > 0:
            self.dropout = nn.Dropout2d(self.dropout_prob)

        self.attention_first = nn.Sequential(
                nn.Conv1d(in_channels=in_channel, out_channels=mid_1_channel,
                    kernel_size=1, stride=1, padding=0, bias=False),
                Norm2d(mid_1_channel),
                nn.ReLU(inplace=True))    

        if layer == 2:
            self.attention_second = nn.Sequential(
                    nn.Conv1d(in_channels=mid_1_channel, out_channels=out_channel,
                        kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True))
        elif layer == 3:
            mid_2_channel = (mid_1_channel * 2)
            self.attention_second = nn.Sequential(
                    nn.Conv1d(in_channels=mid_1_channel, out_channels=mid_2_channel,
                        kernel_size=3, stride=1, padding=1, bias=True),
                    Norm2d(mid_2_channel),
                    nn.ReLU(inplace=True))    
            self.attention_third = nn.Sequential(
                    nn.Conv1d(in_channels=mid_2_channel, out_channels=out_channel,
                        kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True))

        if self.pooling == 'mean':
            #print("##### average pooling")
            self.rowpool = nn.AdaptiveAvgPool2d((128//pos_rfactor,1))
        else:
            #print("##### max pooling")
            self.rowpool = nn.AdaptiveMaxPool2d((128//pos_rfactor,1))

        if pos_rfactor > 0:
            if is_encoding == 0:
                if self.pos_injection == 1:
                    self.pos_emb1d_1st = PosEmbedding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
                elif self.pos_injection == 2:
                    self.pos_emb1d_2nd = PosEmbedding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
            elif is_encoding == 1:
                if self.pos_injection == 1:
                    self.pos_emb1d_1st = PosEncoding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
                elif self.pos_injection == 2:
                    self.pos_emb1d_2nd = PosEncoding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
            else:
                print("Not supported position encoding")
                exit()


    def forward(self, x, out, pos=None, return_attention=False, return_posmap=False, attention_loss=False):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        H = out.size(2)
        x1d = self.rowpool(x).squeeze(3)

        if pos is not None and self.pos_injection == 1:
            if return_posmap:
                x1d, pos_map1 = self.pos_emb1d_1st(x1d, pos, True)
            else:
                x1d = self.pos_emb1d_1st(x1d, pos)

        if self.dropout_prob > 0:
            x1d = self.dropout(x1d)
        x1d = self.attention_first(x1d)

        if pos is not None and self.pos_injection == 2:
            if return_posmap:
                x1d, pos_map2 = self.pos_emb1d_2nd(x1d, pos, True)
            else:
                x1d = self.pos_emb1d_2nd(x1d, pos)

        x1d = self.attention_second(x1d)

        if self.layer == 3:
            x1d = self.attention_third(x1d)
            if attention_loss:
                last_attention = x1d            
            x1d = self.sigmoid(x1d)
        else:
            if attention_loss:
                last_attention = x1d            
            x1d = self.sigmoid(x1d)

        x1d = F.interpolate(x1d, size=H, mode='linear')
        out = torch.mul(out, x1d.unsqueeze(3))

        if return_attention:
            if return_posmap:
                if self.pos_injection == 1:
                    pos_map = (pos_map1)
                elif self.pos_injection == 2:
                    pos_map = (pos_map2)
                return out, x1d, pos_map
            else:
                return out, x1d
        else:
            if attention_loss:
                return out, last_attention
            else:
                return out

如上图所示,宽度方向上的列与高度方向上的相应的列类别分布相似。因此,相对于图像的水平位置提取不同的信息将相对困难。同样,从经验上讲,使用注意力网络与预测宽度类别分布时,没有观察到有意义的性能提升。这证实了HANet概念的基本原理,该思想提取并合并了高度方向的上下文信息,而不是宽度方向的上下文信息。

5. 效果

为了与其他最新模型进行比较,使用经过精细注释的训练和验证集,9000次迭代进行训练。在采用ResNext-101 作为主干网络的情况下,额外使用了粗注释的图像,并且该模型在Mapillary 上进行了预训练。crop和批量大小分别更改为864×864和12。将基于ResNet-101和ResNext-101的最佳模型与Cityscapes测试集上的其他最新模型进行了比较,模型实现了最新的性能。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-05-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI算法修炼营 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
支招 | 使用Pytorch进行文本分类
dropout的值要在 0.1 以下(经验之谈,笔者在实践中发现,dropout取0.1时比dropout取0.3时在测试集准确率能提高0.5%)。
AI研习社
2019/09/17
2.2K0
支招 | 使用Pytorch进行文本分类
音视频开发之旅(89) - Transformer论文解读和源码解析
Transformer是续MLP RNN CNN后又一个影响深远的模型,  之前CNN RNN基本上都在各自的领域发光发热,  但是Transformer 在很多领域都有着很广泛的应用. eg: chatgpt  llama等语言大模型, sd文生图模型, 以及多模态 llava等. 我们最近探索的视频&图像画质评测以及画质增强很多算法也都是基于Transformer. 所以加强对Transformer学习理解和应用迫在眉睫.
音视频开发之旅
2024/09/07
1390
音视频开发之旅(89) - Transformer论文解读和源码解析
pytorch实现的transformer代码分析
代码来源:https://github.com/graykode/nlp-tutorial/blob/master/5-1.Transformer/Transformer-Torch.py
西西嘛呦
2020/08/26
9660
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始、Transformer的实现和代码以及Transformer+Detection:引入视觉领域的首创DETR。
AI算法与图像处理
2021/01/20
8.2K0
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
使用PyTorch实现去噪扩散模型
在深入研究去噪扩散概率模型(DDPM)如何工作的细节之前,让我们先看看生成式人工智能的一些发展,也就是DDPM的一些基础研究。
deephub
2024/01/07
6270
使用PyTorch实现去噪扩散模型
Transformer的PyTorch实现
文本主要介绍一下如何使用 PyTorch 复现 Transformer,实现简单的机器翻译任务。
mathor
2020/07/15
8620
Transformer的PyTorch实现
指针生成网络(PGN)详细指南(引入)
我们首先要了解的是seq2seq(Sequence-to-Sequence)模型。它最早由Google在2014年的一篇论文中提出,是第一个真正意义上的端到端的编码器-解码器(Encoder-Decoder)框架。
@小森
2025/01/24
1100
Python实现替换照片人物背景,精细到头发丝(附代码)
其中,model文件夹放的是模型文件,模型文件的下载地址为:https://drive.google.com/drive/folders/1NmyTItr2jRac0nLoZMeixlcU1myMiYTs
用户8544541
2022/03/24
1.1K0
Python实现替换照片人物背景,精细到头发丝(附代码)
使用pytorch进行文本分类——ADGCNN
在文本分类任务中常用的网络是RNN系列或Transformer的Encoder,很久没有看到CNN网络的身影(很久之前有TextCNN网络)。本文尝试使用CNN网络搭建一个文本分类器,命名为:ADGCNN。
Dendi
2019/12/12
2K0
120分钟吃掉DIEN深度兴趣演化网络
2018年的深度兴趣演化网络, DIEN(DeepInterestEvolutionNetWork)。
lyhue1991
2023/02/23
4830
120分钟吃掉DIEN深度兴趣演化网络
计算机视觉中的注意力:PyTorch中实现MultiHead和CBAM
自从Transformer在“注意力就是你所需要的”的工作中被引入以来,在自然语言处理领域已经发生了一个转变,即用基于注意力的网络取代循环神经网络(RNN)。在当前的文献中,已经有很多很棒的文章描述了这种方法。下面是我在评论中发现的两个最好的:带注释的Transformer和Transformer的可视化解释。
deephub
2021/07/01
6020
3W字长文带你轻松入门视觉Transformer
Transformer整个网络结构完全由Attention机制组成,其出色的性能在多个任务上都取得了非常好的效果。本文从Transformer的结构出发,结合视觉中的成果进行了分析,能够帮助初学者们快速入门。
石晓文
2020/12/08
1.2K0
3W字长文带你轻松入门视觉Transformer
preprint版本 | 何凯明大神新作MAE | CVPR2022最佳论文候选
本文证明了蒙面自动编码器(MAE)是一种可扩展的计算机视觉自监督学习器。我们的MAE方法很简单:我们屏蔽输入图像的随机补丁并重建丢失的像素。
机器学习炼丹术
2021/12/06
1.3K0
preprint版本 | 何凯明大神新作MAE | CVPR2022最佳论文候选
如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)
之前讲解了图注意力网络的官方tensorflow版的实现,由于自己更了解pytorch,所以打算将其改写为pytorch版本的。
西西嘛呦
2020/09/16
2.2K0
如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)
解析 Token to Token Vision Transformer
Vision Transformer!的提出让我们看到了Transformer模型在图像方向的潜力,但其有一些缺点,如需要超大型数据集(JFT)预训练,才能达到现在CNN的精度。本文分析了ViT的一些不合理之处:
BBuf
2021/03/11
7.7K0
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(二)
本文为详细解读Vision Transformer的第二篇,主要包括三个方向的分类:可变形的Transformer ,用于分类任务的Transformer ,用于底层视觉任务的Transformer,分别对应了三篇相关论文。附有超详细的代码解读。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
godweiyang
2021/04/08
3.5K0
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(二)
PyTorch技术点整理
这里我们可以看到PyTorch更简洁,不需要那么多的接口API,更接近于Python编程本身。
算法之名
2022/03/24
7560
PyTorch技术点整理
深度学习 | 《深度学习入门之PyTorch》阅读笔记
KDD(knowledge discovery in database),从数据中获取有意义的信息
Justlovesmile
2021/12/14
1.5K0
深度学习 | 《深度学习入门之PyTorch》阅读笔记
LLM入门5 | SAM代码从入门到出门 | MetaAI
非常好加载,基本上pytorch和torchvision版本不太落后就可以加载。里面的model_type需要和模型参数对应上,"vit_h"或者"vit_l"或者"vit_b",即便加载最大的2.4G的vit_h模型,也只需要占用8G的显卡。算是非常小的模型了。这里SAM测试的效果,很多情况下效果并不太好,是一个foundation model,我觉得主要原因是模型参数比较少。导致他不能很好的解决所有的问题。正确用法是对小领域最微调。
机器学习炼丹术
2023/09/02
1.3K0
LLM入门5 | SAM代码从入门到出门 | MetaAI
可视化VIT中的注意力
来源:DeepHub IMBA 本文约4000字,建议阅读8分钟 本文为你介绍ViT模型。 2022年, Vision Transformer (ViT)成为卷积神经网络(cnn)的有力竞争对手,
数据派THU
2023/04/18
1.1K0
可视化VIT中的注意力
推荐阅读
相关推荐
支招 | 使用Pytorch进行文本分类
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验