当前位置: 首页 > news >正文

pytorch retain_grad vs requires_grad

requires_grad大家都挺熟悉的,因此穿插在retain_grad的例子里进行捎带讲解就行。下面看一个代码片段:

import torch# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3# 继续计算,得到 z
z = y * 4# 反向传播
z.backward()# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: None
/tmp/ipykernel_219007/1060175670.py:17: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)print("y.grad:", y.grad)

警告的大致意思是:访问了非叶子节点的.grad属性,但非叶子节点的.grad属性并不会在反向传播的过程中被自动保存下来(这是为了节省内存,毕竟我们只需要计算那些手动设置.requires_gradTrue的张量的梯度,并进行梯度更新,对吧?)

因此,我们只需要添加一行代码y.retain_grad(),修改后的代码如下:

import torch# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3
y.retain_grad()# 继续计算,得到 z
z = y * 4# 反向传播
z.backward()# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: tensor(4.)

可以看到,现在非叶子节点y的梯度也在反向传播以后被正确保存了!


http://www.mrgr.cn/news/94130.html

相关文章:

  • 项目实操分享:一个基于 Flask 的音乐生成系统,能够根据用户指定的参数自动生成 MIDI 音乐并转换为音频文件
  • git本地仓库链接远程仓库
  • go 标准库包学习笔记
  • Rust 之一 基本环境搭建、各组件工具的文档、源码、配置
  • Burpsuite使用笔记
  • 【大模型统一集成项目】让 AI 聊天更丝滑:SSE 实现流式对话!
  • 【大模型统一集成项目】让 AI 聊天更丝滑:WebSocket 实现流式对话!
  • Android实现Socket通信
  • 利用selenium调用豆包进行自动化问答以及信息提取
  • tcc编译器教程6 进一步学习编译gmake源代码
  • go函数详解
  • 【Linux】线程池、单例模式、死锁
  • JVM内存结构笔记01-运行时数据区域
  • golang 高性能的 MySQL 数据导出
  • 下载以后各个软件或者服务器的启动与关闭
  • Docker安装RabbitMQ
  • Qt入门笔记
  • macOS 安装配置 iTerm2 记录
  • 蓝桥杯省赛真题C++B组2024-握手问题
  • MicroPython 智能硬件开发完整指南