当前位置: 首页 > news >正文

torchvision 教程

PyTorch torchvision 教程

torchvision 是 PyTorch 的一个子库,专为计算机视觉任务设计,提供了常用的数据集、预训练模型、以及图像转换和处理的工具。本文将介绍如何使用 torchvision 中的功能来加载数据集、预处理数据、使用预训练模型以及进行图像增强。

1. 安装 torchvision

首先,你需要安装 torchvision 库。可以使用 pip 安装:

pip install torchvision

2. torchvision 的主要组件

torchvision 的主要组件有:

  • torchvision.datasets:提供常用的数据集,例如 MNIST、CIFAR-10、ImageNet 等。
  • torchvision.transforms:用于图像的预处理和数据增强。
  • torchvision.models:提供预训练的深度学习模型。
  • torchvision.io:用于读取和写入图像、视频等数据。

3. 使用 torchvision.datasets 加载数据集

torchvision 提供了许多流行的数据集,可以直接从 torchvision.datasets 中加载。你可以加载数据集,并使用 DataLoader 迭代数据。

3.1 加载 MNIST 数据集

MNIST 是一个包含手写数字的经典数据集,常用于图像分类任务。

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 定义数据转换 (如归一化)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载 MNIST 数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 使用 DataLoader 加载数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 打印样本
for images, labels in train_loader:print(f"Image batch shape: {images.size()}")print(f"Labels batch shape: {labels.size()}")break
3.2 加载 CIFAR-10 数据集

CIFAR-10 是另一个常用的数据集,包含 10 类自然图片。

# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 使用 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

4. torchvision.transforms 图像预处理与增强

torchvision.transforms 提供了许多常用的图像预处理和增强方法,例如缩放、裁剪、旋转、翻转等。

4.1 基本预处理操作
transform = transforms.Compose([transforms.Resize((32, 32)),            # 调整图像大小transforms.RandomHorizontalFlip(),      # 随机水平翻转transforms.ToTensor(),                  # 转换为 PyTorch 张量transforms.Normalize((0.5,), (0.5,))    # 标准化
])train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
4.2 常见的 transforms 操作
  • transforms.Resize(size):调整图像大小为 size
  • transforms.CenterCrop(size):从图像中心裁剪大小为 size 的部分。
  • transforms.RandomCrop(size):随机裁剪图像。
  • transforms.RandomHorizontalFlip():随机水平翻转图像。
  • transforms.ColorJitter():随机更改图像的亮度、对比度和饱和度。
  • transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。
  • transforms.Normalize(mean, std):标准化图像数据。

5. 使用 torchvision.models 的预训练模型

torchvision.models 提供了多种预训练模型,例如 ResNet、VGG、AlexNet 等,这些模型在 ImageNet 数据集上进行了预训练。

5.1 加载预训练模型

你可以加载一个预训练的 ResNet 模型并在新任务上进行微调。

import torchvision.models as models# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)# 查看模型架构
print(model)
5.2 微调预训练模型

如果你想要微调预训练模型(例如用于 CIFAR-10 数据集),你可以冻结预训练模型的部分参数,并修改最后一层以适应新的任务。

# 冻结所有层的参数
for param in model.parameters():param.requires_grad = False# 修改最后一层以适应 CIFAR-10 (10 类分类任务)
model.fc = torch.nn.Linear(512, 10)# 将模型移动到 GPU(如果有)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
5.3 训练模型
# 训练模型
for epoch in range(2):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播与优化loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
5.4 测试模型
# 测试模型性能
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')

6. torchvision.io 读取和保存图像

torchvision.io 提供了方便的图像读取和保存功能。

6.1 读取图像
import torchvision.io as io# 读取图像
img = io.read_image('image.jpg')  # 读取为张量# 显示张量信息
print(img.size())
6.2 保存图像
# 保存张量为图像文件
io.write_jpeg(img, 'output_image.jpg')

7. 完整示例

以下是一个使用 torchvision 加载数据集、进行数据增强、使用预训练模型微调并进行训练的完整示例:

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim# 数据增强与预处理
transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 加载预训练的 ResNet18 模型并修改最后一层
model = models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False
model.fc = nn.Linear(512, 10)# 设备设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)# 训练模型
for epoch in range(2):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs= model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')# 测试模型
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')

8. 总结

torchvision 是 PyTorch 中处理计算机视觉任务的重要工具,它为常用的数据集、模型、数据处理和增强提供了便利的接口。通过本教程,你可以学习如何使用 torchvision 加载数据集、应用图像预处理、使用预训练模型进行微调,并训练模型来解决实际的计算机视觉任务。


http://www.mrgr.cn/news/29641.html

相关文章:

  • Matplotlib库中show()函数的用法
  • 38配置管理工具(如Ansible、Puppet、Chef)
  • blenderFds代码解读
  • rocketmq——docker-compose安装
  • WordPress 2024主题实例镜像
  • Go语言开发基于SQLite数据库实现用户表查询详情接口(三)
  • (待会删)分享8款AI写论文可以用到的网站神器,请低调使用!
  • ant-design表格自动合并相同内容的单元格
  • 基于windows下docker安装HDDM并运行
  • Linux权限理解【Shell的理解】【linux权限的概念、管理、切换】【粘滞位理解】
  • MODIS/Landsat/Sentinel下载教程详解【常用网站及方法枚举】
  • 【Manim】用manim描述二次曲面——上
  • 构建自己的文生图工具:Python + Stable Diffusion + CUDA
  • 为什么制造业要上MES,有哪些不得不上的理由吗?
  • AntFlow系列教程二之流程同意
  • 系统架构设计师 数据库篇
  • ARM架构中的重要知识点的详细解释
  • python中Web API 框架
  • 力扣之181.超过经理收入的员工
  • 多线程2(gamere)
  • Python 课程16-Pygame
  • spring security 手机号 短信验证码认证、验证码认证 替换默认的用户名密码认证132
  • itk c++ 3D医学图像刚性配准
  • 【与C++的邂逅】--- C++的IO流
  • wifi中的相干带宽
  • Windows系统下使用VS排查内存泄露的两种办法