当前位置: 首页 > news >正文

使用 pytorch 运行预训练模型的框架

PyTorch 简介:

PyTorch 是一个 Python 程序库,我们可以使用 PyTorch 来构建深度学习项目。

PyTorch 的两个特点:

  1. PyTorch 的核心数据结构是张量,张量是一个多维数组,与 NumPy 数组有许多相似之处。
  2. PyTorch 提供了在专用硬件上执行加速数学操作的特性,这使得神经网络结构设计以及在单机或并行计算资源上训练它们变得很方便。

因此,我们可以将 PyTorch 描述为一个在 Python 中为科学计算提供优化支持的高性能库。

PyTorch 大部分是用 C++ 和 CUDA 编写的,CUDA 是一种来自英伟达的类 C++的语言,可以被编译并在 GPU 上以并行方式运行。


使用 pytorch 运行预训练模型的框架

import torch
  1. 定义模型类 1.1 自定义模型类 1.2 从 torchvision 模块加载模型: from torchvision import models

  1. 实例化模型类
resnet101 = models.resnet101() 

  1. 给实例化的模型类加载预训练好的参数 3.1 实例化模型类和加载预训练好的权重同时进行(这种情况可以省略第 2 步)
resnet101 = models.resnet101(pretrained=True)  # pretrained=True 指示函数下载 resnet101 在 ImageNet数据集上训练好的权重

3.2 使用模型的 load_state_dict() 方法将预训练权重加载到 resnet101 中

model_path = '......'
model_data = torch.load(model_path)
resnet101.load_state_dict(model_data)

3.3 使用 torch.hub 从 github 加载模型(这种情况可以省略第 1、2 步)

from torch import hub
resnet101 = hub.load('pytorch/vision:main''resnet101', pretrained=True)  # 第一项是 GitHub 存储库的名称和分支,第二项是入口点函数的名称

以上代码将 pytorch/vision 主分支的快照及其权重默认下载到本地的 C:\Users\username.cache\torch\hub 目录下,然后运行 resnet101 入口点函数返回实例化的模型,参数 pretrained=true 会从 ImageNet 获得预训练权重,并加载到 resnet101 中。


  1. 使用 Python 图像操作模块 Pillow 从本地文件系统加载一幅图像
from PIL import Image  # PIL 指的是 pillow
img = Image.open(".../xxx.jpg")

  1. 使用 TorchVision 模块提供的 transforms 定义一个对输入图像进行预处理的管道
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),  # 将输入图像缩放到 256× 256 个像素
                                 transforms.ToTensor(),  # 转换为一个张量
                                ])

  1. 使用预处理管道 preprocess 对图像 img 进行预处理
img_t = preprocess(img)

  1. 给数据添加一个新的维度:批次维度
batch_t = torch.unsqueeze(img_t, 0)

  1. 进行推理时,我们需要将神经网络置于 eval 模式
resnet.eval()

  1. eval 模式设置好之后,进行推理
out = resnet101(batch_t)
out

......

本文由 mdnice 多平台发布


http://www.mrgr.cn/news/65801.html

相关文章:

  • Spring Bean的作用域和生命周期
  • git clone,用https还是ssh
  • 数字信号处理:自动增益控制(AGC)
  • Redis-“自动分片、一定程度的高可用性”(sharding水平拆分、failover故障转移)特性(Sentinel、Cluster)
  • VScode的C/C++点击转到定义,不是跳转定义而是跳转声明怎么办?(内附详细做法)
  • 深入探讨 Jenkins 中 HTML 格式无法正常显示的现象及解决方案
  • FFmpeg 4.3 音视频-多路H265监控录放C++开发十二:在屏幕上显示多路视频播放,可以有不同的分辨率,格式和帧率。
  • HTB:Shocker[WriteUP]
  • 如何在BSV区块链上实现可验证AI
  • 隆盛策略股票杠杆交易市场罕见,26只“牛股”提示风险
  • VSCode 1.82之后的vscode server离线安装
  • Centos使用yum获取离线安装包
  • springboot 单元测试-各个模块举例
  • 爱奇艺大数据多AZ统一调度架构:打破数据孤岛,提升效率
  • windows——病毒的编写
  • Fish Agent:集成 ASR 和 TTS 的端到端语音处理模型,支持多语言转换
  • 单体架构的 IM 系统设计
  • 【教学类-12-10】20241104《连连看竖版6*6 (3套题目空心图案)中2班
  • 泛微开发修炼之旅--53ecology表单转pdf源码修改相关(表单转pdf时可以修改最后生成的pdf的内容)
  • mysql5安装
  • 数字证书的简单记录
  • 基于SpringBoot司机信用评价的货运管理系统【附源码】
  • Windows无法访问\\192.168.1.156,错误代码0x800704cf
  • 11.4OpenCV_图像预处理习题02
  • Python 继承、多态、封装、抽象
  • 字符串算法