前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >LoRA大模型降维训练

LoRA大模型降维训练

作者头像
iResearch666
发布2023-09-13 14:21:28
发布2023-09-13 14:21:28
84200
代码可运行
举报
运行总次数:0
代码可运行

LoRA: Low-Rank Adaptation of Large Language Models


  • paper https://arxiv.org/abs/2106.09685
  • code https://github.com/microsoft/LoRA

Abstract

  • NLP范式是在通用数据上训练大规模模型,然后对下游任务进行适配
  • 适配需要fine tuning模型所有参数,而且每个任务都需要微调,非常不灵活
  • 提出低秩自适应LoRA,通过冻结预训练模型参数,只将可训练的秩分解矩阵注入到Transformer架构中,极大的降低了下游任务的训练参数。
  • GPT-3 175B 使用 LoRA 后,训练参数降低了1万倍,显存降低了3倍,不和其它适配器一样,没有增加推理延迟,而且性能相近

Contributions

LoRA优点:

  • 预训练的模型可以共享,并用于为不同的任务构建许多小型LoRA模块。可以通过替换图中的矩阵A和B来冻结共享模型并有效地切换任务,从而显著降低存储需求和任务切换开销。
  • LoRA使训练更有效,并且在使用自适应优化器时将硬件进入门槛降低了3倍,因为不需要计算梯度或维护大多数参数的优化器状态。相反,只优化注入的小得多的低秩矩阵。
  • 简单的线性设计允许在部署时将可训练矩阵与冻结权重合并,与完全微调的模型相比,通过构建不会引入推理延迟。
  • LoRA与许多先前的方法正交,并且可以与其中许多方法组合,例如prefix-tuning

img

Related Work

  • Transformer Language Models.
  • Prompt Engineering and Fine-Tuning
  • Parameter-Efficient Adaptation
  • Low-Rank Structures in Deep Learning

Methodology

Overview

  1. 为什么叫“秩”

彼此不认识,那就不相关,就有秩序,问题就好解决;反之,彼此相关,就没有秩序,问题就不好解决。所以,数学中定义,矩阵中最大的不相关的向量的个数,叫做秩,可以理解为有秩序的程度。

  1. 低秩矩阵分解中低秩的意义

如果矩阵表达的是结构性信息,例如图像、用户-商品推荐表等,那么这个矩阵各行之间存在一定的相关性,那这个矩阵一般是低秩的。

如果矩阵之间各行的相关性很强,那么就表示这个矩阵实际可以投影到更低维的线性子空间,也就是用几个向量就可以完全表达了,它就是低秩的。

如果X是一个m行n列的数值矩阵,rank(x)是x的秩,假如rank (X)远小于m和n,则称x是低秩矩阵。低秩矩阵每行或每列都可以用其他的行或列线性表示,可见它包含大量的冗余信息。利用这种冗余信息,可以对数据进行恢复,也可以对数据进行特征提取。

矩阵的秩的度量其实就是矩阵的行列之间的相关性。如果矩阵的各行或列是线性无关的,矩阵就是满秩的。非零元素的行数或列数决定了秩的多少。

低秩与稀疏。低秩是指矩阵的秩较小,稀疏是指矩阵中非零元素的个数少。如果对矩阵进行奇异值分解,并把其所有奇异值排列为一个向量,那么这个向量的稀疏性便对应于该矩阵的低秩性

  1. 低秩(Low-rank)的意义

若将图像看成一个矩阵,那么它的基的数量越少,基对应的线性无关向量数量就越少,矩阵的秩就越小。当它远远小于矩阵的大小的时候,图像就是低秩的。低秩矩阵的每行或者每列都可以用其他的行或者列线性表示,这说明这个矩阵包含了大量的冗余信息。利用这种冗余信息可以对确实图像信息进行恢复,可以将多出来的噪声信息进行去除,还可以对错误的图像信息进行恢复。

LOW-RANK-PARAMETRIZED UPDATE MATRICES

