知乎:真中合欢 链接:https://www.zhihu.com/question/627258986/answer/3262812950
答案简单,就是匹配显存。
6B模型可以在在12/16/24G显存的消费级显卡部署和训练。如果一个公司的模型不打算在消费级显卡部署,通常不会训6B这个规模。而且通常还会有一个1.4b或者2.8b,这个是比较适合在手机、车载端量化部署的尺寸。
13B模型按照4k长度组织数据,数据并行=2,刚好占满一个8卡机,并且可以量化部署在A10甚至4090。
下一档也不是130B,目前更大模型有16B、34B、52B、56B、65B、70B、100B、130B、170B、220B这几个规模,基本都是刚好占满某种规格的算力,要么是训练要么是推理。如果需要加快训练速度,只需要倍增卡数即可。比如我们训7B模型以8卡为单位8*8卡训,70B模型以80卡为单位80*6卡训。
-------补充回答-------
评论区有朋友问到怎么计算显存占用,这里给个简单的方法:
首先是训练框架,deepspeed和megatron框架的显存占用是不同的。一家公司如果做pretrain,那么pretrain、sft、rlhf三阶段一般都会使用megatron,因为提供的并行选项更丰富,而且megatron用了很多apex的融合算子,计算效率更高一些,更适合大规模训练。如果只做sft,可能会选择deepspeed或者自己在huggingface模型骨架的基础上实现一套框架。我这里用megatron举例。
megatron框架的模型&优化器存储系数是18,也就是模型参数量*18=显存占用。对于13B的模型是13Bx18=234GB显存占用。这个18的来源是2(半精度模型参数) + 4(单精度梯度) + 4(单精度模型参数副本) + 4(单精度一阶动量) + 4(单精度二阶动量)。
注:混合精度训练时,megatron存储的梯度并不是半精度的,而是利用apex算子直接计算单精度的梯度,存储在parameter.main_grad中,所以上面的公式是“单精度梯度"。
在 pipeline并行时,所有这些都会被平分到每张卡上,所以系数18可以整体除以卡数。在zero并行时,半精度模型参数和单精度梯度每张卡都有,后面三个平分到每张卡上,所以系数中只有后面三个4+4+4=12可以除以zero并行数,前面的2+4=6不行。
用13B模型、seqlength=4096作为例子计算一遍,模型&优化器显存占用在megatron框架下是13 x 18 = 234GB。这个至少就用到4卡x80G才能装下。
正向传播中间变量的显存占用可以用40xSHL来近似,也就是40x4096(序列长度)x5120(隐层维度)x40(模型层数) =34G。
如果用zero1数据并行:那么模型&优化器显存占用系数是(6 + 12/zero并行数)。数据并行每张卡都要有完整的一条数据,所以每张卡有34G要拿来存储正向传播的中间变量,可用显存只有80-34=46G, 也就是13 x (6+12/zero并行数) < 46,不等式无解,无论如何也是装不下的。
如果你用了 pipeline并行,那么模型&显存占用系数就是18/pipeline并行数,正向传播的中间变量也平均分配到每张卡上占用就是34G/pipeline并行数,那么可以列出公式(13 x 18/pipeline并行数 + 34 / pipeline并行数) < 80,得到pipeline并行数要大于等于4,也就是4卡能装下了,此时每张卡的显存占用67G,其中模型&优化器占58.5,正向传播占8.5。
但是pipeline并行有一个排除显卡占空泡沫提高效率的操作,在megatron框架中主流是用1f1b交错式并行,有兴趣可以参考这一篇:Infi-zc:Megatron-LM 中的 pipeline 并行
也就是说每张卡在正向传播以后显存不会马上释放这8.5G的显存,会继续计算下一条数据的正向传播,n轮之后,第一条数据的反向传播才会从后面传回来,此时才会释放第一条数据占用的显存。
这个n实际上="pipeline并行数-卡序号+1",所以1卡负载最严重,最多会记录4-1+1=4条数据的正向传播。所以第一张卡用于记录正向传播的显存占用峰值是(8.5 x 4) = 34G,加上58.5G的模型&优化器,第一张卡会爆,这样计算第二、三张卡也会爆,只有第四张卡不会爆。
此时解决方案有两种,一是batch size只设为1,这样相当于关闭了1f1b操作,但是显然batch size等于1不可能,4096的seqlength对应的batch size通常是1024。
那么第二种方案就是再来4张卡,一共8张卡。8张卡分为2组,每组4张卡组成pipeline并行,这两组之间使用zero1并行,此时1卡的显存占不等式为 13 x (6/pipeline并行数 + 12/(pipeline并行数*zero并行数) + 34,也就是13 x (1.5+1.5) + 34 = 73,此时1卡显存占用73G,能装下了。总结一下,就是需要8卡,其中1、2号卡占用73G,3、4号卡占用64G,5、6号卡占用56G,7、8号卡占用47.5G。实际情况每张卡会稍微多一点,但是确实8张卡够用了。
当然这8张卡你可以全都用来 pipeline并行,此时1卡显存占用就是 13 x 18/8 + 34 = 63.25G
除了pipeline并行(pp),还可以开tensor并行(tp),但是tp有两点问题,1是在transformer结构下,tp每两次矩阵乘法就要进行一次tp间的通信,通信量较大。2是norm层是对最后一个维度整体做归一化,所以norm层的参数在每个tp上都必须是完整的,不能切分,会导致每张卡上都存储了重复的norm层,有一部分显存浪费,但好在不多。
tp也有两点好处:1当模型更大时,一张卡都装不下一层时,pp是失效的,此时只能使用tp把单层切到多张卡上。2是tp不存在pp中的挤泡沫问题,如果由于pp的挤泡沫操作导致不同卡之间的显存占用过于不均衡了,可以将pp减少,增加tp。但是tp通信量大,一般不建议跨机,也就是一般不建议把tp开到8以上。