PyTorch使用(7)-张量常见运算函数
1. 基本数学运算
1.1 平方根和幂运算
import torchx = torch.tensor([4.0, 9.0, 16.0])# 平方根
sqrt_x = torch.sqrt(x) # tensor([2., 3., 4.])# 平方
square_x = torch.square(x) # tensor([16., 81., 256.])# 任意幂次
pow_x = torch.pow(x, 3) # tensor([64., 729., 4096.])# 运算符形式
sqrt_x_alt = x ** 0.5
square_x_alt = x ** 2
1.2 指数和对数
# 自然指数
exp_x = torch.exp(x) # tensor([5.4595e+01, 8.1031e+03, 8.8861e+06])# 自然对数
log_x = torch.log(x) # tensor([1.3863, 2.1972, 2.7726])# 以10为底的对数
log10_x = torch.log10(x) # tensor([0.6021, 0.9542, 1.2041])# 带clip的最小值保护(避免log(0))
safe_log = torch.log(x + 1e-8)
2. 统计运算
2.1 求和与均值
x = torch.randn(3, 4) # 3x4随机张量# 全局求和
total = torch.sum(x) # 标量# 沿特定维度求和
sum_dim0 = torch.sum(x, dim=0) # 形状(4,),沿行求和
sum_dim1 = torch.sum(x, dim=1) # 形状(3,),沿列求和# 均值计算
mean_val = torch.mean(x) # 全局均值
mean_dim0 = torch.mean(x, dim=0) # 沿行求均值
2.2 极值与排序
# 最大值/最小值
max_val = torch.max(x) # 全局最大值
min_val = torch.min(x) # 全局最小值# 沿维度的极值及索引
max_vals, max_indices = torch.max(x, dim=1) # 每行最大值及位置
min_vals, min_indices = torch.min(x, dim=0) # 每列最小值及位置# 排序
sorted_vals, sorted_indices = torch.sort(x, dim=1, descending=True)
2.3 方差与标准差
# 无偏方差(分母n-1)
var_x = torch.var(x, unbiased=True) # 全局方差
var_dim0 = torch.var(x, dim=0) # 沿行方差# 标准差
std_x = torch.std(x) # 全局标准差
std_dim1 = torch.std(x, dim=1) # 沿列标准差
3. 矩阵运算
3.1 基本矩阵运算
A = torch.randn(3, 4)
B = torch.randn(4, 5)# 矩阵乘法
matmul = torch.matmul(A, B) # 形状(3,5)
matmul_alt = A @ B # 等价写法# 点积(向量)
v1 = torch.randn(3)
v2 = torch.randn(3)
dot_product = torch.dot(v1, v2)# 批量矩阵乘法
batch_A = torch.randn(5, 3, 4) # 5个3x4矩阵
batch_B = torch.randn(5, 4, 5) # 5个4x5矩阵
batch_matmul = torch.bmm(batch_A, batch_B) # 形状(5,3,5)
3.2 矩阵分解
# 特征分解(对称矩阵)
sym_matrix = torch.randn(3, 3)
sym_matrix = sym_matrix @ sym_matrix.T # 构造对称矩阵
eigenvals, eigenvecs = torch.linalg.eigh(sym_matrix)# SVD分解
U, S, V = torch.linalg.svd(A)
4. 比较运算
4.1 元素级比较
a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])# 比较运算
eq = torch.eq(a, b) # tensor([False, True, False])
gt = torch.gt(a, b) # tensor([False, False, True])
lt = torch.lt(a, b) # tensor([True, False, False])# 运算符形式
eq_alt = a == b
gt_alt = a > b
4.2 约简比较
# 判断所有元素为True
all_true = torch.all(eq)# 判断任一元素为True
any_true = torch.any(gt)# 判断张量相等(形状和值)
torch.equal(a, b) # False
5. 规约运算
5.1 常用规约
x = torch.randn(2, 3)# 求和规约
sum_all = x.sum() # 全局求和
sum_dim = x.sum(dim=1) # 沿维度规约# 累积和
cumsum = x.cumsum(dim=0) # 沿维度累积# 乘积规约
prod_all = x.prod() # 全局乘积
5.2 高级规约
# 加权平均
weights = torch.softmax(torch.randn(3), dim=0)
weighted_mean = torch.sum(x * weights, dim=1)# 沿维度的logsumexp(数值稳定)
logsumexp = torch.logsumexp(x, dim=1)
6. 工程实践建议
6.1. 广播机制理解:确保运算张量的形状兼容
# 广播示例
a = torch.randn(3, 1)
b = torch.randn(1, 3)
c = a + b # 形状(3,3)
6.2. 原地操作:使用_后缀节省内存
x.sqrt_() # 原地平方根
x.add_(1) # 原地加1
6.3. 设备一致性:确保运算张量在同一设备
if torch.cuda.is_available():x = x.cuda()y = y.cuda()z = x + y
6.4. 梯度保留:注意运算对计算图的影响
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward() # dy/dx = 2x = 4.0
6.5. 数值稳定性:使用稳定实现
# 不稳定的softmax实现
unstable = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)# 稳定的softmax实现
stable = torch.softmax(x, dim=1)
7. 性能优化技巧
7.1 向量化操作:避免Python循环
# 不好的做法
result = torch.zeros_like(x)
for i in range(x.size(0)):result[i] = x[i] * 2# 好的做法
result = x * 2
7.2. 融合操作:减少中间结果
# 低效
temp = x + y
result = temp * z# 高效
result = (x + y) * z
7.3. 使用内置函数:利用优化实现
# 自定义实现
custom_norm = torch.sqrt(torch.sum(x ** 2))# 内置优化函数
optimized_norm = torch.norm(x)
原文地址:https://blog.csdn.net/jiaomongjun/article/details/146927433
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mrgr.cn/news/96973.html 如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mrgr.cn/news/96973.html 如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!