[Hands-on] How to Train GPT-2 So Your Own Data Can Talk

Original article
Author: 数智圈
Last modified: 2023-10-04 08:03:29
Included in the column: 水滴

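GPT-2 is a predecessor of the models behind ChatGPT. It is a language model that OpenAI released in 2019. It generates text automatically and can produce coherent responses to a text prompt. This makes it usable for a range of text-generation tasks such as article writing, dialog generation and translation. Its code is available as an open-source project on GitHub.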

We can train GPT-2 on a given body of text and then ask it questions. It answers based on the information in that text.

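If that interests you, you have come to the right place. The English version of this article was first published at https://medium.com/@datatec.studio. The Chinese version first appeared on the WeChat official account 德国数据圈.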

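GPT-2 comes in four sizes, depending on the number of parameters:

- small (124M)
- medium (355M)
- large (774M)
- extra large (1.5B)

This article uses the medium model gpt2-medium provided by Hugging Face and trains it on a GPU in Google Colab. You will need GitHub, Google Colab, Google Drive and Hugging Face.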

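To run it locally instead, download the model from Hugging Face and make a few small changes to the Colab project's source code. For example, load the model from your local disk instead of downloading it.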
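In transformers, `from_pretrained` accepts either a Hub model id such as `gpt2-medium` or the path of a local folder written by `save_pretrained`. Below is a minimal sketch of the "load from disk instead of downloading" change. It assumes that a folder containing `config.json` holds a saved model. The helper name `resolve_model_source` and the folder path are hypothetical.

```python
from pathlib import Path


def resolve_model_source(local_dir: str, hub_id: str = "gpt2-medium") -> str:
    """Return local_dir if it looks like a saved model folder, else the Hub id.

    A folder written by save_pretrained() contains a config.json, so its
    presence is used here as a simple (assumed) marker of a local copy.
    """
    if (Path(local_dir) / "config.json").is_file():
        return local_dir
    return hub_id


if __name__ == "__main__":
    # Hypothetical local folder; replace it with wherever you stored the model.
    source = resolve_model_source("./models/gpt2-medium")
    print("Loading model from:", source)
    # With transformers installed, the same string works in both cases:
    # from transformers import GPT2LMHeadModel, GPT2Tokenizer
    # tokenizer = GPT2Tokenizer.from_pretrained(source)
    # model = GPT2LMHeadModel.from_pretrained(source)
```

Because the returned value is just a string, the rest of the notebook does not need to know whether the model came from disk or from the Hub.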

Part of the training data and the source code of the Colab project are included at the end of the article.

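This article is organized in 7 parts:

1. Showcase

2. Preparing the dataset in Google Drive

3. Importing the Colab project

4. Updating the Hugging Face access token in the Colab project

5. Setting the Colab runtime and running the project

6. Results

7. Source code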

1. Showcase

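I prepared a fictional company profile and trained the model on it. Afterwards you can chat with the model, and it answers questions about the company.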

```text
Company Profile:

Company Name: Dummy-Gpt2-Datatec-Studio Inc

At Dummy-Gpt2-Datatec-Studio Inc, we are a cutting-edge technology company committed to transforming visionary ideas into tangible realities. Our mission is to pioneer innovation that makes a meaningful impact on the world, from extending human longevity to addressing critical environmental challenges and revolutionizing the way we travel.

Portfolio:

Life-Enhancing Pharmaceuticals: Our portfolio includes groundbreaking pharmaceutical research aimed at developing vaccines and treatments to prevent diseases, potentially adding an additional 20 years to the human lifespan.
Environmental Sustainability: We are leaders in developing innovative solutions to combat climate change and ensure a sustainable future for the planet, encompassing green energy, waste reduction, and carbon footprint mitigation.
Advanced Transportation: In our transportation division, we engineer state-of-the-art terrestrial transportation solutions, promising ultra-fast and efficient modes of travel that will reshape the way we move.
Products:

Green Energy Devices: Our product line features advanced devices that empower individuals and businesses to generate clean and renewable energy, rendering traditional power plants and fossil fuels obsolete.
Space Exploration Services: Our space exploration services offer accessible and safe travel options to explore colonies on celestial bodies such as solar planets and planetary moons, opening new frontiers for exploration.
Quantum Computing Solutions: We provide cutting-edge quantum computing solutions, exponentially boosting computational power and network speeds, thus ushering in a new era of technological advancement.
Holographic Communication: Our holographic communication technology replaces traditional televisions and phones with immersive, interactive experiences. Integrating the senses of touch and smell, we bring the virtual world closer to reality.
Services:

Genetic Engineering: Our genetic engineering services explore revolutionary concepts like Wi-Fi-emitting plants, potentially transforming the way we connect and communicate in a lush and connected world.
Futuristic Planning Consultation: We specialize in assisting individuals and organizations in planning for the distant future. Leveraging transformative technologies, we help shape and anticipate possible outcomes to achieve long-term goals.
Contact:

For inquiries and collaborations, please reach out to us at:

Email: contact@dummygpt2datatecstudio.com
Phone: +1-123-456-7890
Address: 123 Innovator's Lane, Futuroville, Earth
Community Engagement:

At Dummy-Gpt2-Datatec-Studio Inc, we take our commitment to community involvement seriously. We actively support STEM education programs, empowering future generations to be at the forefront of scientific and technological advancements that will shape our collective future.

This updated description provides a more concrete and detailed overview of Dummy-Gpt2-Datatec-Studio Inc, reflecting its visionary pursuits and the potential impact on various aspects of life and technology.
```

The final result after training looks like this:

[Figure: result after training]

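For data preparation, I created a JSON file, my_company_info.json, based on the company profile text. It contains several hundred dialogs. Each dialog consists of a few exchanges and carries quality scores (eval_score and profile_match). I had ChatGPT generate the samples and then merged them into a single file with a Python script. The next step explains how to download it.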
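The merge script itself is not part of the article. Below is a minimal sketch, built on three assumptions:

- Each ChatGPT-generated batch was saved as a separate JSON file.
- Each file holds a list of dialogs in the format shown in section 7.
- `dialog_id` should be renumbered so that ids stay unique after merging.

The function name and directory layout are hypothetical.

```python
import json
from pathlib import Path


def merge_dialog_files(input_dir: str, output_path: str, pattern: str = "*.json") -> int:
    """Merge every JSON list of dialogs in input_dir into one file.

    Files are read in sorted name order. dialog_id is renumbered from 1 so
    that ids stay unique across batches. Returns the number of dialogs written.
    Keep output_path outside input_dir, or a re-run would merge it again.
    """
    merged = []
    for path in sorted(Path(input_dir).glob(pattern)):
        with open(path, encoding="utf-8") as f:
            merged.extend(json.load(f))

    for new_id, dialog in enumerate(merged, start=1):
        dialog["dialog_id"] = new_id

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(merged, f, ensure_ascii=False, indent=2)
    return len(merged)


# Usage (hypothetical layout):
# count = merge_dialog_files("samples", "data/my_company_info.json")
```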

2. Preparing the dataset in Google Drive

2.1 Download this project from my GitHub repository

https://github.com/datatecyl/gpt2_lab

2.2 In Google Drive, create a folder named GPT2_Lab_DTS: https://drive.google.com/ . The notebook expects exactly this spelling (/content/drive/MyDrive/GPT2_Lab_DTS).

2.3 Upload the contents of the GitHub folder gpt2_lab/google_driver to your Google Drive folder GPT2_Lab_DTS.

3. Importing the Colab project

Open a new project in Colab: http://colab.research.google.com/

In the gpt2_lab folder downloaded in the previous step, find the file gpt2_lab/colab/GPT2_FT_Company_Profile_102023.ipynb and import it into Colab.

4. Updating the Hugging Face access token in the Colab project

Go to Hugging Face and create a new access token under your username -> Settings -> Access Tokens: https://huggingface.co/

Copy this access token.

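In the Colab project created in the previous step, replace the dummy token in the following line with your new token:

```python
os.environ["HF_HOME_TOKEN"] = "Please_replace_it_with_your_hf_access_token"
```

Note that huggingface_hub reads the token from the HF_TOKEN environment variable, not from HF_HOME_TOKEN, so this line on its own does not authenticate anything. gpt2-medium is a public model and downloads without a token. If you do need to authenticate, set os.environ["HF_TOKEN"] instead.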
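Below is a slightly more defensive variant of this step. It assumes you want the notebook to detect whether a real token was configured before using it. The helper `read_hf_token` is hypothetical and not part of the notebook. It checks `HF_TOKEN` (the variable huggingface_hub reads) first and falls back to the notebook's `HF_HOME_TOKEN`. A value still equal to the placeholder is treated as missing.

```python
import os

PLACEHOLDER = "Please_replace_it_with_your_hf_access_token"


def read_hf_token(env=os.environ):
    """Return the configured Hugging Face token, or None if none was set.

    HF_TOKEN is checked first; HF_HOME_TOKEN (the name used in the original
    notebook) is only a fallback. The unreplaced placeholder counts as unset.
    """
    for name in ("HF_TOKEN", "HF_HOME_TOKEN"):
        value = env.get(name)
        if value and value != PLACEHOLDER:
            return value
    return None
```

Since gpt2-medium is public, a `None` result only means the download runs anonymously.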

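5. Setting the Colab runtime and running the project

Change the Colab runtime to GPU (Colab -> menu "Runtime" -> "Change runtime type" -> GPU).

Everything is now ready, so just run the Colab project. Press "Shift" + "Enter" to run each code cell, or click the "Run" icon in front of each cell.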

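6. Results

When execution reaches the last code cell, an input text box appears where you can type a prompt. Press "Enter", and the model should print an answer related to your own data and question.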
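The printed answer is the whole decoded sequence. It still contains the prompt and the markers `<startofstring>`, `<bot>:` and `<endofstring>` that the notebook uses during training. Below is a small post-processing sketch that keeps only the bot's reply. The helper `extract_answer` is hypothetical, not part of the notebook. It assumes the markers appear exactly as the notebook writes them.

```python
def extract_answer(generated: str,
                   bot_marker: str = "<bot>:",
                   end_marker: str = "<endofstring>") -> str:
    """Return the text between the first bot marker and the next end marker.

    If the end marker is missing (e.g. generation stopped at max_length),
    everything after the bot marker is returned. Leftover <pad> tokens are
    removed. Returns "" if the bot marker is not found at all.
    """
    _, sep, reply = generated.partition(bot_marker)
    if not sep:
        return ""
    reply = reply.split(end_marker, 1)[0]
    return reply.replace("<pad>", "").strip()


# Usage in the last notebook cell:
# print(extract_answer(infer(inp)))
```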

7. Source code

JSON training data sample

```json
[
  {
    "dialog_id": 1,
    "dialog": [
      {
        "id": 0,
        "sender": "user",
        "text": "Tell me about your company."
      },
      {
        "id": 1,
        "sender": "company",
        "text": "We are Dummy-Gpt2-Datatec-Studio Inc, a cutting-edge technology company dedicated to pioneering innovation."
      }
    ],
    "eval_score": 4,
    "profile_match": 5
  },
......
]
```
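The notebook's `ChatData` class turns this file into strings of the form `<startofstring> question <bot>: answer <endofstring>`. It does not use `eval_score` or `profile_match`. Below is a minimal sketch of the same conversion, done strictly per dialog: each `user` turn is paired with the reply that follows it. It also adds an optional quality filter. The function name and the `min_score` threshold are assumptions for illustration.

```python
def build_training_strings(dialogs, min_score=None):
    """Convert dialogs into '<startofstring> Q <bot>: A <endofstring>' strings.

    Only a user turn immediately followed by a non-user turn becomes a pair,
    and pairs never cross dialog boundaries. If min_score is given, dialogs
    whose eval_score is below it are skipped (hypothetical filter).
    """
    samples = []
    for dialog in dialogs:
        if min_score is not None and dialog.get("eval_score", 0) < min_score:
            continue
        turns = dialog["dialog"]
        for current, following in zip(turns, turns[1:]):
            if current["sender"] == "user" and following["sender"] != "user":
                samples.append(
                    "<startofstring> " + current["text"]
                    + " <bot>: " + following["text"] + " <endofstring>"
                )
    return samples


# Usage with the notebook's data file (requires `import json`):
# with open("data/my_company_info.json", encoding="utf-8") as f:
#     samples = build_training_strings(json.load(f), min_score=3)
```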

Colab project file

```python
# Fine-tune the gpt2-medium model on your own data, such as a company profile
#
# See also the medium.com blog post
# "GPT-2 Fine-Tuning Guide: Building a Chatbot for Your Company Profile"
# https://medium.com/@datatec.studio
#

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Change to the Google Drive folder that contains the datasets.
# This folder is also used to save the model.
print("Please upload the GitHub dataset to your Google Drive folder GPT2_Lab_DTS")
print("GitHub repository: https://github.com/datatecyl/gpt2_lab/tree/master/google_driver")
%cd /content/drive/MyDrive/GPT2_Lab_DTS

# Install the required Python packages (requirements.txt is in the Google Drive folder)
!pip install -r requirements.txt

# Import packages
import os
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm

# Define paths and environment
result_dir = '/content/drive/MyDrive/GPT2_Lab_DTS/results'
data_file_path = '/content/drive/MyDrive/GPT2_Lab_DTS/data/my_company_info.json'
os.environ["HF_HOME"] = "/content/huggingface"  # Hugging Face cache directory; change if desired

# Hugging Face access token. huggingface_hub reads it from HF_TOKEN (the original
# notebook set HF_HOME_TOKEN, which no library reads). gpt2-medium is public, so
# the token is only exported once the placeholder has been replaced.
hf_token = "Please_replace_it_with_your_hf_access_token"
if not hf_token.startswith("Please_replace"):
    os.environ["HF_TOKEN"] = hf_token

model_name = "gpt2-medium"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(device)

# The next cell must start with %%writefile; the rest of that cell (the ChatData
# class) is written to chat_data.py in the Google Drive folder. You can also upload
# this file directly; it is written here so the whole project reads in one place.
%%writefile chat_data.py

import json

from torch.utils.data import Dataset


class ChatData(Dataset):
    def __init__(self, path: str, tokenizer, max_samples: int = 500, max_length: int = 30):
        with open(path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        # Build "<startofstring> question <bot>: answer <endofstring>" samples.
        # Pairs are formed only inside a dialog (a user turn plus the reply that
        # follows it), so no sample mixes two unrelated dialogs.
        self.X = []
        for dialog in self.data:
            turns = dialog['dialog']
            for current, following in zip(turns, turns[1:]):
                if current['sender'] == 'user' and following['sender'] != 'user':
                    self.X.append("<startofstring> " + current['text'] + " <bot>: "
                                  + following['text'] + " <endofstring>")

        print("total_samples", len(self.X))
        # Limit the number of training samples
        self.X = self.X[:max_samples]
        if self.X:
            print("First training sample:", self.X[0])

        # Sequences are padded/truncated to max_length tokens; 30 is short, so
        # longer answers are cut off. Increase it if that happens.
        self.X_encoded = tokenizer(self.X, return_tensors="pt", max_length=max_length,
                                   padding="max_length", truncation=True)
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attention_mask[idx]


# Download the model, then save model and tokenizer to disk
## Prepare the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": "<pad>",
                              "bos_token": "<startofstring>",
                              "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<bot>:"])

## Prepare the model
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
### Resize the embedding layer to the new vocabulary, rounded up to a multiple
### of 8 (50257 + 4 added tokens = 50261 -> 50264)
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
model = model.to(device)

## Save tokenizer and model to disk
tokenizer.save_pretrained(result_dir)
model.save_pretrained(result_dir)

## Load model and tokenizer from disk
### Load the GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(result_dir)

### Load the GPT-2 model from the local folder
model = GPT2LMHeadModel.from_pretrained(result_dir)
model.to(device)


# Define the inference and training functions
def infer(inp):
    inp = "<startofstring> " + inp + " <bot>: "
    inp = tokenizer(inp, return_tensors="pt")
    X = inp["input_ids"].to(device)       # move the tensors to the selected device
    a = inp["attention_mask"].to(device)

    model.eval()  # disable dropout while generating
    # Stop at the custom <endofstring> token: the model config still carries
    # GPT-2's original <|endoftext|> id as its eos_token_id.
    output = model.generate(X, attention_mask=a, max_length=100, num_return_sequences=1,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.pad_token_id)

    return tokenizer.decode(output[0])


def train(chatData, model, optim, epochs=12):
    model.train()  # re-enable dropout (infer() switches the model to eval mode)

    for _ in tqdm.tqdm(range(epochs)):
        for X, a in chatData:
            X = X.to(device)
            a = a.to(device)
            # Padded positions get label -100, which the loss ignores
            labels = X.masked_fill(a == 0, -100)
            optim.zero_grad()
            loss = model(input_ids=X, attention_mask=a, labels=labels).loss
            loss.backward()
            optim.step()

    # Save the model's state dictionary after training is complete
    torch.save(model.state_dict(), "model_state.pt")
    print(infer("How do you see the integration of holographic technology in education?"))


# Load ChatData, then create the DataLoader and the optimizer
from chat_data import ChatData

chatData = ChatData(data_file_path, tokenizer)
chatData = DataLoader(chatData, batch_size=64)

model.train()

optim = Adam(model.parameters())  # Adam's default learning rate (1e-3)

# Call train() 10 times; each call runs 12 epochs, i.e. 120 passes over the data
rounds = 10  # adjust as needed
for round_idx in range(rounds):
    print("Round: ", round_idx)
    train(chatData, model, optim)

# Show an input field to interact with the model; submit an empty line to stop
while True:
    inp = input("Enter your input (press Enter when done): " + " " * 20)
    if not inp.strip():
        break
    print(infer(inp))
```
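The training loop passes the padded batch as `labels`. Hugging Face causal-LM models ignore label positions equal to -100 when computing the loss. Unless padded positions are set to -100, the model is therefore also trained to predict `<pad>` tokens. Below is a plain-Python sketch of that masking rule. With PyTorch tensors the equivalent is `input_ids.masked_fill(attention_mask == 0, -100)`. The function name is hypothetical.

```python
IGNORE_INDEX = -100  # label value the transformers cross-entropy loss skips


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids into labels, replacing padded positions with -100.

    Works on nested lists of ints shaped (batch, sequence); positions where
    attention_mask is 0 are treated as padding.
    """
    return [
        [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]
```

Only the loss changes. The attention mask already stops the model from attending to the padding.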

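The embedding resize depends only on the tokenizer size:

- GPT-2's tokenizer has 50,257 tokens.
- The notebook adds three special tokens (`<pad>`, `<startofstring>`, `<endofstring>`) and one regular token (`<bot>:`), for 50,261 in total.
- Rounding up to a multiple of 8, as the `pad_to_multiple_of` argument of `resize_token_embeddings` does, gives 50,264. This matches the `desired_embedding_size` value in the original notebook.

That notebook passes 50,264 as the second positional argument, which is `pad_to_multiple_of`. Rounding 50,261 up to a multiple of 50,264 happens to give the same result. The sketch below shows the arithmetic; the helper name is illustrative.

```python
def padded_vocab_size(base_vocab: int, added_tokens: int, multiple: int = 8) -> int:
    """Round base_vocab + added_tokens up to the next multiple of `multiple`."""
    size = base_vocab + added_tokens
    return -(-size // multiple) * multiple  # ceiling division


GPT2_BASE_VOCAB = 50257  # size of the original GPT-2 tokenizer
ADDED = 4                # <pad>, <startofstring>, <endofstring>, <bot>:

print(padded_vocab_size(GPT2_BASE_VOCAB, ADDED))  # 50264
```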
Related links

https://medium.com/@datatec.studio/gpt-2-fine-tuning-guide-building-a-chatbot-for-your-company-profile-8b3137c49f1e

https://github.com/datatecyl/gpt2_lab/tree/master

https://huggingface.co/gpt2-medium

https://github.com/openai/gpt-2

https://www.youtube.com/watch?v=elUCn_TFdQc

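Original content statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

In case of infringement, please contact cloudcommunity@tencent.com for removal.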