当前位置: 首页 > news >正文

batc和mini-batch

一、概念介绍

batch

批处理,在机器学习中,batch 是指一次处理整个训练数据集的方式。例如,如果有 1000 个训练样本,使用 batch 训练时,模型会同时使用这 1000 个样本进行一次参数更新。也就是说,计算损失函数(如均方误差、交叉熵等)是基于整个数据集的所有样本。

mini-batch

小批次,将整个训练数据集分成多个较小的子集(批次)来进行训练。比如还是 1000 个训练样本,我们可以将其分成 10 个 mini - batch,每个 mini - batch 包含 100 个样本。模型在训练时,每次使用一个 mini - batch 来计算损失和更新参数。

二、区别

batch 训练参数更新方向更稳定但可能陷入局部最优;mini - batch 在训练中有一定随机性,有助于寻找全局最优,但批次过小时可能使训练不稳定。

三、使用场景

数据量较小:使用batch;
数据量较大:使用mini-batch;在神经网络基本使用这个

四、mini-batch代码

以下是在深度学习模型中使用batch和mini - batch的方法:

1. 数据准备阶段

  • 数据加载:首先,需要将原始数据加载到程序中。对于图像数据,可以使用ImageDataLoader(PyTorch中)等工具;对于文本数据,可以使用DataLoader结合自定义的文本处理函数。例如,在PyTorch中:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader# 加载MNIST数据集
train_data = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
test_data = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
  • 划分批次(针对mini - batch):使用数据加载器将数据集划分为指定大小的批次。例如,继续上面的代码,设置batch_size为64来创建训练集和测试集的数据加载器:
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

这里batch_size参数决定了每个mini - batch的样本数量,shuffle参数用于在每个训练轮次(epoch)开始时是否打乱数据顺序,对于mini - batch训练通常设置为True以增加随机性;对于测试集,一般不需要打乱数据。

2. 模型训练阶段

  • 使用mini - batch进行训练:在训练循环中,每次从数据加载器中获取一个mini - batch的数据进行训练。以下是一个典型的使用PyTorch训练神经网络的示例:
# 假设model是已经定义好的模型,criterion是损失函数,optimizer是优化器
for epoch in range(num_epochs):for i, (x_batch, y_batch) in enumerate(train_loader):# 前向传播outputs = model(x_batch)loss = criterion(outputs, y_batch)# 反向传播和更新参数optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')

在这个示例中,train_loader每次迭代返回一个mini - batch的输入数据x_batch和对应的标签y_batch。模型使用这些数据进行前向传播计算预测值,然后计算损失,接着进行反向传播更新模型参数。

enumerate将一个可遍历的数据对象(如列表、元组、字符串或迭代器)组合为一个索引序列。

3. 模型评估阶段

  • 使用mini - batch评估:在测试循环中,使用测试数据加载器以mini - batch的方式获取数据进行评估。例如:
correct = 0
total = 0
with torch.no_grad():for x_batch, y_batch in test_loader:outputs = model(x_batch)_, predicted = torch.max(outputs.data, 1)total += y_batch.size(0)correct += (predicted == y_batch).sum().item()accuracy = correct / total
print(f'Test Accuracy: {accuracy}')

这里使用test_loader以mini - batch方式获取测试数据,对每个mini - batch进行预测,并统计正确预测的样本数量,最后计算模型在整个测试集上的准确率。


http://www.mrgr.cn/news/63422.html

相关文章:

  • FastAPI 核心概念:构建高性能的 Python Web 服务
  • SpringBoot抗疫物资管理:系统开发与部署
  • mongodb 按条件进行备份和恢复
  • Python基于TensorFlow实现双向循环神经网络GRU加注意力机制分类模型(BiGRU-Attention分类算法)项目实战
  • Java面试经典 150 题.P80. 删除有序数组中的重复项 II(004)
  • 假设检验简介
  • Java面试题十五
  • 基于springboot的Java学习论坛平台
  • prometheus 快速入门
  • python enum用法
  • opencv - py_imgproc - py_grabcut GrabCut 算法提取前景
  • JavaScript 实战技巧:让你成为前端高手的必备知识3(进阶版)
  • 【环境问题】pycharm远程服务器文件路径问题
  • 【前端】项目中遇到的问题汇总(长期更新)
  • 热点扫描:人工智能专利布局背后的商业博弈
  • Java思想
  • 拒绝无效发稿!软文推广这样精选媒体,一不小心省下百万宣发费用!媒介盒子分享
  • 视频一键转换3D:Autodesk 发布 Video to 3D Scene
  • 【django】django RESTFramework前后端分离框架快速入门
  • dim的方向 傻傻分不清
  • BGP实验--BGP路由汇总
  • 苹果手机备忘录怎么看字数统计
  • WAF+AI结合,雷池社区版的强大防守能力
  • RUM最佳实践:内网IP地址映射地图地理位置场景
  • 周转时间、带权周转时间、平均周转时间、平均带权周转时间
  • uniapp MD5加密