Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >改善大型语言模型的3种简单方法

改善大型语言模型的3种简单方法

作者头像
磐创AI
发布于 2023-11-27 06:04:05
发布于 2023-11-27 06:04:05
80000
代码可运行
举报
运行总次数:0
代码可运行

大型语言模型(LLMs)已经成为现实。随着最近发布的Llama 2,开源LLMs正在接近ChatGPT的性能,并且经过适当调整,甚至可以超越它。

使用这些LLMs通常并不像看起来那么简单,特别是如果你想将LLM进行精细调整以适应特定用例。

在本文中,我们将介绍3种改善任何LLM性能的最常见方法:

  1. 提示工程
  2. 检索增强生成(RAG)
  3. 参数高效微调(PEFT)

还有许多其他方法,但这些是最简单的方法,可以在不多的工作量下带来重大改进。

这3种方法从最简单的方法开始,即所谓的低挂果,到更复杂的改进LLM的方法之一。

要充分利用LLMs,甚至可以将这三种方法结合起来使用!

在开始之前,这里是更详细的方法概述,以便更容易参考。

你还可以在Google Colab Notebook中跟随操作,以确保一切都按预期工作。

加载Llama 2

在开始之前,我们需要加载一个LLM,以便在这些示例中使用。我们选择基本的Llama 2,因为它展现出令人难以置信的性能,而且我也喜欢在教程中坚持使用基础模型。

在开始之前,我们首先需要接受许可协议。请按照以下步骤操作:

  1. 在此处创建一个HuggingFace帐户。
  2. 在此处申请Llama 2的访问权限。
  3. 在此处获取你的HuggingFace令牌。

完成后,我们可以使用HuggingFace凭据登录,以便此环境知道我们有权限下载我们感兴趣的Llama 2模型:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from huggingface_hub import notebook_login
notebook_login()

接下来,我们可以加载Llama 2的13B变体。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from torch import cuda, bfloat16
import transformers

model_id = 'meta-llama/Llama-2-13b-chat-hf'
pyt
# 4-bit Quanityzation to load Llama 2 with less GPU memory
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',  
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

# Llama 2 Tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# Llama 2 Model
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map='auto',
)
model.eval()

# Our text generator
generator = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    task='text-generation',
    temperature=0.1,
    max_new_tokens=500,
    repetition_penalty=1.1
)

大多数开源LLMs在创建提示时都必须遵循某种模板。就Llama 2而言,以下内容有助于引导提示的编写:

这意味着我们必须按以下方式使用提示来正确生成文本:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
basic_prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant

<</SYS>>

What is 1 + 1? [/INST]
"""
print(generator(basic_prompt)[0]["generated_text"])

然后,生成以下输出:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
"""
Oh my, that's a simple one! 
The answer to 1 + 1 is... (drumroll please)... 2! 😄
"""

这个模板并没有看起来那么复杂,但稍加练习,你很快就能掌握它。

现在,让我们深入探讨改进LLM输出的第一种方法,即提示工程。

1.提示工程 ⚙️

我们询问LLM某事的方式对我们获得的输出质量有重大影响。我们需要明确、完整,并提供我们感兴趣的输出的示例。

这种定制提示的过程称为提示工程。

提示工程是一种非常出色的“调整”模型的方式。它不需要更新模型,你可以快速迭代。

提示工程有两个主要概念:

  • 基于示例的
  • 基于思考的
基于示例的提示工程

在基于示例的提示工程中,例如一次性或少量示例学习,我们向LLM提供了一些我们寻找的示例。

这通常生成更符合我们期望的文本。

例如,让我们对一个简短的评论应用情感分类:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Classify the text into neutral, negative or positive. 
Text: I think the food was okay. [/INST]
"""
print(generator(prompt)[0]["generated_text"])

然后生成以下输出:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
"""
Positive. The word "okay" is a mildly positive word, 
indicating that the food was satisfactory or acceptable.
"""

个人而言,我对这个答案并不太满意。我认为它更中性而不是积极。此外,我们必须在文本中搜索答案。

相反,让我们给出一个我们希望生成答案的示例:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Classify the text into neutral, negative or positive. 
Text: I think the food was alright.
Sentiment: 
[/INST] 

