当前位置: 首页 > news >正文

PyTorch实战-手写数字识别-CNN模型

1 需求


2 接口


3 示例

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义超参数
batch_size = 128
learning_rate = 0.001
num_epochs = 10# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5)self.conv2 = nn.Conv2d(32, 64, kernel_size=5)self.fc1 = nn.Linear(64 * 4 * 4, 1024)self.fc2 = nn.Linear(1024, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 4 * 4)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return nn.functional.log_softmax(x, dim=1)# 实例化模型
model = CNNModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for batch_idx, (data, targets) in enumerate(train_loader):# 前向传播outputs = model(data)loss = criterion(outputs, targets)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item()}')# 在测试集上评估模型
model.eval()
with torch.no_grad():correct = 0total = 0for data, targets in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = correct / totalprint(f'Test Accuracy: {accuracy * 100:.2f}%')

4 参考资料


http://www.mrgr.cn/news/65018.html

相关文章:

  • 如何生成 PEM 格式的 SSH 密钥 ?
  • oracle-函数-NULLIF (expr1, expr2)的妙用
  • 【WebRTC】WebRTC的简单使用
  • 智算中心建设热潮涌动 AI服务器赋能加速
  • Redis常见面试题概览——针对实习面试
  • 算法妙妙屋-------1.递归的深邃回响:二叉树的奇妙剪枝
  • MDK 平台下弱声明函数实现后不能执行原因排查
  • 第04章 MySQL图形化管理工具的介绍
  • 别人卷技术,我们卷变现。。。
  • 深入理解 ZooKeeper:分布式协调服务的核心与应用
  • 研究了100个小绿书十万加之后,我们发现2024小绿书独家秘籍就是:在于“先抄后超,持续出摊,量大管饱”!
  • 「Mac畅玩鸿蒙与硬件25」UI互动应用篇2 - 计时器应用实现
  • ERP项目(进销存仓储管理系统)-1
  • 11.1 网络编程-套接字
  • C语言-详细讲解-洛谷P1909 [NOIP2016 普及组] 买铅笔
  • 【数据结构】二叉树——层序遍历
  • Python Matplotlib 如何处理大数据集的绘制,提高绘图效率
  • 上尚优选项目
  • interrupt、interrupted、isInterrupted方法详解
  • WPF+MVVM案例实战(二十一)- 制作一个侧边弹窗栏(CD类)
  • LeetCode 0685.冗余连接 II:并查集(和I有何不同分析)——详细题解(附图)
  • Docker容器消耗资源过多导致宿主机死机解决方案
  • 发现不为人知的AI宝藏:深藏功与名! —— 《第十期》
  • js逆向-模拟加密
  • Linux的IP网路命令: 用于显示和操作网络接口(网络设备)的命令ip link详解
  • masm汇编字符串输出演示