大模型训练中,当一个张量计算无法放到单张GPU上进行时,需要使用张量并行策略,将其拆分到不同的GPU上进行计算。张量如何拆分,才能达到计算和通信效率的平衡,使得训练效率最优。围绕以上问题,本文主要介绍:
1)对张量按行和按列拆分,并以不同聚合方式得到相同的结果。 2)模型架构中前馈网络MLP和注意力模块的拆分方式和通信量计算。 3)编码层和损失层的拆分方式和优化方案。
假设:有输入矩阵X和参数矩阵Y,维度分别为[2,4]和[4,2],两个矩阵相乘得到一个矩阵Z维度为[2,2],如下图。如果使用矩阵分块运算,对于参数矩阵Y,有两种的切分方式:

图1,矩阵分块计算。
将矩阵Y按列分块得到两个矩阵Y1&Y2,矩阵X分别和 Y1&Y2相乘,然后将得到的结果Z1&Z2,按列拼接得到结果矩阵Z。
将矩阵Y按行分块得到两个矩阵Y1&Y2,为了符合矩阵乘法规则,将X按列分块得到X1&X2,分块结果分别对应相乘,得到中间结果Y1&Y2再相加后得到最终结果矩阵Z。
小结:按列拆分特点是,结果按列拼接。按行拆分特点是,对应左乘矩阵需按列拆分,每部分结果相加后即为最终结果。简记为“行拼接,列相加”。
以上两种方式得到的最终结果Z和不分块是一样的。了解了矩阵分块计算原理,对应transformer架构中向量是怎样分块的?可以分为三大部分,前馈网络MLP,自注意力self-attention,编码层损失层。
前馈网络中有参数矩阵A,维度为[h,4h]和参数矩阵B,维度为[4h,h] ,先对输入进行升维后,然后再降低维度。公式如下:
参数矩阵分解方式:对A按列切分,对B按行切分。
原因:尽量保证各GPU上的计算相互独立,减少通讯量。GELU非线性性质,限制了参数矩阵A必须按照列切分,GELU后的结果不用进行同步。GELU(Y) 不等于 GELU(Y1)+GELU(Y2)

图2,MLP分块计算。
通信量计算
算子 f 操作功能,将X矩阵复制到每块GPU上,进行前向计算,无需通信。
算子 g 操作功能,将得到的前向结果 Zi 进行AIl-Reduce通信求和操作,结果相加产生Z,通信量为2M。
算子 g 操作功能,将损失对Z的梯度信息,分发到每个GPU上,每个GPU即可进行反向传播,无需通信。
算子 f 操作功能,当前层的梯度计算完毕,需要传递到下一层继续做梯度计算,需要求得 dL/dX,此时GPU之间做一次AIl-Reduce,把各自的梯度dL/dX 相加即可。通信量为2M。
综上:前向和反向共有2次All-Reduce操作,总的通信量为4M(m为参数量个数 bsh)。

图3,自注意力层拆分方式。
参数矩阵分解方式:对三个参数矩阵Q,K,V,按照”列切割”,每个头放到一块GPU上,做并行计算。对线性层B,按照”行切割”
注意:在实际应用中,并不一定按照一个head占用一块GPU来切割权重,我们也可以一个多个head占用一块GPU,这依然不会改变单块GPU上独立计算的目的。所以实际设计时,我们尽量保证head总数能被GPU个数整除。
通信量计算:类比于MLP层,self attention层在forward中做一次AlI-Reduce,在backward中做一次AlI-Reduce,总通讯是也是4M(M为bsh个数)。

图4,注意力分块计算。
输入编码需要两部分:
位置编码一般为[seq, h],seq为输入的最大长度,远小于字典的个数V,每个GPU保存一分位置编码是可以接受的。下面主要对字典编码进行分解。

图5,字典编码拆分方式。
字典编码维度[v,h],按照字典个数行进行切分,每个GPU保存一部分,输入的X和每个GPU上的字典编码相乘,如果需要的索引不在,就将其置0,最后所有GPU上得到的位置编码,做reduce 通信相加后,即可得到完整的编码向量。
一般流程:
最后输出层,得到分块后的结果Zi,与对应的wording embedding 相乘后,再进行 All-reduce 操作,拼接起来合成最后的输出,再与label计算交叉熵,即可得到最后的损失。
注意:为了确保word embedding的更新是正确的,我们需要将这两个梯度相加,然后用总和来更新word embedding。
原因:由于输入层和输出层共用同一套word embedding,因此在反向传播时,word embedding的梯度会被计算两次(一次在输入层,一次在输出层)。
一般流程:
正常来说输出层,需要对Y1和Y2做一次AIl-Gather,把它们concat起来形成Y,然后对Y的每一行做softmax,就可得到对于当前位置每个词出现的概率,再用此概率和真值组做cross-entropy 即可。
问题:对Y1和Y2做一次AIl-Gather,会产生额外的通讯量 bsV。当词表V很大时,这个通讯开销也不容忽视。
优化方案:
优化结果:通讯量从 b∗s∗v 降至 b∗s+N 。

图6,交叉熵通信优化。

图7,张量并行通信量示意图。
模型每层前向需要 2 次 AlI-Reduce,反向需要 2次 All-Reduce。还有在每个batch开始和结束时,embedding层和损失层的通信。
1D 张量井行局限:
以上介绍的方式按照行或是列进行切分,为1D方式,其存在几点的局限性:
为了克服以上的局限,发展出了 2D,2.5D,3D 的切分方式,总体思想是:利用矩阵性质,对输入和参数在多个维度进行切分,进而达到精简显存,降低通信需求的目的。具体可以参考以下文献。
参考:
[1] arXiv:1909.08053v4 [2] arXiv:2105.14500 [3] arXiv:2104.05343