arcface
GitHub - bubbliiiing/arcface-pytorch: 这是一个arcface-pytorch的源码,可以用于训练自己的模型。
https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch
将PyTorch模型转换为ONNX
import torch
import arcface
from nets.arcface import Arcface as arcface
from torch.onnx import export
import onnxruntime as ort
import numpy as np
def convert2onnx_demo(
    model_path='./model_data/arcface_iresnet50.pth',
    onnx_path='./model_data/arcface_iresnet50.onnx',
    backbone='iresnet50',
    batch_size=4,
    opset=10,
    input_name='input.1',
):
    """Export an ArcFace checkpoint to ONNX and run one ONNX Runtime inference.

    The weights at ``model_path`` are loaded into an ``arcface`` network built
    in ``"predict"`` mode, the network is exported to ``onnx_path`` with a
    dynamic first (batch) axis on the input, and the exported file is then
    loaded with ONNX Runtime and fed one random batch as a smoke test.

    Args:
        model_path: Path of the ``.pth`` state dict to load. The checkpoint
            must match ``backbone`` because it is loaded with ``strict=True``.
        onnx_path: Destination path of the exported ``.onnx`` file.
        backbone: Backbone name forwarded to ``arcface(backbone=...)``.
            Alternatives noted alongside the original code were
            ``'mobilefacenet'`` (``arcface_mobilefacenet.pth``) and
            ``'mobilenetv1'`` (``arcface_mobilenet_v1.pth``).
        batch_size: Batch size of the dummy tensor used for tracing and of
            the random array used for the ONNX Runtime check.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        input_name: Name given to the graph input in the exported model.

    Returns:
        None. The first output array and its shape are printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading weights into state dict...')
    net = arcface(backbone=backbone, mode="predict").eval()
    net.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    net = net.to(device)
    print('{} model loaded.'.format(model_path))

    dummy_input = torch.randn(batch_size, 3, 112, 112).to(device)

    # Name the graph input explicitly: dynamic_axes is keyed by tensor name,
    # so relying on the exporter's auto-generated name would make the dynamic
    # batch axis depend on how the tracer happens to label the input.
    dynamic_axes = {input_name: {0: 'batch_size'}}
    export(
        net,
        dummy_input,
        onnx_path,
        opset_version=opset,
        input_names=[input_name],
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
    )

    # Smoke-test the exported file; ask the session for the input name
    # instead of hard-coding it a second time.
    ort_session = ort.InferenceSession(onnx_path)
    feed_name = ort_session.get_inputs()[0].name
    random_batch = np.random.randn(batch_size, 3, 112, 112).astype(np.float32)
    outputs = ort_session.run(None, {feed_name: random_batch})
    print(outputs[0], outputs[0].shape)


convert2onnx_demo()
ONNX模型推理
import onnxruntime as ort
import numpy as np
import cv2

# Load the ONNX model (swap in the commented-out path to use the other model).
# session = ort.InferenceSession("./model_data/arcface_iresnet50.onnx")
session = ort.InferenceSession("./model_data/arcface_mobilenet_v1.onnx")

# Read and preprocess the image.
image_path = "./img/1_001.jpg"
image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals a missing or undecodable file by returning None
    # rather than raising; fail here with a clear message instead of letting
    # cv2.resize raise an obscure error.
    raise FileNotFoundError(
        "Could not read or decode image: {}".format(image_path)
    )
# NOTE(review): cv2.imread yields channels in BGR order -- confirm the model
# was trained on BGR input; otherwise a BGR->RGB conversion is needed here.
image = cv2.resize(image, (112, 112))  # assumes the model expects 112x112 input
image = image.transpose(2, 0, 1)  # HxWxC -> CxHxW
image = image.astype(np.float32)
# NOTE(review): normalization constants should match the ones used in
# training -- verify against the training pipeline.
image = (image - 127.5) / 128.0  # normalize pixel values
# Add the batch dimension.
image = np.expand_dims(image, axis=0)

# Run the model, feeding the array under the input name the session reports.
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: image})

# 'outputs' holds the model outputs; the first one is assumed to be the
# feature vector.
features = outputs[0]
print(features)
print(features.shape)
参考博客
Arcface部署应用实战-CSDN博客
https://zhuanlan.zhihu.com/p/165294876