当前位置: 首页 > news >正文

Pytorch分布式训练

现在深度学习模型占用显存大,数据量也大,单张显卡上训练已经满足不了要求了,只有多GPU并行训练才能加快训练速度;并行训练又分为模型并行和数据并行两种。模型并行比较少用到,这里主要介绍数据并行,pytorch中数据并行有两种DataParallel和DistributedDataParallel,前者是pytorch训练早期采用的,由于其单线程和显存利用率低等缺点,现在大多使用后者。


文章目录

  • 1、并行训练
  • 2、DataParallel与DistributedDataParallel
    • 2.1 DataParallel
      • 2.1.1 DataParallel架构
      • 2.1.2 模型参数的更新过程
      • 2.1.3 DataParallel的优缺点
    • 2.2 DistributedDataParallel
      • 2.2.1 DistributedDataParallel架构
      • 2.2.2 训练方式及模型参数更新
  • 3、DDP训练代码框架


1、并行训练

并行训练分两种,模型并行和数据并行。
1)模型并行。模型并行通常是指要训练的模型非常大,大到一块卡根本放不下,因而需要把模型进行拆分放到不同的卡上。例如早期的AlexNet就是拆分模型利用两块GPU训练的。
2)数据并行。数据并行通常用于训练数据非常庞大的时候,比如有几百万张图像用于训练模型。此时,如果只用一张卡来进行训练,那么训练时间就会非常的长。或者模型比较大,由于单卡显存的限制,训练时的batch size不能设置过大。这时就需要多个GPU训练来提升batchsize大小。
在这里插入图片描述
如上图所示,数据并行又分为单进程的DataParallel和分布式数据并行DistributedDataParallel。


2、DataParallel与DistributedDataParallel

2.1 DataParallel

在PyTorch中,DataParallel是一种用于多GPU训练的数据并行方式。其训练过程是单进程的,但会利用多个GPU来加速模型的训练。关于DataParallel训练方式如何更新模型参数,可以详细解释如下:

2.1.1 DataParallel架构

DataParallel采用参数服务器架构。在这种架构中,通常会将一块GPU作为server(服务器),其余的GPU作为worker(工作节点)。每个worker上都会保留一个模型的副本用于计算。

2.1.2 模型参数的更新过程

  1. 数据拆分与分发:
    在训练开始时,输入数据会被拆分成多个部分,并分发到不同的GPU上。
  2. 前向传播与计算梯度:
    每个GPU独立地对其分配到的数据进行前向传播,计算预测输出。
    接着进行反向传播,计算损失函数关于模型参数的梯度。
  3. 梯度汇总与参数更新:
    所有的worker计算得到的梯度都会被汇总到server GPU上。
    在server GPU上,使用这些汇总的梯度来更新模型参数。
    更新后的模型参数会被同步到其他所有的GPU上,以确保每个GPU上的模型副本都是最新的。
    DataParallel梯度汇总与参数更新如下图所示:
    在这里插入图片描述

2.1.3 DataParallel的优缺点

  • 优点:
    代码简洁易于实现:只需在原有单卡训练的代码中加上一行即可实现多GPU训练(例如:model = nn.DataParallel(model))。
  • 缺点:
    通信开销大:作为server的GPU需要和其他所有的GPU进行通信,梯度汇总、参数更新等步骤都由它完成,导致效率较低。
    可扩展性差:仅支持单机多卡,随着GPU数量的增加,通信开销也会线性增长,因此不适用于GPU数量非常多的情况。
    训练效率低:占用显存不均衡,且不支持混合精度训练,训练效率较低。

2.2 DistributedDataParallel

PyTorch的DistributedDataParallel(DDP)训练方式是一种高效的分布式训练方式,特别适用于多GPU甚至跨多节点(机器)的训练场景。DDP采用Ring-All-Reduce架构,其训练过程是多进程的。在DDP中,模型参数的更新过程涉及多个步骤和组件的协同工作。以下是DDP训练方式更新模型参数的详细解释:

2.2.1 DistributedDataParallel架构

DistributedDataParallel采用Ring-All-Reduce架构同步各个GPU之间的梯度,所有的GPU设备安排在一个逻辑环中,每个GPU只与它两个相邻的GPU之间通信,每个GPU对应一个进程,通过N轮迭代,每个GPU设备都会用于全局平均梯度。

