当前位置: 首页 > news >正文

研究线性模型训练中损失变化的规律和最优学习率的影响

探究一维线性模型训练中,测试损失随训练步数变化的缩放定律及其最优学习率影响,并研究多维线性模型训练的缩放定律,确定参数以符合特定损失衰减模式。

研究大模型的缩放定律对减少其训练开销至关重要,即最终的测试损失如何随着训练步数和模型大小的变化而变化?本题中,我们研究了训练线性模型时的缩放定律。

  1. 在本小问中,考虑使用梯度下降学习一个一维线性模型的情况。
  • 定义数据分布 D \mathcal{D} D为一个 R 2 \mathbb{R}^2 R2上的分布,每个数据是一个数对 ( x , y ) (x, y) (x,y),分别代表输入和输出,并服从分布 x ∼ N ( 0 , 1 ) , y ∼ N ( 3 x , 1 ) x\sim N(0, 1),y\sim N(3x, 1) xN(0,1),yN(3x,1)

  • 用梯度下降算法学习线性模型 f w ( x ) = w ⋅ x f_{w}(x)=w \cdot x fw(x)=wx,其中 w , x ∈ R w, x\in\mathbb{R} w,xR。初始化 ω 0 = 0 ω_0=0 ω0=0并进行多步迭代。每次迭代时,从 D \mathcal{D} D中采样 ( x t , y t ) (x_t,y_t) (xt,yt),然后更新 w t w_t wt w t + 1 ← w t − η ∇ l t ( w t ) w_{t+1}\leftarrow w_t-\eta\nabla l_t(w_t) wt+1wtηlt(wt),其中 l t ( w ) = 1 2 ( f w ( x t ) − y t ) 2 l_t(w)=\frac{1}{2}(f_w(x_t)-y_t)^2 lt(w)=21(fw(xt)yt)2是平方损失函数, η > 0 \eta>0 η>0是学习率。

设学习率 η ∈ ( 0 , 1 3 ] \eta\in(0,\frac{1}{3}] η(0,31],那么 T ≥ 0 T≥0 T0步迭代之后的测试损失的期望