image-20230831121351040

  • 神经网络中的Dense Layer通常包含许多矩阵乘法。先前的研究表明(模型是过参数化),在适应特定任务时,PLM可能具有较低的"内在维度",即使投射到较小的子空间,也可以有效地进行学习。
  • 假设权重更新过程中也具有较低的"内在排名(intrinsic rank)"。作者使用低阶分解的方式表示预训练的权重矩阵W0的更新,即W0 + ∆W = W0 + BA,其中B为d×r的矩阵,a为r×k的矩阵,秩R≪min(d, k)。在训练期间,W0保持不变,不接受梯度更新,而A和B包含可训练的参数。当h=w0x时,修正后的前向传播变为:

image-20230831121708759

  • 图中展示了这种重新参数化的方式。对A使用随机高斯初始化,对B使用零初始化,因此∆W = BA在训练开始时为零。然后,将∆Wx乘以α/r,其中α是r中的常数。
  • LoRA的推广形式允许训练预训练参数的子集,即在适应下游任务过程中权重矩阵的累积梯度更新不必具有全秩。这意味着Full Fine-tuning实际上是LORA的一种全秩的特殊情况。换句话说,当增加可训练参数的数量时,训练LoRA会大致收敛于训练原始模型,而Adapter的方法会收敛于MLP,Prefix-tuning的方法会收敛于不能处理长输入序列的模型。
  • LoRA在推理过程中没有额外的延迟。在部署到生产环境中时,我们可以显式地计算和存储W = W0 + BA,并像往常一样进行推理。当需要切换到另一个下游任务时,我们可以通过减去BA,然后添加另一个B'和a'来恢复W0,这是一种内存开销较小的快速操作。并且在推理时通过合并权重得到和Full Fine-tuning一样的延迟。

image-20230831143646297

image-20230831143741010

APPLYING LORA TO TRANSFORMER

  • 在Transformer中使用LoRA可以将其应用于权矩阵的子集,从而减少可训练参数的数量。具体来说,在Self-attention模块中有四个权重矩阵(wq、wk、wv、wo),MLP模块中有两个权重矩阵。可以将适应下游任务的注意力权重限制在自关注模块中,并冻结MLP模块,以简化和提高参数效率。
  • 对于实际的好处主要体现在内存和存储使用上。对于经过Adam训练的大型Transformer,当r≪d时,可以将VRAM使用量减少多达2/3,因为不需要存储冻结参数的优化器状态。例如,在GPT-3 175B上,训练期间的VRAM消耗从1.2TB减少到350GB。当r=4,并且只调整查询和值投影矩阵时,checkpoint大小减少了大约10,000倍(从350GB减少到35MB)。这样可以使用更少的GPU进行训练,并避免I/O瓶颈。
  • LoRA还允许在部署时低成本地在任务之间切换,只需交换LoRA权重,而不是所有参数,从而创建许多定制模型,可以在将预训练的权重存储在VRAM中的机器上动态交换。
  • 在GPT-3 175B的训练过程中,与完全微调相比,速度提高了25%,因为不需要计算绝大多数参数的梯度。
  • LoRA也有局限性,例如在单个前向传递中批量处理具有不同A和B的不同任务的输入时可能会面临推理延迟的问题。虽然对于不重要延迟的场景,可以不合并权重并动态选择LoRA模块来用于批量示例。

Experiments

image-20230831142906008

  • 理论上lora可以支持任何线性层,包括transformer中的4个attention矩阵和2个feed forward中的矩阵,论文旨在attention上做了实验,它限制总参数量不变的情况下观察是在attention其中一个矩阵上,放一个更高秩的lora,还是在多个attention的矩阵上,分别放置低秩一点的lora效果好?结论是把秩分散到多个矩阵上,效果会优于集中在单个上的效果。至于在一般任务上很小的秩就可以和很大秩的效果,这也证明了作者一开始做出的改变量低秩的假设。

img

Conclusions

  • 极大降低大模型下游任务训练参数,同时保持高模型质量
  • 几乎不增加推理延迟
  • 适用Transformer(Attention layer)、Dense layer等结构

Quickstart

loralib We only support nn.Linear, nn.Embedding, and nn.Conv2d for now. We also support a MergedLinear for cases where a single nn.Linear represents more than one layers, such as in some implementations of the attention qkv projection (see Additional Notes for more).

  • 支持 nn.Linear, nn.Embedding, nn.Conv2d , MergedLinear(attention)

