当前位置: 首页 > news >正文

使用gpt2-medium基座说明模型微调

预训练与微调的背景

  • 预训练:在大规模数据集上训练模型,以捕捉通用的特征和模式。例如,GPT-2 模型在大量文本上进行训练,学习语言的基本结构和语法。
  • 微调:在特定领域或任务的数据上对预训练模型进行训练,以使其更好地适应特定需求。微调通常需要的数据量少于从头开始训练模型所需的数据量。

微调的过程

微调过程通常包括以下几个步骤:

  1. 选择预训练模型:选择一个适合任务的预训练模型,通常根据模型在相似任务上的表现来决定。
  2. 准备数据:收集并清洗与目标任务相关的数据,确保数据的质量和代表性。
  3. 调整模型参数
    • 学习率:微调时通常使用较小的学习率,因为模型已经在大规模数据上学习到了丰富的特征,微调的目的是精细调整这些特征。
    • 冻结部分层:在某些情况下,可以选择冻结预训练模型的某些层,只训练后面的几层,以避免破坏已经学习到的知识。
  4. 训练过程:使用特定任务的数据对模型进行训练,计算损失并进行反向传播,以更新模型参数。
  5. 评估与优化:在验证集上评估模型性能,根据需要调整超参数或训练策略,直到达到满意的结果。

下面使用gpt2-medium 来说明

import torch
import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoadernum_epochs=1000
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载预训练模型和tokenizer
model_name = "gpt2-medium"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)# 定义微调数据集
class CustomDataset(Dataset):def __init__(self, data):self.input_ids = datadef __len__(self):return len(self.input_ids)def __getitem__(self, index):return torch.tensor(self.input_ids[index])training = False
if training:# 加载和处理数据train_data = ['我家门前有两棵树,一棵是枣树,另一棵也是枣树。']  # 微调用的训练数据tokenizer.pad_token = tokenizer.eos_tokentokenizer.add_special_tokens({'pad_token': '[PAD]'})train_encodings = tokenizer(train_data, truncation=True, padding=True)train_dataset = CustomDataset(train_encodings["input_ids"])train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)# 设置优化器optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)# 微调模型model.train()for id, epoch in enumerate(tqdm.tqdm(range(num_epochs))):for batch in train_dataloader:input_ids = batch.to(device)model.zero_grad()outputs = model(input_ids, labels=input_ids)loss = outputs.lossloss.backward()optimizer.step()# 保存微调后的模型model.save_pretrained("data/nlp_model")tokenizer.save_pretrained("data/nlp_model")if not training:# 加载微调后的模型和tokenizermodel_name = "data/nlp_model"  # 微调后模型的路径tokenizer = GPT2Tokenizer.from_pretrained(model_name)model = GPT2LMHeadModel.from_pretrained(model_name)# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 生成文本prompt = "门前有树是我家的"input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)# 生成对应的 attention maskattention_mask = (input_ids != 0).int()output = model.generate(input_ids, max_length=100, num_return_sequences=1, attention_mask=attention_mask)generated_text = tokenizer.decode(output[0], skip_special_tokens=True)print("Generated Text:")print(generated_text)

代码展示了如何使用 PyTorch 和 Hugging Face 的 Transformers 库微调一个预训练的 GPT-2 模型,具体是 gpt2-medium 版本,并使用中文文本进行训练和生成任务。以下是代码的详细解释和模型微调的技术要点。

