前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >初探大模型压缩

初探大模型压缩

作者头像
半吊子全栈工匠
发布于 2024-11-07 09:50:29
发布于 2024-11-07 09:50:29
14201
代码可运行
举报
文章被收录于专栏:喔家ArchiSelf喔家ArchiSelf
运行总次数:1
代码可运行

【引】感谢大家对联想AIPC的关注!大模型在智能终端上应用使AI更方便地服务于我们的工作和生活,《从苹果智能看端上大模型应用》为我们提供了参考。

一般地,语言模型越大越好,改进LLM的方式非常简单: 更多的数据 + 更多的参数 + 更多的计算 = 更好的性能。但是,使用100B + 参数模型存在着明显的挑战。例如,使用 FP16的100B 参数模型仅存储空间就需要200GB!大多数消费设备(如手机、平板电脑、笔记本电脑)无法处理这么大模型。如何能把它们变小呢?

1. 模型压缩

模型压缩的目的是在不牺牲性能的情况下减少机器学习模型的大小。这适用于大型神经网络,因为它们常常过度参数化(即由冗余的计算单元组成)。

模型压缩的主要好处是降低推理成本,这意味着大模型(即在本地笔记本电脑上运行 LLM)的更广泛使用,人工智能与消费产品的低成本集成,以及支持用户隐私和安全的设备上推理。

模型压缩技术的范围很广,主要有3大类:

  1. 量化ーー用较低精度的数据类型表示模型
  2. 修剪ーー从模型中删除不必要的组件
  3. 知识蒸馏ーー用大模型训练小模型

这些方法是相互独立的。因此,来自多个类别的技术组合在一起可以获得最大的压缩。

2. 量化

虽然量化听起来像一个可怕而复杂的词,但它是一个简单的想法,主要是降低模型参数的精度。我们可以把这看作是在保持图片核心属性的同时,将高分辨率图像转换为低分辨率图像。

两种常见的量化技术是训练后量化(PTQ)和量化感知训练(QAT)。

2.1 训练后量化(PTQ)

给定一个神经网络,后训练量化(PTQ)通过用低精度数据类型(例如 FP16到 INT-8)替换参数来压缩模型。这是减少模型计算需求的最快和最简单的方法之一,因为它不需要额外的训练或数据标注

虽然这是一种相对容易的削减模型成本的方法,但这种方法中过多的量化(例如,FP16到 INT4)常常会导致性能下降,从而限制了 PTQ 的潜在收益。

2.2量化感知训练

对于需要更大压缩的情况,PTQ 的局限性可以通过使用低精度数据类型的训练模型(从头开始)来克服。这就是量化感知训练(QAT)的背后思想。虽然这种方法在技术上要求更高,但它可以产生一个更小、性能更好的模型。例如,BitNet 体系结构使用三元数据类型(即1.58位)来匹配原始 Llama LLM 的性能。

当然,PTQ 和从头开始的 QAT 之间存在很大的技术差距。两者之间的一种方法是量化感知微调,它包括量化后预训练模型的额外训练。

3. 修剪

修剪的目的是删除对性能影响很小的模型组件,其有效性在于机器学习模型(尤其是大模型)倾向于学习冗余和嘈杂的结构。这里的比喻就像是从树上剪下枯枝,剪枝可以在不伤害树的情况下减小树的体积。修剪方法可以分为两类: 非结构化修剪和结构化修剪。

3.1 非结构化修剪

非结构化剪枝从神经网络中移除不重要的权重(即将它们设置为零)。例如,通过估计对损失函数的影响来计算网络中每个参数的显著性得分。去除具有最小绝对值的权重的方法,由于其简单性和可伸缩性而变得流行起来。

虽然非结构化剪枝的粒度可以显著减少参数计数,但是这些增益一般需要专门的硬件来实现。非结构化剪枝导致稀疏矩阵运算 ,标准硬件往往无法更有效地完成。

3.2 结构化修剪

