当前位置: 首页 > news >正文

【机器学习】K近邻算法

目录

算法引入:

KNN算法的核心思想

KNN算法的步骤

KNN常用的距离度量方法

KNN算法的优缺点

优点:

缺点:

K值的选择

KNN的C++实现

复杂度分析:


K近邻算法(K-Nearest Neighbors, KNN)是一种简单但非常实用的监督学习算法,主要用于分类和回归问题。KNN 基于相似性度量(如欧几里得距离)来进行预测,核心思想是给定一个样本,找到与其最接近的 K 个邻居,根据这些邻居的类别或特征对该样本进行分类或预测。

算法引入:

我们假设平面上有两类点的集合,一类属于A类,一类属于B类,A类有三个点,B类有三个点。如果我们要加入一个橙色的点,那么它是属于A类还是B类? 

我们不去考虑算法的话,如果去归类A类还是B类,那么我们肯定想到的就是这个点它离哪个点最近,就属于哪一类。到这里就与我们的K近邻算法大概相同了,只不过我们要选取一个范围,在这个范围里找点进行判断,比如,我们选择它的邻域三个点,即K=3,那么这个区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。

K近邻算法有三个要素,如下图所示,第一个就是距离度量,这个距离有很多种距离,比如:欧几里得距离、曼哈顿距离、闵可夫斯基距离等。上面的例子中我们选择的欧几里得距离。第二个是K值,就是这个一个范围的的点个数,上面例子中,K=3。第三个是少数服从多数规则,上面例子中区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。


KNN算法的核心思想

1. 分类问题:
   - 给定一个未标记的数据点,通过计算该数据点与已标记的训练数据集中的每一个数据点的距离,选择距离最近的 K 个邻居。
   - 根据这 K 个邻居的类别,采用“多数投票”的方式来决定未标记数据点的类别。

2. 回归问题:
   - 计算未标记数据点的 K 个最近邻居的值,然后取这些邻居的平均值或加权平均值作为该点的预测值。

KNN算法的步骤

1. 数据准备:准备好训练数据集,包括特征和标签(分类问题中为类别,回归问题中为数值)。
2. 选择K值:选择邻居的数量K,一般是正整数。
3. 计算距离:对每个未标记数据点,计算它与训练集中每一个数据点的距离(常见的距离度量方法有欧几里得距离、曼哈顿距离等)。
4. 选择K个最近邻居:根据距离从小到大排序,选择距离最近的K个邻居。
5. 投票或平均:分类问题中,根据K个邻居的类别进行投票选择类别;回归问题中,计算邻居的平均值作为预测结果。


KNN常用的距离度量方法

1. 欧几里得距离:

   欧几里得距离是最常用的距离度量方法,适用于连续变量的情况。

2. 曼哈顿距离:
   
   曼哈顿距离适用于某些特定场景,尤其是当特征值的变化范围不均匀时。

3. 闵可夫斯基距离:
  
   闵可夫斯基距离是欧几里得距离和曼哈顿距离的推广形式,其中 p 是一个参数,当 p=2 时,便是欧几里得距离,当 p=1 时便是曼哈顿距离。


KNN算法的优缺点

优点:

1. 简单易理解:KNN算法实现简单,易于理解和解释。
2. 无参数模型:KNN不需要训练过程,可以直接使用数据进行预测。
3. 适用性广泛:KNN可以用于分类和回归问题,并且对非线性数据有较好的适应性。

缺点:

1. 计算复杂度高:KNN算法需要对每一个测试样本都计算与所有训练样本的距离,因此在大数据集下计算开销较大。
2. 内存开销大:KNN需要存储整个训练数据集,占用较大的存储空间。
3. 对噪声敏感:KNN对噪声数据较为敏感,特别是在K值较小的情况下,少量噪声数据可能会对结果产生很大影响。


K值的选择

- 小K值:K值较小时,模型会更加复杂,可能会过拟合。即使有少量噪声数据,也会对分类结果产生较大的影响。
- 大K值:K值较大时,模型会更加平滑,可能会欠拟合。K值过大会忽略数据的局部结构。

通常,K值通过交叉验证等方法来选择合适的值。


KNN的C++实现

