当前位置: 首页 > news >正文

YOLOv6-4.0部分代码阅读笔记-ema.py

ema.py

yolov6\utils\ema.py

目录

ema.py

1.所需的库和模块

2.class ModelEMA: 

3.def copy_attr(a, b, include=(), exclude=()): 

4.def is_parallel(model): 

5.def de_parallel(model): 


1.所需的库和模块

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
import math
# copy.deepcopy()
# 是深拷贝,会拷贝对象及其子对象,哪怕以后对其有改动,也不会影响其第一次的拷贝。
# 函数可以递归地复制对象及其所有嵌套对象,创建一个全新的独立副本,而不共享任何数据。
from copy import deepcopy
import torch
import torch.nn as nn

2.class ModelEMA: 

class ModelEMA:# 模型的指数移动平均(Exponential Moving Average, EMA)来自 https://github.com/rwightman/pytorch-image-models# 保持模型 state_dict(参数和缓冲区)中所有内容的移动平均值。# 这旨在允许以下功能:# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage# 一些训练方案需要权重的平滑版本才能表现良好。# 此类在模型初始化、GPU 分配和分布式训练包装器的序列中初始化时很敏感。""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-modelsKeep a moving average of everything in the model state_dict (parameters and buffers).This is intended to allow functionality likehttps://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageA smoothed version of the weights is necessary for some training schemes to perform well.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""# 1.self : 类的实例。# 2.model : 要进行 EMA 处理的模型。# 3.decay : EMA 衰减率,默认值为 0.9999。# 4.updates : 更新次数,默认值为 0。def __init__(self, model, decay=0.9999, updates=0):# 这行代码创建了 model 的深拷贝,并赋值给 self.ema 。 deepcopy 用于确保模型参数的值被复制而不是仅仅复制引用。# is_parallel(model) 检查模型是否是并行的(例如,是否使用了 torch.nn.DataParallel 或 torch.nn.parallel.DistributedDataParallel )。# 如果模型是并行的,则 model.module 包含了原始模型,否则直接使用 model 。# eval() 将复制的模型设置为评估模式,这会关闭 Dropout 和 Batch Normalization 等训练特有的层。self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA# 将传入的 updates 参数值赋给实例的 self.updates 属性。self.updates = updates# 定义了一个 lambda 函数作为 self.decay 属性。这个函数根据传入的 x 值计算衰减率。这里 decay 参数被用作基础衰减率,而 x 通常是当前的迭代次数或 epoch 数。self.decay = lambda x: decay * (1 - math.exp(-x / 2000))# 遍历 self.ema (EMA 模型)的所有参数。for param in self.ema.parameters():# 对于每个参数,调用 requires_grad_(False) 方法,将参数的 requires_grad 属性设置为 False 。这意味着在反向传播时,这些参数不会计算梯度,通常用于冻结参数,不对其进行训练。param.requires_grad_(False)# 这段代码定义了一个名为 update 的方法,它用于更新指数移动平均(EMA)模型的参数。# 1.model :当前训练的模型。def update(self, model):# 使用 PyTorch 的 torch.no_grad() 上下文管理器,这会暂时禁用梯度计算,使得在这个上下文中的操作不会追踪梯度。这对于更新 EMA 参数是必要的,因为这些操作不需要梯度。with torch.no_grad():# 更新次数 self.updates 加一。self.updates += 1# 计算当前的衰减率。衰减率是根据 self.updates 和在构造函数中定义的衰减函数计算得出的。decay = self.decay(self.updates)# 获取当前训练模型的状态字典( state_dict )。# 如果模型是并行的(即被 torch.nn.DataParallel 或 torch.nn.parallel.DistributedDataParallel 包装过),则通过 model.module 访问原始模型的状态字典。# 否则直接使用 model.state_dict() 。# model.state_dict()# 在PyTorch中, .state_dict() 方法是 torch.nn.Module 类的一个实例方法,用于返回一个包含模型所有参数和缓存的字典(state dictionary)。这个字典通常用于保存和加载模型的权重。# 参数 : 无参数。# 返回值 : 返回一个包含模型中所有参数和缓存的字典。state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict    模型状态字典# 遍历 EMA 模型的状态字典中的每个参数。for k, item in self.ema.state_dict().items():# 检查参数 item 是否是浮点类型。EMA 通常只对浮点类型的参数进行更新。if item.dtype.is_floating_point:# 将 EMA 参数 item 乘以衰减率 decay ,这是 EMA 更新的一部分。item *= decay# 将当前训练模型的对应参数 state_dict[k] 乘以 (1 - decay) ,然后加到 EMA 参数 item 上。# .detach() 方法用于从当前计算图中分离出参数,这样在更新 EMA 参数时不会追踪梯度。# .detach()# 在 PyTorch 中, .detach() 函数用于从当前计算图中分离出一个张量(Tensor),使得之后对这个张量的操作不会追踪梯度,也就是说,它将不再参与梯度的计算和反向传播。# 1. 方法定义 : .detach() 是 torch.Tensor 的一个方法,可以被任何张量(Tensor)对象调用。# 2. 返回值 : .detach() 返回一个新的张量,该张量与原始张量共享数据但不会追踪梯度。如果张量已经是不需要梯度的,那么 .detach() 将直接返回原始张量。# 3. 停止梯度追踪 :当进行某些操作时,我们可能不希望建立计算图,或者不希望某些部分的计算图被追踪,这时可以使用 .detach() 。item += (1 - decay) * state_dict[k].detach()def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):copy_attr(self.ema, model, include, exclude)

