当前位置: 首页 > news >正文

pytorch dataloader学习

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):def __init__(self):# 创建一些示例数据(100个样本,每个样本包含10个特征)self.data = torch.randn(100, 10)self.labels =torch.from_numpy(np.arange(100))  # 二分类标签def __len__(self):# 返回数据集的大小return len(self.data)def __getitem__(self, idx):# 根据索引 idx 返回对应的样本和标签sample = self.data[idx]label = self.labels[idx]return sample, label# 创建数据集的实例
dataset = CustomDataset()# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)# 迭代DataLoader
for i in range(2):for batch_idx, (inputs, labels) in enumerate(dataloader):print(f"Batch {batch_idx+1}")print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度print(labels)# 在这里你可以对数据进行训练# 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的
在这里插入图片描述

如果shuffle=False

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):def __init__(self):# 创建一些示例数据(100个样本,每个样本包含10个特征)self.data = torch.randn(100, 10)self.labels =torch.from_numpy(np.arange(100))  # 二分类标签def __len__(self):# 返回数据集的大小return len(self.data)def __getitem__(self, idx):# 根据索引 idx 返回对应的样本和标签sample = self.data[idx]label = self.labels[idx]return sample, label# 创建数据集的实例
dataset = CustomDataset()# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)# 迭代DataLoader
for i in range(2):for batch_idx, (inputs, labels) in enumerate(dataloader):print(f"Batch {batch_idx+1}")print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度print(labels)# 在这里你可以对数据进行训练# 例如:outputs = model(inputs)

结果如下
在这里插入图片描述


http://www.mrgr.cn/news/57396.html

相关文章:

  • 基于 Datawhale 开源的量化投资学习指南(7):量化择时策略
  • 算法的学习笔记—数字在排序数组中出现的次数(牛客JZ53)
  • tomcat部署war包部署运行,IDEA一键运行启动tomacat服务,maven打包为war包并部署到tomecat
  • 与ai一起作诗(《校园清廉韵》)
  • 五、列表——————相关概念详解
  • No.18 笔记 | XXE(XML 外部实体注入)漏洞原理、分类、利用及防御整理
  • 动态规划算法专题(八):01 背包问题
  • 1024是什么日子
  • 头条微头条文章洗稿发布软件注意事项(四)
  • 中国最有钱的起名大师颜廷利名字的含义和历史背景是什么?
  • CF978
  • C++ 判断语句的深入解析
  • 使用亚马逊SQS实现一个队列任务,包括:向队列发送消息和从队列中读取消息
  • IBM Granite 3.0:一款开源,SOTA 企业模型
  • python画图|坐标轴显隐设置
  • 【开源鸿蒙】OpenHarmony 5.0轻量系统最小开发环境搭建
  • AI自主学习:未来的智能系统
  • 近似推断 - 最大后验推断和稀疏编码篇
  • AI学习指南深度学习篇-对比学习的变种
  • Python | Leetcode Python题解之第503题下一个更大元素II
  • SELinux详解
  • Golang | Leetcode Golang题解之第504题七进制数
  • 一文彻底搞透Redis的数据类型及具体的应用场景
  • 重温Java基础语法随笔录
  • 【QT】常用控件(四)
  • 12_Linux进程管理命令详解