L ‾ η , T = E w T E ( x , y ) ∼ D [ 1 2 ( f w T ( x ) − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{w_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{w_T}(x)-y)^2] Lη,T=EwTE(x,y)D[21(fwT(x)y)2]

是多少?

  1. 现在我们在第一小问的设定下,考虑学习率 η \eta η被调到最优的情况,求函数 g ( T ) g(T) g(T),使得当 T → + ∞ T\rightarrow+\infty T+时,以下条件成立:

∣ inf ⁡ η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log ⁡ T ) 2 T 2 ) \left|\underset{η\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O(\frac{(\log T)^2}{T^2}) η(0,31]infIn,Tg(T) =O(T2(logT)2)

  1. 一个常常被观测到的实验现象是大语言模型的预训练过程大致遵循Chinchilla缩放定律:

L ‾ N , T ≈ A N α + B T β + C \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C LN,TNαA+TβB+C

其中 L ‾ N , T \overline{\mathcal{L}}_{N,T} LN,T是在经过 T T T步训练后具有 N N N个参数的模型的测试损失的期望, A A A B B B a a a β β β C C C是常数。现在我们举一个训练多维线性模型的例子,使其也遵循类似的缩放定律。

  • 固定 a > 0 , b ≥ 1 a>0,b≥1 a>0,b1,每个数据 ( x ⋅ , y ) (x_{\cdot},y) (x,y)由一个输入和输出组成,其中输入 x ⋅ x_{\cdot} x是一个无限维向量(可看作一个序列),输出 y y y满足 y ∈ R y\in\mathbb{R} yR。定义数据分布 D \mathcal{D} D如下。首先,从Zipf分布中采样 k k k Pr ⁡ [ k = i ] ∝ i − ( a + 1 ) ( i ≥ 1 ) \Pr[k=i]\propto i^{-(a+1)}\quad(i\geq 1) Pr[k=i]i(a+1)(i1)。令 j : = [ k b ] j:=[k^b] j:=[kb],然后,从 m a t h c a l N ( 0 , 1 ) mathcal{N}(0,1) mathcalN(0,1)中采样得到 x ⋅ x_{\cdot} x的第 j j j个坐标 x j x_j xj,并令其余坐标为0。最后, y ∼ N ( 3 x j , 1 ) y\sim N(3x_j,1) yN(3xj,1)。这样得到的 ( x ⋅ , y ) (x_{\cdot},y) (x,y)的分即数据分布 D \mathcal{D} D

  • 我们研究一个仅关注前 N N N个输入坐标的线性模型。定义函数 ϕ N ( x x ⋅ ) = ( x 1 , . . . , x N ) \phi_N(xx_{\cdot})=(x_1,...,x_N) ϕN(xx)=(x1,...,xN)。我们研究的线性模型具有参数 w ∈ R N \mathbf{w}\in\mathbb{R}^N wRN,输出为 f w ( x ) = ( w , ϕ N ( x ⋅ ) ) f_{\mathbf{w}}(x)=(\mathbf{w},\phi_N(x_{\cdot})) fw(x)=(w,ϕN(x))

  • 我们使用梯度下降算法学习该线性模型。初始化 w 0 = 0 \mathbf{w}_0=0 w0=0并进行多步迭代。每次迭代时,从 D \mathcal{D} D中采样 ( x t , ⋅ , y t ) (x_{t,\cdot},y_t) (xt,,yt),然后更新 w t \mathbf{w}_t wt w t + 1 ← w t − η ∇ l t ( w t ) \mathbf{w}_{t+1}\gets \mathbf{w}_t-\eta\nabla l_t(\mathbf{w}_t) wt+1wtηlt(wt),其中 l t ( w ) = 1 2 ( f w ( x t , ⋅ ) − y t ) 2 l_t(\mathbf{w})=\frac{1}{2}(f_\mathbf{w}(x_{t,\cdot})-y_t)^2 lt(w)=21(fw(xt,)yt)2

L ‾ η , T = E w T E ( x , y ) ∼ D [ 1 2 ( f w T ( x ) − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{\mathbf{w}_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{\mathbf{w}_T}(x)-y)^2] Lη,T=EwTE(x,y)D[21(fwT(x)y)2]为以学习率 η ∈ ( 0 , 1 3 ] \eta\in(0,\frac{1}{3}] η(0,31]对其有N个参数的线性模型进行 T ≥ 0 T≥0 T0步训练后的测试损失的期望。

请求出 α α α β β β C C C,使得 ∀ γ > 0 , ∀ c > 0 \forall\gamma>0,\forall c>0 γ>0,c>0,当 T = N c + o ( 1 ) T=N^{c+o(1)} T=Nc+o(1) N N N足够大时,以下条件成立:

ϵ ( N , T ) : = inf ⁡ η ∈ ( 0 , 1 3 ] L ‾ N , T − C A N α + B T β \epsilon(N,T):=\frac{\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}-C}{\frac{A}{N^\alpha}+\frac{B}{T^\beta}} ϵ(N,T):=NαA+TβBinfη(0,31]LN,TC

( log ⁡ N + log ⁡ T ) − γ ≤ ϵ ( N , T ) ≤ ( log ⁡ N + log ⁡ T ) γ (\log N+\log T)^{-γ}\leq \epsilon(N,T)\leq(\log N+\log T)^γ (logN+logT)γϵ(N,T)(logN+logT)γ。即 inf ⁡ η ∈ ( 0 , 1 3 ] L ‾ N , T = Θ ~ ( N − α + T − β ) + C \inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}=\tilde{\Theta}(N^{-\alpha}+T^{-\beta})+C infη(0,31]LN,T=Θ~(Nα+Tβ)+C,其中 Θ ~ \tilde{\Theta} Θ~表示忽略任何关于 log ⁡ N \log N logN log ⁡ T \log T logT的多项式。

解:

  1. 首先,我们来计算测试损失的期望 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T

由于 x x x y y y是独立的随机变量,且 y y y的条件分布是 N ( 3 x , 1 ) N(3x, 1) N(3x,1),我们可以写出测试损失的期望为:

