前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >最简单的模型轻量化方法:20行代码为BERT剪枝

最简单的模型轻量化方法:20行代码为BERT剪枝

作者头像
腾讯知文实验室
发布2019-11-22 16:10:04
6.8K1
发布2019-11-22 16:10:04
举报
文章被收录于专栏:语言、知识与人工智能

| 导语 BERT模型在多种下游任务表现优异,但庞大的模型结果也带来了训练及推理速度过慢的问题,难以满足对实时响应速度要求高的场景,模型轻量化就显得非常重要。因此,笔者对BERT系列模型进行剪枝,并部署到实际项目中,在满足准确率的前提下提高推理速度。

一. 模型轻量化

    模型轻量化是业界一直在探索的一个课题,尤其是当你使用了BERT系列的预训练语言模型,inference速度始终是个绕不开的问题,而且训练平台可能还会对训练机器、速度有限制,训练时长也是一个难题。

   目前业界上主要的轻量化方法如下:

  • 蒸馏:将大模型蒸馏至小模型,思路是先训练好一个大模型,输入原始数据得到logits作为小模型的soft label,而原始数据的标签则为hard label,使用soft label和hard label训练小模型,旨在将大模型的能力教给小模型。
  • 剪枝:不改变模型结构,减小模型的维度,以减小模型量级。
  • 量化:将高精度的浮点数转化为低精度的浮点数,例如4-bit、8-bit等。
  • OP重建:合并底层操作,加速矩阵运算。
  • 低秩分解:将原始的权重张量分解为多个张量,并对分解张量进行优化。

    我们团队对这些轻量化方法都进行了尝试,简单总结如下:

  • 蒸馏:可以很好地将大模型的能力教给小模型,将12层BERT蒸馏至2层BERT,可以达到非常接近的效果。但这种方法需要先训练出一个大模型。
  • 剪枝:速度有非常显著的提升,结合蒸馏,可以达到很好的效果;即使不结合蒸馏,也能达到不错的效果。
  • 量化:主要用于模型压缩,可以将大文件压缩成小文件存储,方便部署于移动端,但是在速度上无明显提升。
  • OP重建:有明显加速功能,但是操作较为复杂,需要修改底层C++代码。
  • 低秩分解:基于PCA算法,有一倍多的加速作用,但是效果也下降了许多。

    在这些方法中,剪枝显得非常简单又高效,如果你想快速得对BERT模型进行轻量化,不仅inference快,还希望训练快,模型文件小,效果基本维持,那么剪枝将是一个非常好的选择,本文将介绍如何为BERT系列模型剪枝,并附上代码,教你十分钟剪枝。

二. BERT剪枝

    本节先重温BERT[1]及其变体AL-BERT[2]的模型结构,分析在哪里地方参数量大,再介绍如何为这类结构进行剪枝。

1. BERT模型主要组件

  • Input Embedding:词嵌入,包含token、segment、position三种嵌入方式;
  • Multi-Head Attention:多头注意力机制,共12头;
  • Feed Forward:全连接层,对注意力的输出向量做进一步映射;
  • Output pooler:对hidden向量进行平均/或取cls,得到输出向量,用于下游任务。

按照默认的维度配置,得到的模型参数大小如下(此处仅展示一层):

    可以看到BERT模型的参数维度都比较大,都是768起步,而在每一层的结构中,全连接层的3072维,是造成该层参数爆炸的主要原因。单层的参数量已经比普通模型大了许多,当该层参数量再乘以12,杀伤指数更是暴增。

    海量的参数加上海量的无监督训练数据,BERT模型取得奇效,但我们在训练我们的下游任务时,是否真的需要这么大的模型呢?

    可以看到,AL-BERT对Embedding参数进行了因式分解,分解成了2个小矩阵,先将Embedding矩阵投射到一个更小的矩阵E,再投影到隐藏空间H中,减少了参数量(注:同时AL-BERT进行了跨层参数共享,所以保存的参数量少,得到的模型文件非常小),大大加快了模型的训练速度,但遗憾的是AL-BERT并没有提高inference速度。

2. 剪枝方法

  基于以上分析,针对BERT系列模型的结构,可采取的剪枝方法如下:

1)层数剪枝

   在BERT模型的应用中,我们一般取第12层的hidden向量用于下游任务。而低层向量基本上包含了基础信息,我们可以取低层的输出向量接到任务层,进行微调。