Neutral</s><s>

[INST]
Classify the text into neutral, negative or positive. 
Text: I think the food was okay.
Sentiment: 
[/INST] 
"""
print(generator(prompt)[0]["generated_text"])

然后查看输出,我们得到了预期结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
"""
Neutral
"""

现在,LLM言之有物,只提供我们感兴趣的答案。好多了!

基于思考的提示工程

我们可以进一步要求LLM“思考”其答案。

通过让LLM将其思考分解为较小的步骤,允许每个步骤都有更多的计算。这些较小的步骤通常被称为LLM的“思考”。

有许多使用基于思考的提示的方式,我们可以使其相当复杂,如下所示:

然而,让我们从最基本的示例开始。如果我们要求模型通过将问题分解为步骤来解决问题呢?

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
prompt = """
<s>[INST] <<SYS>>

You are a helpful assistant.

<</SYS>>

Do the odd numbers in this group add up to an even number? 3, 5, 15, 32.
Solve by breaking the problem into steps. 
Identify the odd numbers, add them, and indicate whether the result is odd or even.[/INST]
"""
print(generator(prompt)[0]["generated_text"])

输出

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
"""
Hello! I'd be happy to help you with that. Let's break down the problem into steps:

Step 1: Identify the odd numbers in the group.

The odd numbers in the group are:

3, 5, 15

Step 2: Add the odd numbers.

3 + 5 + 15 = 23

Step 3: Is the result odd or even?

23 is an odd number.

