当前位置: 首页 > news >正文

parameters()函数 --- 获取模型参数量

  parameters() 函数是 PyTorch 中 torch.nn.Module 类的一个方法,用于返回模型中所有可训练的参数。下面是对这个函数的详细解释:

1. parameters() 方法工作机制

parameters() 方法工作机制:定义一个模型,通常会将多个层(如卷积层、线性层等)组合在一起,这些层就是主模块的子模块。而这些子模块中有些也可能包含自己的子模块,形成一个递归的层次结构。parameters() 方法会自动遍历整个层次结构,获取每个模块和子模块中的可训练参数。

2. 返回值

  • 返回一个迭代器,其中包含所有可训练的参数。

 3. 示例

import torch
import torch.nn as nn# 定义一个包含多个子模块的模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 子模块1self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 子模块2self.fc = nn.Linear(32 * 16 * 16, 10)  # 子模块3def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)  # Flattening for the fully connected layerx = self.fc(x)return x# 实例化模型
model = MyModel()# 获取模型的参数
for param in model.parameters():print(param.size())

      上述代码中Conv2d函数bias没有指定值,则取默认值bias=True,也就是说,在上述代码中,bias是存在的。

分析:

  • MyModel 包含了 3 个层:两个卷积层(conv1conv2)以及一个全连接层(fc)。
  • 这些层都是 MyModel子模块
  • 当你调用 model.parameters() 时,它不仅返回 MyModel 自己的参数,还会递归地返回所有子模块(即 conv1conv2fc)的参数。

3. 递归参数获取

parameters() 函数递归地遍历模块中的所有子模块,获取每个模块的参数。每个 nn.Module 对象的 parameters() 方法会遍历它自己和它的所有子模块,将所有可训练参数打包在一起。

你可以通过以下代码验证这一点:

# 打印每个模块的参数名和大小 
for name, param in model.named_parameters(): print(f"参数名: {name}, 大小: {param.size()}")

输出:

参数名: conv1.weight, 大小: torch.Size([16, 3, 3, 3]) 
参数名: conv1.bias, 大小: torch.Size([16]) 
参数名: conv2.weight, 大小: torch.Size([32, 16, 3, 3]) 
参数名: conv2.bias, 大小: torch.Size([32]) 
参数名: fc.weight, 大小: torch.Size([10, 8192]) 
参数名: fc.bias, 大小: torch.Size([10])

这里你会看到,不仅 MyModel 中的卷积层 conv1conv2 的权重和偏置参数被列出,连全连接层 fc 的参数也被列出了。

4. 参数过滤(过滤出可训练的参数)

可以通过条件过滤来获取特定类型的参数,例如仅获取可训练的参数:

trainable_params = [p for p in model.parameters() if p.requires_grad]

5. 计算参数总量

可以结合 numel() 方法来计算模型的参数总量:

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 
print(f"总参数量: {total_params}")

在这个示例中,p.numel() 返回参数的元素数量,if p.requires_grad 确保只计算需要梯度的参数(即可训练的参数)。运行这个代码将输出模型的参数总量。 

 


http://www.mrgr.cn/news/34707.html

相关文章:

  • C语言 | Leetcode C语言题解之第432题全O(1)的数据结构
  • 二、电脑入门2之常用dos命令
  • [vulnhub] Jarbas-Jenkins
  • 生物反馈治疗仪——精神患者治疗方案
  • 2024从传统到智能,AI做PPT软件的崛起之路
  • mysql高级
  • 12. Scenario Analysis for greedy algorithm
  • MyBatis - 动态SQL
  • VirtualBox+Vagrant快速搭建Centos7系统【最新详细教程】
  • 爬虫的流程
  • 毕业设计选题:基于ssm+vue+uniapp的英语学习激励系统小程序
  • 免费的高质量、美观的甘特图模板
  • 【前端】读取 xlsx 文件并转化成 json 数据
  • Springboot Mybatis条件查询
  • 基于 Amazon Bedrock +lambda函数调用大模型构建你的智能网页助手
  • 【已解决】用JAVA代码实现递归算法-从自然数中取3个数进行组合之递归算法-用递归算法找出 n(n>=3) 个自然数中取 3 个数的组合。
  • 匈牙利算法详解与实现
  • 如何使用GLib的单向链表GSList
  • 【leetcode】环形链表、最长公共前缀
  • 注册建造师执业工程规模标准(市政公用工程)