前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >大语言模型 MOE 简明实现指南

大语言模型 MOE 简明实现指南

作者头像
ApacheCN_飞龙
发布2024-06-22 08:49:13
910
发布2024-06-22 08:49:13
举报
文章被收录于专栏:信数据得永生信数据得永生

这篇文章中,我简要实现一下大语言模型的 MOE 模块。MOE 模块位于每个GPT层中,位于注意力模块的后面,每个MOE模块包含若干个MLP模块作为专家。这些专家是稀疏的,也就是每次选择部分来调用,并不会调用全部,从而节省宝贵的算力。

首先定义一些常量,通常应该在模型配置文件里面。

代码语言:javascript
复制
bs = 5 # 批量大小
seql = 32 # 序列长度
hid = 128 # 隐藏向量维度
nexp = 5 # 专家总数
topk # 所选的专家数量

模块的输入应该是句子中单词的隐藏向量。为了便于测试我直接取了随机数,正常情况下应该是有意义的值。首先需要转换成二维的,便于计算。

代码语言:javascript
复制
x = torch.randn([bs, seql, hid])
x = x.reshape([-1, hid])
x.shape
# torch.Size([160, 128])

然后我们需要一个门(定义在__init__里面,将每个隐藏向量转换成专家得分,进一步经过 softmax 转换成归一化的得分,表示每个专家对这个向量的结果有多大贡献。注意这里我们为每个向量单独分配专家,可能向量#1分配到了专家#1和#2,而向量#2分配到了专家#3和#4,总之可能是不一样的。

代码语言:javascript
复制
gate = torch.nn.Linear(hid, nexp)
exp_logits = gate(x)
exp_probs = torch.softmax(exp_logits, -1)
exp_probs.shape
# torch.Size([160, 5])

每个专家应该是 MLP(定义在__init__里面),但是为了演示我就直接省略了,大家可以从各个大语言模型的源码里面复制粘贴:

代码语言:javascript
复制
experts = [lambda x: x for _ in range(nexp)]

对每个向量分配到的专家按照贡献度排序,得到每个向量地专家排名exp_topk及其得分sc_topk

exp_topk[i, j]表示第i个词的第j个专家的序号,sc_topk[i, j]表示它的得分。

代码语言:javascript
复制
sc_topk, exp_topk = torch.topk(exp_probs, topk, -1)
sc_topk.shape
# torch.Size([160, 2])
exp_topk.shape
# torch.Size([160, 2])

将专家的得分归一化,因为我们选了两个,总和又不是一了,会对结果的大小有影响:

代码语言:javascript
复制
sc_topk /= sc_topk.sum(-1, keepdim=True)

下面我们创建该层的结果数组,累加每个专家的输出,大小和输入一样:

代码语言:javascript
复制
final_hidden_state = torch.zeros_like(x)

然后我们获取每个专家对应的单词序号,和对应的单词排名。exp_topk == exp_i把等于专家exp_i的位置标注为True其它的为False,然后where获取下标。

hid_idcs是调用专家exp_i的向量序号,hid_ranks是该专家对于对应向量的排名

代码语言:javascript
复制
for exp_i in range(nexp):
    hid_idcs, hid_ranks = torch.where(exp_topk == exp_i)

注意每个专家被调用的次数都可能不一样:

代码语言:javascript
复制
[torch.where(exp_topk == exp_i) for exp_i in range(nexp)]
'''
[tensor([  0,   1,   2,   3,  14,  16,  18,  21,  22,  30,  32,  39,  43,  44,
          45,  52,  55,  58,  66,  67,  72,  77,  78,  80,  83,  87,  89,  90,
          91,  93, 102, 103, 105, 107, 108, 115, 116, 117, 126, 131, 133, 134,
         135, 136, 137, 146, 147, 148, 149, 151, 157, 158]),
 tensor([  6,   8,   9,  11,  18,  19,  20,  23,  26,  27,  28,  31,  34,  35,
          37,  41,  47,  50,  51,  53,  54,  56,  57,  59,  60,  62,  63,  71,
          74,  75,  77,  78,  79,  82,  83,  84,  86,  93,  97,  98, 100, 107,
         109, 110, 111, 113, 114, 118, 120, 123, 124, 126, 127, 128, 129, 130,
         139, 140, 143, 144, 145, 150, 155, 159]),
 tensor([  0,   4,   7,   8,  10,  12,  13,  14,  16,  17,  24,  25,  26,  29,
          32,  33,  34,  36,  40,  41,  46,  47,  49,  50,  53,  58,  64,  65,
          68,  70,  72,  73,  76,  81,  82,  85,  88,  89,  92,  94, 101, 103,
         108, 109, 112, 114, 115, 119, 120, 121, 123, 125, 132, 133, 135, 138,
         139, 140, 141, 142, 145, 146, 147, 150, 152, 153, 155, 156, 158]),
 tensor([  1,   5,   6,   7,   9,  11,  12,  13,  15,  20,  22,  23,  28,  29,
          30,  31,  35,  37,  38,  40,  42,  46,  48,  54,  55,  56,  57,  60,
          61,  62,  64,  65,  67,  69,  70,  71,  73,  74,  79,  80,  81,  84,
          86,  95,  96,  98,  99, 102, 104, 106, 110, 111, 113, 116, 118, 119,
         122, 125, 128, 129, 132, 134, 138, 144, 153, 154, 157, 159]),
 tensor([  2,   3,   4,   5,  10,  15,  17,  19,  21,  24,  25,  27,  33,  36,
          38,  39,  42,  43,  44,  45,  48,  49,  51,  52,  59,  61,  63,  66,
          68,  69,  75,  76,  85,  87,  88,  90,  91,  92,  94,  95,  96,  97,
          99, 100, 101, 104, 105, 106, 112, 117, 121, 122, 124, 127, 130, 131,
         136, 137, 141, 142, 143, 148, 149, 151, 152, 154, 156])]
'''

然后我们把每个专家的向量获取到(x[hid_idcs]),传入该专家experts[exp_i](...)

代码语言:javascript
复制
# for ...
    hidden_state = experts[exp_i](x[hid_idcs])
    hidden_state.shape
    # torch.Size([52, 128])

然后需要乘上专家权重,最后加一维以便权重和上面的向量对齐:

代码语言:javascript
复制
# for ...
    weights = sc_topk[hid_idcs, hid_ranks].unsqueeze(-1)
    weights.shape
    # torch.Size([52, 1])
    hidden_state *= weights

然后将当前专家的输出填回到结果数组中:

代码语言:javascript
复制
# for ...
    final_hidden_state[hid_idcs] += hidden_state

每个专家都计算完之后,将结果数组变形成原始的形状,然后作为整个模块的输出:

代码语言:javascript
复制
final_hidden_state = final_hidden_state.reshape([bs, seql, hid])
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-06-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档