当前位置: 首页 > news >正文

《GBDT 算法的原理推导》 11-15更新决策树的叶子节点值 公式解析

本文是将文章《GBDT 算法的原理推导》中的公式单独拿出来做一个详细的解析,便于初学者更好的理解。


公式(11-15)出现在GBDT算法推导的过程中,用于更新决策树的叶子节点值

公式(11-15)如下:

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxiRmjL(yi,fm1(xi)+c)

其中:

  • c m j c_{mj} cmj 表示第 m m m 棵树在叶子节点 R m j R_{mj} Rmj 上的输出值,也就是叶子节点的预测值。
  • R m j R_{mj} Rmj 表示第 m m m 棵树的第 j j j 个叶子节点区域。
  • L ( y i , f ( x i ) ) L(y_i, f(x_i)) L(yi,f(xi)) 是损失函数,衡量样本 x i x_i xi 的真实值 y i y_i yi 和当前模型预测值 f ( x i ) f(x_i) f(xi) 之间的误差。
  • f m − 1 ( x i ) f_{m-1}(x_i) fm1(xi) 是前 m − 1 m-1 m1 轮构建的模型在 x i x_i xi 处的预测值。

1. 公式(11-15)的背景

在GBDT中,每一棵树的任务是对当前模型的误差进行拟合和修正。我们通过新增一棵树 T ( x ; Θ m ) T(x; \Theta_m) T(x;Θm) 来改善当前模型的预测能力。

这棵新树的结构(分裂方式)已经确定,它会将输入样本分配到不同的叶子节点区域 R m j R_{mj} Rmj。接下来,我们需要为每个叶子节点分配一个值,使得这棵树能够最好地拟合该节点区域内的样本误差。

2. 目标是最小化损失

在叶子节点 R m j R_{mj} Rmj 上,我们希望找到一个最优的输出值 c m j c_{mj} cmj,使得它能够最小化该节点区域内所有样本的损失。具体来说,给定前 m − 1 m-1 m1 棵树的预测值 f m − 1 ( x i ) f_{m-1}(x_i) fm1(xi),新树的叶子节点值 c m j c_{mj} cmj 需要满足以下优化目标:

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxiRmjL(yi,fm1(xi)+c)

这意味着,我们为每个叶子节点选择一个最优的常数 c c c,使得该节点区域内所有样本的损失之和最小。

3. 损失函数 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm1(xi)+c)

在GBDT算法中,不同的损失函数 L ( y , f ( x ) ) L(y, f(x)) L(y,f(x)) 会影响叶子节点的最优输出值计算方式。常见的损失函数包括:

  • 平方损失:用于回归任务。
  • 对数损失:用于二分类任务。

不同的损失函数会导致不同的叶子节点值计算方式。下面以平方损失为例来说明如何求解。

例子:平方损失

假设损失函数是平方损失:

L ( y i , f ( x i ) ) = 1 2 ( y i − f ( x i ) ) 2 L(y_i, f(x_i)) = \frac{1}{2} (y_i - f(x_i))^2 L(yi,f(xi))=21(yif(xi))2

代入公式(11-15)中的 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm1(xi)+c)

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j 1 2 ( y i − ( f m − 1 ( x i ) + c ) ) 2 c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - (f_{m-1}(x_i) + c))^2 cmj=argcminxiRmj21(yi(fm1(xi)+c))2

我们对 c c c 求导,并让导数等于零,以找到最优的 c c c

∂ ∂ c ∑ x i ∈ R m j 1 2 ( y i − f m − 1 ( x i ) − c ) 2 = 0 \frac{\partial}{\partial c} \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - f_{m-1}(x_i) - c)^2 = 0 cxiRmj21(yifm1(xi)c)2=0

这等价于:

∑ x i ∈ R m j ( y i − f m − 1 ( x i ) − c ) = 0 \sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i) - c) = 0 xiRmj(yifm1(xi)c)=0

解这个方程,可以得到:

c = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} c=RmjxiRmj(yifm1(xi))

即:

c m j = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c_{mj} = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} cmj=RmjxiRmj(yifm1(xi))

这表明,在平方损失的情况下,叶子节点的输出值 c m j c_{mj} cmj 是该叶子节点区域内所有样本残差( y i − f m − 1 ( x i ) y_i - f_{m-1}(x_i) yifm1(xi))的平均值。

总结

公式(11-15)表示了GBDT算法中如何确定每棵树的叶子节点值。通过最小化叶子节点区域内的损失,可以找到一个最优的输出值 c m j c_{mj} cmj,使得该节点区域的样本预测误差最小化。在不同的损失函数下,最优值 c m j c_{mj} cmj 的计算方式可能有所不同,但原理都是基于最小化损失来确定最佳的叶子节点值。


http://www.mrgr.cn/news/64804.html

相关文章:

  • 微信公众号推送
  • 【算法】Prim最小生成树算法
  • 企业AI助理驱动的决策支持:从数据洞察到战略执行
  • MATLAB绘制水蒸气温度和压力曲线(IAPWS-IF97公式)
  • 【C++类和对象篇】类和对象的六大默认成员函数——构造,析构,拷贝构造,赋值重载,普通对象取地址重载,const对象取地址重载
  • Django视图写法
  • Linux内核编程(十八)ADC驱动
  • 深入解析RSA算法:加密与安全性
  • Spring DispatcherServlet详解
  • 在vue中 什么是slot机制,如何使用以及使用场景详细讲解
  • JWT 是什么?JWT 如何防篡改?JWT 使用【hutools 工具包】
  • python爬虫之JS逆向入门,了解JS逆向的原理及用法(18)
  • 003 配置网络
  • springBoot动态加载jar,将类注册到IOC
  • 【数据分析】怎么提升GMV
  • df_new_last.iloc[:,-1]与df_new_last.iloc[:,:-1]
  • Redis 的使⽤和原理
  • IT运维的365天--018 如何在内网布置一个和外网同域名的网站,并开启SSL(https访问),即外网证书如何在内网使用
  • Kubernetes中常见的volumes数据卷
  • SPI协议——笔记
  • cangjie仓颉程序设计-数据结构(四)
  • [LeetCode] 面试题08.01 三步问题
  • 企业实现数字化转型需要考虑的方面?
  • LeetCode题练习与总结:超级次方--372
  • ‌SSB在时域上的特征
  • RHCE-SElinux+防火墙