下面是一个简单的KNN算法的C++实现,用于分类问题,采用欧几里得距离来计算邻居之间的距离。

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;// 定义一个点,包含特征和标签
struct Point {vector<double> features;int label;
};// 计算欧几里得距离
double euclideanDistance(const vector<double>& a, const vector<double>& b) {double distance = 0.0;for (int i = 0; i < a.size(); i++) {distance += pow(a[i] - b[i], 2);}return sqrt(distance);
}// KNN算法实现
int knn(const vector<Point>& train_data, const vector<double>& test_point, int k) {vector<pair<double, int>> distances; // 距离和标签的对// 计算每个训练数据点到测试点的距离for (const auto& point : train_data) {double distance = euclideanDistance(point.features, test_point);distances.push_back({distance, point.label});}// 按距离排序sort(distances.begin(), distances.end());// 统计前k个最近邻的类别vector<int> label_count(100, 0); // 假设标签在0-99之间for (int i = 0; i < k; i++) {label_count[distances[i].second]++;}// 返回出现次数最多的类别int max_count = 0;int predicted_label = -1;for (int i = 0; i < label_count.size(); i++) {if (label_count[i] > max_count) {max_count = label_count[i];predicted_label = i;}}return predicted_label;
}int main() {int n, m, k;cin >> n;//训练数据的个数cin >> m;//测试数据的维度cin >> k;// 输入训练数据vector<Point> train_data(n);for (int i = 0; i < n; i++) {train_data[i].features.resize(m);for (int j = 0; j < m; j++) {cin >> train_data[i].features[j];}cin >> train_data[i].label;}// 输入测试点vector<double> test_point(m);for (int j = 0; j < m; j++) {cin >> test_point[j];}// 使用KNN进行分类int predicted_label = knn(train_data, test_point, k);cout << predicted_label << endl;return 0;
}

复杂度分析:

- 时间复杂度:对于每个测试点,KNN需要计算与所有训练点的距离,因此时间复杂度为 O(n * m),其中 n 是训练集大小,m 是特征维度。
- 空间复杂度:主要用于存储训练数据和距离结果,空间复杂度为 O(n)。


K近邻算法是一个简单直观的非参数分类算法,适用于低维、小数据集的情况。然而,由于它的计算复杂性较高,KNN在大数据集或高维数据上的表现不佳。因此,KNN算法通常被用作基准模型或在小规模数据集上使用。


http://www.mrgr.cn/news/71224.html

相关文章:

  • 【云计算】OpenStack云计算平台
  • LabVIEW与WPS文件格式的兼容性
  • 实力认证 | 海云安入选《信创安全产品及服务购买决策参考》
  • qt QPainter setViewport setWindow viewport window
  • 【STM32-学习笔记-7-】USART串口通信
  • 如何将 sqlserver 数据迁移到 mysql
  • 7天用Go从零实现分布式缓存GeeCache(学习)(3)
  • CTF-RE 从0到N: windows反调试-获取Process Environment Block(PEB)信息来检测调试
  • Go开发指南-Gin与Web开发
  • android studio 配置过程
  • 开启鸿蒙开发之旅:核心组件及其各项属性介绍——布局容器组件2
  • Mysql高可用架构方案
  • 如何在60分钟内进行ASO竞争对手分析(App Store 和 Google Play Store)
  • seatunnel常用集群操作命令
  • 鸿蒙系统(HarmonyOS)与OpenHarmony
  • notepad++下载安装教程
  • 全球碳循环数据集(2000-2023)包括总初级生产力、生态系统净碳交换和生态系统呼吸变量
  • upload-labs通关练习---更新到15关
  • 【Conda】Windows下conda的安装并在终端运行
  • 第三百二十节 Java线程教程 - Java线程中断、Java Volatile变量
  • 3349、检测相邻递增子数组 Ⅰ
  • golang如何实现sse
  • 一文熟悉redis安装和字符串基本操作
  • 37 string类关键函数的模拟实现
  • 【网络安全渗透测试零基础入门】之Vulnhub靶场PWNOS: 2.0 多种渗透方法,收藏这一篇就够了!
  • FAS在数据库环境中应用详解