2.2.2 训练方式及模型参数更新

  1. 环境准备与初始化
    环境准备:
    确保所有节点(机器)上的PyTorch版本一致,并且安装了对应的CUDA版本,并确保节点间可以相互通信。
    初始化:
    在每个节点上启动多个进程,每个进程绑定到特定的GPU上。
    使用torch.distributed.init_process_group初始化分布式进程组,指定后端(如nccl)和初始化方法(如env://或tcp://)。
  2. 数据分发与模型复制
    数据分发:
    使用torch.utils.data.distributed.DistributedSampler来分发数据集,确保每个进程处理的数据集部分是不重叠的。
    创建DataLoader时,将DistributedSampler作为sampler参数传入。
    模型复制:
    在每个GPU上复制一份模型副本。使用torch.nn.parallel.DistributedDataParallel包装模型,指定device_ids和output_device。
  3. 前向传播与反向传播
    前向传播:每个进程在其绑定的GPU上独立地对分配到的数据进行前向传播。
    反向传播:每个进程独立地计算损失函数关于模型参数的梯度,并进行反向传播。
  4. 梯度同步与参数更新
    梯度同步:
    使用all-reduce操作将所有进程计算得到的梯度进行同步,确保每个进程都能获得全局平均梯度。这一步是并行的,避免了单节点成为通信瓶颈。
    参数更新:
    每个进程使用同步后的梯度独立地更新其模型参数。由于每个进程都使用了相同的初始参数和同步后的梯度进行更新,因此更新后的模型参数在所有进程中都是一致的。
  5. 训练循环与模型保存
    训练循环:
    在每个epoch开始时,使用DistributedSampler的set_epoch方法来更新采样器状态,以确保每个epoch的数据打乱方式是一致的。
    遍历数据加载器,进行前向传播、反向传播和参数更新。
    模型保存:
    通常在rank为0的进程上保存模型参数,以避免多次保存导致的重复和浪费。
  6. 模型评估
    在DDP训练中,每个GPU进程独立计算其数据的评估结果(如准确率、损失等),在评估时,可能需要收集和整合这些结果。
    通过torch.distributed.all_gather函数,可以将所有进程的评估结果聚集到每个进程中。这样每个进程都可以获取到完整的评估数据,进而计算全局的指标。如果只需要全局的汇总数据(如总损失或平均准确率),可以使用torch.distributed.reduceall_reduce操作直接计算汇总结果,这样更加高效。

使用Ring-All-Reduce架构同步梯度,这一步是并行的,避免了单节点成为通信瓶颈,提升了训练效率,如下图所示。
在这里插入图片描述

3、DDP训练代码框架

参考:https://www.cnblogs.com/liyier/p/18135209


http://www.mrgr.cn/news/81334.html

相关文章:

  • 用Unity做没有热更需求的单机游戏是否有必要使用AssetBundle?
  • Lecture 6 Isolation System Call Entry
  • Linux(二)_清理空间
  • 2002 - Can‘t connect to server on ‘192.168.1.XX‘ (36)
  • 【test】git clone lfs问题记录
  • Confluent Cloud Kafka 可观测性最佳实践
  • Unity模型观察脚本
  • Android开发环境搭建和编译系统
  • 知识图谱嵌入大总结:难点、方法、工具、和图嵌入的区别
  • 【innodb 阅读笔记】之 数据页结构介绍
  • springboot容器无法获取@Autowired对象,报null对象空指针问题的解决方式
  • Element-plus表格使用总结
  • 5、mysql的读写分离
  • Docker数据库的主从复制
  • 基于springboot的海洋知识服务平台的设计与实现
  • HuaWei、NVIDIA 数据中心 AI 算力对比
  • ThinkPHP接入PayPal支付
  • Kibana:LINUX_X86_64 和 DEB_X86_64两种可选下载方式的区别
  • RT-DETR学习笔记(2)
  • CTFHub disable_functions通关
  • 华为路由器AR101W-S
  • go语言并发文件备份,自动比对自动重命名(逐行注释)
  • Require:离线部署 Sourcegraph
  • Linux驱动开发--字符设备驱动开发
  • STM32 高级 谈一下IPV4/默认网关/子网掩码/DNS服务器/MAC
  • c++类型判断和获取原始类型