研究线性模型训练中损失变化的规律和最优学习率的影响
探究一维线性模型训练中,测试损失随训练步数变化的缩放定律及其最优学习率影响,并研究多维线性模型训练的缩放定律,确定参数以符合特定损失衰减模式。
研究大模型的缩放定律对减少其训练开销至关重要,即最终的测试损失如何随着训练步数和模型大小的变化而变化?本题中,我们研究了训练线性模型时的缩放定律。
- 在本小问中,考虑使用梯度下降学习一个一维线性模型的情况。
-
定义数据分布 D \mathcal{D} D为一个 R 2 \mathbb{R}^2 R2上的分布,每个数据是一个数对 ( x , y ) (x, y) (x,y),分别代表输入和输出,并服从分布 x ∼ N ( 0 , 1 ) , y ∼ N ( 3 x , 1 ) x\sim N(0, 1),y\sim N(3x, 1) x∼N(0,1),y∼N(3x,1)。
-
用梯度下降算法学习线性模型 f w ( x ) = w ⋅ x f_{w}(x)=w \cdot x fw(x)=w⋅x,其中 w , x ∈ R w, x\in\mathbb{R} w,x∈R。初始化 ω 0 = 0 ω_0=0 ω0=0并进行多步迭代。每次迭代时,从 D \mathcal{D} D中采样 ( x t , y t ) (x_t,y_t) (xt,yt),然后更新 w t w_t wt为 w t + 1 ← w t − η ∇ l t ( w t ) w_{t+1}\leftarrow w_t-\eta\nabla l_t(w_t) wt+1←wt−η∇lt(wt),其中 l t ( w ) = 1 2 ( f w ( x t ) − y t ) 2 l_t(w)=\frac{1}{2}(f_w(x_t)-y_t)^2 lt(w)=21(fw(xt)−yt)2是平方损失函数, η > 0 \eta>0 η>0是学习率。
设学习率 η ∈ ( 0 , 1 3 ] \eta\in(0,\frac{1}{3}] η∈(0,31],那么 T ≥ 0 T≥0 T≥0步迭代之后的测试损失的期望
L ‾ η , T = E w T E ( x , y ) ∼ D [ 1 2 ( f w T ( x ) − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{w_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{w_T}(x)-y)^2] Lη,T=EwTE(x,y)∼D[21(fwT(x)−y)2]
是多少?
- 现在我们在第一小问的设定下,考虑学习率 η \eta η被调到最优的情况,求函数 g ( T ) g(T) g(T),使得当 T → + ∞ T\rightarrow+\infty T→+∞时,以下条件成立:
∣ inf η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log T ) 2 T 2 ) \left|\underset{η\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O(\frac{(\log T)^2}{T^2}) η∈(0,31]infIn,T−g(T) =O(T2(logT)2)
- 一个常常被观测到的实验现象是大语言模型的预训练过程大致遵循Chinchilla缩放定律:
L ‾ N , T ≈ A N α + B T β + C \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C LN,T≈NαA+TβB+C,
其中 L ‾ N , T \overline{\mathcal{L}}_{N,T} LN,T是在经过 T T T步训练后具有 N N N个参数的模型的测试损失的期望, A A A, B B B, a a a, β β β, C C C是常数。现在我们举一个训练多维线性模型的例子,使其也遵循类似的缩放定律。
-
固定 a > 0 , b ≥ 1 a>0,b≥1 a>0,b≥1,每个数据 ( x ⋅ , y ) (x_{\cdot},y) (x⋅,y)由一个输入和输出组成,其中输入 x ⋅ x_{\cdot} x⋅是一个无限维向量(可看作一个序列),输出 y y y满足 y ∈ R y\in\mathbb{R} y∈R。定义数据分布 D \mathcal{D} D如下。首先,从Zipf分布中采样 k k k, Pr [ k = i ] ∝ i − ( a + 1 ) ( i ≥ 1 ) \Pr[k=i]\propto i^{-(a+1)}\quad(i\geq 1) Pr[k=i]∝i−(a+1)(i≥1)。令 j : = [ k b ] j:=[k^b] j:=[kb],然后,从 m a t h c a l N ( 0 , 1 ) mathcal{N}(0,1) mathcalN(0,1)中采样得到 x ⋅ x_{\cdot} x⋅的第 j j j个坐标 x j x_j xj,并令其余坐标为0。最后, y ∼ N ( 3 x j , 1 ) y\sim N(3x_j,1) y∼N(3xj,1)。这样得到的 ( x ⋅ , y ) (x_{\cdot},y) (x⋅,y)的分即数据分布 D \mathcal{D} D。
-
我们研究一个仅关注前 N N N个输入坐标的线性模型。定义函数 ϕ N ( x x ⋅ ) = ( x 1 , . . . , x N ) \phi_N(xx_{\cdot})=(x_1,...,x_N) ϕN(xx⋅)=(x1,...,xN)。我们研究的线性模型具有参数 w ∈ R N \mathbf{w}\in\mathbb{R}^N w∈RN,输出为 f w ( x ) = ( w , ϕ N ( x ⋅ ) ) f_{\mathbf{w}}(x)=(\mathbf{w},\phi_N(x_{\cdot})) fw(x)=(w,ϕN(x⋅))。
-
我们使用梯度下降算法学习该线性模型。初始化 w 0 = 0 \mathbf{w}_0=0 w0=0并进行多步迭代。每次迭代时,从 D \mathcal{D} D中采样 ( x t , ⋅ , y t ) (x_{t,\cdot},y_t) (xt,⋅,yt),然后更新 w t \mathbf{w}_t wt为 w t + 1 ← w t − η ∇ l t ( w t ) \mathbf{w}_{t+1}\gets \mathbf{w}_t-\eta\nabla l_t(\mathbf{w}_t) wt+1←wt−η∇lt(wt),其中 l t ( w ) = 1 2 ( f w ( x t , ⋅ ) − y t ) 2 l_t(\mathbf{w})=\frac{1}{2}(f_\mathbf{w}(x_{t,\cdot})-y_t)^2 lt(w)=21(fw(xt,⋅)−yt)2。
令 L ‾ η , T = E w T E ( x , y ) ∼ D [ 1 2 ( f w T ( x ) − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{\mathbf{w}_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{\mathbf{w}_T}(x)-y)^2] Lη,T=EwTE(x,y)∼D[21(fwT(x)−y)2]为以学习率 η ∈ ( 0 , 1 3 ] \eta\in(0,\frac{1}{3}] η∈(0,31]对其有N个参数的线性模型进行 T ≥ 0 T≥0 T≥0步训练后的测试损失的期望。
请求出 α α α, β β β, C C C,使得 ∀ γ > 0 , ∀ c > 0 \forall\gamma>0,\forall c>0 ∀γ>0,∀c>0,当 T = N c + o ( 1 ) T=N^{c+o(1)} T=Nc+o(1)且 N N N足够大时,以下条件成立:
ϵ ( N , T ) : = inf η ∈ ( 0 , 1 3 ] L ‾ N , T − C A N α + B T β \epsilon(N,T):=\frac{\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}-C}{\frac{A}{N^\alpha}+\frac{B}{T^\beta}} ϵ(N,T):=NαA+TβBinfη∈(0,31]LN,T−C,
( log N + log T ) − γ ≤ ϵ ( N , T ) ≤ ( log N + log T ) γ (\log N+\log T)^{-γ}\leq \epsilon(N,T)\leq(\log N+\log T)^γ (logN+logT)−γ≤ϵ(N,T)≤(logN+logT)γ。即 inf η ∈ ( 0 , 1 3 ] L ‾ N , T = Θ ~ ( N − α + T − β ) + C \inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}=\tilde{\Theta}(N^{-\alpha}+T^{-\beta})+C infη∈(0,31]LN,T=Θ~(N−α+T−β)+C,其中 Θ ~ \tilde{\Theta} Θ~表示忽略任何关于 log N \log N logN和 log T \log T logT的多项式。
解:
- 首先,我们来计算测试损失的期望 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T。
由于 x x x和 y y y是独立的随机变量,且 y y y的条件分布是 N ( 3 x , 1 ) N(3x, 1) N(3x,1),我们可以写出测试损失的期望为:
L ‾ η , T = E ( x , y ) ∼ D [ 1 2 ( w T x − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(w_T x - y)^2] Lη,T=E(x,y)∼D[21(wTx−y)2]
由于 y = 3 x + ϵ y=3x+\epsilon y=3x+ϵ,其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim N(0, 1) ϵ∼N(0,1)且独立于 x x x,我们可以将 y y y替换为 3 x + ϵ 3x+\epsilon 3x+ϵ:
L ‾ η , T = E x , ϵ [ 1 2 ( w T x − ( 3 x + ϵ ) ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{x,\epsilon}[\frac{1}{2}(w_T x - (3x+\epsilon))^2] Lη,T=Ex,ϵ[21(wTx−(3x+ϵ))2]
展开并利用 E [ ϵ 2 ] = 1 \mathbb{E}[\epsilon^2]=1 E[ϵ2]=1和 E [ x 2 ] = 1 \mathbb{E}[x^2]=1 E[x2]=1(因为 x ∼ N ( 0 , 1 ) x\sim N(0, 1) x∼N(0,1)):
L ‾ η , T = E x [ 1 2 ( w T 2 x 2 − 6 w T x 2 + 9 x 2 + ϵ 2 − 6 w T x ϵ + 3 w T 2 x 2 ) ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_x[\frac{1}{2}(w_T^2 x^2 - 6w_T x^2 + 9x^2 + \epsilon^2 - 6w_T x \epsilon + 3w_T^2 x^2)] Lη,T=Ex[21(wT2x2−6wTx2+9x2+ϵ2−6wTxϵ+3wT2x2)]
由于 ϵ \epsilon ϵ和 x x x是独立的,我们可以分别计算期望:
L ‾ η , T = 1 2 ( w T 2 − 6 w T + 9 ) E [ x 2 ] + 1 2 E [ ϵ 2 ] \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9)\mathbb{E}[x^2] + \frac{1}{2}\mathbb{E}[\epsilon^2] Lη,T=21(wT2−6wT+9)E[x2]+21E[ϵ2]
L ‾ η , T = 1 2 ( w T 2 − 6 w T + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9) + \frac{1}{2} Lη,T=21(wT2−6wT+9)+21
现在我们需要计算 w T w_T wT的期望值。由于 w t w_t wt的更新规则是 w t + 1 = w t − η ∇ l t ( w t ) w_{t+1}=w_t-\eta\nabla l_t(w_t) wt+1=wt−η∇lt(wt),我们有:
∇ l t ( w t ) = w t x t − y t = w t x t − ( 3 x t + ϵ ) \nabla l_t(w_t) = w_t x_t - y_t = w_t x_t - (3x_t + \epsilon) ∇lt(wt)=wtxt−yt=wtxt−(3xt+ϵ)
因此,更新规则变为:
w t + 1 = w t − η ( w t x t − 3 x t − ϵ ) w_{t+1} = w_t - \eta(w_t x_t - 3x_t - \epsilon) wt+1=wt−η(wtxt−3xt−ϵ)
取期望并利用 E [ x t ] = 0 \mathbb{E}[x_t]=0 E[xt]=0和 E [ ϵ ] = 0 \mathbb{E}[\epsilon]=0 E[ϵ]=0:
E [ w t + 1 ] = E [ w t ] − η ( 3 E [ x t 2 ] ) \mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - \eta(3\mathbb{E}[x_t^2]) E[wt+1]=E[wt]−η(3E[xt2])
由于 x t 2 x_t^2 xt2的期望是1,我们有:
E [ w t + 1 ] = E [ w t ] − 3 η \mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - 3\eta E[wt+1]=E[wt]−3η
由于 w 0 = 0 w_0=0 w0=0,我们可以递归地计算 w T w_T wT:
E [ w T ] = − 3 η T \mathbb{E}[w_T] = -3\eta T E[wT]=−3ηT
将 E [ w T ] \mathbb{E}[w_T] E[wT]代入测试损失的期望中:
L ‾ η , T = 1 2 ( ( − 3 η T ) 2 − 6 ( − 3 η T ) + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}((-3\eta T)^2 - 6(-3\eta T) + 9) + \frac{1}{2} Lη,T=21((−3ηT)2−6(−3ηT)+9)+21
L ‾ η , T = 1 2 ( 9 η 2 T 2 + 18 η T + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(9\eta^2 T^2 + 18\eta T + 9) + \frac{1}{2} Lη,T=21(9η2T2+18ηT+9)+21
L ‾ η , T = 9 η 2 T 2 + 18 η T + 10 2 \overline{\mathcal{L}}_{\eta,T}=\frac{9\eta^2 T^2 + 18\eta T + 10}{2} Lη,T=29η2T2+18ηT+10
- 接下来,我们需要找到 g ( T ) g(T) g(T)。
首先,我们需要最小化 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T关于 η \eta η。我们可以通过设置 d L ‾ η , T d η = 0 \frac{d\overline{\mathcal{L}}_{\eta,T}}{d\eta}=0 dηdLη,T=0来找到最优的学习率 η ∗ \eta^* η∗:
d d η ( 9 η 2 T 2 + 18 η T + 10 2 ) = 9 η T 2 + 18 T = 0 \frac{d}{d\eta}(\frac{9\eta^2 T^2 + 18\eta T + 10}{2})=9\eta T^2 + 18T=0 dηd(29η2T2+18ηT+10)=9ηT2+18T=0
解得:
η ∗ = 2 3 T \eta^* = \frac{2}{3T} η∗=3T2
将 η ∗ \eta^* η∗代入 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T中,我们得到最小化测试损失的表达式:
L ‾ η ∗ , T = 9 ( 2 3 T ) 2 T 2 + 18 ( 2 3 T ) T + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{2}{3T})^2 T^2 + 18(\frac{2}{3T}) T + 10}{2} Lη∗,T=29(3T2)2T2+18(3T2)T+10
L ‾ η ∗ , T = 9 ( 4 9 T 2 ) T 2 + 18 ( 2 3 T ) T + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{4}{9T^2}) T^2 + 18(\frac{2}{3T}) T + 10}{2} Lη∗,T=29(9T24)T2+18(3T2)T+10
L ‾ η ∗ , T = 4 + 12 + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{4 + 12 + 10}{2} Lη∗,T=24+12+10
L ‾ η ∗ , T = 26 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{26}{2} Lη∗,T=226
L ‾ η ∗ , T = 13 \overline{\mathcal{L}}_{\eta^*,T}=13 Lη∗,T=13
现在,我们需要找到 g ( T ) g(T) g(T),使得当 T → + ∞ T\rightarrow+\infty T→+∞时,以下条件成立:
∣ inf η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log T ) 2 T 2 ) \left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right) η∈(0,31]infIn,T−g(T) =O(T2(logT)2)
由于我们已经找到了最优的学习率 η ∗ \eta^* η∗,我们可以将 L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη∗,T视为 I n , T \mathcal{I}_{n,T} In,T的下界。因此,我们需要找到一个函数 g ( T ) g(T) g(T),使得当 T T T趋向于无穷大时, L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη∗,T和 g ( T ) g(T) g(T)之间的差异满足上述条件。
考虑到 L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη∗,T是一个常数13,我们可以推断 g ( T ) g(T) g(T)应该也是一个常数,因为测试损失的期望在最优学习率下不随 T T T变化。因此,我们可以选择 g ( T ) = 13 g(T)=13 g(T)=13。
现在,我们需要验证这个选择是否满足条件:
∣ inf η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log T ) 2 T 2 ) \left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right) η∈(0,31]infIn,T−g(T) =O(T2(logT)2)
由于 I n , T \mathcal{I}_{n,T} In,T的最小值是13,我们有:
∣ 13 − 13 ∣ = 0 \left|13-13\right|=0 ∣13−13∣=0
显然, 0 = O ( ( log T ) 2 T 2 ) 0=O\left(\frac{(\log T)^2}{T^2}\right) 0=O(T2(logT)2),因为当 T T T趋向于无穷大时, ( log T ) 2 T 2 \frac{(\log T)^2}{T^2} T2(logT)2趋向于0。因此,我们的选择 g ( T ) = 13 g(T)=13 g(T)=13是正确的。
综上所述, g ( T ) = 13 g(T)=13 g(T)=13满足题目中的条件。
3.为了解决这个问题,我们需要推导出多维线性模型在给定数据分布下的缩放定律。根据题目描述,我们有一个线性模型,其参数遵循特定的缩放定律。我们将通过以下步骤来解决这个问题:
步骤 1: 理解数据分布
数据分布 D \mathcal{D} D 是通过 Zipf 分布来选择输入向量的非零坐标,然后根据该坐标的值来生成输出 y y y。这意味着大部分的数据集中在较少的非零坐标上。
步骤 2: 定义损失函数
损失函数 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T 是在给定学习率 η \eta η 和训练步数 T T T 后,模型参数 w \mathbf{w} w 的测试损失的期望。
步骤 3: 推导缩放定律
我们需要找到 α \alpha α, β \beta β,和 C C C 使得损失函数符合 L ‾ N , T ≈ A N α + B T β + C \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C LN,T≈NαA+TβB+C 的形式。
对于 α \alpha α 的推导:
-
参数 N N N 表示模型考虑的输入向量的维度。由于数据分布的特性,大部分的权重不会接收到有效的梯度更新,因为它们对应的输入坐标为零。因此,增加 N N N 的数量不会显著改善模型的性能,但也不会损害它,因为只有少数权重会被更新。
-
Zipf 分布的特性意味着非零坐标的数量随着 N N N 的增加而减少。因此,我们可以预期 α \alpha α 大于 0,但小于 1,因为增加维度对于模型性能的提升是有上限的。
对于 β \beta β 的推导:
-
参数 T T T 表示训练步数。随着训练步数的增加,模型将获得更多的机会来更新其权重,从而减少损失。因此,我们可以预期 β \beta β 大于 0。
-
由于数据分布的特性,并不是每一步都会对所有权重进行有效更新。因此, β \beta β 可能不会是 1,而是小于 1 的某个值。
对于 C C C 的推导:
- 常数 C C C 表示当 N N N 和 T T T 趋于无穷大时,测试损失的最低值。这是由于数据本身的噪声和模型的能力限制导致的。
步骤 4: 确定 α \alpha α, β \beta β,和 C C C
为了确定 α \alpha α, β \beta β,和 C C C,我们需要进行以下分析:
-
对于 α \alpha α:考虑到只有少数权重会被更新,我们可以假设 α \alpha α 在 0 和 1 之间。更具体地,由于 Zipf 分布的特性,我们可以假设 α \alpha α 接近于 1,但小于 1,因为随着 N N N 的增加,额外维度的边际贡献会减少。一个合理的猜测是 α = 1 b \alpha = \frac{1}{b} α=b1。
-
对于 β \beta β:考虑到每一步并不是对所有权重都进行有效更新,我们可以假设 β \beta β 小于 1。一个合理的猜测是 β = 1 2 \beta = \frac{1}{2} β=21,这是因为通常情况下,梯度下降的收敛速度与步数的平方根成反比。
-
对于 C C C:这是数据噪声和模型表达能力限制的结果。在没有更多信息的情况下,我们无法精确确定 C C C,但可以假设它是一个正数。
步骤 5: 验证条件
我们需要验证 ϵ ( N , T ) \epsilon(N,T) ϵ(N,T) 的条件是否成立。这通常涉及到对 L ‾ N , T \overline{\mathcal{L}}_{N,T} LN,T 进行详细的分析,并证明它符合给定的缩放形式。这通常需要数学上的证明和/或实验验证。
综上所述,我们可以假设 α = 1 b \alpha = \frac{1}{b} α=b1, β = 1 2 \beta = \frac{1}{2} β=21, C C C 是一个正数。然而,为了得到精确的值,我们需要更深入的分析和实验数据。在实际应用中,这些参数通常是通过实验来确定的。