3.def copy_attr(a, b, include=(), exclude=()): 

# 目的是从一个实例( b )复制属性并将它们设置到另一个实例( a )。
# 1.a :目标实例,将从 b 实例复制的属性设置到这个实例上。
# 2.b :源实例,从这个实例复制属性。
# 3.include :一个元组,指定要复制的属性名称。如果为空,则复制所有属性。
# 4.exclude :一个元组,指定要排除的属性名称。
def copy_attr(a, b, include=(), exclude=()):# 从一个实例复制属性并将其设置到另一个实例。"""Copy attributes from one instance and set them to another instance."""# 遍历源实例 b 的所有属性。 b.__dict__ 返回一个字典,包含 b 的所有属性及其值。for k, item in b.__dict__.items():# 检查每个属性是否满足以下条件之一,如果满足,则跳过该属性:# len(include) and k not in include :如果 include 参数不为空,并且属性名称 k 不在 include 元组中。# k.startswith('_') :如果属性名称 k 以单个下划线开头,通常表示私有属性。# k in exclude :如果属性名称 k 在 exclude 元组中。if (len(include) and k not in include) or k.startswith('_') or k in exclude:continueelse:# 如果属性不满足上述任何条件,则执行以下操作:# 使用 setattr 函数将属性 k 从源实例 b 复制到目标实例 a 。setattr(a, k, item)

4.def is_parallel(model): 

def is_parallel(model):# 如果模型的类型是 DP 或 DDP,则返回 True,否则返回 False。'''Return True if model's type is DP or DDP, else False.'''return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

5.def de_parallel(model): 

def de_parallel(model):# 解除模型的并行化。如果模型类型为 DP 或 DDP,则返回单 GPU 模型。'''De-parallelize a model. Return single-GPU model if model's type is DP or DDP.'''return model.module if is_parallel(model) else model

http://www.mrgr.cn/news/64820.html

相关文章:

  • 力扣题目解析--最长公共前缀
  • vue系列==Vuex状态管理器
  • Vision - 开源视觉分割算法框架 Grounded SAM2 配置与推理 教程 (1)
  • 珠海盈致mes系统在来料检验管理的优缺点
  • java设计模式之行为型模式(11种)
  • 【MWorks】Ubuntu 系统搭建
  • 2024年一带一路金砖技能大赛之大数据容器云开发
  • Win10 连接到 Ubuntu 黑屏无法连接 使用Rustdesk显示 No Displays 没有显示器
  • GOF的C++软件设计模式的分类和模式名称
  • 数据结构初阶排序全解
  • 力扣周赛:第422场周赛
  • roberta融合模型创新中文新闻文本标题分类
  • 优青博导团队/免费指导/一站式服务/数据分析/实验设计/论文润色/组学技术服务 、表观组分析、互作组分析、遗传转化实验、单细胞检测与生物医学
  • ctfshow——web(总结持续更新)
  • 将分类标签转换为模型可以处理的数值格式
  • 计算机网络串联——打开网站的具体步骤
  • Linux 进程间通信 共享内存_消息队列_信号量
  • 提高交换网络可靠性之端口安全配置
  • windows rdp 将远程技术嵌入到你的软件——未来之窗行业应用跨平台架构
  • 第四次:2024年郑州马拉松赛事记
  • 什么是三大范式, 为什么要有三大范式, 什么场景下不用遵循三大范式
  • 《GBDT 算法的原理推导》 11-15更新决策树的叶子节点值 公式解析
  • Linux内核编程(十八)ADC驱动
  • 深入解析RSA算法:加密与安全性
  • Spring DispatcherServlet详解
  • 在vue中 什么是slot机制,如何使用以及使用场景详细讲解