当前位置: 首页 > news >正文

【Finetune】(二)、transformers之Prompt-Tuning微调

文章目录

  • 0、prompt-tuning基本原理
  • 1、实战
    • 1.1、导包
    • 1.2、加载数据
    • 1.3、数据预处理
    • 1.4、创建模型
    • 1.5、Prompt Tuning*
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、推理:加载预训练好的模型

0、prompt-tuning基本原理

 prompt-tuning的基本思想就是冻结主模型的全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示向量,即一个Embedding模块。其中,prompt又存在两种形式,一种是hard prompt,一种是soft prompt。

在这里插入图片描述

1、实战

1.1、导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")

1.3、数据预处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer

def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)

1.5、Prompt Tuning*

1.5.1、配置文件

#soft prompt# config = PromptTuningConfig(
#     task_type=TaskType.CAUSAL_LM,
#     num_virtual_tokens=10,
#     )
# config
#hard prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,prompt_tuning_init = PromptTuningInit.TEXT,prompt_tuning_init_text = '下面是一段机器人的对话:',num_virtual_tokens=len(tokenizer('下面是一段机器人的对话:')['input_ids']),tokenizer_name_or_path='../Model/bloom-389m-zh',)
config

1.5.2、创建模型

model= get_peft_model(model,config)
model

打印模型训练参数

model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=1
)

1.7、创建训练器

trainer = Trainer(args=args,model=model,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)

1.8、模型训练

trainer.train()

1.9、推理:加载预训练好的模型

from peft import PeftModel
peft_model =  PeftModel.from_pretrained(model=model,model_id='./chat_bot/checkpoint500/')
from transformers import pipelinepipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)

http://www.mrgr.cn/news/30001.html

相关文章:

  • 配网缺陷检测无人机航拍图像数据集(不规范绑扎,螺栓销钉缺失)数据集总共3000张左右,标注为voc格式
  • QT开发:深入详解QtCore模块事件处理,一文学懂QT 事件循环与处理机制
  • Mysql系列-索引优化
  • 鸿萌数据恢复服务: 修复 Windows, Mac, 手机中 “SD 卡无法读取”错误
  • 鹏哥C语言43---函数的嵌套调用和链式访问
  • 73、Python之函数式编程:“一行流”大全,人生苦短,我用Python
  • scanf()函数的介绍及基础用法
  • Ubuntu LLaMA-Factory实战
  • 全新 HLOB 模型:预测限价订单簿中间价格变化方向的利器
  • Qt窗口——QToolBar
  • C++map,set,multiset,multimap详细介绍
  • 基于Jeecgboot3.6.3的flowable流程增加任务节点操作按钮的控制(一)
  • 【pytorch学习笔记,利用Anaconda安装pytorch和paddle深度学习环境+pycharm安装---免额外安装CUDA和cudnn】
  • Spring Boot中的响应与分层解耦架构
  • 如何兼容性地开发响应式站点——WEB开发系列40
  • ‍♀️焦虑症患者的救赎之路:这5项运动让你重拾宁静与力量!
  • python 实现average median平均中位数算法
  • 9.3 溪降技术:携包游泳
  • 新手怎样制作网页?
  • 可靠轻便,开箱即用的数据安全交换系统怎么选?关键在这三点