L ‾ η , T = E ( x , y ) ∼ D [ 1 2 ( w T x − y ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(w_T x - y)^2] Lη,T=E(x,y)D[21(wTxy)2]

由于 y = 3 x + ϵ y=3x+\epsilon y=3x+ϵ,其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim N(0, 1) ϵN(0,1)且独立于 x x x,我们可以将 y y y替换为 3 x + ϵ 3x+\epsilon 3x+ϵ

L ‾ η , T = E x , ϵ [ 1 2 ( w T x − ( 3 x + ϵ ) ) 2 ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{x,\epsilon}[\frac{1}{2}(w_T x - (3x+\epsilon))^2] Lη,T=Ex,ϵ[21(wTx(3x+ϵ))2]

展开并利用 E [ ϵ 2 ] = 1 \mathbb{E}[\epsilon^2]=1 E[ϵ2]=1 E [ x 2 ] = 1 \mathbb{E}[x^2]=1 E[x2]=1(因为 x ∼ N ( 0 , 1 ) x\sim N(0, 1) xN(0,1)):

L ‾ η , T = E x [ 1 2 ( w T 2 x 2 − 6 w T x 2 + 9 x 2 + ϵ 2 − 6 w T x ϵ + 3 w T 2 x 2 ) ] \overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_x[\frac{1}{2}(w_T^2 x^2 - 6w_T x^2 + 9x^2 + \epsilon^2 - 6w_T x \epsilon + 3w_T^2 x^2)] Lη,T=Ex[21(wT2x26wTx2+9x2+ϵ26wTxϵ+3wT2x2)]

由于 ϵ \epsilon ϵ x x x是独立的,我们可以分别计算期望:

L ‾ η , T = 1 2 ( w T 2 − 6 w T + 9 ) E [ x 2 ] + 1 2 E [ ϵ 2 ] \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9)\mathbb{E}[x^2] + \frac{1}{2}\mathbb{E}[\epsilon^2] Lη,T=21(wT26wT+9)E[x2]+21E[ϵ2]

L ‾ η , T = 1 2 ( w T 2 − 6 w T + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9) + \frac{1}{2} Lη,T=21(wT26wT+9)+21

现在我们需要计算 w T w_T wT的期望值。由于 w t w_t wt的更新规则是 w t + 1 = w t − η ∇ l t ( w t ) w_{t+1}=w_t-\eta\nabla l_t(w_t) wt+1=wtηlt(wt),我们有:

∇ l t ( w t ) = w t x t − y t = w t x t − ( 3 x t + ϵ ) \nabla l_t(w_t) = w_t x_t - y_t = w_t x_t - (3x_t + \epsilon) lt(wt)=wtxtyt=wtxt(3xt+ϵ)

因此,更新规则变为:

w t + 1 = w t − η ( w t x t − 3 x t − ϵ ) w_{t+1} = w_t - \eta(w_t x_t - 3x_t - \epsilon) wt+1=wtη(wtxt3xtϵ)

取期望并利用 E [ x t ] = 0 \mathbb{E}[x_t]=0 E[xt]=0 E [ ϵ ] = 0 \mathbb{E}[\epsilon]=0 E[ϵ]=0

E [ w t + 1 ] = E [ w t ] − η ( 3 E [ x t 2 ] ) \mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - \eta(3\mathbb{E}[x_t^2]) E[wt+1]=E[wt]η(3E[xt2])

由于 x t 2 x_t^2 xt2的期望是1,我们有:

E [ w t + 1 ] = E [ w t ] − 3 η \mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - 3\eta E[wt+1]=E[wt]3η

由于 w 0 = 0 w_0=0 w0=0,我们可以递归地计算 w T w_T wT

E [ w T ] = − 3 η T \mathbb{E}[w_T] = -3\eta T E[wT]=3ηT

E [ w T ] \mathbb{E}[w_T] E[wT]代入测试损失的期望中:

L ‾ η , T = 1 2 ( ( − 3 η T ) 2 − 6 ( − 3 η T ) + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}((-3\eta T)^2 - 6(-3\eta T) + 9) + \frac{1}{2} Lη,T=21((3ηT)26(3ηT)+9)+21

L ‾ η , T = 1 2 ( 9 η 2 T 2 + 18 η T + 9 ) + 1 2 \overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(9\eta^2 T^2 + 18\eta T + 9) + \frac{1}{2} Lη,T=21(9η2T2+18ηT+9)+21

L ‾ η , T = 9 η 2 T 2 + 18 η T + 10 2 \overline{\mathcal{L}}_{\eta,T}=\frac{9\eta^2 T^2 + 18\eta T + 10}{2} Lη,T=29η2T2+18ηT+10

  1. 接下来,我们需要找到 g ( T ) g(T) g(T)

首先,我们需要最小化 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T关于 η \eta η。我们可以通过设置 d L ‾ η , T d η = 0 \frac{d\overline{\mathcal{L}}_{\eta,T}}{d\eta}=0 dηdLη,T=0来找到最优的学习率 η ∗ \eta^* η

d d η ( 9 η 2 T 2 + 18 η T + 10 2 ) = 9 η T 2 + 18 T = 0 \frac{d}{d\eta}(\frac{9\eta^2 T^2 + 18\eta T + 10}{2})=9\eta T^2 + 18T=0 dηd(29η2T2+18ηT+10)=9ηT2+18T=0

解得:

η ∗ = 2 3 T \eta^* = \frac{2}{3T} η=3T2

η ∗ \eta^* η代入 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T中,我们得到最小化测试损失的表达式:

L ‾ η ∗ , T = 9 ( 2 3 T ) 2 T 2 + 18 ( 2 3 T ) T + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{2}{3T})^2 T^2 + 18(\frac{2}{3T}) T + 10}{2} Lη,T=29(3T2)2T2+18(3T2)T+10

