《GBDT 算法的原理推导》 11-15更新决策树的叶子节点值 公式解析
本文是将文章《GBDT 算法的原理推导》中的公式单独拿出来做一个详细的解析,便于初学者更好的理解。
公式(11-15)出现在GBDT算法推导的过程中,用于更新决策树的叶子节点值。
公式(11-15)如下:
c m j = arg min c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxi∈Rmj∑L(yi,fm−1(xi)+c)
其中:
- c m j c_{mj} cmj 表示第 m m m 棵树在叶子节点 R m j R_{mj} Rmj 上的输出值,也就是叶子节点的预测值。
- R m j R_{mj} Rmj 表示第 m m m 棵树的第 j j j 个叶子节点区域。
- L ( y i , f ( x i ) ) L(y_i, f(x_i)) L(yi,f(xi)) 是损失函数,衡量样本 x i x_i xi 的真实值 y i y_i yi 和当前模型预测值 f ( x i ) f(x_i) f(xi) 之间的误差。
- f m − 1 ( x i ) f_{m-1}(x_i) fm−1(xi) 是前 m − 1 m-1 m−1 轮构建的模型在 x i x_i xi 处的预测值。
1. 公式(11-15)的背景
在GBDT中,每一棵树的任务是对当前模型的误差进行拟合和修正。我们通过新增一棵树 T ( x ; Θ m ) T(x; \Theta_m) T(x;Θm) 来改善当前模型的预测能力。
这棵新树的结构(分裂方式)已经确定,它会将输入样本分配到不同的叶子节点区域 R m j R_{mj} Rmj。接下来,我们需要为每个叶子节点分配一个值,使得这棵树能够最好地拟合该节点区域内的样本误差。
2. 目标是最小化损失
在叶子节点 R m j R_{mj} Rmj 上,我们希望找到一个最优的输出值 c m j c_{mj} cmj,使得它能够最小化该节点区域内所有样本的损失。具体来说,给定前 m − 1 m-1 m−1 棵树的预测值 f m − 1 ( x i ) f_{m-1}(x_i) fm−1(xi),新树的叶子节点值 c m j c_{mj} cmj 需要满足以下优化目标:
c m j = arg min c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxi∈Rmj∑L(yi,fm−1(xi)+c)
这意味着,我们为每个叶子节点选择一个最优的常数 c c c,使得该节点区域内所有样本的损失之和最小。
3. 损失函数 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm−1(xi)+c)
在GBDT算法中,不同的损失函数 L ( y , f ( x ) ) L(y, f(x)) L(y,f(x)) 会影响叶子节点的最优输出值计算方式。常见的损失函数包括:
- 平方损失:用于回归任务。
- 对数损失:用于二分类任务。
不同的损失函数会导致不同的叶子节点值计算方式。下面以平方损失为例来说明如何求解。
例子:平方损失
假设损失函数是平方损失:
L ( y i , f ( x i ) ) = 1 2 ( y i − f ( x i ) ) 2 L(y_i, f(x_i)) = \frac{1}{2} (y_i - f(x_i))^2 L(yi,f(xi))=21(yi−f(xi))2
代入公式(11-15)中的 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm−1(xi)+c):
c m j = arg min c ∑ x i ∈ R m j 1 2 ( y i − ( f m − 1 ( x i ) + c ) ) 2 c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - (f_{m-1}(x_i) + c))^2 cmj=argcminxi∈Rmj∑21(yi−(fm−1(xi)+c))2
我们对 c c c 求导,并让导数等于零,以找到最优的 c c c:
∂ ∂ c ∑ x i ∈ R m j 1 2 ( y i − f m − 1 ( x i ) − c ) 2 = 0 \frac{\partial}{\partial c} \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - f_{m-1}(x_i) - c)^2 = 0 ∂c∂xi∈Rmj∑21(yi−fm−1(xi)−c)2=0
这等价于:
∑ x i ∈ R m j ( y i − f m − 1 ( x i ) − c ) = 0 \sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i) - c) = 0 xi∈Rmj∑(yi−fm−1(xi)−c)=0
解这个方程,可以得到:
c = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} c=∣Rmj∣∑xi∈Rmj(yi−fm−1(xi))
即:
c m j = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c_{mj} = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} cmj=∣Rmj∣∑xi∈Rmj(yi−fm−1(xi))
这表明,在平方损失的情况下,叶子节点的输出值 c m j c_{mj} cmj 是该叶子节点区域内所有样本残差( y i − f m − 1 ( x i ) y_i - f_{m-1}(x_i) yi−fm−1(xi))的平均值。
总结
公式(11-15)表示了GBDT算法中如何确定每棵树的叶子节点值。通过最小化叶子节点区域内的损失,可以找到一个最优的输出值 c m j c_{mj} cmj,使得该节点区域的样本预测误差最小化。在不同的损失函数下,最优值 c m j c_{mj} cmj 的计算方式可能有所不同,但原理都是基于最小化损失来确定最佳的叶子节点值。