当前位置: 首页 > news >正文

分类模型onnx推理,并生成混淆矩阵

废话不多说直接上代码

import onnxruntime
import numpy as np
import os
import cv2
import argparse
import time
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaylabels = ["0", "1", "2", "3", "4", "5", "6", "7"]def sigmoid(x):"""Sigmoid function for a scalar or NumPy array."""return 1 / (1 + np.exp(-x))def getFileList(dir, Filelist, ext=None):"""获取文件夹及其子文件夹中文件列表输入 dir:文件夹根目录输入 ext: 扩展名返回: 文件路径列表"""newDir = dirif os.path.isfile(dir):if ext is None:Filelist.append(dir)else:if ext in dir[-3:]:Filelist.append(dir)elif os.path.isdir(dir):for s in os.listdir(dir):newDir = os.path.join(dir, s)getFileList(newDir, Filelist, ext)return Filelistdef read_image(image_path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)image = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)image = cv2.resize(image, (224, 224))image = image.astype(np.float32)image = image / 255.0image = image.transpose(2, 0, 1)mean = np.array(mean, dtype=np.float32).reshape((3,1,1))std = np.array(std, dtype=np.float32).reshape((3,1,1))# 对图像进行归一化normalized_image = (image - mean) / stdnormalized_image = np.expand_dims(normalized_image, axis=0)return normalized_image, srcdef load_onnx_model(model_path):providers = ['CUDAExecutionProvider']  # 使用 GPU# providers = ['CPUExecutionProvider']options = onnxruntime.SessionOptions()options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALLoptions.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIALsession = onnxruntime.InferenceSession(model_path, options, providers=providers)print("ONNX模型已成功加载。")return sessiondef main(image_path, session):image, src = read_image(image_path)input_name = session.get_inputs()[0].nameoutput_name = session.get_outputs()[0].namepred = session.run([output_name], {input_name: image})[0]pred = np.squeeze(pred)pred = [sigmoid(x) for x in pred]return pred.index(max(pred)), max(pred), labels[pred.index(max(pred))]def plot_confusion_matrix(y_true, y_pred, labels):"""绘制混淆矩阵输入 y_true: 真实标签输入 y_pred: 预测标签输入 labels: 标签名称"""cm = confusion_matrix(y_true, y_pred)disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)disp.plot(cmap=plt.cm.Blues)plt.title('Confusion Matrix')plt.show()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--images_path', type=str, default="/home/workspace/temp/test_result/kaideng", help='images_path')parser.add_argument('--model_path', type=str, default="/home/workspace/temp/47-val-Loss-0.0203-Acc-0.9942.onnx", help='model_path')args = parser.parse_args()img_list = []img_list = getFileList(args.images_path, img_list)count = 0session = load_onnx_model(args.model_path)start = time.time()y_true = []y_pred = []count_time = 0for img in img_list:#true_label = int(img.split('/')[-2].split('-')[0])true_label = img.split('/')[-3] #这一句代码是获取图像类别文件夹的名称,具体索引需要修改start_1 = time.time()predicted_index, score, label = main(img, session)count_time += time.time() - start_1y_true.append(true_label)#y_pred.append(predicted_index)y_pred.append(label)if label == true_label:count += 1# else:#     dst_path = img.replace('test', 'test_out')#     dst_dir = os.path.dirname(dst_path)#     if not os.path.exists(dst_dir):#         os.makedirs(dst_dir)#     shutil.copy(img, dst_path.replace('.jpg', "-" + label + '.jpg'))accuracy = count / len(img_list) * 100print(f"Accuracy: {accuracy:.2f}%")print(f"Correct predictions: {count}, Total images: {len(img_list)}")print(f"Time taken: {time.time() - start:.2f} seconds")print("推理", len(img_list), "张图像用时", count_time)# 绘制混淆矩阵plot_confusion_matrix(y_true, y_pred, labels)

1.要确保图像类别文件夹的名称和labels列表相对应,不然无法生成混淆矩阵

2.read_image函数中的预处理方式要和训练时的一致,请根据个人需要修改代码


http://www.mrgr.cn/news/66097.html

相关文章:

  • Android CCodec Codec2 (十九)C2LinearBlock
  • 【Centos】在 CentOS 9 上使用 Apache 搭建 PHP 8 教程
  • solidity selfdestruct合约销毁
  • Python3 No module named ‘pymysql‘
  • Django中分组查询(annotate 和 aggregate 使用)
  • Mac下载 安装MIMIC-IV 3.0数据集
  • Mysql数据库的UDF提权
  • 文件描述符fd和0 1 2的含义(stdin..)
  • 如何配置 GreptimeDB 作为 Prometheus 的长期存储
  • YOLO11改进 | 融合改进 | C3k2引入多尺度分支来增强特征表征【全网独家 附结构图】
  • OBOO鸥柏丨甘肃火车站/高铁多媒体网络广告刷屏机数字转型
  • 2024年最新10款顶级项目管理软件排行
  • 类与对象—中
  • mutable用法
  • vue 使用openlayers加载超图图层
  • 富格林:揭露欺诈陷阱用心追损
  • Spring Boot 内置工具类
  • OpenCV视觉分析之目标跟踪(10)估计两个点集之间的刚性变换函数estimateRigidTransform的使用
  • KVM虚拟机的冷热迁移
  • 量化交易 股市技术指标
  • 【ARM Linux 系统稳定性分析入门及渐进 1.4 -- Crash 工具调用】
  • Vue 3 性能提升与 Vue 2 的比较 - 2024最新版前端秋招面试短期突击面试题【100道】
  • 51单片机--- 蜂鸣器电子琴仿真
  • 【Linux】网络编程:实现一个简易的基于HTTP协议格式、TCP传输的服务器,处理HTTP请求并返回HTTP响应;GET方法再理解
  • Odoo的结构
  • 数据分析-39-时间序列分解之经验小波分解EWT