前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >Swin-T图像论文复现

Swin-T图像论文复现

作者头像
Srlua
发布2024-11-30 10:55:00
发布2024-11-30 10:55:00
13000
代码可运行
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运
运行总次数:0
代码可运行

概述

在计算机视觉领域,卷积神经网络(CNN)一直是构建模型的主流选择。自从AlexNet在ImageNet竞赛中取得了突破性的成绩后,CNN的结构不断演进,变得更庞大、更深入、更多样化。与此同时,自然语言处理领域的网络架构发展则呈现不同的轨迹,目前最流行的是Transformer模型。这种模型专为处理序列数据和转换任务而设计,以其能够捕捉数据中的长距离依赖关系而著称。Transformer在语言处理方面的显著成就激发了研究者探索其在计算机视觉领域的应用潜力,近期的研究表明,它在图像分类、目标检测、图像分割等任务上已经取得了令人鼓舞的成果。 实验得到该模型在图像分类、图像检测、目标检测有很好的效果。

Image Name
Image Name

上表列出了从 224^2 到 384^2 不同输入图像大小的 Swin Transformer 的性能。通常,输入分辨率越大,top-1 精度越高,但推理速度越慢。

Swin Transformer模型原理

1. Swin Transformer模型框架
Image Name
Image Name

首先,我们将图像送入一个称为Patch Partition的模块,该模块负责将图像分割成小块。然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样。

最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。

2. W-MSA详解

引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,对于feature map中的每个像素在Self-Attention计算过程中需要和所有的像素去计算。在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。

3. SW-MSA详解

采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了SW-MSA模块,即进行偏移的W-MSA。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了⌊ M/2 ⌋ 个像素)。比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。

环境配置

复现Swin Transformer需要首先准备pytorch环境。

安装必要的Python依赖:
代码语言:javascript
代码运行次数:0
复制
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
数据准备

下载好数据集,代码中默认使用的是花分类数据集。当然也可以使用自定义的图像数据集,只要更改分类的数目和参数即可。需要确保数据集目录结构正确,以便Swin Transformer能正确读取数。 以下推荐的数据集文件目录:

代码语言:javascript
代码运行次数:0
复制
├── flower_photos
│   ├── daisy
│   ├── sunflowers
│   └── tulips
├── weights
│   ├── model-0.pth
│   ├── model-1.pth
│   └── model-2.pth
├── pre_weights
│   ├── swin_large_patch4_window7_224_22k.pth
│   └── swin_tiny_patch4_window7_224.pth
├── labels
│   ├── train2017
│   └── val2017
├── class_indices.json
├── record.txt
└── requeirments.txt

部分核心代码

代码语言:javascript
代码运行次数:0
复制
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

训练过程

此处可调整分类图像任务的种类数目,训练轮数,batch_size,训练图像,预训练模型等参数。

代码语言:javascript
代码运行次数:0
复制
num_classes = 5
epochs = 10
batch_size = 8
lr = 0.0001
data_path = "flower_photos"  # 修改为你的数据集路径
weights = './pre_weights/swin_tiny_patch4_window7_224.pth'
freeze_layers = False

通过8个线程进行模型训练,训练10轮因为数据集较大,耗时比较长有2个小时。查看结果发现只进行了几轮图像分类准确率在90%以上,效果较好:

本人用cpu跑的,最好用cuda跑。 输出的结果在weights中。 挑选准确最高,损失最小的模型model-x.pth进行消融实验即可。

测试和评估

采用model-9.pth模型进行蒲公英的图像分类预测,结果如下所示

这里是用花卉的数据集进行模型训练,可以自定义选择图像数据集进行训练。

混淆矩阵

查看图像分类的混淆矩阵,可以看出效果还是不错的:

参考论文:

​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-11-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 概述
    • Swin Transformer模型原理
      • 1. Swin Transformer模型框架
      • 2. W-MSA详解
      • 3. SW-MSA详解
    • 环境配置
      • 安装必要的Python依赖:
      • 数据准备
    • 部分核心代码
    • 训练过程
    • 测试和评估
    • 混淆矩阵
  • 参考论文:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档