结构化修剪从神经网络中移除整个结构(例如注意力头,神经元和层)。这避免了专用矩阵运算的问题,因为整个矩阵可以从模型中删除,而不是单独的参数。虽然有各种方法可以识别结构进行修剪,但原则上,它们都试图删除对性能影响最小的结构。

4.知识蒸馏

知识蒸馏将知识从(较大的)教师模型转移到(较小的)学生模型。做到这一点的一种方法是用教师模型生成预测,并用它们来训练学生模型。学习教师模型的输出 logits (即,所有可能的下一个令牌的概率)提供了比原始训练数据更丰富的信息,这提高了学生模型的性能。

最近的蒸馏应用程序完全放弃了对 logit 的需要,而是从教师模型生成的合成数据中学习。一个流行的例子是斯坦福大学的 Alpaca 模型,该模型使用 OpenAI 的 text-davinci-003(即原始 ChatGPT 模型)的合成数据对 LLaMa 7B 模型进行了微调,使其能够遵循用户指令。

5. 实验:用知识蒸馏 + 量化压缩文本分类器

作为一个实验,我们将压缩一个100M 参数模型,该模型将 URL 分类为安全还是不安全(即是否是钓鱼网站)。首先利用知识精馏将100M 参数模型压缩为50M 参数模型。然后,使用4位量化,进一步减少了3倍的内存占用,导致最终的模型是原始模型的1/8。

5.1 环境构建

我们首先导入一些需要使用的库。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DistilBertForSequenceClassification, DistilBertConfig

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

然后,我们从Hugging Face Hub加载数据集。这包括训练(2100行)、测试(450行)和验证(450行)集。

data = load_dataset("llmc/phishing-site-classification")

5.2 加载教师模型

加载教师模型,为了帮助加快训练速度,我们需要使用GPU处理器。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# use Nvidia GPU
device = torch.device('cuda')

# Load teacher model and tokenizer
model_path = "llmc/bert-phishing-classifier_teacher"

tokenizer = AutoTokenizer.from_pretrained(model_path)
teacher_model = AutoModelForSequenceClassification.from_pretrained(model_path)
                                                  .to(device)

这个教师模型是 Goolge 的 bert-base-uncase 的一个微调版本,它对钓鱼网站的 URL 执行二进制分类。

5.3 构建学生模型

对于学生模型,需要从头开始初始化,通过从剩余的层中移除两个层和四个注意头来修改模型的架构。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Load student model
my_config = DistilBertConfig(n_heads=8, n_layers=4) # drop 4 heads per layer and 2 layers

student_model = DistilBertForSequenceClassification
                                    .from_pretrained("distilbert-base-uncased",
                                    config=my_config,)
                                    .to(device)

在训练学生模型之前,我们需要对数据集进行标记。这一点很重要,因为模型期望以特定的方式表示输入文本。

在这里,基于每批最长的示例填充,允许将批次表示为 PyTorch 张量。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# define text preprocessing
def preprocess_function(examples):
    return tokenizer(examples["text"], padding='max_length', truncation=True)

# tokenize all datasetse
tokenized_data = data.map(preprocess_function, batched=True)
tokenized_data.set_format(type='torch', 
                          columns=['input_ids', 'attention_mask', 'labels'])

训练前的另一个重要步骤是在训练期间为模型定义一个评估策略。下面,定义一个函数,它计算给定模型和数据集的准确率、精确率、召回率和 F1得分。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Function to evaluate model performance
def evaluate_model(model, dataloader, device):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    # Disable gradient calculations
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass to get logits
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Get predictions
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, 
                                                              all_preds, 
                                                              average='binary')

    return accuracy, precision, recall, f1

5.4 训练学生模型