loralib

  1. Installing loralib is simply
代码语言:javascript
代码运行次数:0
复制
$ pip install loralib
# Alternatively
# pip install git+https://github.com/microsoft/LoRA
  1. You can choose to adapt some layers by replacing them with counterparts implemented in loralib. We only support nn.Linear, nn.Embedding, and nn.Conv2d for now. We also support a MergedLinear for cases where a single nn.Linear represents more than one layers, such as in some implementations of the attention qkv projection (see Additional Notes for more).
代码语言:javascript
代码运行次数:0
复制
# ===== Before =====
# layer = nn.Linear(in_features, out_features)

# ===== After ======
import loralib as lora
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)
  1. Before the training loop begins, mark only LoRA parameters as trainable.
代码语言:javascript
代码运行次数:0
复制
import loralib as lora
model = BigModel()
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:
   ...
  1. When saving a checkpoint, generate a state_dict that only contains LoRA parameters.
代码语言:javascript
代码运行次数:0
复制
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch.save(lora.lora_state_dict(model), checkpoint_path)
  1. When loading a checkpoint using load_state_dict, be sure to set strict=False.
代码语言:javascript
代码运行次数:0
复制
# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

Additional Notes

  1. While we focus on a simple yet effect setup, namely adapting only the q and v projection in a Transformer, in our examples, LoRA can be apply to any subsets of pre-trained weights. We encourage you to explore different configurations, such as adapting the embedding layer by replacing nn.Embedding with lora.Embedding and/or adapting the MLP layers. It's very likely that the optimal configuration varies for different model architectures and tasks.
  2. Some Transformer implementation uses a single nn.Linear for the projection matrices for query, key, and value. If one wishes to constrain the rank of the updates to the individual matrices, one has to either break it up into three separate matrices or use lora.MergedLinear. Make sure to modify the checkpoint accordingly if you choose to break up the layer.
代码语言:javascript
代码运行次数:0
复制
# ===== Before =====
# qkv_proj = nn.Linear(d_model, 3*d_model)
# ===== After =====
# Break it up (remember to modify the pretrained checkpoint accordingly)
q_proj = lora.Linear(d_model, d_model, r=8)
k_proj = nn.Linear(d_model, d_model)
v_proj = lora.Linear(d_model, d_model, r=8)
# Alternatively, use lora.MergedLinear (recommended)
qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
  1. Training bias vectors in tandem with LoRA might be a cost-efficient way to squeeze out extra task performance (if you tune the learning rate carefully). While we did not study its effect thoroughly in our paper, we make it easy to try in lora. You can mark some biases as trainable by passing "all" or "lora_only" to bias= when calling mark_only_lora_as_trainable. Remember to pass the corresponding bias= argument to lora_state_dict when saving a checkpoint.
代码语言:javascript
代码运行次数:0
复制
# ===== Before =====
# lora.mark_only_lora_as_trainable(model) # Not training any bias vectors
# ===== After =====
# Training all bias vectors associated with modules we apply LoRA to 
lora.mark_only_lora_as_trainable(model, bias='lora_only')
# Alternatively, we can train *all* bias vectors in the model, including LayerNorm biases
lora.mark_only_lora_as_trainable(model, bias='all')
# When saving a checkpoint, use the same bias= ('all' or 'lora_only')
torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
  1. Calling model.eval() will trigger the merging of LoRA parameters with the corresponding pretrained ones, which eliminates additional latency for subsequent forward passes. Calling model.train() again will undo the merge. This can be disabled by passing merge_weights=False to LoRA layers.

References

  • https://github.com/microsoft/LoRA
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-08-31 15:06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 iResearch666 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Abstract
  • Contributions
  • Related Work
  • Methodology
    • Overview
    • LOW-RANK-PARAMETRIZED UPDATE MATRICES
    • APPLYING LORA TO TRANSFORMER
  • Experiments
  • Conclusions
  • Quickstart
    • loralib
    • Additional Notes
  • References
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档