L ‾ η ∗ , T = 9 ( 4 9 T 2 ) T 2 + 18 ( 2 3 T ) T + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{4}{9T^2}) T^2 + 18(\frac{2}{3T}) T + 10}{2} Lη,T=29(9T24)T2+18(3T2)T+10

L ‾ η ∗ , T = 4 + 12 + 10 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{4 + 12 + 10}{2} Lη,T=24+12+10

L ‾ η ∗ , T = 26 2 \overline{\mathcal{L}}_{\eta^*,T}=\frac{26}{2} Lη,T=226

L ‾ η ∗ , T = 13 \overline{\mathcal{L}}_{\eta^*,T}=13 Lη,T=13

现在,我们需要找到 g ( T ) g(T) g(T),使得当 T → + ∞ T\rightarrow+\infty T+时,以下条件成立:

∣ inf ⁡ η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log ⁡ T ) 2 T 2 ) \left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right) η(0,31]infIn,Tg(T) =O(T2(logT)2)

由于我们已经找到了最优的学习率 η ∗ \eta^* η,我们可以将 L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη,T视为 I n , T \mathcal{I}_{n,T} In,T的下界。因此,我们需要找到一个函数 g ( T ) g(T) g(T),使得当 T T T趋向于无穷大时, L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη,T g ( T ) g(T) g(T)之间的差异满足上述条件。

考虑到 L ‾ η ∗ , T \overline{\mathcal{L}}_{\eta^*,T} Lη,T是一个常数13,我们可以推断 g ( T ) g(T) g(T)应该也是一个常数,因为测试损失的期望在最优学习率下不随 T T T变化。因此,我们可以选择 g ( T ) = 13 g(T)=13 g(T)=13

现在,我们需要验证这个选择是否满足条件:

∣ inf ⁡ η ∈ ( 0 , 1 3 ] I n , T − g ( T ) ∣ = O ( ( log ⁡ T ) 2 T 2 ) \left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right) η(0,31]infIn,Tg(T) =O(T2(logT)2)

由于 I n , T \mathcal{I}_{n,T} In,T的最小值是13,我们有:

∣ 13 − 13 ∣ = 0 \left|13-13\right|=0 1313=0

显然, 0 = O ( ( log ⁡ T ) 2 T 2 ) 0=O\left(\frac{(\log T)^2}{T^2}\right) 0=O(T2(logT)2),因为当 T T T趋向于无穷大时, ( log ⁡ T ) 2 T 2 \frac{(\log T)^2}{T^2} T2(logT)2趋向于0。因此,我们的选择 g ( T ) = 13 g(T)=13 g(T)=13是正确的。

综上所述, g ( T ) = 13 g(T)=13 g(T)=13满足题目中的条件。

3.为了解决这个问题,我们需要推导出多维线性模型在给定数据分布下的缩放定律。根据题目描述,我们有一个线性模型,其参数遵循特定的缩放定律。我们将通过以下步骤来解决这个问题:

步骤 1: 理解数据分布

数据分布 D \mathcal{D} D 是通过 Zipf 分布来选择输入向量的非零坐标,然后根据该坐标的值来生成输出 y y y。这意味着大部分的数据集中在较少的非零坐标上。

步骤 2: 定义损失函数

损失函数 L ‾ η , T \overline{\mathcal{L}}_{\eta,T} Lη,T 是在给定学习率 η \eta η 和训练步数 T T T 后,模型参数 w \mathbf{w} w 的测试损失的期望。

步骤 3: 推导缩放定律

我们需要找到 α \alpha α β \beta β,和 C C C 使得损失函数符合 L ‾ N , T ≈ A N α + B T β + C \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C LN,TNαA+TβB+C 的形式。

对于 α \alpha α 的推导:
  • 参数 N N N 表示模型考虑的输入向量的维度。由于数据分布的特性,大部分的权重不会接收到有效的梯度更新,因为它们对应的输入坐标为零。因此,增加 N N N 的数量不会显著改善模型的性能,但也不会损害它,因为只有少数权重会被更新。

  • Zipf 分布的特性意味着非零坐标的数量随着 N N N 的增加而减少。因此,我们可以预期 α \alpha α 大于 0,但小于 1,因为增加维度对于模型性能的提升是有上限的。

