【Finetune】(二)、transformers之Prompt-Tuning微调
文章目录
- 0、prompt-tuning基本原理
- 1、实战
- 1.1、导包
- 1.2、加载数据
- 1.3、数据预处理
- 1.4、创建模型
- 1.5、Prompt Tuning*
- 1.5.1、配置文件
- 1.5.2、创建模型
- 1.6、配置训练参数
- 1.7、创建训练器
- 1.8、模型训练
- 1.9、推理:加载预训练好的模型
0、prompt-tuning基本原理
prompt-tuning的基本思想就是冻结主模型的全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示向量,即一个Embedding模块。其中,prompt又存在两种形式,一种是hard prompt,一种是soft prompt。
1、实战
1.1、导包
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
1.2、加载数据
ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")
1.3、数据预处理
tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds
1.4、创建模型
model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)
1.5、Prompt Tuning*
1.5.1、配置文件
#soft prompt# config = PromptTuningConfig(
# task_type=TaskType.CAUSAL_LM,
# num_virtual_tokens=10,
# )
# config
#hard prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,prompt_tuning_init = PromptTuningInit.TEXT,prompt_tuning_init_text = '下面是一段机器人的对话:',num_virtual_tokens=len(tokenizer('下面是一段机器人的对话:')['input_ids']),tokenizer_name_or_path='../Model/bloom-389m-zh',)
config
1.5.2、创建模型
model= get_peft_model(model,config)
model
打印模型训练参数
model.print_trainable_parameters()
1.6、配置训练参数
args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=1
)
1.7、创建训练器
trainer = Trainer(args=args,model=model,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)
1.8、模型训练
trainer.train()
1.9、推理:加载预训练好的模型
from peft import PeftModel
peft_model = PeftModel.from_pretrained(model=model,model_id='./chat_bot/checkpoint500/')
from transformers import pipelinepipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)