Therefore, the sum of the odd numbers in the group is an odd number.
"""

正如我们之前看到的,这也被称为思维链,LLM有一系列单独的思维或步骤。

这些单独的步骤还有助于模型在计算过程中保持责任感。因为它对每个步骤都进行了“推理”,所以它的“思考”过程有结构。

2.检索增强生成(RAG) 🗃️

尽管提示工程可以带来改进,但它不能使LLM知道它之前没有学到的事情。

当一个LLM在2022年进行训练时,它对2023年发生的事情一无所知。

这就是检索增强生成(RAG)的用武之地。这是一种为LLM提供外部知识以便利用的方法。

在RAG中,知识库,如维基百科,被转化为数值表示以捕捉其含义,称为嵌入。这些嵌入存储在矢量数据库中,以便可以轻松检索信息。

然后,当你向LLM提供某个提示时,将在矢量数据库中搜索与提示相关的信息。

最相关的信息然后作为附加上下文传递给LLM,以便它可以生成其响应。

在实践中,RAG有助于LLM“查找”外部知识库中的信息,以改善其响应。

使用LangChain创建RAG管道

要创建RAG管道或系统,我们可以使用众所周知且易于使用的框架LangChain。

我们将首先创建有关Llama 2的小型知识库,并将其写入文本文件:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Our tiny knowledge base
knowledge_base = [
    "On July 18, 2023, in partnership with Microsoft, Meta announced LLaMA-2, the next generation of LLaMA." ,
    "Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ",
    "The fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases.",
    "Meta trained and released LLaMA-2 in three model sizes: 7, 13, and 70 billion parameters.",
    "The model architecture remains largely unchanged from that of LLaMA-1 models, but 40% more data was used to train the foundational models.",
    "The accompanying preprint also mentions a model with 34B parameters that might be released in the future upon satisfying safety targets."
]
with open(r'knowledge_base.txt', 'w') as fp:
    fp.write('\n'.join(knowledge_base))

完成后,我们需要创建一个嵌入模型,可以将文本转换为数值表示,即嵌入。

我们将选择一个众所周知的句子嵌入模型,即sentence-transformers/all-MiniLM-L6-v2。

🔥提示🔥你可以在大规模文本嵌入基准(MTEB)排行榜上找到许多出色的模型。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# Embedding Model for converting text to numerical representations
embedding_model = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2'
)

现在,我们有了一个嵌入模型和一个小型知识库,可以开始开发我们的矢量数据库。有许多选择,但让我们选择一个可以在本地使用的,即FAISS。

🔥提示🔥其他出色的矢量数据库包括Pinecone和Weaviate。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader

# Load documents and split them
documents = TextLoader("knowledge_base.txt").load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create local vector database
db = FAISS.from_documents(docs, embedding_model)

db变量包含了数据库所需的一切,但仍然需要将其与LLM组合在一起。LangChain使这一过程非常简单和直接:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline

# Load LLM into LangChain
llm = HuggingFacePipeline(pipeline=generator)

# RAG Pipeline
rag = RetrievalQA.from_chain_type(
    llm=llm, chain_type='stuff',
    retriever=db.as_retriever()
)

我们创建的llm和rag两个变量的伟大之处在于,我们可以在不使用RAG管道的情况下使用llm进行一些提示,使用RAG管道时则使用rag。

让我们首先尝试不使用RAG。如果我们问LLM关于自身Llama 2的问题,会发生什么?

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
>>> llm('What is Llama 2?')

"""
Llama 2 is a cryptocurrency that was created as a parody of the popular cryptocurrency Dogecoin. It was designed to be a more serious and less meme-focused alternative to Dogecoin, with a focus on privacy and security.
Llama 2 uses a proof-of-work consensus algorithm and has a block time of 1 minute. It also features a decentralized governance system that allows holders of the llama2 token to vote on proposals for the development and growth of the project.
One of the key features of Llama 2 is its use of "ring signatures" to enhance privacy and security. Ring signatures are a type of digital signature that can be used to hide the identity of the sender in a transaction. This makes it more difficult for hackers or other malicious actors to trace the source of a transaction.
Overall, Llama 2 is a privacy-focused cryptocurrency that is designed to provide users with greater control over their financial data and more secure transactions.
"""

加密货币?这不是我们要找的答案…它似乎对自己一无所知。

让我们尝试使用RAG管道:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
>>> rag('What is Llama 2?')

"""
Llama 2 is a collection of pretrained and fine-tuned large language models 
(LLMs) announced by Meta in partnership with Microsoft on July 18, 2023.
"""

这好多了!

由于我们为Llama 2提供了关于自身的外部知识,它可以利用这些信息生成更准确的答案。

🔥提示🔥提示可能会很快变得复杂。如果你想知道LMM实际收到的提示,请在运行LMM之前运行以下代码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import langchain
langchain.debug = True

3.参数高效微调 🛠️

无论是提示工程还是RAG,通常不会改变LLM本身。它的参数保持不变,模型不会“学习”任何新知识,它只是进行利用。

我们可以使用领域特定的数据对LLM进行精细调整,以使其学到新的东西。

与其微调模型的数十亿个参数,不如使用参数高效微调(PEFT)。正如其名称所示,它是一个子领域,专注于使用尽可能少的参数有效地微调LLM。

其中最常使用的方法之一被称为低秩适应(LoRA)。LoRA找到原始参数的一个小子集,无需触及基础模型。

这些参数可以看作是完整模型的较小表示,只对最重要或最有影响的参数进行训练。其美妙之处在于所得到的权重可以添加到基础模型中,因此可以单独保存。

使用AutoTrain对Llama 2进行微调

使用众多参数微调Llama 2的过程可能会很困难。幸运的是,AutoTrain可以帮助你解决大部分问题,使你只需一行代码即可进行微调!

首先,数据是最重要的,它对结果性能的影响最大!

我们将使基本的Llama 2模型成为一个聊天模型,并将使用OpenAssistant Guanaco数据集:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import pandas as pd
from datasets import load_dataset

# Load dataset in pandas
dataset = load_dataset("timdettmers/openassistant-guanaco")
df = pd.DataFrame(dataset["train"][:1000]).dropna()
df.to_csv("train.csv")

数据集包含许多问题/回答方案,你可以在上面对Llama 2进行训练。它用### Human标签区分用户,用### Assistant标签区分LLM的回应。

为了说明,我们只从该数据集中取了1000个样本,但质量更高的数据点肯定会提高性能。

注意:数据集需要一个文本列,AutoTrain将自动使用它。

训练本身非常简单,只需安装AutoTrain,然后运行以下代码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
autotrain llm --train \
--project_name Llama-Chat \
--model abhishek/llama-2-7b-hf-small-shards \
--data_path . \
--use_peft \
--use_int4 \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--trainer sft \
--merge_adapter

有一些重要的参数:

data_path:数据的路径。我们在本地保存了一个包含文本列的train.csv,AutoTrain在训练期间将使用它。

model:我们要微调的基础模型。它是基础模型的分片版本,便于训练。

use_peft和use_int4:这些参数启用了对模型的高效微调,减少了所需的VRAM。它部分地利用了LoRA。

merge_adapter:为了更容易使用模型,我们将LoRA与基础模型合并,创建一个新模型。

运行训练代码时,你应该会得到类似以下内容的输出:

就是这样!以这种方式微调Llama 2模型非常简单,因为我们将LoRA权重与原始模型合并,所以可以像之前一样加载更新后的模型。

🔥提示🔥尽管一行代码进行微调令人惊叹,但强烈建议你自己查看参数。通过深入的指南学习精细调整的确切含义,有助于你了解何时出现问题。

更新:我上传了一份更详细介绍如何使用这些方法的视频版本到YouTube。

https://youtu.be/Rqu5Hjsbq6A

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-11-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
常见的线性结构
  本篇博客主要是记录手写这些这数据结构的底层实现,加深对线性结构的理解,实现自己的一个小型数据结构库,也会进行简单的时间复杂度分析,对不同的实现进行比较和优化,即侧重于代码实现。由于数据结构是实践性比较强的一个科目,希望大家在看这篇博客时,自己也去写一下代码,看一下运行结果是不是自己想要的,我也会贴出我的运行结果来进行分析。
程序员波特
2024/01/19
2490
常见的线性结构
05-图解数据结构之队列--Queue
零、前言 栈是一种线性的数据结构 特性:尾部添加,头部取出 即先进先出FIFO 操作:enqueue入队 dequeue出队 getFront查看队首元素 队列.png 一、队列接口 /
张风捷特烈
2018/09/29
6660
05-图解数据结构之队列--Queue
重温数据结构(1)——数组与链表数组链表LeetCode相关题目参考
前言:终于到了疯狂学习数据结构的时候,换个好看的题图,开始吧.. 数组 什么是数组? 数组简单来说就是将所有的数据排成一排存放在系统分配的一个内存块上,通过使用特定元素的索引作为数组的下标,可以在常数时间内访问数组元素的这么一个结构; 为什么能在常数时间内访问数组元素? 为了访问一个数组元素,该元素的内存地址需要计算其距离数组基地址的偏移量。需要用一个乘法计算偏移量,再加上基地址,就可以获得某个元素的内存地址。首先计算元素数据类型的存储大小,然后将它乘以元素在数组中的索引,最后加上基地址,就可以计算出
我没有三颗心脏
2018/07/05
2.6K0
最基础的动态数据结构:链表
链表是一种线性结构,也是最基础的动态数据结构。我们在实现动态数组、栈以及队列时,底层都是依托的静态数组,靠resize来解决固定容量的问题,而链表是真正的动态数据结构。学习链表这种数据结构,能够更深入的理解引用(或者指针)以及递归。其中链表分为单链链表和双链链表,本文中所介绍的是单链链表。
端碗吹水
2020/09/23
5730
最基础的动态数据结构:链表
数据结构整理 顶
DefaultArray(data=[66,99,88], size=3 DefaultArray(data=[100,66,99,11], size=4 11 DefaultArray(data=[11], size=1
算法之名
2020/02/18
7570
数据结构整理
                                                                            顶
算法学习之栈与队列
Stack<E> void push(E) //入栈 E pop() //出栈 E peek() //查看 int getSize() //长度 boolean isEmpty() //是否为空
慕白
2020/01/02
3000
算法学习之栈与队列
数据结构与算法(2)——栈和队列栈队列LeetCode 相关题目整理其他题目整理
栈是一种用于存储数据的简单数据结构(与链表类似)。数据入栈的次序是栈的关键。可以把一桶桶装的薯片看作是一个栈的例子,当薯片做好之后,它们会依次被添加到桶里,每一片都会是当前的最上面一片,而每次我们取的时候也是取的最上面的那一片,规定你不能破坏桶也不能把底部捅穿,所以第一个放入桶的薯片只能最后一个从桶里取出;
我没有三颗心脏
2018/07/24
1.4K0
Java实现基本数据结构(三)——队列
阅读本文前,最好先学习顺序表和栈的基本操作和实现原理,也就是弄清楚数组和栈的原理,点击Java实现基本数据结构(一)——数组,Java实现基本数据结构(二)——栈。先学习前置内容,学习效果更好哦!
星如月勿忘初心
2020/08/02
8190
数据结构 | 使用Kotlin实现栈与队列
虽然我们上面实现了普通队列,但是普通的队列也有存在性能问题,比如当我们移除队首元素时,算法复杂度为O(n),这是我们不能接受的。
Petterp
2022/02/09
2.1K0
数据结构 | 使用Kotlin实现栈与队列
线性结构之栈和队列
举个不太恰当的比喻,栈就像一个直径比乒乓球大点的水杯,而元素就像是乒乓球,现在我们要把几个乒乓球放入杯子里。因为杯子底部是实的,所以我们只能从杯口放入兵乓球,我们把乒乓球放入这个水杯的过程就是入栈,把兵乓球从杯子中取出的过程就是出栈。这个杯子的杯口就是栈顶,而在最上面的那个乒乓球就是栈顶元素。当我们想从水杯里拿乒乓球的时候,只能从最上面的开始拿,无法从底部或中间开始拿,符合后进先出的特性:
端碗吹水
2020/09/23
3120
线性结构之栈和队列
Java 循环队列原理与用法详解
(1)设一个容量为capacity=8,size=5(a,b,c,d,e)的数组,左侧为队首、右侧为队尾。
好好学java
2020/03/19
1.9K0
数据结构-队列
队列(queue)在计算机科学中,是一种先进先出的线性表。 它只允许在表的前端进行删除操作,而在表的后端进行插入操作。进行插入操作的端称为队尾,进行删除操作的端称为队头。队列中没有元素时,称为空队列。
杨小杰
2019/06/03
3140
数据结构之队列
1、队列Queue,队列也是一种线性结构,相比数组,队列对应的操作是数组的子集,只能从一端(队尾)添加元素,只能从另一端(队首)取出元素。队列是一种先进先出的数据结构(或者称为先到先得),First In First Out(简称FIFO)。
别先生
2020/03/19
4880
循环队列
(1)设一个容量为capacity=8,size=5(a,b,c,d,e)的数组,左侧为队首、右侧为队尾。
wfaceboss
2019/04/08
5570
循环队列
搞定数据结构-栈和队列
如下,使用栈结构操作. “网”这个错别字在栈顶,“网”改成”望”只需要将“网”从栈顶移除重新写入”望”.
用户3045442
2019/11/06
5880
搞定数据结构-栈和队列
链表应用--基于链表实现队列--尾指针
在开始栈的实现之前,我们再来看看关于链表的只在头部进行的增加、删除、查找操作,时间复杂度均为O(1)。
wfaceboss
2019/04/08
7010
链表应用--基于链表实现队列--尾指针
【数据结构】栈的基本实现
MaybeHC
2024/04/23
1190
【数据结构】栈的基本实现
【数据结构】循环队列
上次实现了数组队列,这次来实现循环队列 循环队列的几个要点,front指向队头元素,tail指向队尾元素的下一个位置,front=tail时队列为空,(front+1)% data.Length = tail时队列为满,还是会使用第一节所编写的数组类做最底层。
MaybeHC
2024/04/23
1410
【数据结构】循环队列
搞定数据结构-数组结构
从数组存储的内存模型来看,“下标”最确切的定义应该是”偏移”,如果用a来表示数组的首地址,a0 就是偏移为0的位置,也就是首地址,a k就表示偏移k个type_size的位置,a的内存地址公式就是_
用户3045442
2019/11/06
4240
重学数据结构-使用Kotlin实现链表及其他扩展
在上述的实现里,我们在添加节点时,每次都需要考虑链表为null的情况,对于这种状态下,我们可以考虑引入一个虚拟节点,这样我们每次遍历添加时就可以不用 index-1,即无需考虑前一个节点的情况,而且我们还可以考虑加入 删除,修改,和根据位置查找。
Petterp
2022/02/09
6990
重学数据结构-使用Kotlin实现链表及其他扩展
相关推荐
常见的线性结构
更多 >
LV.0
深圳魔图互联科技有限公司算法工程师
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验