YOLOX: predicted images are not being saved
Adding this one argument (shown below) fixes it:
parser.add_argument("--save_result",action="store_true",default="True",help="whether to save the inference result of image/video",
)
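Why this works: with action="store_true", argparse assigns the default when the flag is absent, and the stock YOLOX demo defaults it to False, so running without --save_result silently skips saving. Note that the default should be the boolean True rather than the string "True"; a non-empty string happens to be truthy, but the boolean is the correct value. A minimal, self-contained sketch of the behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_result", action="store_true", default=True)

print(parser.parse_args([]).save_result)                 # True: the default applies
print(parser.parse_args(["--save_result"]).save_result)  # True: the flag sets it too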
The complete prediction script, demo.py, is as follows:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
# Prediction/inference script

import argparse
import os
import time

from loguru import logger

import cv2
import torch

import sys
sys.path.append(r'E:\python_code\YOLOX-0.2.0')  # make the YOLOX repo importable

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import VOC_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "-do", "--demo", default="image", help="demo type, e.g. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("--path", default="./assets/mask.jpg", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        default=True,  # the added fix: save results even when the flag is omitted
        help="whether to save the inference result of image/video",
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=r"E:\python_code\YOLOX-0.2.0\exps\example\yolox_voc\yolox_voc_s.py",
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        default=r"E:\python_code\YOLOX-0.2.0\YOLOX_outputs\yolox_voc_s\best_ckpt.pth",
        type=str,
        help="ckpt for eval",
    )
    parser.add_argument(
        "--device", default="gpu", type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.3, type=float, help="test conf")
    parser.add_argument("--nms", default=0.45, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=640, type=int, help="test img size")
    parser.add_argument(
        "--fp16", dest="fp16", default=False, action="store_true",
        help="Adopting mixed-precision evaluating.",
    )
    parser.add_argument(
        "--legacy", dest="legacy", default=False, action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse", dest="fuse", default=False, action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt", dest="trt", default=False, action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=VOC_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True,
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]
        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
        if args.fp16:
            model.half()  # to FP16
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model does not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(
        model, exp, VOC_CLASSES, trt_file, decoder,
        args.device, args.fp16, args.legacy,
    )
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)
If that still doesn't work, try running it from the terminal:
python demo.py --save_result
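If the defaults hard-coded in make_parser() do not match your setup, you can also pass everything explicitly (the paths below are examples; substitute your own image, exp file, and checkpoint):

python demo.py --demo image --path ./assets/mask.jpg -f exps/example/yolox_voc/yolox_voc_s.py -c YOLOX_outputs/yolox_voc_s/best_ckpt.pth --save_result --device gpu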