当前位置: 首页 > news >正文

PyTorch使用(7)-张量常见运算函数

1. 基本数学运算

1.1 平方根和幂运算

import torchx = torch.tensor([4.0, 9.0, 16.0])# 平方根
sqrt_x = torch.sqrt(x)  # tensor([2., 3., 4.])# 平方
square_x = torch.square(x)  # tensor([16., 81., 256.])# 任意幂次
pow_x = torch.pow(x, 3)  # tensor([64., 729., 4096.])# 运算符形式
sqrt_x_alt = x ** 0.5
square_x_alt = x ** 2

1.2 指数和对数

# 自然指数
exp_x = torch.exp(x)  # tensor([5.4595e+01, 8.1031e+03, 8.8861e+06])# 自然对数
log_x = torch.log(x)  # tensor([1.3863, 2.1972, 2.7726])# 以10为底的对数
log10_x = torch.log10(x)  # tensor([0.6021, 0.9542, 1.2041])# 带clip的最小值保护(避免log(0))
safe_log = torch.log(x + 1e-8)

2. 统计运算

2.1 求和与均值

x = torch.randn(3, 4)  # 3x4随机张量# 全局求和
total = torch.sum(x)  # 标量# 沿特定维度求和
sum_dim0 = torch.sum(x, dim=0)  # 形状(4,),沿行求和
sum_dim1 = torch.sum(x, dim=1)  # 形状(3,),沿列求和# 均值计算
mean_val = torch.mean(x)  # 全局均值
mean_dim0 = torch.mean(x, dim=0)  # 沿行求均值

2.2 极值与排序

# 最大值/最小值
max_val = torch.max(x)  # 全局最大值
min_val = torch.min(x)  # 全局最小值# 沿维度的极值及索引
max_vals, max_indices = torch.max(x, dim=1)  # 每行最大值及位置
min_vals, min_indices = torch.min(x, dim=0)  # 每列最小值及位置# 排序
sorted_vals, sorted_indices = torch.sort(x, dim=1, descending=True)

2.3 方差与标准差

# 无偏方差(分母n-1)
var_x = torch.var(x, unbiased=True)  # 全局方差
var_dim0 = torch.var(x, dim=0)  # 沿行方差# 标准差
std_x = torch.std(x)  # 全局标准差
std_dim1 = torch.std(x, dim=1)  # 沿列标准差

3. 矩阵运算

3.1 基本矩阵运算

A = torch.randn(3, 4)
B = torch.randn(4, 5)# 矩阵乘法
matmul = torch.matmul(A, B)  # 形状(3,5)
matmul_alt = A @ B  # 等价写法# 点积(向量)
v1 = torch.randn(3)
v2 = torch.randn(3)
dot_product = torch.dot(v1, v2)# 批量矩阵乘法
batch_A = torch.randn(5, 3, 4)  # 5个3x4矩阵
batch_B = torch.randn(5, 4, 5)  # 5个4x5矩阵
batch_matmul = torch.bmm(batch_A, batch_B)  # 形状(5,3,5)

3.2 矩阵分解

# 特征分解(对称矩阵)
sym_matrix = torch.randn(3, 3)
sym_matrix = sym_matrix @ sym_matrix.T  # 构造对称矩阵
eigenvals, eigenvecs = torch.linalg.eigh(sym_matrix)# SVD分解
U, S, V = torch.linalg.svd(A)

4. 比较运算

4.1 元素级比较

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])# 比较运算
eq = torch.eq(a, b)  # tensor([False, True, False])
gt = torch.gt(a, b)  # tensor([False, False, True])
lt = torch.lt(a, b)  # tensor([True, False, False])# 运算符形式
eq_alt = a == b
gt_alt = a > b

4.2 约简比较

# 判断所有元素为True
all_true = torch.all(eq)# 判断任一元素为True
any_true = torch.any(gt)# 判断张量相等(形状和值)
torch.equal(a, b)  # False

5. 规约运算

5.1 常用规约

x = torch.randn(2, 3)# 求和规约
sum_all = x.sum()  # 全局求和
sum_dim = x.sum(dim=1)  # 沿维度规约# 累积和
cumsum = x.cumsum(dim=0)  # 沿维度累积# 乘积规约
prod_all = x.prod()  # 全局乘积

5.2 高级规约

# 加权平均
weights = torch.softmax(torch.randn(3), dim=0)
weighted_mean = torch.sum(x * weights, dim=1)# 沿维度的logsumexp(数值稳定)
logsumexp = torch.logsumexp(x, dim=1)

6. 工程实践建议

6.1. 广播机制理解:确保运算张量的形状兼容

# 广播示例
a = torch.randn(3, 1)
b = torch.randn(1, 3)
c = a + b  # 形状(3,3)

6.2. 原地操作:使用_后缀节省内存

x.sqrt_()  # 原地平方根
x.add_(1)  # 原地加1

6.3. 设备一致性:确保运算张量在同一设备

if torch.cuda.is_available():x = x.cuda()y = y.cuda()z = x + y

6.4. 梯度保留:注意运算对计算图的影响

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()  # dy/dx = 2x = 4.0

6.5. 数值稳定性:使用稳定实现

# 不稳定的softmax实现
unstable = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)# 稳定的softmax实现
stable = torch.softmax(x, dim=1)

7. 性能优化技巧

7.1 向量化操作:避免Python循环

# 不好的做法
result = torch.zeros_like(x)
for i in range(x.size(0)):result[i] = x[i] * 2# 好的做法
result = x * 2

7.2. 融合操作:减少中间结果

# 低效
temp = x + y
result = temp * z# 高效
result = (x + y) * z

7.3. 使用内置函数:利用优化实现

# 自定义实现
custom_norm = torch.sqrt(torch.sum(x ** 2))# 内置优化函数
optimized_norm = torch.norm(x)
原文地址:https://blog.csdn.net/jiaomongjun/article/details/146927433
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mrgr.cn/news/96973.html

相关文章:

  • Design Compiler:库特征分析(ALIB)
  • 第J3-1周:DenseNet算法 实现乳腺癌识别(含真实图片预测)
  • 启动arthas-boot.jar端口占用
  • LeetCode 2140.解决智力问题:记忆化搜索(DFS) / 动态规划(DP)
  • 什么是数据仓库
  • 吾爱置顶软件,吊打电脑自带功能!
  • 关于inode,dentry结合软链接及硬链接的实验
  • AiCube 试用 - 创建流水灯工程
  • 运维之 Centos7 防火墙(CentOS 7 Firewall for Operations and Maintenance)
  • J1 ResNet-50算法实战与解析
  • 搜广推校招面经六十六
  • 运筹帷幄:制胜软件开发
  • 【Pandas】pandas DataFrame select_dtypes
  • 4.1-泛型编程深入指南
  • Ubuntu换Windows磁盘格式化指南
  • 使用Deployment运行无状态应用
  • 部署大模型实战:如何巧妙权衡效果、成本与延迟?
  • Apache httpclient okhttp
  • Git与SVN的区别以及各自的优势
  • Linux基础指令(一)