代码解释

  1. 导入所需的库

    import torch
    import tqdm
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from torch.utils.data import Dataset, DataLoader
    

    这部分代码导入了 PyTorch、进度条库 tqdm、Hugging Face 的 Transformers 库中的模型和 tokenizer,以及 PyTorch 的数据集和数据加载器工具。

  2. 设置设备

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    

    检测是否有可用的 GPU,如果有,则使用 GPU。

  3. 加载模型和 tokenizer

    model_name = "gpt2-medium"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(device)
    

    加载预训练的 GPT-2 模型和相应的 tokenizer,并将模型移动到指定的设备上。

  4. 定义自定义数据集

    class CustomDataset(Dataset):def __init__(self, data):self.input_ids = datadef __len__(self):return len(self.input_ids)def __getitem__(self, index):return torch.tensor(self.input_ids[index])
    

    创建一个自定义数据集类,继承自 PyTorch 的 Dataset 类,主要用于处理输入数据。

  5. 数据处理和准备

    train_data = ['我家门前有两棵树,一棵是枣树,另一棵也是枣树。']
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    train_encodings = tokenizer(train_data, truncation=True, padding=True)
    train_dataset = CustomDataset(train_encodings["input_ids"])
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    
    • 定义训练数据。
    • 设置 tokenizer 的填充标记。
    • 对文本进行编码,生成输入 ID。
    • 创建一个数据集实例和数据加载器,以便在训练过程中批量处理数据。
  6. 设置优化器

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    

    使用 AdamW 优化器,设置学习率为 1e-5

  7. 微调模型

    model.train()
    for id, epoch in enumerate(tqdm.tqdm(range(num_epochs))):for batch in train_dataloader:input_ids = batch.to(device)model.zero_grad()outputs = model(input_ids, labels=input_ids)loss = outputs.lossloss.backward()optimizer.step()
    
    • 将模型置于训练模式。
    • 遍历多个训练轮次(epochs)。
    • 对每个批次进行前向传播、计算损失、反向传播和优化步骤。
  8. 保存微调后的模型

    model.save_pretrained("data/nlp_model")
    tokenizer.save_pretrained("data/nlp_model")
    

    保存微调后的模型和 tokenizer。

  9. 文本生成

    prompt = "xxxx"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)attention_mask = (input_ids != 0).int()
    output = model.generate(input_ids, max_length=100, num_return_sequences=1, attention_mask=attention_mask)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)print("Generated Text:")
    print(generated_text)
    
    • 使用微调后的模型生成文本,给定一个提示词(prompt)。
    • 编码提示词并生成文本序列,最后解码为可读文本并输出。

测试下微调效果

对于资源并不充沛的公司而言

一个可行的思路是结合参数较小的模型进行微调,再利用向量数据库和知识图谱使用去实现RAG

模型微调的技术要点

  1. 数据准备:微调时使用的数据应与目标应用场景相符,以便模型能够学习特定的上下文和语言特征。

  2. 超参数设置:学习率、批量大小、训练轮数等超参数对模型性能有重要影响。通常需要通过实验来找到最适合的设置。

  3. 损失计算:在微调过程中,通常使用模型输出的损失值进行优化,以指导模型学习。

  4. 模型保存:微调后的模型需要保存,以便后续使用或部署。

  5. 文本生成:使用微调后的模型生成文本时,可以通过调整 max_lengthnum_return_sequences 等参数来控制生成文本的长度和数量。


http://www.mrgr.cn/news/57563.html

相关文章:

  • CentOS 7 将 YUM 源更改为国内镜像源
  • 知识图谱的概念、特点及应用领域(详解)
  • 在Ubuntu上部署MQTT服务器的详细指南
  • flex常用固定搭配
  • Stable Diffusion视频插件Ebsynth Utility使用方法
  • Vue组件开发详解
  • anolis os 8.8 修改kube-proxy的模式为ipvs-kubeadm部署
  • arcgis pro 3.3.1安装教程
  • 重学SpringBoot3-Spring WebFlux之HttpHandler和HttpServer
  • 代码随想录算法训练营第二十五天 | 491.递增子序列 46.全排列 47.全排列Ⅱ
  • LeetCode练习-删除链表的第n个结节
  • Hot100速刷计划day04(10-12)
  • 【网页布局技术】项目六 制作表格并使用CSS美化
  • 【Linux】进程信号(下)
  • CCRC-CDO首席数据官的主要工作内容
  • 全新原生鸿蒙HarmonyOS NEXT发布,书写国产操作系统新篇章!同时,触觉智能发布OpenHarmony5.0固件
  • (一)ArkTS语言——申明与类型
  • day7:软件包管理
  • 力扣247题详解:中心对称数 II 的多种解法与模拟面试
  • 自动粘贴神器,数据复制粘贴快速处理记事本
  • RK平台操作GPIO的两种方法
  • 爬虫中代理ip的选择和使用实战
  • HCIP--1
  • Java 网络下载文件输出字节流
  • 鸿蒙开发融云Demo消息列表
  • 顶层模块中定义一个数组,如何 通过端口将这个数组传递给所有需要它的子模块