当前位置: 首页 > news >正文

PyTorch图像预处理--Compose

   torchvision.transforms.Compose 是 PyTorch 中用于图像预处理的核心工具,可将多个图像变换操作组合成一个顺序执行的流水线。

1. 定义与作用

  • 功能‌:将多个图像处理步骤(如缩放、裁剪、归一化等)串联为一个整体,简化代码并确保操作顺序正确‌。
  • 适用场景‌:数据预处理(训练/测试)、数据增强(如随机裁剪、翻转)‌。

2. 基本用法

通过 transforms.Compose() 按顺序传入变换列表:

from torchvision import transformstransform = transforms.Compose([transforms.Resize(256),          # 缩放图像短边至256像素transforms.CenterCrop(224),       # 中心裁剪224x224区域transforms.ToTensor(),            # 转换为张量(范围[0,1])transforms.Normalize(             # 标准化至[-1,1]mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

3. 常用变换操作

操作说明
transforms.Resize()调整图像尺寸(支持固定值或比例缩放)‌
transforms.RandomCrop()随机裁剪(常用于数据增强)‌
transforms.ToTensor()将 PIL 图像或 NumPy 数组转为张量,并归一化至 [0.0, 1.0]
transforms.Normalize()标准化处理(需先执行 ToTensor())‌

4. 标准化处理详解

假设输入为范围 [0,1] 的张量,Normalize 按以下公式处理:
image = (image - mean) / std

  • 示例‌:若 mean=0.5std=0.5,则数据范围被映射到 [-1, 1]‌。

5. 完整示例

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 定义变换
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 随机水平翻转(数据增强)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集并应用变换
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)# 训练循环
for images, labels in train_loader:# 输入模型训练...
  • 数据流‌:原始图像 → 随机翻转 → 张量转换 → 标准化 → 批处理输入模型‌。

http://www.mrgr.cn/news/96016.html

相关文章:

  • Linux面试题
  • 优选算法系列(4.前缀和 _下) k
  • CAS(Compare And Swap)
  • 23种设计模式-观察者(Observer)设计模式
  • ElasticSearch -- 部署完整步骤
  • 黑盒测试与白盒测试详解
  • dynamic_cast的理解
  • LangChain4j(1):初识LangChain4j
  • Matlab Hessian矩阵计算(LoG算子)
  • kafka零拷贝技术的底层实现
  • 《Operating System Concepts》阅读笔记:p483-p488
  • Vala编成语言教程-构造函数和析构函数
  • 人员进出新视界:视觉分析算法的力量
  • 全书测试:《C++性能优化指南》
  • element与elementplus入门
  • pytorch与其他ai工具
  • 23种设计模式-外观(Facade)设计模式
  • 23种设计模式-抽象工厂(Abstract Factory)设计模式
  • 23种设计模式-中介者(Mediator)设计模式
  • 【视频】m3u8相关操作