当前位置: 首页 > news >正文

【NLP6、损失函数 ① 均方差损失函数】

努力一点,所有的好运都会在我身上降临

                                                        —— 24.12.4

均方差损失函数

1. 定义

均方差损失函数(Mean Squared Error Loss Function),也称为L2损失函数,用于衡量预测值与真实值之间的差异。在有监督学习的回归问题中应用广泛。

对于一个包含 n 个样本的数据集,设预测值为 y_pre,真实值为 y_true

均方差损失函数计算公式为:

首先,对于数据集中的每一个样本,计算预测值与真实值的差,

然后,将这个差值进行平方操作。这样做的目的是为了让差值的正负不相互抵消(因为无论差值是正还是负,平方后都是正的),并且放大较大的差值(因为差值越大,平方后变得更大)

接着,把所有样本的差值平方结果相加起来

最后,将这个总和除以样本的数量,得到的结果就是均方差损失函数的值。

这个值代表了模型预测值与真实值之间平均的差异程度,损失函数的值越小,表示模型的预测效果越好。


2. 作用原理

衡量误差:

它通过计算预测值和真实值差的平方的平均值来量化预测的误差。平方操作确保了损失值始终为非负,并且对较大的误差给予更大的惩罚。例如,当预测值与真实值相差较大时,(y_pre-y_true) ^ 2 的值会很大,损失函数的值也就会变大

优化目标:

在模型训练过程中,模型的目标是最小化这个损失函数。例如,在神经网络的训练中,通过反向传播算法,根据损失函数对模型参数的梯度来调整参数,使得模型的预测结果越来越接近真实值,从而降低损失函数的值。


3. 应用场景

线性回归:

在简单的线性回归模型 y = ax + b + ξ 中(是真实值,是自变量,和 b 是模型参数,ξ 是误差项),使用均方差损失函数来评估模型预测的直线与真实数据点之间的差异。通过最小化均方差损失,找到最佳的 a 和 b,使得直线能够最好地拟合数据。

神经网络回归任务:

在更复杂的深度神经网络用于回归问题时,例如预测股票价格、气温变化等连续数值,均方差损失函数是常用的损失函数。在训练过程中,每次迭代都计算均方差损失,然后根据损失对网络各层的权重进行更新,逐步改进模型的预测能力


4. 代码示例

假设我们有一个简单的线性回归模型  y = mx + b,并且有一组训练数据 (x_i, y_i)

np.mean() 是 NumPy 库中的一个函数,用于计算数组(或沿指定轴的元素)的算术平均值

import numpy as np# 模拟生成训练数据
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 6, 8, 10])# 初始化模型参数
m = 1.0
b = 0.0# 预测函数
def predict(x):return m * x + b# 计算均方差损失函数
def mse_loss(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 进行预测
y_pred = predict(x)
loss = mse_loss(y_pred, y_true=y)
print("均方差损失:", loss)

首先定义了预测函数 predict 均方差损失函数 mse_loss。然后,使用给定的模型参数 m 和 b 对训练数据 x 进行预测,得到预测值 y_pred

最后,通过均方差损失函数计算预测值和真实值之间的损失,并打印出来。然后通过优化算法(如梯度下降)来调整 m 和 b,以降低损失函数的值。


http://www.mrgr.cn/news/79196.html

相关文章:

  • 创建线程、socket通信、recv非阻塞
  • centos8 安装docker换源
  • 泷羽Sec-Burp Suite自动刷漏洞-解放双手
  • 高并发数据采集场景下Nginx代理Netty服务的优化配置
  • 第三部分:进阶概念 9.错误处理 --[JavaScript 新手村:开启编程之旅的第一步]
  • 手机租赁系统开发指南一站式服务流程解析
  • Android 使用TabLayout + ViewPager2 实现标签页的视图切换
  • 【Android】EventBus的使用及源码分析
  • 技术栈6:Docker入门 Linux入门指令
  • 【5G】5G技术组件 5G Technology Components
  • 【C++】入门【六】
  • 数字IC前端学习笔记:脉动阵列的设计方法学(以串行FIR滤波器为例)
  • 优傲协作机器人 Remote TCP Toolpath URCap(操作记录)
  • L17.【LeetCode笔记】另一棵树的子树
  • 【OpenDRIVE_Python】使用python脚本输出OpenDRIVE数据中含有隧道tunnel的道路ID和隧道信息
  • SCP命令实现Linux中的文件传输
  • Qt Quick 开发基础 + 实战(持续更新中…)
  • Vue3 Ts 如何获取组件的类型
  • 【OpenDRIVE_Python】使用python脚本输出OpenDRIVE数据中含有桥梁bridge的道路ID和桥梁信息
  • cgo内存泄漏排查
  • 微信小程序版小米商城的搭建流程详解!
  • Springboot 2.x 升级到Springboot 2.7.x问题汇总
  • mysql集群NDB方式部署
  • 基于python爬虫的智慧人才数据分析系统
  • string类函数的手动实现
  • mysql中的skip_name_resolve详解