当前位置: 首页 > news >正文

线性可分支持向量机代码实现

### 实现线性可分支持向量机
### 硬间隔最大化策略
class Hard_Margin_SVM:### 线性可分支持向量机拟合方法def fit(self, X, y):# 训练样本数和特征数m, n = X.shape# 初始化二次规划相关变量:P/q/G/hself.P = matrix(np.identity(n + 1, dtype=np.float))self.q = matrix(np.zeros((n + 1,), dtype=np.float))self.G = matrix(np.zeros((m, n + 1), dtype=np.float))self.h = -matrix(np.ones((m,), dtype=np.float))# 将数据转为变量self.P[0, 0] = 0for i in range(m):self.G[i, 0] = -y[i]self.G[i, 1:] = -X[i, :] * y[i]# 构建二次规划求解sol = solvers.qp(self.P, self.q, self.G, self.h)# 对权重和偏置寻优self.w = np.zeros(n,) self.b = sol['x'][0] for i in range(1, n + 1):self.w[i - 1] = sol['x'][i]return self.w, self.b### 定义模型预测函数def predict(self, X):return np.sign(np.dot(self.w, X.T) + self.b)

线性可分支持向量机的硬间隔最大化策略

该代码实现了线性可分支持向量机(SVM) 的硬间隔最大化策略。支持向量机是用于二分类问题的监督学习算法,而硬间隔策略意味着数据集是线性可分的,并且我们尝试通过最大化分类间隔来找到最优的决策边界。该实现依赖 cvxopt 库来求解一个二次规划问题。

以下是对代码的详细解释:

1. 类的定义

class Hard_Margin_SVM:

定义了一个名为 Hard_Margin_SVM 的类,用于实现硬间隔支持向量机。这个类有两个主要的方法:

  • fit():训练模型的方法。
  • predict():根据训练好的模型进行预测。

2. fit() 方法

def fit(self, X, y):

fit() 方法用于训练支持向量机模型,即根据给定的训练数据 X X X 和标签 y y y,通过二次规划求解最优的权重 w w w 和偏置 b b b,构建出最大化间隔的分类超平面。

(a) 训练样本数和特征数
m, n = X.shape

m 是训练样本的数量,n 是特征的维数。

(b) 初始化二次规划的参数矩阵
self.P = matrix(np.identity(n + 1, dtype=np.float))
self.q = matrix(np.zeros((n + 1,), dtype=np.float))
self.G = matrix(np.zeros((m, n + 1), dtype=np.float))
self.h = -matrix(np.ones((m,), dtype=np.float))
  • P P P:定义目标函数中的二次项。为了计算 1 2 w T w \frac{1}{2} w^T w 21wTwP 被初始化为一个单位矩阵,其中额外的维度是为偏置项 b b b 保留的。
  • q q q:定义目标函数中的线性项。在硬间隔 SVM 中,线性项为 0,所以初始化为零向量。
  • G G G h h h:定义约束条件 G x ≤ h Gx \leq h GxhG 用于约束支持向量的位置,h 是用来实现 y i ( w ⋅ x i + b ) ≥ 1 y_i(w \cdot x_i + b) \geq 1 yi(wxi+b)1 的不等式条件,确保所有点都被正确分类且满足硬间隔条件。
(c) 设置 P 矩阵和 G 矩阵
self.P[0, 0] = 0
for i in range(m):self.G[i, 0] = -y[i]self.G[i, 1:] = -X[i, :] * y[i]
  • self.P[0, 0] = 0:确保 P 的第一项为 0,因为我们不需要对偏置项 b b b 做二次惩罚。
  • self.G:构建了不等式约束矩阵 G G G,用于确保 y i ( w ⋅ x i + b ) ≥ 1 y_i (w \cdot x_i + b) \geq 1 yi(wxi+b)1self.G[i, 0] 对应偏置项 b b bself.G[i, 1:] 对应权重 w w w
