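Batch-annotating image content with the Qwen2-VL model (image understanding)
The Qwen2-VL model can produce annotations for an image in question-and-answer form. This post records the workflow and the post-processing of the resulting data.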
1. Download the weights
Download the Qwen2-VL weights from Hugging Face: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/tree/main
Local path where the weights are saved: /home/user/models/Qwen2-VL-7B-Instruct
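The weights can also be fetched from a script instead of the web page. The following is a minimal sketch, assuming the huggingface_hub package is installed and the machine can reach huggingface.co; download_weights is a hypothetical helper name that does not appear in the original workflow.

```python
def download_weights(
    repo_id="Qwen/Qwen2-VL-7B-Instruct",
    local_dir="/home/user/models/Qwen2-VL-7B-Instruct",
):
    """Download all files of the model repository into local_dir and return the path."""
    # Imported lazily so that defining the helper does not require huggingface_hub.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=repo_id, local_dir=local_dir)


# Usage (starts a large download, so it is left commented out):
# download_weights()
```

The default arguments simply mirror the repository and save path used in this post; adjust them if the weights live elsewhere.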
Environment setup: the transformers and qwen_vl_utils libraries are required.
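Before loading a 7B model it may be worth confirming that the required modules are importable. The sketch below is one way to do that with the standard library only; missing_modules is a hypothetical helper, and tqdm is included because the batch script uses it for the progress bar. On PyPI the packages are usually installed as transformers and qwen-vl-utils, but treat the exact package names as an assumption to verify for your environment.

```python
import importlib.util

REQUIRED_MODULES = ["transformers", "qwen_vl_utils", "tqdm"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are importable.")
```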
2. Batch-processing code
Image folder: /home/user/data/images_need_processing. The images are named 00000.jpg through 05000.jpg, 5000 images in total.
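Since the script processes whatever os.listdir returns, a quick check of the folder against the expected naming scheme can catch missing or stray files before a long run. This is a minimal sketch; expected_names and check_folder are hypothetical helpers, and the bounds are an assumption: the range 00000 to 05000 inclusive yields 5001 names, so adjust first and last to match the actual data.

```python
import os


def expected_names(first=0, last=5000, width=5, ext=".jpg"):
    """Build the zero-padded file names from first to last, inclusive."""
    return [f"{i:0{width}d}{ext}" for i in range(first, last + 1)]


def check_folder(folder, expected):
    """Return (missing, unexpected) file names, each as a sorted list."""
    present = set(os.listdir(folder))
    wanted = set(expected)
    return sorted(wanted - present), sorted(present - wanted)


# Usage:
# missing, unexpected = check_folder(
#     "/home/user/data/images_need_processing", expected_names()
# )
```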
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
from tqdm import tqdm

model_path = "/home/user/models/Qwen2-VL-7B-Instruct"

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)

# default processor
processor = AutoProcessor.from_pretrained(model_path)

image_data_root = "/home/user/data/images_need_processing"
images = os.listdir(image_data_root)
images.sort()

for image in tqdm(images):
    image_path = os.path.join(image_data_root, image)
    Question_description = "Describe the image."  # FIXME: replace with the question you need
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": Question_description},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generate the output, then drop the prompt tokens from each sequence
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Append one record per image: file name, marker, description
    with open("/home/user/data/descriptions.txt", "a") as f:
        f.write(image + "#####" + output_text + "\n")
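This code writes the captions of all images into a single txt file. To make lookup easy, each record is stored as the image file name, then "#####", then the description.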
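Because each record is meant to occupy exactly one line, it can help to normalize the generated text before writing it, since model output may itself contain newline characters. The following is a minimal sketch of writing and reading this format; format_record and parse_record are hypothetical helpers, not part of the original script.

```python
SEPARATOR = "#####"


def format_record(image_name, description):
    """Build one single-line record: image name, separator, description."""
    # Collapse every run of whitespace (including newlines) into one space
    # so that a record never spans more than one line.
    one_line = " ".join(description.split())
    return f"{image_name}{SEPARATOR}{one_line}\n"


def parse_record(line):
    """Split a record back into (image_name, description)."""
    name, sep, description = line.rstrip("\n").partition(SEPARATOR)
    if not sep:
        raise ValueError(f"separator not found in line: {line!r}")
    return name, description
```

Using partition splits only at the first occurrence of the marker, so a description that happens to contain "#####" is still returned intact.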
3. Post-processing
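In the txt file actually obtained, some image descriptions are broken across lines, apparently because multiple GPUs were processing the files at the same time (line breaks inside the generated text itself would have the same effect). For example: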
00001.jpg#####The image xxxx.
The background of the image xxx.
00002.jpg#####The image xxxx. The background of the image xxx.
00003.jpg#####The image xxxx. The background of the image xxx.
00004.jpg#####The image xxxx.
The background of the image xxx.
00005.jpg#####The image xxxx. The background of the image xxx.
00006.jpg#####The image xxxx. The background of the image xxx.
00007.jpg#####The image xxxx. The background of the image xxx.
00008.jpg#####The image xxxx. The background of the image xxx.
To post-process the contents of this txt file, run the following code, which relies on the ##### marker:
def raw_qwen_caption_handle():
    # Function: filter out blank lines and merge descriptions broken across lines.
    raw_qwen_ann_path = "/home/user/data/descriptions.txt"
    with open(raw_qwen_ann_path, "r") as f:
        anns = f.readlines()

    new_anns = []
    number = 0  # number of blank lines removed
    for ann in anns:  # drop lines that contain nothing but a newline
        if ann.strip("\n") == "":
            number += 1
        else:
            new_anns.append(ann)

    final_anns = []
    for ann in new_anns:  # use the marker to merge a broken fragment into the previous record
        caption = ann.replace("\n", "")
        if "#####" not in caption and final_anns:
            final_anns[-1] = final_anns[-1] + " " + caption
        else:
            final_anns.append(caption)

    prompt_save_file = "/home/user/data/descriptions_new.txt"
    # Open with "w" so that re-running the script does not append duplicate records
    with open(prompt_save_file, "w") as f:
        for ann in final_anns:
            f.write(ann + "\n")


if __name__ == "__main__":
    raw_qwen_caption_handle()
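Once the cleaned file has been written, it may be useful to load it into a dictionary and confirm that every line carries the marker. The sketch below works on any iterable of lines; load_captions is a hypothetical helper, and the file path in the usage comment is the output path used in this post.

```python
def load_captions(lines, separator="#####"):
    """Map image name -> description and collect lines that lack the separator."""
    captions, malformed = {}, []
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # ignore blank lines
        name, sep, description = line.partition(separator)
        if sep:
            captions[name] = description
        else:
            malformed.append(line)
    return captions, malformed


# Usage:
# with open("/home/user/data/descriptions_new.txt", "r") as f:
#     captions, malformed = load_captions(f)
# print(len(captions), "captions;", len(malformed), "malformed lines")
```

An empty malformed list suggests that every broken fragment was merged back into its record; a caption count different from the number of images points to skipped or duplicated files.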