前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MInference:通过动态稀疏Attention加速长文本推理

MInference:通过动态稀疏Attention加速长文本推理

原创
作者头像
aaronwjzhao
发布2024-07-11 11:10:33
5421
发布2024-07-11 11:10:33
举报
文章被收录于专栏:AI工程落地

论文

由于注意力机制的二次复杂度,举例来说(如图1 a 所示),在单台装有 A100 的机器上为 LLaMA-3-8B 提供服务时,如果提示有 30 万个 token,模型需要 6 分钟才能完成预填充( pre-filling)阶段,如果提示增加到 100 万个 token,这个数字将增加到 30 分钟。自注意力计算的开销占到了总预填充延迟的 90% 以上,这使其成为 LLM 处理长上下文时的主要瓶颈。

图 1 Attention特性介绍
图 1 Attention特性介绍

注意力,特别是在长上下文中,是稀疏和动态的,即在不同的输入中,稀疏模式有很大的不同,如图1 b和1 c。这种动态稀疏性呈现出三种适用于所有输入的独特空间聚合模式:A 形(A-shape)、垂直 - 斜线(Vertical-Slash)和块状 - 稀疏(Block-Sparse)。如下图2

图 2 三种稀疏模式
图 2 三种稀疏模式

MInference 首先使用内核感知稀疏模式搜索算法为每个头部离线确定最佳动态稀疏模式,并为每种稀疏模式设计了对应的attention计算过程,如下图 3

图 3 三种稀疏Attention实现算法
图 3 三种稀疏Attention实现算法

对于「垂直 - 斜线」模式,作者首先利用最后一个 Q 和 K 之间的注意力计算来估计垂直线和斜线的最佳指数。然后,他们利用动态稀疏编译器 PIT 和 Triton 构建垂直 - 斜线 FlashAttention 内核,加速注意力计算。

对于块状 - 稀疏模式,作者首先在注意力计算中使用 Q 和 K 的均值池。利用均值池和 MatMul 的交换属性,可以估算出块状 - 稀疏指数。然后,他们使用 Triton 构建块稀疏 FlashAttention 内核,加速注意力计算。

代码

整个model使用MInference推理(仅限MInference支持的模型):

from transformers import pipeline +from minference import MInference pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto") # Patch MInference Module +minference_patch = MInference("minference", model_name) +pipe.model = minference_patch(pipe.model) pipe(prompt, max_length=10)

分别使用单个算法:

from minference import vertical_slash_sparse_attention, block_sparse_attention, streaming_forward attn_output = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash) attn_output = block_sparse_attention(q, k, v, topk) attn_output = streaming_forward(q, k, v, init_num, local_window_num)

稀疏模式搜索+执行相应稀疏算法:

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 论文
  • 代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档