(d) 使用 cvxopt.solvers.qp() 进行二次规划求解
sol = solvers.qp(self.P, self.q, self.G, self.h)

cvxopt.solvers.qp()cvxopt 中用于求解二次规划的函数。它使用矩阵 P P P q q q G G G h h h 来构建二次规划问题,并返回最优解 sol。该最优解包含了权重 w w w 和偏置 b b b

(e) 提取权重 w w w 和偏置 b b b
self.w = np.zeros(n,) 
self.b = sol['x'][0] 
for i in range(1, n + 1):self.w[i - 1] = sol['x'][i]
  • self.b:是从求解器中提取的偏置项 b b b,它是解向量 sol['x'] 的第一个元素。
  • self.w:是从解向量中提取的权重项 w w w,并且赋值给类的属性 self.w

3. predict() 方法

def predict(self, X):return np.sign(np.dot(self.w, X.T) + self.b)

predict() 方法用于对新的数据 X X X 进行预测:

  • np.dot(self.w, X.T) \text{np.dot(self.w, X.T)} np.dot(self.w, X.T):计算数据点与权重向量 w w w 的点积。
  • np.sign() \text{np.sign()} np.sign():通过决策函数的符号来决定分类结果。如果结果为正,则归为正类;否则为负类。

4. 示例使用

以下是如何使用该类来训练和预测的示例:

import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt# 生成数据集
X, y = make_blobs(n_samples=100, centers=2, random_state=42)
y = 2 * (y - 0.5)  # 转换为 -1 和 1# 创建 SVM 实例
svm = Hard_Margin_SVM()# 训练模型
svm.fit(X, y)# 对数据集进行预测
y_pred = svm.predict(X)# 可视化分类结果
plt.scatter(X[:, 0], X[:, 1], c=y_pred, cmap='coolwarm')
plt.show()

5. 结论

  • Hard_Margin_SVM 类实现了线性可分支持向量机的硬间隔最大化策略。通过 cvxopt 库中的二次规划求解器,我们能够找到最优的超平面,使得正类和负类样本被分隔开,并且分类间隔最大。

另一篇文章是对这段代码的举例说明:线性可分支持向量机代码 举例说明 具体的变量数值变化


http://www.mrgr.cn/news/58003.html

相关文章:

  • 论文《Text2SQL is Not Enough: Unifying AI and Databases with TAG》
  • Matlab中HybridFcn参数的用法
  • 基于大型语言模型的智能网页抓取
  • 2-解决联想拯救者Y7000p在ubuntu20.04未找到wifi适配器,安装rtl8852ce网卡驱动问题
  • 面试知识梳理
  • COSCon'24 志愿者招募令:共创开源新生活!
  • Python 代码的主要功能是从给定的日志文件和设备列表中提取特定设备(华为和华三)的用户账号信息
  • Java 开发——(下篇)从零开始搭建后端基础项目 Spring Boot 3 + MybatisPlus
  • AI基础:传教士与野人
  • Python如何处理zip压缩文件(Python处理zip压缩文件接口源码)
  • SLAM:未来智能科技的核心——探索多传感器融合的无限可
  • [蓝桥杯 2024 省 C] 回文数组
  • LeetCode199. 二叉树的右视图(2024秋季每日一题 47)
  • Linux 权限的理解
  • 前端发送请求格式
  • 1024——视触觉传感器GelSight的前世今生
  • 系统移植相关概念总结
  • 力扣周赛第420场 中等 3325.字符至少出现k次的子字符串 I
  • C语言程序设计:现代设计方法习题笔记《chapter4》
  • java的maven打包插件来了,package一键打包exe、dmg、rpm等
  • JAVA应用测试,线上故障排查分析全套路!
  • C++,STL 045(24.10.24)
  • 【Linux】进程状态及其转换
  • Github_以太网开源项目verilog-ethernet代码阅读与移植(八)——移植工程分享
  • 头歌——人工智能(遗传算法)
  • 获取图像的风格矩阵