前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >麻省理工(MIT) | 提出跨层Attention,减少Transformer大模型键值(KV)缓存,加快LLM推理!

麻省理工(MIT) | 提出跨层Attention,减少Transformer大模型键值(KV)缓存,加快LLM推理!

作者头像
ShuYini
发布2024-05-28 16:32:19
2180
发布2024-05-28 16:32:19
举报

引言

键值 (KV) 缓存能够显著提升Transformer大模型的解码速度。但是当面对长序列的时候,键值 (KV) 缓存需要大量的内存资源。当前减少键值 (KV) 缓存的两个主要方法分别为:Multi-Query Attention(MQA)和Grouped-Query Attention (GQA)。这两种方法主要是修改了Attention块,使得多头请求头共享单个KV头,从而大大减少了不同KV的数量。

而本文作者受前人启发,提出了一种新的Attention设计方法:跨层注意力(Cross-Layer Attention, CLA),即通过在不同层之间共享KV头,减少了KV缓存的大小。对比其它方法,在相同准确性的情况下,可以将KV缓存的大小缩小2倍!

https://arxiv.org/pdf/2405.12981

背景介绍

大模型实际部署应用的时候,键值 (KV) 缓存的内存可能会成为其应用的瓶颈。由于 KV 缓存的大小与序列长度以及批处理大小存在比例关系,因此在长序列长度上操作时,KV 缓存内存的大小可能会限制批处理大小,并且当面对设备内存不足时,也有会采用成本较高的方法,例如:offloading策略。但为了提升推理速度减少冗余计算,人们也非常希望能够长期保存键值 (KV) 。然而,KV 缓存的大小直接决定了存储和检索此类持久缓存的成本

随着LLM 新应用的出现,需要更长的序列长度,KV 缓存的内存占用的挑战越来越受到研究人员的关注。并且当前研究人员也提出了多种减少KV缓存内存占用的方法,例如:采用低精度来缓存KV、驱逐不重要的KV缓存条目以及跨请求头共享KV等。

与之前方法不同,本文提出了一种新的方法:跨层Attention(Cross-Layer Attention,CLA),简单来说,该方法主要是通过减少KV缓存中唯一层的数量来减小KV缓存的大小

MQA and GQA

在介绍跨层Attention之前,先带大家简单的了解一下多头Attention(MHA)多请求Attention(MQA)分组请求Attention(GQA)

最初的Transformer架构主要是用多头Attention(MHA),其中每个请求头主要关注不同KV头生成的KV。在MHA中,每个KV头的KV激活必须单独存储在KV缓存中,这对于每个token来说,它的存储开销为:

2\ast n_{query}\ast d_{head}

,其中

n_{query}

表示每个请求头数量,

d_{head}

表示每个头的嵌入维度。

为了减少减少 Transformer 解码期间存储和访问 KV 缓存相关的开销,有研究人员提出了多请求Attention,并逐渐的将其推广至分组请求Attention。分组查询Attention通过将每个Attention层的请求头编制成组来修改Transformer架构,其中每组请求头共享单个KV头。由于 KV 缓存的大小仅随着不同KV头的数量而变化,而不是请求头的数量,因此 GQA 将 KV 缓存的存储开销降低到

2\ast n_{group}\ast d_{head}

,其中

n_{group}

表示 GQA的组数,且很明显:

n_{group} < n_{query}

。另外,MQA 可以看作是 GQA 的特例,其中

n_{group} = 1

研究发现,与具有相同头尺寸的 MHA 架构相比,MQA 和 GQA 能够显着减少 KV 缓存大小和Transformer解码延迟,但精度会有略微下降。所以在模型设计过程中,需要平衡Attention架构的准确性和KV缓存大小之间的关系

跨层Attention

受MQA 和 GQA的启发,本文作者提出了跨层共享KV头,并将这种Attention架构称为:跨层Attention(CLA),如下图所示:

可以看到在CLA中,只有模型中的一部分层会计算KV投影,而没有计算KV投影的层的Attention块会重新使用之前层的KV激活值。这意味着只有计算了KV投影的那些层会使用KV缓存,从而与传统架构相比,后者在每一层都应用了独立的KV投影,对比之下,CLA能够减少对内存的使用。

除此之外,CLA可以与MQA、GQA、MHA 进行组合。此外,与 GQA 允许不同的

n_{group}

访问一系列不同的Attention配置一样,CLA 可以改变共享每个 KV 投影输出的层数,作者将其称为共享因子。通过共享因子来引用 CLA 的不同配置,从而产生了 CLA2,它在一对相邻层之间共享每个 KV 投影,CLA3,它在一组 3 层之间共享每个 KV 投影,依此类推。如下图所示:

另外,作者还在系统工程的角度总结了 CLA 对相关关键指标的影响:

  • KV 缓存内存:CLA 显着减少了 KV 缓存内存占用量,减少的倍数等于共享因子
  • 训练内存占用:CLA 减少了训练期间具体化的中间 KV 激活张量的内存占用,尽管对于 GQA 和 MQA 模型,此类 KV 张量与模型的隐藏状态和 MLP 激活相比通常很小。
  • 模型并行性:CLA 与标准完全兼容 张量并行技术,可用于跨多个加速器分片模型权重。
  • 参数和FLOP:由于CLA 减少了模型中KV投影块的总数,因此CLA 略微减少了模型中参数的数量以及前向或后向传递期间所需的FLOP 数量。
  • 解码延迟:在完整的LLM 服务堆栈的背景下,CLA 可以实现比其他方式更大的批量大小和更长的KV 缓存持久时间,可以减少推理延迟。
  • 核心Attention延迟:与MQA和GQA不同,CLA对每个解码步骤中Attention机制消耗的内存带宽没有直接影响。

实验结果

下图展示了CLA在准确性/内存权衡上的影响。可以看到在1B和3B参数规模的模型上,CLA结合MQA相比于单纯的MQA基线,KV缓存所需内存缩小了2倍,同时仅造成微小的困惑度(perplexity)增加。同时作者也展示了不同CLA共享模式的实验结果,可以发现CLA2在性能上一致优于其他配置。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-05-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AINLPer 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 引言
  • 背景介绍
  • MQA and GQA
  • 跨层Attention
  • 实验结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档