首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >ViG:图像分类领域前沿

ViG:图像分类领域前沿

作者头像
Srlua
发布2024-12-23 08:13:56
发布2024-12-23 08:13:56
4840
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运

图片分类任务方法概述

卷积神经网络(CNN)

发展背景: CNN的出现标志着深度学习在图像识别领域的重大突破。最早的CNN模型可以追溯到1998年的LeNet,而2012年的AlexNet模型在ImageNet竞赛中取得优异成绩,使得CNN成为图像分类任务的主流方法。

分类方法优点:

局部感知野: 通过卷积操作,CNN能够捕捉图像的局部特征,减少参数数量。 参数共享: 卷积核在整张图像上共享,提高了模型的泛化能力。 平移不变性: CNN具有平移不变性,能够识别图像中的物体,即使它们的位置发生变化。

视觉Transformer(ViT)

发展背景: ViT于2020年被提出,借鉴了自然语言处理领域的Transformer架构,将自注意力机制应用于图像分类任务。

分类方法优点:

自注意力机制: 能够捕捉图像中的长距离依赖关系,提高分类准确性。 可扩展性: Transformer结构易于扩展,适用于大规模数据集。 并行计算: 自注意力机制使得ViT能够更好地利用并行计算资源。

视觉图神经网络(ViG)

发展背景: ViG的提出是为了解决CNN和ViT在处理不规则和复杂目标时的局限性。ViG将图像视为图结构,通过图卷积操作进行特征提取和分类。

分类方法优点:

灵活的图结构: ViG采用图结构表示图像,能够更好地处理不规则形状的物体,提高对复杂场景的识别能力。 图卷积操作: 通过图卷积,ViG能够有效地聚合和更新节点信息,捕捉局部和全局特征。 节点特征变换: FFN模块(多层感知器)用于节点特征变换,增强了模型的表达能力

ViG模型

图片切成patch

(a) Grid Structure 作用: 像素级信息捕获:通过将图像切分成均匀分布的小块(Patch),每个Patch代表图像的一个局部区域。 空间关系保持:保留了图像的空间布局信息,使得模型能够理解对象的位置和相对位置。 重要性: 经典方法的基础:这是许多传统计算机视觉算法的基本假设,包括早期的人工设计特征提取方法和现代的深度学习模型(如卷积神经网络CNN)。 简单直观:易于理解和实施,是初学者入门的好选择。 (b) Sequence Structure 作用: 序列化处理:将图像的Patch按某种顺序排列,形成一维序列。 时间维度模拟:虽然实际处理的是静态图像,但通过序列化的方式,可以引入类似于自然语言处理(NLP)领域的时间维度概念。 重要性: Transformer的应用:这种结构特别适合于基于Transformer架构的方法,如Vision Transformer(ViT)。ViT等模型通过自注意力机制对序列化的Patch进行处理,从而有效地捕捉全局上下文信息。 灵活性提升:相比固定大小的卷积核,序列化处理允许模型关注任意距离的Patch之间的关系,提高了模型的灵活性和泛化能力。 © Graph Structure 作用: 非结构化数据建模:将图像中的Patch视为图中的节点,允许模型处理更加复杂和灵活的数据结构。 适应性强:能够更好地适应各种形状和尺寸的对象,尤其是对于那些不能很好地用网格或序列描述的情况。 重要性: 图神经网络优势:结合图神经网络(GNN)的优点,能够有效处理具有复杂拓扑结构的数据,如社交网络、分子结构等。 创新性突破:在视觉任务中引入图结构是一种创新尝试,有望带来新的突破和进展,特别是在需要精细分析和理解场景的情况下。

模型架构

图像输入

首先,从一张原始图像开始。在这个例子中,图像展示了一条鱼和一个人的部分身体。

图结构生成

接下来,将图像划分为若干个Patch,并将这些Patch作为图中的节点。每个节点代表图像的一部分,而边则表示这些部分之间的关联。红色圆圈内的节点可能表示图像的关键部分,比如鱼的身体或者人的衣服图案。

网络模块

然后,进入网络模块,该模块由两部分组成:图处理和特征变换。

图处理

在这一步骤中,模型会对图结构进行处理,以提取出各个Patch之间的关系和相互影响。这可以通过图卷积操作或其他类型的图神经网络技术完成。

特征变换

经过图处理之后,得到的特征会被送入特征变换模块。这里可能会涉及到一些标准的神经网络组件,如全连接层、激活函数等,目的是进一步提炼和转化所获得的信息。

多尺度处理

整个过程会重复多次(L次),每次都会产生一个新的特征图。这样做的好处是可以从不同的层次和角度来观察和理解图像内容,增强模型的表现力。

输出头

最后,所有经过多轮处理后的特征被整合起来,传递给输出头(Head for recognition)。这个输出头负责最终的识别任务,可能是分类、回归或者其他类型的问题。

ViG代码

PatchEmbedding

代码语言:javascript
复制
class Stem(nn.Module):
    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, out_dim//8, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//8),
            act_layer(act),
            nn.Conv2d(out_dim//8, out_dim//4, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//4),
            act_layer(act),
            nn.Conv2d(out_dim//4, out_dim//2, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//2),
            act_layer(act),
            nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            act_layer(act),
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        x = self.convs(x)
        return x

模型主体架构设计

代码语言:javascript
复制
self.backbone = Seq(*[Seq(Grapher(channels, num_knn[i], 1, conv, act, norm,
                 bias, stochastic, epsilon, 1, drop_path=dpr[i]),
                 FFN(channels, channels * 4, act=act, drop_path=dpr[i])
) for i in range(self.n_blocks)])

核心代码

聚合特征

代码语言:javascript
复制
class MRConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(MRConv2d, self).__init__()
        self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        print(x.shape, edge_index.shape)
        x_i = batched_index_select(x, edge_index[1])
        print(x_i.shape)
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
            print(x_j.shape)
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, _ = x.shape
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _)
        print(x.shape)
        return self.nn(x)

演示效果

附件使用

安装相应依赖包

代码语言:javascript
复制
pip install -r requirements.txt

获取cifa10数据集

代码语言:javascript
复制
import torchvision
import torchvision.transforms as transforms

# transforms用于数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 下载并加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 下载并加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# CIFAR-10数据集中的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

运行代码

代码语言:javascript
复制
python train.py
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-12-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 图片分类任务方法概述
    • 卷积神经网络(CNN)
    • 视觉Transformer(ViT)
    • 视觉图神经网络(ViG)
  • ViG模型
    • 图片切成patch
    • 模型架构
      • 图像输入
      • 图结构生成
      • 网络模块
      • 图处理
      • 特征变换
      • 多尺度处理
      • 输出头
  • ViG代码
    • PatchEmbedding
    • 模型主体架构设计
    • 核心代码
  • 演示效果
  • 附件使用
    • 安装相应依赖包
    • 获取cifa10数据集
    • 运行代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档