当前位置: 首页 > news >正文

distances = np.linalg.norm(data[:, None] - centers, axis=2)

问题:

如果data的形状是 n×d(样本数 × 特征维度),centers的形状是 k×d(聚类中心数×特征维度)

data[:,None]扩充1维 形状变成   n×1×d

那为什么做减法 形状变成了 n×k×d?

distances = np.linalg.norm(data[:, None] - centers, axis=2),虽然 centers 没有显式地扩展维度,但通过 data[:, None] 的操作,数据的维度已经被调整,使得广播机制可以正常工作。

广播

NumPy 的广播机制允许不同形状的数组在一起进行算术运算。广播的规则如下:

  1. (维度数不同)如果数组的维度数不相同,则将维度较小的数组的形状前面补 1,直到两个数组的维度数相同。
  2. (固定维度 比长度 长度有1)如果两个数组在某个维度上的长度不相同,但其中一个数组在该维度上的长度为 1,则可以进行广播。
  3. (维度相同,长度不同 没有1)如果两个数组在任何维度上的长度都不相同,并且其中一个数组在该维度上的长度不为 1,则无法进行广播。

实例学习法:

假设我们有一个二维数组 data,其形状为 (n_samples, feature_dim),以及一个二维数组 centers,其形状为 (n_clusters, feature_dim)

import numpy as npdata = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])centers = np.array([[2, 3, 4],[8, 9, 10]])print("Data shape:", data.shape)
print("Centers shape:", centers.shape)
Data shape: (3, 3)
Centers shape: (2, 3)

通过 data[:, None] 操作,我们在第二个维度位置插入一个新的维度,使得 data 的形状变为 (n_samples, 1, feature_dim)

data_expanded = data[:, None]
print("Data expanded shape:", data_expanded.shape)
Data expanded shape: (3, 1, 3)

此时,data_expanded 的形状为 (3, 1, 3),而 centers 的形状为 (2, 3)。根据广播机制的规则:

  1. data_expanded 的形状为 (3, 1, 3)
  2. centers 的形状为 (2, 3)

为了使两个数组的形状相同,NumPy 会将 centers 的形状扩展为 (1, 2, 3),这样两个数组的形状就变为:

  • data_expanded 的形状为 (3, 1, 3)
  • centers 的形状为 (1, 2, 3)

然后,NumPy 会将这两个数组广播为相同的形状 (3, 2, 3),从而可以进行减法操作。

distances = np.linalg.norm(data[:, None] - centers, axis=2)
print("Distances shape:", distances.shape)
print("Distances:\n", distances)
Distances shape: (3, 2)
Distances:[[ 1.73205081 10.39230485][ 5.19615242  5.19615242][10.39230485  1.73205081]]

解释

  • data[:, None] 的形状为 (3, 1, 3)
  • centers 的形状为 (2, 3),通过广播机制扩展为 (1, 2, 3)
  • 通过广播机制,data[:, None] - centers 的结果形状为 (3, 2, 3)
  • np.linalg.norm 函数计算沿着 axis=2(即特征维度)的范数,得到的 distances 的形状是 (3, 2),表示每个样本到每个中心点的距离。

通过这种方式,我们可以方便地计算每个样本与每个聚类中心之间的距离,而不需要显式地使用循环。


http://www.mrgr.cn/news/69919.html

相关文章:

  • 缓存冲突(Cache Conflict)
  • 「QT」文件类 之 QDir 目录类
  • 【Excel】数据透视表分析方法大全
  • 2分钟在阿里云ECS控制台部署个人应用(图文示例)
  • Docker 基础命令介绍和常见报错解决
  • VTK知识学习(7)-纹理贴图
  • spring-security(记住密码,CSRF)
  • C#-抽象类、抽象函数
  • 腾讯云双11狂欢:拼团优惠、会员冲榜、限时秒杀,多重好礼等你来拿!
  • 论文解读之SDXL: Improving Latent Diffusion Models forHigh-Resolution Image Synthesis
  • 「iOS」——知乎日报第三周总结
  • 销售管理SCRM助力企业高效提升业绩与客户关系管理
  • 【C++练习】二进制到十进制的转换器
  • The Rank-then-Encipher Approach
  • 「Mac玩转仓颉内测版1」入门篇1 - Cangjie环境的搭建
  • goframe开发一个企业网站 开发环境DOCKER 搭建16
  • MATLAB实现最大最小蚁群算法(Max-Min Ant Colony Optimization, MMAS)
  • leetcode hot100【LeetCode 131.分割回文串】java实现
  • Jquery添加或删除Class属性实例代分享
  • Linux应用项目之量产工具(一)——显示系统
  • SwiftUI开发教程系列 - 第7章:数据流和状态管理
  • 信息安全数学基础(46)域和Galois理论
  • Python实现Delaunay三角剖分之Bowyer-Watson算法
  • 区块链技术在版权保护中的应用
  • Java项目实战II基于Spring Boot的农商对接系统的设计与实现(开发文档+数据库+源码)
  • Iceberg 写入和更新模式,COW,MOR(Copy-on-Write,Merge-on-Read)