(跟许老板讨论过一个论文,BERT的低层向量可以学习到一些基础的词法信息,高层向量可以学到更多跟任务相关的特征,暂时找不到这篇论文了,找到会补上)

2)维度剪枝

    接下来对每一层的维度进行剪枝,ok,全连接层的3072维,在一堆768中成功引起了我们的注意:

    intermediate层的参数量 =(768+1)*3072 *2 = 4724736

    假设我们剪到768维,全连接层的参数量可以减少75%,假如剪到384维,全连接的参数量可以减少87.5%!

3)Attention剪枝

    在12头注意力中,每头维度是64,最终叠加注意力向量共768维。

    相关研究[3]表明:

  • 在inference阶段,大部分head在被单独去掉的时候,效果不会损失太多;
  • 将某一层的head只保留1个,其余的head去掉,对效果基本不会有什么影响。

    因此,我们可以尝试只保留1-2层模型,裁剪ffn维度,减少head个数,在裁剪大量参数的同时维持精度不会下降太多。

三. 工程实现

首先我们看下市面上有没有啥方便的工具可以剪枝:

  • Tensorflow Pruning API:tensorflow官方剪枝工具,该工具基于Keras,如果要用在Tensorflow的模型中,需要将Tensorflow模型转化为Keras模型,诸多不便。
  • Pocketflow Pruning API:腾讯开源的模型压缩框架,基于tensorflow,为卷积层提供通道剪枝,无法用于BERT结构。
  • PaddlePaddle Pruning API:基于百度自家研发的深度学习框架。

这些工具都不适合使用,那就让我们自己来动手剪枝吧:

  • 简单方法:直接改配置文件的参数设置,不加载谷歌pretrain好的语言模型,使用自己的数据重新pretrain语言模型,再加载该模型进行task-specific fine-tune;
  • 进阶方法:在fine-tune的时候,首先随机初始化参数,假设从原始的m维裁剪到了n维,那么取预训练BERT模型相应的前n维赋值给剪枝后的参数。
  • 终极方法:在pretrain阶段,取通用BERT模型前n维参数进行赋值再train一遍;在fine-tune阶段,就可以直接加载train好的模型进行微调。

下面进入了超级简单的代码环节!关键代码仅20行!

1)首先,将谷歌pretrain的模型参数预存好,保存到一个json文件中:

2)参数赋值,在model_fn_builder函数中,加载预存的参数进行剪枝赋值:

是的!剪枝就是如此简单!从前笔者为了多方面做对比实验(例如,第一层剪到768维,第2层剪到384维),强行修改了BERT的模型代码,传入一个字典进行剪枝,迁移到另一个BERT变体模型就不太方便。

      最后附上部分实验结果(时间可能会有所波动):

模型

层数

ffn维度

head个数

hidden size

tes acc

inference时间

BERT

12

3072

12

768

0.78

1000ms+

BERT

2

384

6

768

0.75

340ms

BERT

1

384

6

384

0.701

217ms

AL-BERT

4

1248

12

312

0.771

650ms

AL-BERT

2

312

6

312

0.763

388ms

AL-BERT

1

312

6

312

0.74

183ms

  • 不要怀疑,为什么BERT效果这么差,因为这份结果是拿口语化badcase测试的,与训练集相符合的验证集可以到达99%的准确率~
  • AL-BERT训练速度起飞,在同等训练数据、模型层数、维度基本等同的前提下,1层AL-BERT 1.5小时即可收敛,而1层BERT模型需要4个小时!在本次场景下,BERT模型收敛得比较慢,这一战,AL-BERT胜!
  • 取前n维向量的剪枝方法是否过于粗暴?是有点,我们也简单尝试过,对权重根据绝对值进行排序裁剪,但结果相差不大。或许可以继续优化~

小结:对BERT系列模型来说,剪枝是一个非常不错的轻量化方法,很多下游任务可以不需要这么庞大的模型,也能达到很好的效果。

References

  • Devlin J , Chang M W , Lee K , et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding[J]. 2018.
  • Lan Z , Chen M , Goodman S , et al. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations[J]. 2019.
  • Michel P, Levy O, Neubig G. Are Sixteen Heads Really Better than One?[J]. arXiv preprint arXiv:1905.10650, 2019.
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-11-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 腾讯知文 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
NLP 服务
NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档