对于 β \beta β 的推导:
  • 参数 T T T 表示训练步数。随着训练步数的增加,模型将获得更多的机会来更新其权重,从而减少损失。因此,我们可以预期 β \beta β 大于 0。

  • 由于数据分布的特性,并不是每一步都会对所有权重进行有效更新。因此, β \beta β 可能不会是 1,而是小于 1 的某个值。

对于 C C C 的推导:
  • 常数 C C C 表示当 N N N T T T 趋于无穷大时,测试损失的最低值。这是由于数据本身的噪声和模型的能力限制导致的。

步骤 4: 确定 α \alpha α β \beta β,和 C C C

为了确定 α \alpha α β \beta β,和 C C C,我们需要进行以下分析:

  • 对于 α \alpha α:考虑到只有少数权重会被更新,我们可以假设 α \alpha α 在 0 和 1 之间。更具体地,由于 Zipf 分布的特性,我们可以假设 α \alpha α 接近于 1,但小于 1,因为随着 N N N 的增加,额外维度的边际贡献会减少。一个合理的猜测是 α = 1 b \alpha = \frac{1}{b} α=b1

  • 对于 β \beta β:考虑到每一步并不是对所有权重都进行有效更新,我们可以假设 β \beta β 小于 1。一个合理的猜测是 β = 1 2 \beta = \frac{1}{2} β=21,这是因为通常情况下,梯度下降的收敛速度与步数的平方根成反比。

  • 对于 C C C:这是数据噪声和模型表达能力限制的结果。在没有更多信息的情况下,我们无法精确确定 C C C,但可以假设它是一个正数。

步骤 5: 验证条件

我们需要验证 ϵ ( N , T ) \epsilon(N,T) ϵ(N,T) 的条件是否成立。这通常涉及到对 L ‾ N , T \overline{\mathcal{L}}_{N,T} LN,T 进行详细的分析,并证明它符合给定的缩放形式。这通常需要数学上的证明和/或实验验证。

综上所述,我们可以假设 α = 1 b \alpha = \frac{1}{b} α=b1 β = 1 2 \beta = \frac{1}{2} β=21 C C C 是一个正数。然而,为了得到精确的值,我们需要更深入的分析和实验数据。在实际应用中,这些参数通常是通过实验来确定的。


http://www.mrgr.cn/news/61764.html

相关文章:

  • EG2133 (三相独立半桥驱动芯片)的功能介绍
  • R语言装环境Gcc报错以及scater包的安装
  • 单例模式-如何保证全局唯一性?
  • 【大数据】Apache Superset:可视化开源架构
  • TypeScript Jest 单元测试 搭建
  • Spring实现通过工具类统一输出日志(不改变日志类信息)
  • 2024 Rust现代实用教程:1.3获取rust的库国内源以及windows下的操作
  • Infinity-MM数据集:一个包含 4000 万个样本的开源视觉语言模型的大规模多模态指令数据集。
  • 【征程 6 工具链性能分析与优化-1】编译器预估 perf 解读与性能分析
  • 矩阵压缩格式转换:COO转换CSC(C++)
  • Python世界:自动化办公Word之批量替换文本生成副本
  • nginx[新手用][模块化][高效]配置
  • 使用命令行上传 ipa 到 App Store(iTMSTransporter 3.3)
  • [JAVAEE] 面试题(二) - CAS 和 原子类
  • 计算机组成原理之高级语言程序与机器级代码之间的对应、高级语言和机器级代码的具体示例
  • 优化云成本,打造卓越体验,他们有话说
  • 微信小程序 - 获取汉字拼音首字母(汉字英文首字母)根据汉字查拼音,实现汉字拼音首字母获取,在小程序上实现汉字的拼音提取首字母!
  • [专有网络VPC]管理VPC配额
  • 智慧园区 | 数智引领,让智慧触手可及
  • String的长度有限,而我对你的思念却无限延伸
  • IDEA 打包首个java项目为jar包
  • 开箱即用!智能文档处理“百宝箱”
  • Faces in Things数据集: 由麻省理工学院、微软等联合发布,探索人类视觉错觉的新里程碑
  • Ollama运行本地LLM大模型简单教程:大显存很重要
  • 【Golang】Golang的数组和slice切片的区别
  • 数据集(Dataset)是指为特定目的而收集、整理、存储的数据集合