首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >使用SwinTransformer进行图片分类

使用SwinTransformer进行图片分类

作者头像
lyhue1991
发布2023-09-05 13:32:58
发布2023-09-05 13:32:58
7830
举报

SwinTransformer 是微软亚洲研究院在2021年提出的适用于CV领域的一种基于Tranformer的backbone结构。

它是 Shift Window Transformer 的缩写,主要创新点如下。

  • 1,分Window进行Transformer计算,将自注意力计算量从输入尺寸的平方量级降低为线性量级。
  • 2,使用Shift Window 即窗格偏移技术 来 融合不同窗格之间的信息。(SW-MSA)
  • 3,使用类似七巧板拼图技巧 和Mask 技巧 来对 Window偏移后不同大小的窗格进行注意力计算以提升计算效率。
  • 4,在经典的QKV注意力公式中引入 Relative Position Bias 项来非常自然地表达位置信息的影响。
  • 5,使用Patch Merging技巧来 实现特征图的下采样,作用类似池化操作但不易丢失信息。
  • 6,使用不同大小的Window提取不同层次的特征并进行融合。

SwinTransformer虽然采用了Transformer的实现方法,但在整体设计上借鉴了非常多卷积的设计特点。

如:局域性,平移不变性,特征图逐渐减小,通道数逐渐增加,多尺度特征融合等。

同时它还应用了非常多的trick来弥补Transformer的不足,如效率问题,位置信息表达不充分等。

B站上有UP主说SwinTransformer是披着Transformer皮的CNN。但毕竟它的主要内在计算是Transformer,所以我感觉它更像是叠加了卷积Buff的Transformer。

SwinTransformer这个backbone结构表达能力非常强,同时适用性广泛,可适用于图片分类,分割,检测等多种任务,而且结构设计和实验工作都做得比较touch,所以被评为了2021年的ICCV best paper.

下面的范例我们微调 timm库中的 SwinTransformer模型来 做一个猫狗图片分类任务。

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和数据集下载链接。

代码语言:javascript
复制
#!pip install -U  timm, torchkeras 

〇,预训练模型

代码语言:javascript
复制
import timm 
from urllib.request import urlopen
from PIL import Image
import timm
import torch 

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
img 
代码语言:javascript
复制
model = timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k", pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)

代码语言:javascript
复制


代码语言:javascript
复制
info = timm.data.ImageNetInfo()
class_codes = info.__dict__['_synsets']
class_names = [info.__dict__['_lemmas'][x] for x in class_codes]

代码语言:javascript
复制
{class_names[i]:v for i,v in zip(top5_class_indices.tolist()[0],
                                top5_probabilities.tolist()[0])}
代码语言:javascript
复制
{'espresso': 0.1655443161725998,
 'cup': 0.12100766599178314,
 'chocolate sauce, chocolate syrup': 0.11809349805116653,
 'eggnog': 0.06144588068127632,
 'tray': 0.03965265676379204}
识别出来的主要是 espresso(蒸馏咖啡),cup 啥的,跟图片差不多,么得问题。

一,准备数据

代码语言:javascript
复制
import torch
import os 

代码语言:javascript
复制
data_path = './datasets/cats_vs_dogs'

train_cats = os.listdir(os.path.join(data_path,"train","cats"))
img = Image.open(os.path.join(os.path.join(data_path,"train","cats",train_cats[0])))
img 

代码语言:javascript
复制
train_dogs = os.listdir(os.path.join(data_path,"train","dogs"))
img = Image.open(os.path.join(os.path.join(data_path,"train","dogs",train_dogs[0])))
img 

代码语言:javascript
复制
from torchvision.datasets import ImageFolder


ds_train = ImageFolder(os.path.join(data_path,"train"),transforms)

ds_val = ImageFolder(os.path.join(data_path,"val"),transforms)


dl_train = torch.utils.data.DataLoader(ds_train, batch_size=4 ,
                                             shuffle=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2,
                                             shuffle=True)

class_names = ds_train.classes

print(len(ds_train))
print(len(ds_val))

代码语言:javascript
复制
2000
995
代码语言:javascript
复制
for batch in dl_val:
    break 
    
代码语言:javascript
复制
batch[1]
代码语言:javascript
复制
tensor([0, 1])

二,定义模型

代码语言:javascript
复制
model.reset_classifier(num_classes=2)

代码语言:javascript
复制
model(batch[0])
代码语言:javascript
复制
tensor([[ 0.1698, -0.3366],
        [ 0.4805,  0.1415]], grad_fn=<AddmmBackward0>)
代码语言:javascript
复制
model.cuda();

三,训练模型

代码语言:javascript
复制
from torchkeras import KerasModel 
from torchmetrics import Accuracy

loss_fn = torch.nn.CrossEntropyLoss()
metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=2)}

optimizer = torch.optim.Adam(model.parameters(),
                            lr=1e-5)

keras_model = KerasModel(model,
                   loss_fn = loss_fn,
                   metrics_dict= metrics_dict,
                   optimizer = optimizer
                  )     

代码语言:javascript
复制
features,labels = batch 
代码语言:javascript
复制
loss_fn(model(features.cuda()),labels.cuda())
代码语言:javascript
复制
tensor(0.6743, device='cuda:0', grad_fn=<NllLossBackward0>)
代码语言:javascript
复制
dfhistory= keras_model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=100, 
                    ckpt_path='checkpoint.pt',
                    patience=10, 
                    monitor="val_acc",
                    mode="max",
                    mixed_precision='no',
                    plot = True,
                    quiet=True
                   )

可以看到SwinTransformer的拟合能力非常逆天,在这个简单的数据集上,finetune两个Epoch就直接把训练集上的Acc打到了100%,最后的验证集结果也是高达99.8%,非常强大~

四,评估模型

代码语言:javascript
复制
keras_model.evaluate(dl_val)

五,使用模型

代码语言:javascript
复制
from PIL import Image 
img = Image.open('./datasets/cats_vs_dogs/val/dogs/dog.2005.jpg')
代码语言:javascript
复制
model.eval();
代码语言:javascript
复制
model(transforms(img)[None,...].cuda()).softmax(axis=1)
代码语言:javascript
复制
tensor([[1.1537e-04, 9.9988e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

六,保存模型

代码语言:javascript
复制
torch.save(model.state_dict(),'swin_transformer.pt') 
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-06-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法美食屋 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 〇,预训练模型
  • 一,准备数据
  • 二,定义模型
  • 三,训练模型
  • 四,评估模型
  • 五,使用模型
  • 六,保存模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档