Transformer架构因其强大的通用性而备受瞩目,它能够处理文本、图像或任何类型的数据及其组合。其核心的“Attention”机制通过计算序列中每个token之间的自相似性,从而实现对各种类型数据的总结和生成。在Vision Transformer中,图像首先被分解为正方形图像块,然后将这些图像块展平为单个向量嵌入。这些嵌入可以被视为与文本嵌入(或任何其他嵌入)完全相同,甚至可以与其他数据类型进行连接。通常图像块的创建步骤会与使用2D卷积的第一个可学习的非线性变换相结合,这对于初学者来说可能比较难以理解,所以本文将深入探讨这一过程。
为了简单起见,本文使用MNIST数据集,这是一个手写数字的集合,常用于训练基本的图像分类器。MNIST图像在PyTorch中可以直接获取,并且可以使用 DataLoader 类方便地加载:
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch
torch.manual_seed(42)
img_size = (32,32) # We will resize MNIST images to this size
batch_size = 4
transform = T.Compose([
T.ToTensor(),
T.Resize(img_size)
])
train_set = MNIST(
root="./../datasets", train=True, download=True, transform=transform
)
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
batch = next(iter(train_loader)) # loads the first batch
上述代码首先下载MNIST数据集,然后定义一个PyTorch变换,该变换将图像转换为PyTorch张量并将其大小调整为32x32。接着,使用 DataLoader 类加载一个大小为batch_size=4的图像批次。 torch.manual_seed 函数用于将随机数生成器初始化为相同的值,以确保读者在自己的 notebook 中能够看到与本文中相同的图像。有关PyTorch的 DataSet 和 DataLoader 类的更多信息,请参考以下链接:
可以使用matplotlib可视化该批次,其中包含四个图像和对应的四个标签:
import matplotlib.pyplot as plt
# batch[0] contains the images and batch[1] the labels
images = batch[0]
labels = batch[1]
# Create a figure and axes for the subplots
fig, axes = plt.subplots(1, batch_size, figsize=(12, 4))
# Iterate through the batch of images and labels
for i in range(batch_size):
# Convert the image tensor to a NumPy array and remove the channel dimension if it's a grayscale image
image_np = images[i].numpy().squeeze()
# Display the image in the corresponding subplot
axes[i].imshow(image_np, cmap='gray') # Use 'gray' cmap for grayscale images
axes[i].set_title(f"Class: {labels[i].item()}") # Assuming labels are tensors, use .item() to get the value
axes[i].axis('off')
# Adjust the spacing between subplots
plt.tight_layout()
# Display the plot
plt.show()
使用Transformer神经网络处理图像的第一步是将其分解为图像块。例如,可以将32x32的图像分解为64个4x4的图像块(每个块包含16个像素)、16个8x8的图像块(每个块包含64个像素)或4个16x16的图像块(每个块包含256个像素):
虽然我们以二维形式展示这些图像块,但也可以将它们存储在维度分别为16、64或256的列向量中。这些向量嵌入与文本嵌入已经没有本质区别,它们的序列可以被视为与字符串或单词的序列相同。有关Transformer架构的更多信息,可以参考以下链接,其中使用文本嵌入作为示例进行了详细讲解:
以下是使用PyTorch的 unfold 算子分解图像的代码:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
# Image and patch sizes
img_size = (32, 32)
patch_size = (8, 8)
n_channels = 1
image = batch[0][1].unsqueeze(0)
# Patch Class
class Patch(nn.Module):
def __init__(self, img_size, patch_size, n_channels):
super().__init__()
self.patch_size = patch_size
self.n_channels = n_channels
def forward(self, x): # B x C x H X W
x = x.unfold(
2, self.patch_size[0], self.patch_size[0]
).unfold(
3,self.patch_size[1],self.patch_size[1]
) # (B, C, P_row, P_col, P_height, P_width)
x = x.flatten(2) #(B, C, P_row*P_col*P_height*P_width)
x = x.transpose(1, 2) # (B, P_row*P_col*P_height*P_width, C)
return x
# Instantiate model
patch = Patch(img_size, patch_size, n_channels)
# Extract patches
with torch.no_grad():
patches = patch(image)
# Visualize
patches = patches.squeeze(0) # Remove batch dimension -> (P, d_model)
patches = patches.view(-1, patch_size[0], patch_size[1]) # reshape back into 8x8
npatches = img_size[0] // patch_size[0]
# Plot patches
fig, axs = plt.subplots(npatches, npatches, figsize=(6, 6)) # 4x4 grid for (32x32) -> 16 patches
for i in range(npatches):
for j in range(npatches):
patch_idx = i * npatches + j # Patch index
axs[i, j].imshow(patches[patch_idx], cmap="gray", vmin=0, vmax=1)
axs[i, j].axis("off")
plt.show()
如代码所示,核心操作发生在 Patch 类的 forward 方法中。该类继承自 nn.Module ,其 forward 方法首先沿高度维度进行展开,然后再沿宽度维度进行展开。代码注释中展示了每一步操作后张量的维度,其中B代表批次大小,C代表通道数(在本例中为1),H代表高度,W代表宽度。展开操作之后,从存储图像数据的第二个维度开始展平张量,最后转置张量,以便颜色通道位于最后一个维度。
代码的剩余部分用于实例化 Patch 类,转换图像并将其可视化。需要注意的是,在可视化之前,需要先删除批次维度,然后将一维的图像数据转换回二维张量,才能正确显示图像块。
上述方法在某种程度上将嵌入维度限制为原始图像尺寸的倍数。为了打破这个限制,可以在展开操作之后添加一个线性投影层,从而创建一个可学习的嵌入。
为了便于可视化,这些嵌入被转换回二维张量,从而展示了线性投影层如何对图像块进行操作。使用单位矩阵作为 nn.Linear 类的权重初始化,表明原始数据得以保留。使用随机权重,可以看到图像中具有零值的部分保持不变。最后,添加一个偏差项表明该变换确实平等地影响了每个图像块——所有空白图像块都显示出完全相同的偏差。
以下是新的 PatchEmbedding 类及其实例化代码。注意,这里引入了一个新的变量 d_model ,它代表期望的输出嵌入的维度。 d_model 可以是任意数值。这里选择 d_model=64 是为了与上面图像的设置保持一致,但实际上不再有任何限制。
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, n_channels, d_model):
super().__init__()
self.patch_size = patch_size
self.n_channels = n_channels
self.d_model = d_model
# Linear projection layer to map each patch to d_model
self.linear_proj = nn.Linear(patch_size[0] * patch_size[1] * n_channels, d_model,bias=False)
# The next two lines are unnecessary, but help to visualize that the linear
# projection operates along the correct dimensions
#with torch.no_grad():
# self.linear_proj.weight.copy_(torch.eye(self.linear_proj.weight.shape[0]))
def forward(self, x): # B x C x H X W
x = x.unfold(
2, self.patch_size[0], self.patch_size[0]
).unfold(
3,self.patch_size[1],self.patch_size[1]
) # (B, C, P_row, P_col, P_height, P_width)
B, C, P_row, P_col, P_height, P_width = x.shape
x = x.reshape(B,C,P_row*P_col,P_height*P_width)
x = self.linear_proj(x) # (B*N, d_model)
x = x.flatten(2) #(B, C, P_row*P_col*P_height*P_width)
x = x.transpose(1, 2) # (B, P_row*P_col*P_height*P_width, C)
x = x.view(B, -1, self.d_model)
return x
d_model = 64
# Instantiate model
patch = PatchEmbedding(img_size, patch_size, n_channels, d_model)
只要维度是二次的,我们仍然可以可视化结果,下图展示了 d_model=4 和 d_model=2500 时的输出:
可以看到,非线性变换(一个全连接的神经网络,它接受从8x8 (64)到 d_model 的输入)可以包含相当多的可学习参数,从左侧的64x4(256)到右侧的64x2500(160k)。可以使用以下代码自行测试:
def count_parameters(model):count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_parameters(patch)
unfold 算子使用起来比较繁琐。实际上有一种更简单的方法可以将展开和线性变换结合起来,那就是使用2D卷积,并设置卷积核大小和步长长度与期望的图像块大小相对应。这样卷积操作将不再逐像素进行,而是逐图像块进行,从而产生与组合使用 unfold 和 nn.Linear 相同的结果:
以下是修改后的 PatchEmbedding 类:
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, n_channels, d_model):
super().__init__()
self.patch_size = patch_size
self.n_channels = n_channels
self.d_model = d_model # Flattened patch size
# Conv2d to extract patches
self.linear_project = nn.Conv2d(
in_channels=n_channels,
out_channels=self.d_model, # Each patch is flattened to d_model
kernel_size=patch_size,
stride=patch_size,
bias=False
)
def forward(self, x):
x = self.linear_project(x) # (B, d_model, P_row, P_col)
x = x.flatten(2) # (B, d_model, P_row * P_col) -> (B, d_model, P)
x = x.transpose(1, 2) # (B, P, d_model)
return x
可以将上述任何一种图像块嵌入方法提供给Vision Transformer。使用2D卷积进行操作是最通用又是最紧凑的表示形式。需要注意的是,卷积操作为每个维度使用一个专用的卷积核,而到目前为止,我们一直在为每个图像块使用相同的卷积核。
可以通过初始化卷积核权重来演示这一点,并测试卷积操作是否执行任何有趣的操作,例如让每个卷积核仅提取每个图像块的单个像素。以下代码适用于图像块大小为(8,8)且生成的 d_model=64 的情况。将其添加到 PatchEmbedding 类的 __init__ 方法的末尾:
"""Initialize Conv2d to extract patches without transformation."""
with torch.no_grad():
identity_kernel = torch.zeros(
self.d_model, self.n_channels, *self.patch_size
) # Shape: (64, 1, 8, 8)
for i in range(self.d_model):
row = i // self.patch_size[1] # Row index in the patch
col = i % self.patch_size[1] # Column index in the patch
identity_kernel[i, 0, row, col] = 1 # Place a 1 at the correct pixel position
self.linear_project.weight.copy_(identity_kernel)
如代码所示, identity_kernel 张量维护 d_model 个条目,每个维度一个,并且每个图像块只有一个像素设置为1,从而仅提取该像素。一种更简单的方法是将d_model x d_model的单位矩阵简单地转换为 patch_size 的 d_model 矩阵:
identity_matrix = torch.eye(self.d_model)
identity_kernel = identity_matrix.view(d_model, 1, *patch_size) # Shape: (64, 1, 8, 8)
with torch.no_grad():
self.linear_project.weight.copy_(identity_kernel)
两种方法都具有相同的结果,但第一种方法更清楚地说明了实际发生的情况:每个卷积核都是一个零矩阵,只有一个条目为1。
无论使用线性变换还是小卷积核的集合,两者都具有相同数量的参数。可以通过检查两个图像块嵌入的数据结构来看到这一点:
PatchEmbedding(
(linear_proj): Linear(in_features=64, out_features=64, bias=False)
)
和
PatchEmbedding(
(linear_project): Conv2d(1, 64, kernel_size=(8, 8), stride=(8, 8), bias=False)
)
其中一个只是一个64x64矩阵(4096个参数)。另一个由64个8x8矩阵组成,也由4096个参数组成。
Coovally AI模型训练与应用平台,它整合了30+国内外开源社区1000+模型算法。无论是最新的YOLOv12模型还是Transformer系列视觉模型算法,平台全部包含,均可一键下载助力实验研究与产业应用。
在Coovally平台上,无需配置环境、修改配置文件等繁琐操作,可一键另存为我的模型,上传数据集,即可使用ViT等热门模型进行训练与结果预测,全程高速零代码!而且模型还可分享与下载,满足你的实验研究与产业应用。
本文深入探讨了如何在Vision Transformer (ViT)架构中处理图像,包括图像的创建与嵌入过程。通过MNIST数据集的实例,介绍了如何使用PyTorch进行图像分割、图像块分层、以及通过线性投影和2D波形层理解。通过示例代码和嵌入详细讲解,读者能够更清晰地显示视觉块Transformer在任务中的应用,特别是在图像处理中的创新技术。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。