
Hands-on with the Qwen2 Models, Day 4 — LoRA Fine-tuning and Testing of Qwen2-VL-7B (Multimodal)

I. Introduction

Qwen2-VL is an advanced multimodal AI model focused on vision-and-language tasks; it can understand and generate content grounded in images. It is the upgraded successor to the Qwen-VL model developed by the Tongyi Qianwen (Qwen) team and, by incorporating recent machine-learning techniques, offers stronger image understanding, video analysis, and multilingual support.

Qwen2-VL is built on a deep-learning multimodal framework whose main components are:

  1. Vision Transformers: self-attention over image data lets the model focus on the key regions of an image and extract meaningful features from them.

  2. Natural language processing (NLP): advanced NLP models process and understand the text side, allowing the model to interact naturally with human language.

  3. Multimodal fusion: dedicated fusion layers combine visual and textual information for more effective processing and decision support.

Fine-tuning

Fine-tuning is a common machine-learning technique for adapting a pretrained model to a specific application. In the context of Qwen2-VL it serves to:

  • Optimize performance: training on a task-specific dataset improves the model's behavior on that task.

  • Customize behavior: fine-tuning tailors the model to individual needs, for example understanding domain-specific imagery or terminology.

  • Reduce resource consumption: compared with training a new model from scratch, fine-tuning a pretrained model takes far less compute and time.

With these techniques and a fine-tuning strategy, Qwen2-VL can serve a wide range of industries and users with more accurate, more capable vision-language solutions.
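The "LoRA" in this post's title refers to Low-Rank Adaptation, the concrete fine-tuning method used below, and it is what makes the "reduce resource consumption" point above practical: the pretrained weight matrices stay frozen, and only a pair of small low-rank matrices is trained, whose product is added to the frozen weights. A minimal sketch of the idea (illustrative only, not LLaMA-Factory's implementation; the dimensions and names such as lora_a / lora_b are made up for the example):

import torch

# Toy LoRA update: the pretrained weight W is frozen; only A (r x d) and B (d x r) are trained.
d, r = 4096, 8                      # hidden size and LoRA rank (illustrative values)
W = torch.randn(d, d)               # frozen pretrained weight, d*d parameters
lora_a = torch.randn(r, d) * 0.01   # trainable, only r*d parameters
lora_b = torch.zeros(d, r)          # trainable, starts at zero so the model is unchanged at step 0

x = torch.randn(1, d)
# Forward pass with the adapted weight (W + B @ A), computed without forming the full d x d update
y = x @ W.T + (x @ lora_a.T) @ lora_b.T
print(y.shape)  # torch.Size([1, 4096])

Only about 2*r*d parameters per adapted matrix are trained instead of d*d, which is why LoRA fine-tuning fits in modest GPU memory and why the resulting adapter is later merged back into (or loaded alongside) the base model.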

II. Training

On top of the Qwen2-VL environment deployed earlier, download LLaMA-Factory and install its dependencies:

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
pip install deepspeed
pip install flash-attn --no-build-isolation
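As an optional sanity check, recent LLaMA-Factory releases provide a version subcommand that confirms the CLI installed correctly:

llamafactory-cli version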

Before training, edit qwen2vl_lora_sft.yaml under both examples/train_lora and examples/merge_lora so that the model path points to your local copy of the model; otherwise it will be downloaded all over again.
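For orientation, the training config is a plain YAML file that looks roughly like the sketch below (keys follow LLaMA-Factory's stock example; exact entries and defaults vary between versions, so treat the values as illustrative and change model_name_or_path to your local checkpoint path):

### model -- use the local path instead of the Hub ID to avoid re-downloading
model_name_or_path: /path/to/Qwen2-VL-7B-Instruct

### method
stage: sft
do_train: true
finetuning_type: lora

### dataset (the stock example trains on LLaMA-Factory's bundled multimodal demo data)
dataset: mllm_demo
template: qwen2_vl

### output (LoRA adapter weights are written here)
output_dir: saves/qwen2_vl-7b/lora/sft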

llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
llamafactory-cli export examples/merge_lora/qwen2vl_lora_sft.yaml

After merging, the exported model is saved to LLaMA-Factory/models/qwen2_vl_lora_sft.
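That location comes from the export_dir field in examples/merge_lora/qwen2vl_lora_sft.yaml, which merges the LoRA adapter back into the base weights. A rough sketch of that file (again based on the stock example, values illustrative; adapter_name_or_path must match the training output_dir):

### model
model_name_or_path: /path/to/Qwen2-VL-7B-Instruct    # same base model as used for training
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft     # LoRA weights produced by the train step
template: qwen2_vl
finetuning_type: lora

### export
export_dir: models/qwen2_vl_lora_sft                 # -> LLaMA-Factory/models/qwen2_vl_lora_sft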

Problem:

Traceback (most recent call last):
  File "/home/py/ycc/Qwen/1.py", line 39, in <module>
    text = processor.apply_chat_template(
  File "/home/py/anaconda3/envs/qwenycc/lib/python3.9/site-packages/transformers/processing_utils.py", line 988, in apply_chat_template
    raise ValueError(
ValueError: No chat template is set for this processor. Please either set the `chat_template` attribute, or provide a chat template as an argument. See https://huggingface.co/docs/transformers/main/en/chat_templating for more information.

Solution:

Copy chat_template.json from the original model directory into LLaMA-Factory/models/qwen2_vl_lora_sft.
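For example, assuming the base checkpoint was downloaded to ./qwen/Qwen2-VL-7B-Instruct (the path used by the test scripts below; adjust it to your own download location):

cp ./qwen/Qwen2-VL-7B-Instruct/chat_template.json LLaMA-Factory/models/qwen2_vl_lora_sft/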

III. Testing

The tests below use a local image saved as 1.jpg.

1. Testing with the original model

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# default processor
processor = AutoProcessor.from_pretrained("./qwen/Qwen2-VL-7B-Instruct")

# The default range for the number of visual tokens per image in the model is 4-16384. You can set
# min_pixels and max_pixels according to your needs, such as a token count range of 256-1280,
# to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "1.jpg"},
            {"type": "text", "text": "描述一下这张图片"},  # "Describe this image"
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

2. Testing with the fine-tuned model

Swap in the fine-tuned model:

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "LLaMA-Factory/models/qwen2_vl_lora_sft",  # previously "qwen/Qwen2-VL-7B-Instruct"
    torch_dtype="auto", device_map="auto"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# default processor
processor = AutoProcessor.from_pretrained("./LLaMA-Factory/models/qwen2_vl_lora_sft")

# The default range for the number of visual tokens per image in the model is 4-16384. You can set
# min_pixels and max_pixels according to your needs, such as a token count range of 256-1280,
# to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "1.jpg"},
            {"type": "text", "text": "Who are they?"},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

