从扩散模型开始的生成模型范式演变--SDE(1)
SDE是在分数生成模型的基础上,将加噪过程扩展时连续、无限状态,使得扩散模型的正向、逆向过程通过SDE表示。在前文讲解DDPM后,本文主要讲解SDE扩散模型原理。本文内容主要来自B站Up主deep_thoughts分享视频Score Diffusion Model分数扩散模型理论与完整PyTorch代码详细解读,在其讲解的英文文稿上做了翻译和适当的个人理解调整,感兴趣读者可以将本文作为该视频的参考读物结合学习。
文章目录
- 分数匹配网络
- 何为分数
- 基于朗之万动力学采样
- 分数匹配
- 去噪分数匹配
- Noise Conditonal Score Networks
- NCSN定义
- 通过分数匹配学习NCSN
- 基于退火朗之万动力学进行NCSN推理
- Stochastic Differential Equations--SDEs
- SDE下的扰动数据
- 用于采样的逆SDE过程
- 去噪分数匹配估计逆SDE过程所需分数
- 分数网络设计建议
分数匹配网络
何为分数
假设我们的数据集由来自一个未知数据分布 p d a t a ( x ) p_{data}(x) pdata(x)的 N N N个独立同分布样本 { x i ∈ R D } i = 1 N \{x_i \in R^D\}^N_{i=1} {xi∈RD}i=1N组成。定义概率密度 p ( x ) p(x) p(x)的 s c o r e score score为 ∇ x log p ( x ) \nabla_x \log p(x) ∇xlogp(x),即一个概率密度的对数似然关于变量 x x x的梯度就是该概率密度的分数。分数网络 s θ : R D → R D s_{\theta}: R^D \to R^D sθ:RD→RD是一个用 θ \theta θ表示参数的神经网络,其被训练用于去近似 p d a t a ( x ) p_{data}(x) pdata(x)的分数。生成建模的目标是基于数据集训练可以生成符合真实数据分布 p d a t a ( x ) p_{data}(x) pdata(x)新样本的模型。分数生成建模框架有两个核心要素:分数匹配和朗之万动力学;分数匹配的作用就是训练一个分数网络可以预测出真实数据分布的分数,而朗之万动力学则可以基于分数采样出新样本。
基于朗之万动力学采样
朗之万动力学能仅使用概率密度 p ( x ) p(x) p(x)的分数函数 ∇ x log p ( x ) \nabla_x \log p(x) ∇xlogp(x)生成符合 p ( x ) p(x) p(x)的样本。给定一个固定步长 ϵ > 0 \epsilon > 0 ϵ>0,一个初始值 x ~ 0 ∼ π ( x ) \tilde{x}_0 \sim \pi(x) x~0∼π(x), π \pi π是一个先验分布,朗之万动力学基于以下公式递归采样:
x ~ t = x ~ t − 1 + ϵ 2 ∇ x log p ( x ~ t − 1 ) + ϵ z t , (1) \tilde{x}_t = \tilde{x}_{t-1} + \frac{\epsilon}{2}\nabla_x \log p(\tilde{x}_{t-1})+\sqrt{\epsilon}z_t,\tag1 x~t=x~t−1+2ϵ∇xlogp(x~t−1)+ϵzt,(1)
其中 z t ∼ N ( 0 , I ) z_t \sim N(0,I) zt∼N(0,I)。当 ϵ → 0 \epsilon \to 0 ϵ→0, T → ∞ T \to \infty T→∞, x ~ \tilde{x} x~的分布等价于 p ( x ) p(x) p(x),在一些正则条件下 x ~ T \tilde{x}_T x~T成为 p ( x ) p(x) p(x)的一个样本。当 ϵ > 0 \epsilon > 0 ϵ>0, T < ∞ T < \infty T<∞,需要通过Metropolis-Hastings更新去纠正公式(1)中的错误,但实际上该过程进本被忽略。在SDE工作中,当 ϵ \epsilon ϵ很小, T T T很大时,也假设公式(1)存在的问题可以被忽略,即基于公式(1)采样出的样本时符合 p ( x ) p(x) p(x)。
公式(1)仅需要 ∇ x log p ( x ) \nabla_x \log p(x) ∇xlogp(x)就能采样,因此为了获得 p d a t a ( x ) p_{data}(x) pdata(x)的样本,首先要训练一个分数网络 s θ ≈ ∇ x log p ( x ) s_{\theta} \approx \nabla_x \log p(x) sθ≈∇xlogp(x),然后基于朗之万动力学使用公式(1)生成新样本。这是分数生成建模框架的关键。
分数匹配
分数匹配初始被设计用于从一个未知数据分布的独立同分布样本中学习非归一化统计模型,在SDE中将其重新应用于分数估计。使用分数匹配,能直接训练一个分数网络 s θ s_{\theta} sθ去估计 ∇ x log p d a t a ( x ) \nabla_x \log p_{data}(x) ∇xlogpdata(x),而不需要训练一个模型去估计 p d a t a ( x ) p_{data}(x) pdata(x)。与电影的分数匹配使用方法不同,SDE中不使用一个能量模型的梯度作为分数网络,可以避免由于高阶梯度导致额外计算。优化目标是 1 2 E p d a t a [ ∣ ∣ s θ − ∇ x log p d a t a ( x ) ∣ ∣ 2 2 ] \frac{1}{2}E_{p_{data}}[||s_{\theta}-\nabla_x \log p_{data}(x)||^2_2] 21Epdata[∣∣sθ−∇xlogpdata(x)∣∣22],其与以下公式等价
E p d a t a [ t r ( ∇ x s θ ( x ) ) + 1 2 ∣ ∣ s θ ( x ) ∣ ∣ 2 2 ] (2) E_{p_{data}}[tr(\nabla_x s_{\theta}(x))+\frac{1}{2}||s_{\theta}(x)||^2_2]\tag2 Epdata[tr(∇xsθ(x))+21∣∣sθ(x)∣∣22](2)
其中 ∇ x s θ ( x ) \nabla_x s_{\theta}(x) ∇xsθ(x)表示 s θ ( x ) s_{\theta}(x) sθ(x)的雅可比行列式,但深度模型和高i为数据时会导致此雅可比行列式计算困难,在大范围分数匹配场景下的有两种绕过雅可比行列式计算的方法,常用的一种就是去噪分数匹配。
去噪分数匹配
去噪分数匹配是一种完全规避求解 t r ( ∇ x s θ ( x ) ) tr(\nabla_x s_{\theta}(x)) tr(∇xsθ(x))的分数匹配变体。去噪分数匹配先用一个预定义的噪声扰动数据点 x x x,得到扰动后的数据分数 q σ ( x ~ ∣ x ) q_{\sigma}(\tilde{x}|x) qσ(x~∣x);然后使用分数匹配去估计扰动后数据分布 q σ ( x ~ ) = Δ ∫ q σ ( x ~ ∣ x ) p d a t a ( x ) d x q_{\sigma}(\tilde{x}) \overset{\Delta}{=} \int q_{\sigma}(\tilde{x}|x) p_{data}(x) dx qσ(x~)=Δ∫qσ(x~∣x)pdata(x)dx的分数。优化目标等价于以下公式:
1 2 E q σ ( x ~ ∣ x ) p d a t a ( x ) [ ∣ ∣ s θ ( x ~ ) − ∇ x ~ log q σ ( x ~ ∣ x ) ∣ ∣ 2 2 ] , (3) \frac{1}{2}E_{q_{\sigma}(\tilde{x}|x)p_{data}(x)}[||s_{\theta}(\tilde{x})-\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x}|x)||^2_2],\tag3 21Eqσ(x~∣x)pdata(x)[∣∣sθ(x~)−∇x~logqσ(x~∣x)∣∣22],(3)
有理由相信基于上述优化目标训练后的最优模型 s θ ∗ ( x ) = ∇ x log p σ ( x ) s_{\theta^*}(x) = \nabla_x \log p_{\sigma}(x) sθ∗(x)=∇xlogpσ(x)。那么当噪声非常小,即 q σ ( x ) ≈ p d a t a ( x ) q_{\sigma}(x) \approx p_{data}(x) qσ(x)≈pdata(x)时,有 s θ ∗ ( x ) = ∇ x log p σ ( x ) ≈ ∇ x log p d a t a ( x ) s_{\theta^*}(x)=\nabla_x \log p_{\sigma}(x) \approx \nabla_x \log p_{data}(x) sθ∗(x)=∇xlogpσ(x)≈∇xlogpdata(x)
此处去噪分数匹配给数据加噪与下文中的NCSN中的添加噪声不同,其不是条件,而是为了规避计算雅可比行列式的一种手段。
Noise Conditonal Score Networks
去噪分数匹配存在一个问题,即在数据密度较低的区域分数估计不准确(如下图所示),进而后续的采样结果也不会准确。为了解决低密度分数估计不准的问题,引入了噪声条件分数网络,即NSCN。
NSCN是带条件的分数网络,条件就是噪声。在上图数据基础上加噪后(下图所示),整体数据密度变大,低密度区域变得很少,就不再存在分数估计不准的问题。
那么应该如何加噪呢?在上述去噪分数匹配中提到,当加入的噪声很小时,才能基于公式(3)训练得到目标分数网络,但是当噪声很小时,加噪后的数据密度增幅不大,就会存在分数估计不准的问题;当加入的噪声很大时,虽然分数估计变准,但是加噪后的数据与原始数据分布区别就会很大,就违背了去噪分数匹配中的假设。为了兼顾两边,应该使用不同量级的噪声进行加噪,然后使用同一个分数网络去估计不同噪声量级下的分数。当分数网络训练完毕后,首先生成噪声量级大下分布的样本,然后逐渐降低噪声量级,最终基于朗之万采样平滑地生成符合目标分数地样本。
NCSN定义
假设有一组正等比数列 { σ i } i = 1 L \{\sigma_i\}^L_{i=1} {σi}i=1L,满足 σ 1 σ 2 = ⋯ = σ L − 1 σ L > 1 \frac{\sigma_1}{\sigma_2}= \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1 σ2σ1=⋯=σLσL−1>1。 q σ ( x ) = Δ ∫ p d a r a ( t ) N ( x ∣ t , σ 2 I ) d t q_{\sigma}(x) \overset{\Delta}{=} \int p_{dara}(t) N(x|t,\sigma^2I)dt qσ(x)=Δ∫pdara(t)N(x∣t,σ2I)dt表示扰动后的数据分布。 σ i \sigma_i σi就表征不同的噪声等级, σ 1 \sigma_1 σ1足够大,使其能修复数据密度低导致的分数估计不准的问题; σ L \sigma_L σL足够小,可认为扰动后的数据分布等价于原始数据分布。目标是使用一个条件分数网络去同时估计所有不同噪声等级扰动分布的分数,即 ∀ σ ∈ { σ i } i = 1 L , s θ ( x , σ ) ≈ ∇ x q σ ( x ) \forall \sigma \in \{\sigma_i\}^L_{i=1},s_{\theta}(x,\sigma) \approx \nabla_xq_{\sigma}(x) ∀σ∈{σi}i=1L,sθ(x,σ)≈∇xqσ(x)。注意,当 x ∈ R D , s θ ( x , σ ) ∈ R D x\in R^D,s_{\theta}(x,\sigma) \in R^D x∈RD,sθ(x,σ)∈RD, s θ ( x , σ ) s_{\theta}(x,\sigma) sθ(x,σ)即为噪声条件分数网络,NCSN。
通过分数匹配学习NCSN
假设噪声扰动后的数据分布为 q σ ( x ~ ∣ x ) = N ( x ~ ∣ x , σ 2 I ) q_{\sigma}(\tilde{x}|x)=N(\tilde{x}|x,\sigma^2I) qσ(x~∣x)=N(x~∣x,σ2I),则该分布的分数函数为 ∇ x ~ log q σ ( x ~ ∣ x ) = − x ~ − x σ 2 \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}|x)=-\frac{\tilde{x}-x}{\sigma^2} ∇x~logqσ(x~∣x)=−σ2x~−x。对于给定 σ \sigma σ,去噪分数匹配优化目标如下:
l ( θ ; σ ) = Δ 1 2 E p d a t a E x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x , σ ) + x ~ − x σ 2 ∣ ∣ 2 2 ] , (4) l(\theta;\sigma) \overset{\Delta}{=} \frac{1}{2}E_{p_{data}}E_{\tilde{x} \sim N(x,\sigma^2I)}[||s_{\theta}(x,\sigma)+\frac{\tilde{x}-x}{\sigma^2}||^2_2],\tag4 l(θ;σ)=Δ21EpdataEx~∼N(x,σ2I)[∣∣sθ(x,σ)+σ2x~−x∣∣22],(4)
利用公式(4)对多有噪声等级计算后取平均值得到最终目标
L ( θ ; { σ i } i = 1 L ) = Δ 1 L ∑ i = 1 L λ ( σ i ) l ( θ ; σ i ) , (5) L(\theta;\{\sigma_i\}^L_{i=1}) \overset{\Delta}{=} \frac{1}{L}\sum_{i=1}^L \lambda(\sigma_i)l(\theta;\sigma_i),\tag5 L(θ;{σi}i=1L)=ΔL1i=1∑Lλ(σi)l(θ;σi),(5)
其中 λ ( σ i ) > 0 \lambda(\sigma_i)>0 λ(σi)>0是一个基于 σ i \sigma_i σi的系数函数。假设 s θ ( x , σ ) s_{\theta}(x,\sigma) sθ(x,σ)有足够的能力,训练完成之后得到的最后分数网络对于所有噪声等级都有, s θ ∗ ( x , σ i ) = ∇ x log q σ i ( x ) s_{\theta^*}(x,\sigma_i)=\nabla_x \log q_{\sigma_i}(x) sθ∗(x,σi)=∇xlogqσi(x),因为公式(5)是 L L L个去噪分数匹配目标的组合。
公式(5)中不同噪声等级的 λ ( ⋅ ) \lambda(\cdot) λ(⋅)设置有很多方法。理想情况下,希望对于所有噪声等级的 λ ( σ i ) l ( θ ; σ i ) \lambda(\sigma_i)l(\theta;\sigma_i) λ(σi)l(θ;σi)能大致在相同的数量等级。从经验上观察得到当分数网络训练到最优时, ∣ ∣ s θ ( x , σ ) ∣ ∣ 2 ∝ 1 σ ||s_{\theta}(x,\sigma)||_2 \varpropto \frac{1}{\sigma} ∣∣sθ(x,σ)∣∣2∝σ1(分数网络的二阶范数正比于噪声的倒数),这促使将 λ \lambda λ设置为 λ ( σ ) = σ 2 \lambda(\sigma)=\sigma^2 λ(σ)=σ2。因为在这个设置下,有 λ ( σ ) l ( θ ; σ ) = σ 2 l ( θ ; σ ) = 1 2 E p d a t a E x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ σ s θ ( x , σ ) + x ~ − x σ ∣ ∣ 2 2 ] \lambda(\sigma)l(\theta;\sigma)=\sigma^2l(\theta;\sigma)=\frac{1}{2}E_{p_{data}}E_{\tilde{x} \sim N(x,\sigma^2I)}[||\sigma s_{\theta}(x,\sigma)+\frac{\tilde{x}-x}{\sigma}||^2_2] λ(σ)l(θ;σ)=σ2l(θ;σ)=21EpdataEx~∼N(x,σ2I)[∣∣σsθ(x,σ)+σx~−x∣∣22],因为 x ~ − x σ ∼ N ( 0 , I ) \frac{\tilde{x}-x}{\sigma} \sim N(0,I) σx~−x∼N(0,I)并且 ∣ ∣ σ s θ ( x , σ ) ∣ ∣ 2 ∝ 1 ||\sigma s_{\theta}(x,\sigma)||_2 \varpropto 1 ∣∣σsθ(x,σ)∣∣2∝1,可以看到 λ ( σ ) l ( θ ; σ ) \lambda(\sigma)l(\theta;\sigma) λ(σ)l(θ;σ)的量级与 σ \sigma σ无关,即所有噪声等级的数量级一致。
基于退火朗之万动力学进行NCSN推理
当NCSN s θ ( x , σ ) s_{\theta}(x,\sigma) sθ(x,σ)被训练后可使用退火朗之万进行采样,如下表所示。采样从一些固定先验分布初始化实例开始,如从均匀分布中初始化 x ~ 0 \tilde{x}_0 x~0;然后使用步长 α i \alpha_i αi进行朗之万动力学从分布 q σ 1 ( x ) q_{\sigma_1}(x) qσ1(x)中采样。在每个噪声等级下均采样 T T T次,当进入到下一个噪声等级时,用上一个噪声等级的最后一个采样样本作为作为当前噪声等级的起始样本,按此规律迭代采样下去。最后,使用朗之万动力学从 q σ L ( x ) q_{\sigma_L(x)} qσL(x)中采样,即当 σ L ≈ 0 \sigma_L \approx 0 σL≈0时, q σ L ( x ) q_{\sigma_L}(x) qσL(x)近似于 p d a t a ( x ) p_{data}(x) pdata(x)。
将 σ i \sigma_i σi设置为正比于 σ i 2 \sigma_i^2 σi2,动机是固定朗之万动力学中信噪比 α i s θ ( x , σ i ) 2 α i z \frac{\alpha_is_{\theta}(x,\sigma_i)}{2\sqrt{\alpha_i}z} 2αizαisθ(x,σi)的数量级。
为了证明退火朗之万动力学的有效性,通过一个小型实验进行验证,实验设置 { σ i } i = 1 L \{\sigma_i\}^L_{i=1} {σi}i=1L为等比数列, L = 10 , σ 1 = 10 , σ 10 = 0.1 L=10, \sigma_1=10,\sigma_{10}=0.1 L=10,σ1=10,σ10=0.1,训练时使用EMA更新参数更稳定。
Stochastic Differential Equations–SDEs
上述NCSN虽然在多个噪声等级下进行了训练,但本质还是有限步骤加噪,如果将噪声等级推过到有限步骤,或者连续情况,就能使用随机微分方程/SDE对分数生成模型建模进行统一表征。
SDE下的扰动数据
为了基于分数模型生成样本,需要使用一个扩散过程逐渐缓慢加噪将原始数据转换为一个随机噪声,然在逆扩散过程通过估计分布的分数进行样本采样。一个扩散过程是与布朗运行相似的随机过程。假设 { x ( t ) ∈ R d } t = 0 T \{x(t) \in R^d\}^T_{t=0} {x(t)∈Rd}t=0T是一个扩散过程,有连续的时间变量 t ∈ [ 0 , T ] t \in [0,T] t∈[0,T]引导,该扩散过程可以用一个SDE表达,用以下公式表示
d x = f ( x , t ) d t + g ( t ) d w , (6) dx=f(x,t)dt+g(t)dw,\tag6 dx=f(x,t)dt+g(t)dw,(6)
随机微分方程是指含有随机参数或随机过程或随机初始值或随机边界值的微分方程;公式(6)中的 w w w表示标准布朗运动,使得SDE成立;布朗运动具有增量独立性、增量服从高斯分布、轨迹连续。公式(6)中 f ( ⋅ , t ) : R d → R d f(\cdot,t):R^d \to R^d f(⋅,t):Rd→Rd称为SDE的漂移系数, g ( t ) ∈ R g(t) \in R g(t)∈R成为扩散系数。可以将 SDE 理解为常微分方程 (ODE) 的随机推广。后续使用 p t ( x ) p_t(x) pt(x)表示 x ( t ) x(t) x(t)的分布。
基于分数的生成建模中一般有两个边界,即 x ( 0 ) ∼ p 0 x(0) \sim p_0 x(0)∼p0是已有数据集中独立同分布样本组成的原始数据分布, x ( T ) ∼ p T x(T) \sim p_T x(T)∼pT是一个有解析解且易采样的先验分布。整个扩散过程应该足够大,使得在噪声扰动后, p T p_T pT不再依赖 p 0 p_0 p0。
用于采样的逆SDE过程
从一个先验分布 p T p_T pT中样本开始,将SDE过程逆向进行,能逐渐获得一个符合原始数据分布 p 0 p_0 p0的样本,此过程就是逆SDE采样。至关重要的是,逆过程是时间倒流的扩散过程。它由以下逆SDE 公式表达
d x = [ f ( x , t ) − g 2 ( t ) ∇ x log p t ( x ) ] d t + g ( t ) d w ˉ , (7) dx=[f(x,t)-g^2(t)\nabla_x \log p_t(x)]dt+g(t)d\bar{w},\tag7 dx=[f(x,t)−g2(t)∇xlogpt(x)]dt+g(t)dwˉ,(7)
其中 w ˉ \bar{w} wˉ是布朗运动的逆向过程, d t dt dt就表示逆时方向上的时间微分。一旦知道了正向SDE过程中漂移系数、扩散系数和时间区间 t ∈ [ 0 , T ] t \in [0,T] t∈[0,T]内数据分布 p t ( x ) p_t(x) pt(x)的分数,该逆向SDE过程就能计算。
基于SDE的正向过程和逆向过程如下图所示,其实与DDPM中加噪过程是一致的,只是加噪过程变成了连续过程,数据变化开始用微分 d x dx dx表示。
去噪分数匹配估计逆SDE过程所需分数
基于公式(7),需要使用依赖时间分数函数 ∇ x log p t ( x ) \nabla_x \log p_t(x) ∇xlogpt(x)去执行逆向SDE过程,然后可使用数值解法从先验分布 p T p_T pT样本迭代计算得到符合 p 0 p_0 p0的样本。故需要训练一个时间相关的分数模型 s θ ( x , t ) s_{\theta}(x,t) sθ(x,t)去近似分数函数 ∇ x log p t ( x ) \nabla_x \log p_t(x) ∇xlogpt(x)。基于去噪分数匹配中公式(3),可以推导出时间相关分数模型的优化目标,如下所示
m i n θ E t ∼ U ( 0 , T ) [ λ ( t ) E x ( 0 ) ∼ p 0 E x ( t ) ∼ p 0 t ( x ( t ) ∣ x ( 0 ) ) [ ∣ ∣ s θ ( x ( t ) , t ) − ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] ] , (8) \underset{\theta}{min}E_{t \sim \mathcal{U}(0,T)}[\lambda(t)E_{x(0) \sim p_0}E_{x(t) \sim p_{0t}(x(t)|x(0))}[||s_{\theta}(x(t),t)-\nabla_{x(t)} \log p_{0t}(x(t)|x(0))||^2_2]],\tag8 θminEt∼U(0,T)[λ(t)Ex(0)∼p0Ex(t)∼p0t(x(t)∣x(0))[∣∣sθ(x(t),t)−∇x(t)logp0t(x(t)∣x(0))∣∣22]],(8)
其中 U ( 0 , T ) \mathcal{U}(0,T) U(0,T)是在区间 [ 0 , T ] [0,T] [0,T]上的均匀分布,实际使用时会进行归一化,即是区间 [ 0 , 1 ] [0,1] [0,1]上的均匀分布,分布中的数值是一个连续量; p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(x(t)|x(0)) p0t(x(t)∣x(0))表示从 x ( 0 ) x(0) x(0)到 x ( t ) x(t) x(t)的转移概率, λ ( t ) \lambda(t) λ(t)表示一个正数权重。
在上述目标中, x ( 0 ) x(0) x(0)的期望可以通过 p 0 p0 p0的数据样本的经验平均值来估计。当SDE中的漂移系数 f ( x , t ) f(x,t) f(x,t)是一个仿射变换,即可通过重参数采样就能从 x ( 0 ) x(0) x(0)和时间 t t t采样出样本 x ( 0 ) x(0) x(0),即 x ( t ) x(t) x(t)的期望可通过从 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(x(t)|x(0)) p0t(x(t)∣x(0))采样来估计。为了匹配不同时间点上的分数匹配损失的数量级,权重函数 λ ( t ) \lambda(t) λ(t)一般设置为 1 E [ ∣ ∣ ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] \frac{1}{E[||\nabla_{x(t)} \log p_{0t}(x(t)|x(0))||^2_2]} E[∣∣∇x(t)logp0t(x(t)∣x(0))∣∣22]1,与上述NCSN中一致,如果噪声符合高斯分布, 1 E [ ∣ ∣ ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] \frac{1}{E[||\nabla_{x(t)} \log p_{0t}(x(t)|x(0))||^2_2]} E[∣∣∇x(t)logp0t(x(t)∣x(0))∣∣22]1就等于方差 σ t 2 \sigma^2_t σt2。
分数网络设计建议
时间相关的分数模型的网络架构没有任何限制,除了它们的输出应该与输入具有相同的维度,并且它们应该以时间为条件。以下实在结构设计上的几条有用建议
- 使用Unet作为分数网络 s θ ( x , t ) s_{\theta}(x,t) sθ(x,t)的主干通常效果更好
- 对于时间信息可以不只是使用简单的embedding隐射到分数网络中,而是使用高斯随机特征。具体做法是先从分布 N ( 0 , s 2 I ) N(0,s^2I) N(0,s2I)中采样一个 ω \omega ω,该分数是固定的,即 s s s是固定的,训练过程中不用学习。对于一个时间步 t t t,对应的高斯随机特征定义为 [ sin ( 2 π ω t ) ; cos ( 2 π ω t ) ] [\sin(2\pi \omega t);\cos(2\pi \omega t)] [sin(2πωt);cos(2πωt)], [ a ⃗ ; b ⃗ ] [\vec{a} ; \vec{b}] [a;b]表示 a ⃗ \vec{a} a和 b ⃗ \vec{b} b的拼接。该高斯随机特征可以用作时间步长 t 的编码,以便分数网络可以通过合并该编码来以 t 为条件。
- 可以对Unet的输出用一个系数 1 E [ ∣ ∣ ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] \frac{1}{\sqrt{E[||\nabla_{x(t)} \log p_{0t}(x(t)|x(0))||^2_2]}} E[∣∣∇x(t)logp0t(x(t)∣x(0))∣∣22]1进行缩放。这是为了使最优的分数网络 s θ ( x , t ) s_{\theta}(x,t) sθ(x,t)的二阶范数近似于 E [ ∣ ∣ ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 ] E[||\nabla_{x(t)} \log p_{0t}(x(t)|x(0))||_2] E[∣∣∇x(t)logp0t(x(t)∣x(0))∣∣2],此缩放过程有助于捕捉真实分数的范数。
- 训练时使用EMA更新模型权重,采用时使用EMA权重,此方法采样的样本质量更高。