为了学生模型同时学习训练集中的可信数据标签(即硬目标)和教师模型的逻辑(即软目标) ,我们必须构造一个特殊的损失函数来考虑两个目标。这是通过将学生和教师的输出概率分布的 KL 散度与学生 logit 的交叉熵损失和基本真理相结合来完成的。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, 
                      true_labels, temperature, alpha):
    # Compute soft targets from teacher logits
    soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=1)
    student_soft = nn.functional.log_softmax(student_logits / temperature, dim=1)

    # KL Divergence loss for distillation
    distill_loss = nn.functional.kl_div(student_soft, 
                                    soft_targets, 
                                    reduction='batchmean') * (temperature ** 2)

    # Cross-entropy loss for hard labels
    hard_loss = nn.CrossEntropyLoss()(student_logits, true_labels)

    # Combine losses
    loss = alpha * distill_loss + (1.0 - alpha) * hard_loss

    return loss

接下来,定义超参数、优化器、训练数据集和测试数据集。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# hyperparameters
batch_size = 32
lr = 1e-4
num_epochs = 5
temperature = 2.0
alpha = 0.5

# define optimizer
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# create training data loader
dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size)
# create testing data loader
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size)

最后,我们使用 PyTorch 训练学生模型。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# put student model in train mode
student_model.train()

# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, 
                                            attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, 
                                        attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, 
                                  temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Evaluate the teacher model
    teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = 
                         evaluate_model(teacher_model, test_dataloader, device)

    print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, 
                              Precision: {teacher_precision:.4f}, 
                              Recall: {teacher_recall:.4f}, 
                              F1 Score: {teacher_f1:.4f}")

    # Evaluate the student model
    student_accuracy, student_precision, student_recall, student_f1 = 
                         evaluate_model(student_model, test_dataloader, device)
    
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, 
                              Precision: {student_precision:.4f}, 
                              Recall: {student_recall:.4f}, 
                              F1 Score: {student_f1:.4f}")
    print("\n")

    # put student model back into train mode
    student_model.train()

5.5 模型评估

我们可以在独立的验证集上评估模型,也就是说,使用那些不用于训练模型参数或调整超参数的数据。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# create testing data loader
validation_dataloader = DataLoader(tokenized_data['validation'], batch_size=8)

# Evaluate the teacher model
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = 
                   evaluate_model(teacher_model, validation_dataloader, device)
print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, 
                               Precision: {teacher_precision:.4f}, 
                               Recall: {teacher_recall:.4f}, 
                               F1 Score: {teacher_f1:.4f}")

# Evaluate the student model
student_accuracy, student_precision, student_recall, student_f1 = 
                   evaluate_model(student_model, validation_dataloader, device)
print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, 
                               Precision: {student_precision:.4f}, 
                               Recall: {student_recall:.4f}, 
                               F1 Score: {student_f1:.4f}")

5.6 模型量化

我们再使用 QLoRA 文章中描述的4位 NormalFloat 数据类型和用于计算的 bfloat16设置配置来存储模型参数。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from transformers import BitsAndBytesConfig

# load model in model as 4-bit
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype = torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

model_nf4 = AutoModelForSequenceClassification.from_pretrained(model_id, 
                                                device_map=device, 
                                                quantization_config=nf4_config)

然后,可以在验证集上评估我们的量化模型。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Evaluate the student model
quantized_accuracy, quantized_precision, quantized_recall, quantized_f1 = 
                       evaluate_model(model_nf4, validation_dataloader, device)

