当前位置: 首页 > news >正文

onnx底层入门

一、定义

  1. 架构
  2. 报错
  3. onnx 模型调试
  4. pytorch 成功转换为onnx 模型的条件
  5. 案例: 缺少映射关系(ATen 算子 asinh)
  6. 案例: 缺少映射关系(TorchScript 算子 torchvision::deform_conv2d)
  7. 案例: 自定义torch 算子
  8. 案例: debug 每一层,判定前后精度是否损失

二、实现

  1. 架构
    在这里插入图片描述
    一个 ONNX 模型可以用 ModelProto 类表示。ModelProto 包含了版本、创建者等日志信息,还包含了存储计算图结构的graph。
    GraphProto 类则由输入张量信息、输出张量信息、节点信息组成。
    张量信息 ValueInfoProto 类包括张量名、基本数据类型、形状。
    节点信息 NodeProto 类包含了算子名、算子输入张量名、算子输出张量名。
import onnx
from onnx import helper
from onnx import TensorProto

# Build output = a * x + b by hand: value infos -> nodes -> graph -> model.

# Tensor info (ValueInfoProto): tensor name, element type, shape.
a_info = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x_info = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b_info = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])

# Nodes (NodeProto): operator type, input tensor names, output tensor names.
mul = helper.make_node('Mul', ['a', 'x'], ['c'])
# Nodes are wired together by name: the upstream node's output name ('c')
# must be exactly the downstream node's input name.
add = helper.make_node('Add', ['c', 'b'], ['output'])

# Graph (GraphProto): nodes, graph name, graph inputs, graph outputs.
graph = helper.make_graph([mul, add], 'linear_func',
                          [a_info, x_info, b_info], [output_info])

# Model (ModelProto): wrap the graph, validate it, then save it.
model = helper.make_model(graph)
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')

# ------------------------------------------------------------------ test
import onnxruntime
import numpy as np

# Pass providers explicitly; onnxruntime builds with more than one execution
# provider (e.g. GPU builds) refuse to create a session without it.
sess = onnxruntime.InferenceSession('linear_func.onnx',
                                    providers=['CPUExecutionProvider'])
a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(10, 10).astype(np.float32)
x = np.random.rand(10, 10).astype(np.float32)
output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0]
print(output)
assert np.allclose(output, a * x + b)
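  2. 报错
    报错信息: Unsupported model IR version: 10, max supported IR version: 9
    原因与解决: 较新的 onnx(1.16 及以上)生成的模型 IR 版本为 10,而当前安装的 onnxruntime 最高只支持 IR 版本 9。将 onnx 降级到 1.15(例如 pip install onnx==1.15.0)即可;也可以升级 onnxruntime 到支持 IR 10 的版本。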

  3. ONNX 模型调试 - 提取子模型

import torch


class Model(torch.nn.Module):
    """Toy CNN with two parallel branches (convs2, convs3) merged by addition."""

    def __init__(self):
        super().__init__()
        self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))

    def forward(self, x):
        x = self.convs1(x)
        x1 = self.convs2(x)
        x2 = self.convs3(x)
        x = x1 + x2
        x = self.convs4(x)
        return x


model = Model()
dummy_input = torch.randn(1, 3, 20, 20)
torch.onnx.export(model, dummy_input, 'whole_model.onnx')

import onnx

model = onnx.load("whole_model.onnx")
# Printing the model lists the tensor names usable as extraction boundaries.
print(model)

# Extract the sub-model between the given input and output tensor names.
# NOTE(review): names like '/convs1/convs1.1/Conv_output_0' depend on the
# PyTorch exporter version — check them against the print(model) output.
onnx.utils.extract_model('whole_model.onnx', 'partial_model.onnx',
                         ['/convs1/convs1.1/Conv_output_0'],
                         ['/convs4/convs4.0/Conv_output_0'])
  4. PyTorch 模型成功转换为 ONNX 模型的条件

    1. 算子在pytorch 中有实现 https://pytorch.org/docs/stable/torch.html
    2. 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法 https://github.com/pytorch/pytorch/tree/master/torch/onnx 或 安装目录 G:\ProgramData\python39\Lib\site-packages\torch\onnx
    3. ONNX 有相应的算子 https://onnx.ai/onnx/operators/
  5. 案例: 缺少映射关系(ATen 算子)
    问题: 缺少 PyTorch 到 ONNX 的映射 → 需要为 ATen 算子补充描述映射规则的符号函数。(ATen 是 PyTorch 内置的 C++ 张量计算库,PyTorch 算子在底层绝大多数计算都是用 ATen 实现的。)
    以 Asinh 算子转 ONNX 为例,按上述三个条件逐一检查:
    a. PyTorch 实现查看: https://pytorch.org/docs/stable/generated/torch.asinh.html#torch.asinh
    b. 映射关系查看: 在 G:\ProgramData\python39\Lib\site-packages\torch\onnx 中没有 asinh 的映射
    c. ONNX 相应算子: 从第 9 版算子集(opset 9)开始支持 https://onnx.ai/onnx/operators/
    步骤:
    1. 查看接口定义 torch/_C/_VariableFunctions.pyi

# Interface of torch.asinh as declared in torch/_C/_VariableFunctions.pyi
# (stub only, no body). The symbolic function for aten::asinh mirrors this
# signature, with the graph context `g` prepended.
def asinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
    2. 注册
    g.op() 函数定义于 torch/onnx/_internal/jit_utils.py,签名如下:
# Excerpt of GraphContext.op from torch/onnx/_internal/jit_utils.py; the body
# is omitted here, only the signature is shown.
@_beartype.beartype
def op(
    self,
    opname: str,  # operator name as it appears in ONNX, e.g. "Asinh"
    *raw_args: Union[torch.Tensor, _C.Value],  # operator inputs
    outputs: int = 1,  # number of outputs the node produces
    **kwargs,  # node attributes, with a type suffix such as name_s / value_t
):
    ...
步骤:
1 获取原算子的前向推理接口。
2 获取目标 ONNX 算子的定义。
# Step 3: write the symbolic function and bind it to the ATen operator.
import torch
# NOTE(review): torch.onnx.symbolic_registry only exists in older PyTorch
# releases (it was removed around 1.13). The original note said
# "torch: 1.19", which is presumably a typo for 1.9 — confirm your version.
from torch.onnx.symbolic_registry import register_op


class Model(torch.nn.Module):
    """Single-op model used to exercise the aten::asinh -> Asinh mapping."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.asinh(x)


def asinh_symbolic(g, input, *, out=None):
    """Symbolic function for aten::asinh.

    The parameters mirror the ATen interface (input, keyword-only out), with
    the graph context `g` first. It emits a single ONNX Asinh node.
    """
    return g.op("Asinh", input)


# Bind: target ATen op name, symbolic function, domain ('' for ATen ops),
# and the opset version (ONNX Asinh exists from opset 9).
register_op('asinh', asinh_symbolic, '', 9)

model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, dummy_input, 'asinh.onnx')
# ------------------------------------------------------------------ test
import onnxruntime
import torch
import numpy as np


class Model(torch.nn.Module):
    """Same single-op model as the exported one, used as the reference."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.asinh(x)


model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
torch_output = model(dummy_input).detach().numpy()
sess = onnxruntime.InferenceSession('asinh.onnx',
                                    providers=['CPUExecutionProvider'])
# The model was exported without input_names, so the exporter chose the input
# name itself and it varies across PyTorch versions (this test used to
# hard-code '0'). Ask the session for it instead.
input_name = sess.get_inputs()[0].name
ort_output = sess.run(None, {input_name: dummy_input.numpy()})[0]
assert np.allclose(torch_output, ort_output)

# Method 2: register the symbolic function with
# torch.onnx.register_custom_op_symbolic.
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic


def asinh_symbolic(g, input, *, out=None):
    """Symbolic function: emit an ONNX Asinh node for aten::asinh."""
    return g.op("Asinh", input)


# The public API takes the fully qualified op name ("aten::asinh"), the
# symbolic function, and the opset version to register it for.
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)


class Model(torch.nn.Module):
    """Single-op model whose forward calls torch.asinh."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.asinh(x)
        return x


def export_norm_onnx():
    """Export Model to asinh.onnx at opset 12 with named input and output."""
    dummy_input = torch.rand(1, 5)
    model = Model()
    model.eval()
    file = "asinh.onnx"
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=12,
    )
    print("Finished normal onnx export")


if __name__ == "__main__":
    export_norm_onnx()

# Method 3: register through torch.onnx._internal.registration (a private
# API, so it may change between PyTorch releases).
import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx._internal import registration

# Fix the opset once so the decorator below only needs the op name.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)


@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
    """Symbolic function: emit an ONNX Asinh node for aten::asinh (opset 9).

    NOTE(review): the export below uses opset 12; this presumably relies on
    the registry picking the highest registered opset <= the export opset.
    """
    return g.op("Asinh", input)


class Model(torch.nn.Module):
    """Single-op model whose forward calls torch.asinh."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.asinh(x)
        return x


def export_norm_onnx():
    """Export Model to asinh2.onnx at opset 12 with named input and output."""
    dummy_input = torch.rand(1, 5)
    model = Model()
    model.eval()
    file = "asinh2.onnx"
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=12,
    )
    print("Finished normal onnx export")


if __name__ == "__main__":
    export_norm_onnx()
  6. 案例: 缺少映射关系(TorchScript 算子)
    报错: torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace torchvision::deform_conv2d.
import torch
import torchvision


class Model(torch.nn.Module):
    """Deformable conv whose offsets are predicted by a plain conv."""

    def __init__(self):
        super().__init__()
        # 18 = 2 * 3 * 3 offset channels: an (x, y) offset for every position
        # of DeformConv2d's 3x3 kernel (one offset group).
        self.conv1 = torch.nn.Conv2d(3, 18, 3)
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)

    def forward(self, x):
        return self.conv2(x, self.conv1(x))


model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
# Expected to fail: no symbolic function is registered for this op yet.
torch.onnx.export(model, dummy_input, 'dcn.onnx')
#torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace torchvision::deform_conv2d. If you are trying to export a custom operator, make sure you registered it with the right domain and version.
# i.e. torchvision::deform_conv2d is not recognized by the exporter.

排查步骤:
1. 查看推理算子(PyTorch / torchvision 中的实现)
2. 查看映射关系(是否有对应的符号函数)
3. 查看 ONNX 中是否有对应算子

import torch
import torchvision
# Registers symbolic functions for TorchScript ops such as torchvision::*.
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args


class Model(torch.nn.Module):
    """Deformable conv whose offsets are predicted by a plain conv."""

    def __init__(self):
        super().__init__()
        # 18 = 2 * 3 * 3 offset channels for DeformConv2d's 3x3 kernel.
        self.conv1 = torch.nn.Conv2d(3, 18, 3)
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)

    def forward(self, x):
        return self.conv2(x, self.conv1(x))


# A TorchScript op's symbolic function must declare the type of every input
# with @parse_args: "v" = a torch value (usually a tensor), "i" = int,
# "f" = float, "none" = pass the argument through without parsing.
# The full list of codes is in torch/onnx/symbolic_helper.py.
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none")
def symbolic(g, input, weight, offset, mask, bias,
             stride_h, stride_w, pad_h, pad_w, dil_h, dil_w,
             n_weight_grps, n_offset_grps, use_mask):
    """Map torchvision::deform_conv2d to a custom::deform_conv2d node.

    Only `input` and `offset` are wired into the node. The weight, mask, bias
    and all conv hyper-parameters are dropped, so this shows how to register
    a mapping; it is not a faithful translation of the op.
    """
    return g.op("custom::deform_conv2d", input, offset)


register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9)

model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, dummy_input, 'dcn.onnx')
  7. 案例: 自定义 torch 算子
    1. 采用c++扩展的方式添加算子
    2. python 方式添加算子
    3. torch.autograd.Function 来封装算子的底层调用,symbolic函数挂载onnx算子
    4. 正常使用
import torch


class MyAddFunction(torch.autograd.Function):
    """Custom operator computing 2 * a + b, with its own ONNX mapping."""

    @staticmethod
    def forward(ctx, a, b):
        # PyTorch-side implementation of the custom operator.
        return 2 * a + b

    @staticmethod
    def symbolic(g, a, b):
        # ONNX-side implementation: Constant(2) -> Mul -> Add.
        # A float constant gives Mul matching element types directly, rather
        # than relying on the exporter to promote an int64 constant.
        two = g.op("Constant", value_t=torch.tensor([2.0]))
        a = g.op('Mul', a, two)
        return g.op('Add', a, b)


# apply() is the entry point that dispatches to forward() (and to symbolic()
# while exporting to ONNX).
my_add = MyAddFunction.apply


class MyAdd(torch.nn.Module):
    """Module wrapper so the custom op can be exported with torch.onnx.export."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        return my_add(a, b)


model = MyAdd()
# Two different inputs, so swapping a and b would show up as a mismatch.
input_a = torch.rand(1, 3, 10, 10)
input_b = torch.rand(1, 3, 10, 10)
# Name the graph inputs explicitly; the auto-generated names (e.g. 'a.1')
# differ between PyTorch versions.
torch.onnx.export(model, (input_a, input_b), 'my_add.onnx',
                  input_names=['a', 'b'])
torch_output = model(input_a, input_b).detach().numpy()

import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession('my_add.onnx',
                                    providers=['CPUExecutionProvider'])
ort_output = sess.run(None, {'a': input_a.numpy(), 'b': input_b.numpy()})[0]
print(ort_output, torch_output)
assert np.allclose(torch_output, ort_output)
  8. 案例: 逐层插入 debug 节点,判断转换前后精度是否有损失
import torch
import onnx
import onnxruntime
import numpy as np


class DebugOp(torch.autograd.Function):
    """Identity op that exports as a my::Debug node tagged with a name."""

    @staticmethod
    def forward(ctx, x, name):
        return x

    @staticmethod
    def symbolic(g, x, name):
        # name_s: a string-typed attribute called "name".
        return g.op("my::Debug", x, name_s=name)


debug_apply = DebugOp.apply


class Debugger:
    """Compare intermediate PyTorch values with ONNX Runtime values by name."""

    def __init__(self):
        super().__init__()
        self.torch_value = dict()      # debug name -> PyTorch value (numpy)
        self.onnx_value = dict()       # debug name -> ONNX Runtime value
        self.output_debug_name = []    # debug names, in graph-node order

    def debug(self, x, name):
        """Record x's PyTorch value under name and route x through a Debug node."""
        self.torch_value[name] = x.detach().cpu().numpy()
        return debug_apply(x, name)

    def extract_debug_model(self, input_path, output_path):
        """Turn every Debug node into Identity and expose its output.

        Loads the model at input_path, rewrites each Debug node as a
        standard-domain Identity, records its debug name, and saves to
        output_path a model whose outputs are the Debug nodes' outputs.
        """
        model = onnx.load(input_path)
        inputs = [graph_input.name for graph_input in model.graph.input]
        outputs = []
        for node in model.graph.node:
            if node.op_type == 'Debug':
                # The only attribute is the name_s set in DebugOp.symbolic.
                debug_name = node.attribute[0].s.decode('ASCII')
                self.output_debug_name.append(debug_name)
                outputs.append(node.output[0])
                # Identity in the default domain lets ONNX Runtime run it.
                node.op_type = 'Identity'
                node.domain = ''
                del node.attribute[:]
        e = onnx.utils.Extractor(model)
        extracted = e.extract_model(inputs, outputs)
        onnx.save(extracted, output_path)

    def run_debug_model(self, input, debug_model):
        """Run debug_model on the feed dict `input` and store each debug output."""
        sess = onnxruntime.InferenceSession(debug_model,
                                            providers=['CPUExecutionProvider'])
        onnx_outputs = sess.run(None, input)
        # Outputs come back in the order they were listed when extracting.
        for name, value in zip(self.output_debug_name, onnx_outputs):
            self.onnx_value[name] = value

    def print_debug_result(self):
        """Print the mean squared error between PyTorch and ONNX per debug point."""
        for name in self.torch_value.keys():
            if name in self.onnx_value:
                # Square before averaging: MSE = mean((a - b)^2). Squaring the
                # mean lets positive and negative errors cancel out.
                diff = self.torch_value[name] - self.onnx_value[name]
                mse = np.mean(diff ** 2)
                print(f"{name} MSE: {mse}")


class Model(torch.nn.Module):
    """Four sequential conv stages (shape-preserving 3x3 convs, padding 1)."""

    def __init__(self):
        super().__init__()
        self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1))
        self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1))
        self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1))
        self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1),
                                          torch.nn.Conv2d(3, 3, 3, 1, 1))

    def forward(self, x):
        x = self.convs1(x)
        x = self.convs2(x)
        x = self.convs3(x)
        x = self.convs4(x)
        return x


torch_model = Model()
debugger = Debugger()

from types import MethodType


def new_forward(self, x):
    """Model.forward with a named debug probe after every conv stage."""
    x = self.convs1(x)
    x = debugger.debug(x, 'x_0')
    x = self.convs2(x)
    x = debugger.debug(x, 'x_1')
    x = self.convs3(x)
    x = debugger.debug(x, 'x_2')
    x = self.convs4(x)
    x = debugger.debug(x, 'x_3')
    return x


# Patch only this instance so the export traces the probed forward.
torch_model.forward = MethodType(new_forward, torch_model)
dummy_input = torch.randn(1, 3, 10, 10)
# Exporting runs new_forward on dummy_input, which also fills torch_value.
torch.onnx.export(torch_model, dummy_input, 'before_debug.onnx',
                  input_names=['input'])
debugger.extract_debug_model('before_debug.onnx', 'after_debug.onnx')
debugger.run_debug_model({'input': dummy_input.numpy()}, 'after_debug.onnx')
debugger.print_debug_result()

http://www.mrgr.cn/news/48021.html

相关文章:

  • “游戏信息化”:游戏后台系统的未来发展
  • 电脑提示报错NetLoad.dll文件丢失或损坏?是什么原因?
  • 【自留】Unity VR入门
  • 磁盘结构、访问时间、调度算法
  • 显示 Windows 任务栏
  • 前端使用 Element Plus架构vue3.0实现图片拖拉拽,后等比压缩,上传到Spring Boot后端
  • 你知道C++多少——模版进阶
  • 金九银十软件测试面试题(800道)
  • yarn install 报错 Expected version “>=18“,Got “16.20.0“
  • 数据库设计与查询分析(练习--对小白友好)
  • 【Java 22 | 2】 深入解析Java 22 :原生支持的记录类型
  • C++11 简单手撕多线程编程
  • 一个比较复杂的makefile工程例子
  • this,this指向
  • 在Stable Diffusion WebUI中安装SadTalker插件时几种错误提示的处理方法
  • 直流有刷电机驱动芯片:【TOSHIBA:TB6612】
  • Linux基础命令groupmod详解
  • 使用LlamaFactory进行模型微调
  • 低功耗
  • 多人播报配音怎么弄?简单4招分享
  • 【C++学习】核心编程之内存分区模型、引用和函数提高(黑马学习笔记)
  • 简单解析由于找不到xinput1_3.dll,无法继续执行代码的详细解决方法
  • 图的深度优先遍历的非递归算法
  • 服务端测试开发必备的技能:Mock测试!
  • 半周期检查-下降沿发上升沿采
  • AI语音助手在线版本