torchvision 教程
PyTorch torchvision
教程
torchvision
是 PyTorch 的一个子库,专为计算机视觉任务设计,提供了常用的数据集、预训练模型、以及图像转换和处理的工具。本文将介绍如何使用 torchvision
中的功能来加载数据集、预处理数据、使用预训练模型以及进行图像增强。
1. 安装 torchvision
首先,你需要安装 torchvision
库。可以使用 pip
安装:
pip install torchvision
2. torchvision
的主要组件
torchvision
的主要组件有:
torchvision.datasets
:提供常用的数据集,例如 MNIST、CIFAR-10、ImageNet 等。torchvision.transforms
:用于图像的预处理和数据增强。torchvision.models
:提供预训练的深度学习模型。torchvision.io
:用于读取和写入图像、视频等数据。
3. 使用 torchvision.datasets
加载数据集
torchvision
提供了许多流行的数据集,可以直接从 torchvision.datasets
中加载。你可以加载数据集,并使用 DataLoader
迭代数据。
3.1 加载 MNIST 数据集
MNIST 是一个包含手写数字的经典数据集,常用于图像分类任务。
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 定义数据转换 (如归一化)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载 MNIST 数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 使用 DataLoader 加载数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 打印样本
for images, labels in train_loader:print(f"Image batch shape: {images.size()}")print(f"Labels batch shape: {labels.size()}")break
3.2 加载 CIFAR-10 数据集
CIFAR-10 是另一个常用的数据集,包含 10 类自然图片。
# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 使用 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
4. torchvision.transforms
图像预处理与增强
torchvision.transforms
提供了许多常用的图像预处理和增强方法,例如缩放、裁剪、旋转、翻转等。
4.1 基本预处理操作
transform = transforms.Compose([transforms.Resize((32, 32)), # 调整图像大小transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 转换为 PyTorch 张量transforms.Normalize((0.5,), (0.5,)) # 标准化
])train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
4.2 常见的 transforms
操作
transforms.Resize(size)
:调整图像大小为size
。transforms.CenterCrop(size)
:从图像中心裁剪大小为size
的部分。transforms.RandomCrop(size)
:随机裁剪图像。transforms.RandomHorizontalFlip()
:随机水平翻转图像。transforms.ColorJitter()
:随机更改图像的亮度、对比度和饱和度。transforms.ToTensor()
:将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。transforms.Normalize(mean, std)
:标准化图像数据。
5. 使用 torchvision.models
的预训练模型
torchvision.models
提供了多种预训练模型,例如 ResNet、VGG、AlexNet 等,这些模型在 ImageNet 数据集上进行了预训练。
5.1 加载预训练模型
你可以加载一个预训练的 ResNet 模型并在新任务上进行微调。
import torchvision.models as models# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)# 查看模型架构
print(model)
5.2 微调预训练模型
如果你想要微调预训练模型(例如用于 CIFAR-10 数据集),你可以冻结预训练模型的部分参数,并修改最后一层以适应新的任务。
# 冻结所有层的参数
for param in model.parameters():param.requires_grad = False# 修改最后一层以适应 CIFAR-10 (10 类分类任务)
model.fc = torch.nn.Linear(512, 10)# 将模型移动到 GPU(如果有)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
5.3 训练模型
# 训练模型
for epoch in range(2):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播与优化loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
5.4 测试模型
# 测试模型性能
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')
6. torchvision.io
读取和保存图像
torchvision.io
提供了方便的图像读取和保存功能。
6.1 读取图像
import torchvision.io as io# 读取图像
img = io.read_image('image.jpg') # 读取为张量# 显示张量信息
print(img.size())
6.2 保存图像
# 保存张量为图像文件
io.write_jpeg(img, 'output_image.jpg')
7. 完整示例
以下是一个使用 torchvision
加载数据集、进行数据增强、使用预训练模型微调并进行训练的完整示例:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim# 数据增强与预处理
transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 加载预训练的 ResNet18 模型并修改最后一层
model = models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False
model.fc = nn.Linear(512, 10)# 设备设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)# 训练模型
for epoch in range(2):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs= model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')# 测试模型
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')
8. 总结
torchvision
是 PyTorch 中处理计算机视觉任务的重要工具,它为常用的数据集、模型、数据处理和增强提供了便利的接口。通过本教程,你可以学习如何使用 torchvision
加载数据集、应用图像预处理、使用预训练模型进行微调,并训练模型来解决实际的计算机视觉任务。