当前位置: 首页 > news >正文

【机器学习】多项式回归

多项式回归是回归分析的一种扩展形式,通过增加多项式特征,可以模拟输入特征与输出之间的非线性关系。与线性回归不同,线性回归仅适用于直线拟合,而多项式回归则可以用曲线拟合复杂数据。本教程将系统讲解多项式回归模型,包括模型的推导、数据转换、模型训练、损失函数的定义和梯度下降,并使用numpy实现。最后,我们会通过sklearn实现多项式回归模型。

多项式回归模型简介

多项式回归模型的核心在于扩展输入特征。对于一个特征 (x),我们可以将它转化为多项式特征,例如将输入特征 (x) 转换为二次特征(或更高次),模型表达式如下:

y = w 0 + w 1 x + w 2 x 2 + ⋯ + w n x n + b y = w_0 + w_1 x + w_2 x^2 + \cdots + w_n x^n + b y=w0+w1x+w2x2++wnxn+b
其中:

  • ( y ) 是预测值,
  • ( x ) 是输入特征,
  • ( w_0, w_1, \ldots, w_n ) 是模型的权重,
  • ( b ) 是偏置项。

这个模型是输入特征 (x) 的 n次多项式。多项式回归通过增加特征的次幂,允许模型更好地拟合非线性数据。

数据转换:构建多项式特征

在构建多项式回归模型前,需要将输入特征转换为多项式特征。例如,给定一个特征 ( x ),我们将它扩展为多项式特征 ( [1, x, x^2, \ldots, x^n] )。

我们定义一个函数 poly_features 来生成多项式特征。

import numpy as npdef poly_features(X, degree):"""将输入特征 X 扩展为多项式特征矩阵。X : 原始特征 (n_samples, 1)degree : 多项式的最高次数"""X_poly = np.ones((X.shape[0], degree + 1))  # 初始化为 1for i in range(1, degree + 1):X_poly[:, i] = X[:, 0] ** ireturn X_poly

损失函数:均方误差 (Mean Squared Error, MSE)

为了评估模型预测值与真实值的误差,我们使用均方误差作为损失函数:
M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2
其中:

  • ( y_i ) 是第 ( i ) 个样本的真实值,
  • ( \hat{y_i} ) 是模型的预测值,
  • ( n ) 是样本总数。

模型训练:梯度下降优化

多项式回归模型的优化可以通过梯度下降实现。以下是计算梯度的实现:

def compute_gradients(X, y, w, b):n = len(y)y_pred = X.dot(w) + bdw = (2/n) * X.T.dot(y_pred - y)db = (2/n) * np.sum(y_pred - y)return dw, db

使用梯度下降训练模型

接下来,我们使用梯度下降来优化模型参数:

def gradient_descent(X, y, w, b, learning_rate, iterations):for i in range(iterations):dw, db = compute_gradients(X, y, w, b)w -= learning_rate * dwb -= learning_rate * dbif i % 100 == 0:y_pred = X.dot(w) + bloss = mse_loss(y, y_pred)print(f"Iteration {i}: Loss = {loss}")return w, b

代码实现:多项式回归模型

我们使用numpy从头实现一个多项式回归模型。

数据准备

我们生成一个非线性数据集,用来训练多项式回归模型。

import matplotlib.pyplot as plt# 生成非线性数据
np.random.seed(42)
X = 6 * np.random.rand(100, 1) - 3  # X 范围在 [-3, 3]
y = 0.5 * X**2 + X + 2 + np.random.randn(100, 1)  # y = 0.5x^2 + x + 2 + 噪声# 可视化数据
plt.scatter(X, y)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Generated Non-linear Data")
plt.show()

转换多项式特征

假设我们想训练一个二次多项式回归模型。

# 转换成二次多项式特征
degree = 2
X_poly = poly_features(X, degree)

初始化模型参数并定义损失函数

我们初始化权重和偏置,并定义损失函数:

# 初始化参数
w = np.random.randn(degree + 1, 1)
b = np.random.randn(1)# 定义均方误差损失函数
def mse_loss(y_true, y_pred):return np.mean((y_true - y_pred) ** 2)

训练模型

设置学习率和迭代次数,进行模型训练。

learning_rate = 0.01
iterations = 1000
w_trained, b_trained = gradient_descent(X_poly, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

可视化拟合曲线

我们将模型拟合的曲线与原始数据进行对比。

# 生成预测值
X_fit = np.linspace(-3, 3, 100).reshape(100, 1)  # 用于绘制拟合曲线
X_fit_poly = poly_features(X_fit, degree)
y_fit = X_fit_poly.dot(w_trained) + b_trained# 绘制结果
plt.scatter(X, y, label="Original Data")
plt.plot(X_fit, y_fit, color='red', label="Polynomial Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Polynomial Regression Fit")
plt.legend()
plt.show()

使用sklearn实现多项式回归

最后,我们使用sklearn快速实现多项式回归模型。

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression# 生成多项式特征
poly = PolynomialFeatures(degree=2)
X_poly_sklearn = poly.fit_transform(X)# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X_poly_sklearn, y)# 可视化拟合曲线
y_sklearn_fit = lin_reg.predict(poly.fit_transform(X_fit))
plt.scatter(X, y, label="Original Data")
plt.plot(X_fit, y_sklearn_fit, color='red', label="Sklearn Polynomial Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Polynomial Regression with Sklearn")
plt.legend()
plt.show()

总结

本文通过逐步推演实现了多项式回归模型,深入理解了多项式特征转换、损失函数和梯度下降优化过程。最后,我们通过sklearn验证了模型的实现并进行了可视化展示,希望这篇教程帮助你掌握多项式回归的基本原理与实现。


http://www.mrgr.cn/news/58311.html

相关文章:

  • Do not use built-in or reserved HTML elements as component id: map
  • 无人机之低空管控技术
  • React框架详解
  • opencv学习:基于计算机视觉的表情识别系统
  • 5550 取数(max)
  • [论文阅读]Constrained Decision Transformer for Offline Safe Reinforcement Learning
  • GC.2022.六年级.05.数三角形
  • odoo17的分包重新供应路线如何设置?可从销售订单中实时直接触发采购订单或相关单据
  • apache poi导出excel
  • 单片机入门教程
  • 15分钟学 Go 第 20 天:Go的错误处理
  • 【数据结构和算法】二、python中的常用数据结构
  • AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务
  • 写出Windows操作系统内核的程序员,70多岁,还去办公室敲代码
  • openpnp - 解决“底部相机高级校正成功后, 开机归零时,吸嘴自动校验失败的问题“
  • NVR录像机汇聚管理EasyNVR多品牌NVR管理工具/设备视频报警功能详解
  • Chromium127调试指南 Windows篇 - 安装VS Code扩展(四)
  • 数据结构:堆的应用
  • Javascript数据结构——哈希表
  • 揭秘:登录注册表单背后的动画奥秘
  • 一个vue3的待办列表组件
  • Windows AD 域的深度解析 第一篇:AD 域原理与多系统联动
  • GPU 服务器厂家:谁将引领科技未来的强大动力?
  • LLM - CV 图像实例分割开源算法 SAM2(Segment Anything 2) 配置与推理 教程 (1)
  • 力扣之612.平面上的最近距离
  • softmax回归从零实现