当前位置: 首页 > news >正文

昇思MindSpore进阶教程-参数

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

参数

参数(Parameter)是神经网络训练的核心,通常作为神经网络层的内部成员变量。本节我们将系统介绍参数以及其相关使用方法。

Parameter

参数(Parameter)是一类特殊的Tensor,是指在模型训练过程中可以对其值进行更新的变量。MindSpore提供mindspore.Parameter类进行Parameter的构造。为了对不同用途的Parameter进行区分,下面对两种不同类别的Parameter进行定义:

  • 可训练参数。在模型训练过程中根据反向传播算法求得梯度后进行更新的Tensor,此时需要将required_grad设置为True。

  • 不可训练参数。不参与反向传播,但需要更新值的Tensor(如BatchNorm中的mean和var变量),此时需要将requires_grad设置为False。

Parameter默认设置required_grad=True。

下面我们构造一个简单的全连接层:

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameterclass Network(nn.Cell):def __init__(self):super().__init__()self.w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weightself.b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn znet = Network()

在Cell的__init__方法中,我们定义了w和b两个Parameter,并配置name进行命名空间管理。在construct方法中使用self.attr直接调用参与Tensor运算。

获取Parameter

在使用Cell+Parameter构造神经网络层后,我们可以使用多种方法来获取Cell管理的Parameter。

获取单个参数

单独获取某个特定参数,直接调用Python类的成员变量即可。

获取可训练参数

可使用Cell.trainable_params方法获取可训练参数,通常在配置优化器时需调用此接口。

获取所有参数

使用Cell.get_parameters()方法可获取所有参数,此时会返回一个Python迭代器。
或者可以调用Cell.parameters_and_names返回参数名称及参数。

for name, param in net.parameters_and_names():print(f"{name}:\n{param.asnumpy()}")

修改Parameter

直接修改参数值

Parameter是一种特殊的Tensor,因此可以使用Tensor索引修改的方式对其值进行修改。

覆盖修改参数值

可调用Parameter.set_data方法,使用相同Shape的Tensor对Parameter进行覆盖。该方法常用于使用Initializer进行Cell遍历初始化。

运行时修改参数值

参数的主要作用为模型训练时对其值进行更新,在反向传播获得梯度后,或不可训练参数需要进行更新,都涉及到运行时参数修改。由于MindSpore的使用静态图加速编译设计,此时需要使用mindspore.ops.assign接口对参数进行赋值。该方法常用于自定义优化器场景。下面是一个简单的运行时修改参数值样例:

import mindspore as ms@ms.jit
def modify_parameter():b_hat = ms.Tensor([7, 8, 9])ops.assign(net.b, b_hat)return Truemodify_parameter()
print(net.b.asnumpy())

Parameter Tuple

变量元组ParameterTuple,用于保存多个Parameter,继承于元组tuple,提供克隆功能。

如下示例提供ParameterTuple创建方法:

from mindspore.common.initializer import initializer
from mindspore import ParameterTuple
# 创建
x = Parameter(default_input=ms.Tensor(np.arange(2 * 3).reshape((2, 3))), name="x")
y = Parameter(default_input=initializer('ones', [1, 2, 3], ms.float32), name='y')
z = Parameter(default_input=2.0, name='z')
params = ParameterTuple((x, y, z))# 从params克隆并修改名称为"params_copy"
params_copy = params.clone("params_copy")print(params)
print(params_copy)

http://www.mrgr.cn/news/34187.html

相关文章:

  • C语言函数指针,重命名使用
  • 产品经理晋级-Axure中继器+动态面板制作美观表格
  • 推荐一款好用的postman替代工具2024
  • 机器情绪及抑郁症算法
  • CCI3.0-HQ:用于预训练大型语言模型的高质量大规模中文数据集
  • Fortinet Security Fabric安全平台
  • MATLAB在无线传感器网络设计中的应用与实践
  • 从零开始之AI面试小程序
  • LeetCode 20.有效的括号
  • Leetcode 543. 124. 二叉树的直径 树形dp C++实现
  • 输出Hate-C语言
  • 【Ambari自定义组件集成】Bigtop320集成Ranger实战
  • GPT-4o能玩《黑神话》!精英怪胜率超人类,无强化学习纯大模型方案
  • ChatGPT与R语言融合技术在生态环境数据统计分析、绘图、模型中的实践与进阶应用
  • Debian安装mysql遇到的问题解决及yum源配置
  • 苹果和香蕉联合食用,益处最大,能控制血压水平,高血压死亡风险降低 40%!
  • C#知识|继承与多态
  • 【2024.09】关于 UMLS 在支持大型语言模型提出的诊断生成中的作用
  • spring 注解 - @NotNull - 确保字段或参数值不为 null
  • C++学习,命令空间
  • redis常用五种数据类型的常用指令
  • 核心复现—计及需求响应的区域综合能源系统双层优化调度策略
  • 网安新声 | 黎巴嫩BP机爆炸事件带来的安全新挑战与反思
  • ubuntu安装gitlab-runner
  • 力扣647-回文子串(Java详细题解)
  • 光控资本:沪指涨0.72%,煤炭、银行板块拉升,车路云概念活跃