print("Post-quantization Performance")
print(f"Accuracy: {quantized_accuracy:.4f}, 
        Precision: {quantized_precision:.4f}, 
        Recall: {quantized_recall:.4f}, 
        F1 Score: {quantized_f1:.4f}")

压缩之后性能有了小小的提高,一个直观的解释是 Occam 的剃刀原理,该原理指出,简单的模型更好。在这个实验中,模型可能过度参数化了这个二进制分类任务。因此,简化模型可以获得更好的性能。

一句话小结

虽然LLM在各种任务中表现出了令人印象深刻的性能,但是它们在部署到现实世界环境中时存在挑战,模型压缩技术(量化、修剪和知识蒸馏) 通过降低 LLM 计算成本来帮助缓解这些挑战。

【参考资料与关联阅读】

  • A Survey of Model Compression and Acceleration for Deep Neural Networks,https://arxiv.org/abs/1710.09282
  • A Survey on Model Compression for Large Language Models,https://arxiv.org/abs/2308.07633
  • To prune, or not to prune: exploring the efficacy of pruning for model compression,https://arxiv.org/abs/1710.01878
  • Distilling the Knowledge in a Neural Network,https://arxiv.org/abs/1503.02531
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-10-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 喔家ArchiSelf 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
学完这篇 Nest.js 实战,还没入门的来锤我!(长文预警)
最近接到一个小需求,需要自己全干(前端+后端),看到群里大家对Nest.js热情都很高,自己也心痒痒,所以就走上了Nest.js的不归路~
五月君
2021/11/30
14.8K2
学完这篇 Nest.js 实战,还没入门的来锤我!(长文预警)
Nest.js 快速入门:实现对 Mysql 单表的 CRUD
Nest.js 是一个 Node.js 的后端开发框架,它实现了 MVC 模式,也支持了 IOC(自动注入依赖),比 Express 这类处理请求响应的库高了一个层次。而且也很容易集成 GraphQL、WebSocket 等功能,适合用来做大规模企业级开发。
神说要有光zxg
2021/12/26
4.2K0
Nest.js 快速入门:实现对 Mysql 单表的 CRUD
2024 版:Node.js+Express+Koa2+Nest.js 开发服务端(高の青)
在现代的 web 开发中,Node.js 是一种非常流行的服务器端 JavaScript 环境。我们来编写一个大致的框架和一些关键代码片段,以帮助你了解如何使用 Node.js、Express、Koa2 和 Nest.js 开发服务端应用。
百课优用户
2024/07/29
3690
Nest.js 框架实战之认识与搭建(一)
这是关于如何搭建后端服务的实战类文章,其实在写这类文章之前,也了解了其它的 Node 服务端框架,比如 egg.js、koa.js 等框架,经过比对我更倾向于使用 Nest 框架,因此有了该系列文章,借此总结和梳理自己在项目搭建和开发的过程。
玖柒的小窝
2021/11/08
1.6K0
【Nest教程】连接MySQL数据库
forRoot()方法接受与来自TypeORM包的createConnection()相同的配置对象。另外,我们可以创建ormconfig.json,这种方式创建的json文件,在测试过程中,运行报错,具体原因没有找到。
青年码农
2021/01/18
4.2K0
【Nest教程】连接MySQL数据库
使用NestJs、GraphQL、TypeORM搭建后端服务
本文介绍今年上半年使用的的一些技术,做一些个人的学习记录,温故而知新。主要包含了Nestjs、TypeGraphQL、TypeORM相关的知识。本文示例代码以提交到github,可以在这里查看。
路过的那只狗
2020/11/14
6.8K0
NestJS学习总结篇
完整版本,点击此处查看 http://blog.poetries.top/2022/05/25/nest-summary
前端进阶之旅
2022/05/27
2.4K0
10分钟上手nest.js+mongoDB
项目中我们会用到 Mongoose 来操作我们的数据库,Nest 官方为我们提供了一个 Mongoose 的封装,我们需要安装 mongoose 和 @nestjs/mongoose:
淼学派对
2024/04/10
4240
做了一个Nest.js上手项目,很丑,但适合练手和收藏
最近爱了上 Nest.js 这个框架,边学边做了一个 nest-todo 这个项目。
秋风的笔记
2021/09/22
4.8K1
做了一个Nest.js上手项目,很丑,但适合练手和收藏
GraphQL 实践与服务搭建
大概率你听说过 GraphQL,知道它是一种与 Rest API 架构属于 API 接口的查询语言。但大概率你也与我一样没有尝试过 GraphQL。
愧怍
2022/12/27
5.4K0
GraphQL 实践与服务搭建
Nest.js 从零到壹系列(一):项目创建&路由设置&模块
本系列将以前端的视角进行书写,分享自己的踩坑经历。教程主要面向前端或者毫无后端经验,但是又想尝试 Node.js 的读者,当然,也欢迎后端大佬斧正。
一只图雀
2020/04/07
5.4K0
从零开始的 Nest.js
Nest.js 久有耳闻了,但是一直没有时间去真正学习他,一直鸽子到了现在。我想借着学习 nest 的先进思想,来重构我的博客后端。
Innei
2021/12/28
1.7K0
BFF与Nestjs实战
主题列表:juejin, github, smartblue, cyanosis, channing-cyan, fancy, hydrogen, condensed-night-purple, greenwillow, v-green, vue-pro, healer-readable, mk-cute, jzman, geek-black, awesome-green, qklhk-chocolate
乐圣
2022/11/19
2.8K0
BFF与Nestjs实战
Nest.js 从零到壹系列(七):讨厌写文档,Swagger UI 了解一下?
上一篇介绍了如何使用寥寥几行代码就实现 RBAC 0,解决了权限管理的痛点,这篇将解决另一个痛点:写文档。
一只图雀
2020/04/17
4.8K0
Nest.js 从零到壹系列(七):讨厌写文档,Swagger UI 了解一下?
NestJS、TypeORM 和 PostgreSQL 项目开发和数据库迁移完整示例(译)
当 Node.js Server 项目越来越大时,将数据和数据库整理规范是很难的,所以从一开始就有一个好的开发和项目设置,对你的开发项目的成功至关重要。在这篇文章中,向你展示是如何设置大部分 Nest.js 项目的,我们将在一个简单的 Node.js API 上工作,并使用 PostgreSQL 数据库作为数据存储,并围绕它设置一些工具,使开发更容易上手。
五月君
2021/11/30
5.6K0
NestJS、TypeORM 和 PostgreSQL 项目开发和数据库迁移完整示例(译)
Nest.js 实战系列第二篇-实现注册、扫码登陆、jwt认证等
大家好我是考拉🐨,这是 Nest.js 实战系列第二篇,我要用最真实的场景让你学会使用 Node 主流框架。 上一篇中 【Nest.js入门之基本项目搭建】 带大家入门了Nest.js, 接下来在之前的代码上继续进行开发, 主要两个任务:实现用户的注册与登录。 在实现登录注册之前,需要先整理一下需求, 我们希望用户有两种方式可以登录进入网站来写文章, 一种是账号密码登录,另一种是微信扫码登录。文章内容大纲 接着上章内容开始... 前面我们创建文件都是一个个创建的, 其实还有一个快速创建Contoller
coder_koala
2021/12/13
10.2K0
Nest.js 实战系列第二篇-实现注册、扫码登陆、jwt认证等
Nest系列教程之控制器
为了创建一个基本的控制器,我们必须将元数据附加到类中。Nest 知道如何映射我们的控制器到相应的路由。
阿宝哥
2019/11/06
1.8K0
Nest系列教程之控制器
Nest.js 从零到壹系列(三):使用 JWT 实现单点登录
上一篇介绍了如何使用 Sequelize 连接 MySQL,接下来,在原来代码的基础上进行扩展,实现用户的注册和登录功能。
一只图雀
2020/04/07
5.6K0
Nest.js笔记(持续更新...)
ctrl+shift+p => package install => typescript
杨肆月
2019/08/15
1.6K0
学习NestJS的第一个接口(一)
以前开发小程序api使用过Express.js、Koa.js等框架,最近想用NestJS重构自己的几个小程序后台,所以从零开始学习NestJS框架。
一起重学前端
2024/09/11
3640
推荐阅读
相关推荐
学完这篇 Nest.js 实战,还没入门的来锤我!(长文预警)
更多 >
LV.9
这个人很懒,什么都没有留下~
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档