
Optimization of the Transformer architecture focuses mainly on the computational efficiency of the attention mechanism. Current mainstream approaches fall into two categories: linear attention and sparse attention.
Linear attention replaces the softmax computation with a kernel-function approximation, reducing time complexity from O(N²) to O(N). A representative example is Kimi's Delta Attention (KDA), which replaces standard softmax attention with a linear-time formulation.
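The O(N) scaling rests on the associativity of matrix multiplication: once softmax is replaced by a feature map, (QKᵀ)V can be computed as Q(KᵀV), so the N×N matrix is never materialized. A quick numerical check of that identity, with arbitrary illustrative shapes and no softmax:

```python
import torch

N, d = 1024, 64
Q = torch.randn(N, d, dtype=torch.float64)
K = torch.randn(N, d, dtype=torch.float64)
V = torch.randn(N, d, dtype=torch.float64)
# Without softmax, the two orderings are mathematically identical,
# but the right-hand side never builds an N x N intermediate.
left = (Q @ K.T) @ V    # O(N^2 * d): materializes an N x N matrix
right = Q @ (K.T @ V)   # O(N * d^2): only a d x d intermediate
print(torch.allclose(left, right))  # True
```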
Sparse attention keeps the softmax computation but reduces cost by dynamically selecting the most important tokens. For example, DeepSeek's DSA (DeepSeek Sparse Attention) uses a scoring function to select the top-k tokens that participate in the attention computation.
A linear attention implementation using a polynomial kernel feature map:

```python
import torch
import torch.nn as nn

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # reshape to (b, heads, n, head_dim)
        q, k, v = map(lambda t: t.view(b, n, self.heads, -1).transpose(1, 2), qkv)
        # Polynomial kernel feature map (elementwise, non-negative) in place of softmax
        q = (q * self.scale + 1).pow(2)  # (b, h, n, d)
        k = (k + 1).pow(2)               # (b, h, n, d)
        # Linear attention: aggregate K^T V first, so the cost is O(N * d^2) rather than O(N^2)
        kv = torch.matmul(k.transpose(-2, -1), v)                               # (b, h, d, d)
        norm = torch.matmul(q, k.sum(dim=-2, keepdim=True).transpose(-2, -1))   # (b, h, n, 1)
        out = torch.matmul(q, kv) / (norm + 1e-6)                               # (b, h, n, d)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```
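A quick shape check for the sketch above (batch size, sequence length, and model width are arbitrary):

```python
x = torch.randn(2, 1024, 256)              # (batch, seq_len, dim)
attn = LinearAttention(dim=256, heads=8)
print(attn(x).shape)                       # torch.Size([2, 1024, 256])
```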
A simplified implementation of DSA-style dynamic sparse attention:

```python
class DynamicSparseAttention(nn.Module):
    def __init__(self, dim, heads=8, topk=32):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.topk = topk
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)

    def get_score(self, q, k):
        return torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # reshape to (b, heads, n, head_dim)
        q, k, v = map(lambda t: t.view(b, n, self.heads, -1).transpose(1, 2), qkv)
        # Note: this simplified version still materializes the full N x N score matrix;
        # production implementations use a lightweight indexer to avoid that cost.
        scores = self.get_score(q, k)                                   # (b, h, n, n)
        # Dynamically select the top-k keys for each query
        topk_scores, topk_indices = scores.topk(min(self.topk, n), dim=-1)
        sparse_attn = torch.zeros_like(scores)
        sparse_attn.scatter_(-1, topk_indices, topk_scores.softmax(dim=-1))
        out = torch.matmul(sparse_attn, v)                              # (b, h, n, d)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```
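A similar check for the sparse variant; `topk` is clamped to the sequence length in the sketch, so shorter inputs also work:

```python
x = torch.randn(2, 128, 256)                                  # (batch, seq_len, dim)
sparse_attn = DynamicSparseAttention(dim=256, heads=8, topk=32)
print(sparse_attn(x).shape)                                   # torch.Size([2, 128, 256])
```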
Mamba, the representative state space model (SSM) architecture, uses a selective state space to achieve linear complexity:

```python
class MambaBlock(nn.Module):
    def __init__(self, dim, expand=2):
        super().__init__()
        inner_dim = dim * expand
        self.in_proj = nn.Linear(dim, inner_dim * 2)
        # Simplified mixing convolution; the real Mamba block uses a causal depthwise conv
        self.conv = nn.Conv1d(inner_dim, inner_dim, 3, padding=1)
        self.ssm = SSM(inner_dim)  # selective state space module (a minimal sketch follows below)
        self.out_proj = nn.Linear(inner_dim, dim)

    def forward(self, x):
        x = self.in_proj(x)
        x, gate = x.chunk(2, dim=-1)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        # Gated state space update (Mamba uses SiLU gating; sigmoid keeps the sketch simple)
        x = self.ssm(x) * torch.sigmoid(gate)
        return self.out_proj(x)
```
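The `SSM` module used above is left undefined. Below is a minimal, illustrative sketch of a selective state space layer, assuming a diagonal state matrix, input-dependent B, C, and step size, and a plain sequential scan; it is far simpler than Mamba's hardware-aware parallel scan and is not the official implementation.

```python
import torch.nn.functional as F

class SSM(nn.Module):
    """Minimal selective state space layer (illustrative sketch, not the official Mamba kernel)."""
    def __init__(self, dim, state_dim=16):
        super().__init__()
        self.state_dim = state_dim
        # Fixed negative-real diagonal state matrix A; B, C, and step size dt are input-dependent
        self.A_log = nn.Parameter(torch.log(torch.arange(1, state_dim + 1).float()).repeat(dim, 1))
        self.to_bc = nn.Linear(dim, 2 * state_dim)
        self.to_dt = nn.Linear(dim, dim)

    def forward(self, x):                                  # x: (batch, length, dim)
        b, l, d = x.shape
        A = -torch.exp(self.A_log)                         # (d, s), negative for stability
        B, C = self.to_bc(x).chunk(2, dim=-1)              # (b, l, s) each
        dt = F.softplus(self.to_dt(x))                     # (b, l, d), positive step sizes
        # Discretize (simplified zero-order hold): A_bar = exp(dt * A), B_bar = dt * B
        A_bar = torch.exp(dt.unsqueeze(-1) * A)            # (b, l, d, s)
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(2)          # (b, l, d, s)
        h = x.new_zeros(b, d, self.state_dim)              # recurrent state
        ys = []
        for t in range(l):  # sequential scan; Mamba replaces this with a parallel, hardware-aware scan
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t].unsqueeze(-1)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # (b, d)
        return torch.stack(ys, dim=1)                      # (b, l, d)
```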
An example of applying a diffusion model to text generation:

```python
class DiffusionTransformer(nn.Module):
    def __init__(self, dim, num_layers):
        super().__init__()
        # Maps a timestep embedding of size `dim` to a conditioning vector
        self.time_embed = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        self.layers = nn.ModuleList([
            TransformerLayer(dim) for _ in range(num_layers)  # see the placeholder sketch below
        ])

    def forward(self, x, timestep):
        # timestep: (batch, dim) embedding of the diffusion step (e.g. sinusoidal)
        t = self.time_embed(timestep)
        for layer in self.layers:
            x = layer(x + t.unsqueeze(1))
        return x
```
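`TransformerLayer` is likewise not defined above. A minimal pre-norm placeholder built on PyTorch's `nn.MultiheadAttention` could look like this (an illustrative assumption, not the layer used by any specific diffusion language model):

```python
class TransformerLayer(nn.Module):
    """Minimal pre-norm Transformer layer used as a placeholder for the block above."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention with residual
        return x + self.mlp(self.norm2(x))                 # feed-forward with residual
```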
The main approaches compare as follows:

| Technique | Time complexity | Computation characteristics | Suitable sequence length |
|---|---|---|---|
| Standard attention | O(N²) | Full attention matrix | Short sequences (<1k) |
| Linear attention | O(N) | Kernel approximation + matrix factorization | Long sequences (>10k) |
| Sparse attention | O(kN) | Top-k or sliding-window selection | Medium to long sequences |
| SSM architectures | O(N) | State space model + hardware-aware optimization | Ultra-long sequences |
Typical examples:
- Long-sequence modeling scenarios:
- Scenarios requiring exact attention:
- Generative task scenarios:
- Resource-constrained scenarios:

For complex scenarios, consider: