【Classifier Guidance/Classifier-free Guidance】理论推导与代码实现
Classifier Guidance论文链接:Diffusion Models Beat GANs on Image Synthesis
Classifier-free Guidance论文链接:Classifier-Free Diffusion Guidance
原理讲解:一个视频看懂什么是classifier guidance/classifier-free guidance
前言
扩散模型在提出之后是有两大优势的,第一是它生成效果比较好,保真度比较高;其次一点是它生成的这个图片的多样性要明显好于其他一些模型。但扩散模型在保真度的指标上,是一直没有超过Gan的,Gan模型在扩散模型提出之前一直是统治着图像生成领域,保真度用FID来表示,FID越小就说明这张图的保真度越高。直到OpenAI这篇论文的问世,扩散模型才真正的在图像生成领域击败了Gan。
在前边的几篇文章中我们已经学习了 DDPM 以及分别对其训练和采样过程进行改进的工作,不过这些方法都只能进行无条件生成,而无法对生成过程进行控制。我们这次学习的不再是无条件生成,而是通过一定方式对生成过程进行控制,比较常见的有两种:Classifier Guidance 与 Classifier-Free Guidance。
首先说功能上的区别:
classifier guidance | classifier free guidance | |
---|---|---|
是否需要重训模型 | 不需要,拿训好的 diffusion 就行 | 需要,得用这方法从头训 diffusion |
是否需要训别的模型 | 需要,得用加噪的图像训个分类模型 | 相当于不需要,文生图用 clip 就行 |
最终效果 | 可以控制生成的类别。分类模型能分多少类,用这方法就能控制生成多少类。 | 任何条件都可以控制 |
Classifier Guidance论文工作
优化网络结构(OpenAI在 Improved DDPM 的基础上继续进行了一些改进)
- 在模型的尺寸基本不变的前提下,提升模型的深度与宽度之比,相当于使用更深的模型;
- 增加多头注意力中 head 的数量;
- 使用多分辨率 attention,即 32x32、16x16 和 8x8,而不是只在 16x16 的尺度计算 attention;
- 使用 BigGAN 的残差模块来进行上下采样;
- 将残差连接的权重改为 1 2 \frac{1}{\sqrt{2}} 21。
~
经过一系列改进,DDPM 的性能超过了 GAN,文章把改进后的模型称为 Ablated Diffusion Model(ADM)。Classifier Guidance
- 上边的工程改进并不是本文要讨论的重点,我们言归正传来讲 Classifier Guidance。顾名思义,这种可控生成的方式引入了一个额外的分类器,具体来说,是使用分类器的梯度对生成的过程进行引导。
- 这个方法现在已经被广泛应用于stable diffusion这些图像生成的大模型,当然现在可能对于这些大模型来说,相比于Classifier guidance更多的是用到Classifier free guidance,因为原理会更简单一些。
一、问题分析
网上有很多讲这个Classifier Guidance的文章直接说引入了一个分类器,然后通过这个分类器的梯度,对采样方法进行偏移,从而得到更符合条件的一个图像,因为这本身是一种逆向思维,不太好理解。为什么会想到引用这个分类器,第一次看到的时候会比较难以理解,所以这里进行一个正向的推导。
首先我们还是先来分析下问题,现在扩散模型有两个问题:
- FID值偏高(图像的保真度相比于GAN来说,在指标上是比较偏高的)
- 多样性比较高(扩散模型多样性比较好,导致结果不可控)
说到这里肯定会想到条件生成, P ( X ) → P ( X ∣ y ) P(X) →P(X | y) P(X)→P(X∣y) y指的是class,即类别信息(在GAN的研究中以及证明条件生成会使FID指标更好),所以现在的问题从优化DDPM模型变成如何构造Conditional-DDPM。下面的推导将证明要实现CDDPM只要引入一个分类器,用这个分类器的梯度在生成时进行监督,所以就有了Classifier Guidance。
二、公式推导
Classifier Guidance本身是一个采样过程,直接使用已经训练好的DDPM模型,通过Classifier Guidance采样方式,进行优化,就能够得到更好,保证度更高的图片,并且这个生成结果是可控的。它有点像DDIM,这就是Classifier Guidance的一个优势,但Classifier-free Guidance是不一样的,Classifier-free Guidance是要重新进行训练的。
DDPM的逆向过程为 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt)现在我们希望给它变成带类别的,即
q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt−1∣xt,y)通过贝叶斯可以得到
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t ) q ^ ( y ∣ x t ) (1) \hat{q}(x_{t-1}|x_t,y) = \frac{\hat{q}(x_{t-1}|x_t) \hat{q}(y|x_{t-1},x_t)}{\hat{q}(y|x_t)}\tag1 q^(xt−1∣xt,y)=q^(y∣xt)q^(xt−1∣xt)q^(y∣xt−1,xt)(1)这个式子怎么来的呢,是因为
[ q ^ ( x t − 1 ∣ y ) = q ^ ( x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( y ) ] [\hat{q}(x_{t-1}|y) = \frac{\hat{q}(x_{t-1}) \hat{q}(y|x_{t-1})}{\hat{q}(y)}] [q^(xt−1∣y)=q^(y)q^(xt−1)q^(y∣xt−1)]然后给这个式子左右两边加 x t x_t xt即可得到式子1
现在我们的目标变成了求式子1,假设 q ^ \hat{q} q^服从 q ^ ( x t − 1 ∣ x t , y ) ∽ N ( μ ^ , σ 2 ^ ) \hat{q}(x_{t-1}|x_t,y)∽N(\hat{\mu},\hat{\sigma^2}) q^(xt−1∣xt,y)∽N(μ^,σ2^),这样就可以通过之前的采样方式来生成图像
式子1中我们要求右侧的三项,因为这项 q ^ ( y ∣ x t ) \hat{q}(y|x_t) q^(y∣xt)中不含 x t − 1 x_{t-1} xt−1,所以这一项是一个常数,不需要求。前面我们提到还是要使用DDPM训练的模型,即我们虽然不知道 p ^ \hat{p} p^的逆向过程,但是知道正向过程。
1. 已知项
因为我们需要把这个方法用在已经训练好的模型上,所以必须使得正向过程一致。
目前已知3个等式
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) q^(xt∣xt−1,y)=q(xt∣xt−1) (1)q^(x0)=q(x0) (2)q^(x1:T∣x0,y)=t=1∏nq^(xt∣xt−1,y)∽Markov (3)
接下来我们已经解决的项将用绿色表示 q ^ ( y ∣ x t ) \color{green}{\hat{q}(y|x_t)} q^(y∣xt),正在解决的用红色表示 正在解决 \color{red}{\text{正在解决}} 正在解决
2. 求 q ^ ( x t − 1 ∣ x t ) \hat{q}(x_{t-1}|x_t) q^(xt−1∣xt)
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) (1) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{red}\hat{q}(x_{t-1}|x_t)} \hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)}\tag1 q^(xt−1∣xt,y)=q^(y∣xt)q^(xt−1∣xt)q^(y∣xt,xt−1)(1)
使用贝叶斯得
q ^ ( x t − 1 ∣ x t ) = q ^ ( x t ∣ x t − 1 ) q ^ ( x t − 1 ) q ^ ( x t ) (2) {\color{red}\hat{q}(x_{t-1}|x_t)} = \frac{{\color{red}\hat{q}(x_t|x_{t-1})} \hat{q}(x_{t-1})}{\hat{q}(x_t)}\tag2 q^(xt−1∣xt)=q^(xt)q^(xt∣xt−1)q^(xt−1)(2)
步骤一:
先求 q ^ ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} q^(xt∣xt−1),在已知(1)中,已经知道了有y的等式,我们想要得到在上一个状态 𝑥 𝑡 − 1 𝑥_{𝑡−1} xt−1的条件下转移到当前状态 𝑥 𝑡 𝑥_𝑡 xt的无条件概率 q ^ ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} q^(xt∣xt−1)。为了消除类别标签 𝑦 对结果的影响,可以利用全概率公式对 𝑦 积分:
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) ∫ y q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} = \int_{y} \hat{q}(x_t,y|x_{t-1})dy \\ = \int_{y}{\hat{q}(x_t|y,x_{t-1})}{\hat{q}(y|x_{t-1})}dy \\ = \int_y{q(x_t|x_{t-1})}{\hat{q}(y|x_{t-1})}dy \\ = q(x_t|x_{t-1}) \int_{y} {\hat{q}(y|x_{t-1})}dy \\ = q(x_t|x_{t-1}) q^(xt∣xt−1)=∫yq^(xt,y∣xt−1)dy=∫yq^(xt∣y,xt−1)q^(y∣xt−1)dy=∫yq(xt∣xt−1)q^(y∣xt−1)dy=q(xt∣xt−1)∫yq^(y∣xt−1)dy=q(xt∣xt−1)
- 第一行中y在变量里,我们已知的y在条件里,用条件概率的定义式转化为第二行;
- 第二行到第三行变换用了已知(1)的等式;
- 在第三行中,由于是对y积分,前一项就是常数项,可以提出来,后面一项因为变量只含y,不管条件是什么,全概率公式加起来为1,即 ∫ y q ^ ( y ∣ x t − 1 ) d y = 1 \int_{y} {\hat{q}(y|x_{t-1})}dy=1 ∫yq^(y∣xt−1)dy=1
等式(2)现在可以表示为
q ^ ( x t − 1 ∣ x t ) = q ^ ( x t ∣ x t − 1 ) q ^ ( x t − 1 ) q ^ ( x t ) (2) {\color{red}\hat{q}(x_{t-1}|x_t)} = \frac{{\color{green}\hat{q}(x_t|x_{t-1})} {\color{red}\hat{q}(x_{t-1})}}{\color{red}\hat{q}(x_t)}\tag2 q^(xt−1∣xt)=q^(xt)q^(xt∣xt−1)q^(xt−1)(2)
目前已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) ( 4 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) \\ \hat{q}(x_t|x_{t-1}) = q(x_t|x_{t-1})~~(4) q^(xt∣xt−1,y)=q(xt∣xt−1) (1)q^(x0)=q(x0) (2)q^(x1:T∣x0,y)=t=1∏nq^(xt∣xt−1,y)∽Markov (3)q^(xt∣xt−1)=q(xt∣xt−1) (4)
步骤二:
求 q ^ ( x t ) {\color{red}\hat{q}(x_t)} q^(xt),继续用全概率公式
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 (3) {\color{red}\hat{q}(x_t)} = \int_{x_0:t-1} {\hat{q}(x_{0:t})}dx_{0:t-1} \\ = \int_{x_0:t-1} {\hat{q}(x_{0})\hat{q}(x_{1:t}|x_0)}dx_{0:t-1} \tag3 q^(xt)=∫x0:t−1q^(x0:t)dx0:t−1=∫x0:t−1q^(x0)q^(x1:t∣x0)dx0:t−1(3)
- x 0 : t − 1 ⇨ x 0 , x 1 . . . x t − 1 x_{0:t-1}⇨ x_0, x_1... x_{t-1} x0:t−1⇨x0,x1...xt−1这是一个联合的随机变量;
- 根据已知(3), x 0 x_0 x0在条件的位置,所以使用条件概率变换为第二行;
用全概率公式继续求 q ^ ( x 1 : t ∣ x 0 ) \color{red}\hat{q}(x_{1:t}|x_0) q^(x1:t∣x0):
q ^ ( x 1 : t ∣ x 0 ) = ∫ y q ^ ( x 1 : t , y ∣ x 0 ) d y = ∫ y q ^ ( x 1 : t ∣ y , x 0 ) q ^ ( y ∣ x 0 ) d y = ∫ y q ^ ( y ∣ x 0 ) ∏ 1 t q ^ ( x t ∣ x t − 1 , y ) d y = ∫ y q ^ ( y ∣ x 0 ) ∏ 1 t q ( x t ∣ x t − 1 ) d y = ∫ y q ^ ( y ∣ x 0 ) q ( x 1 : t ∣ x 0 ) d y = q ( x 1 : t ∣ x 0 ) {\color{red}\hat{q}(x_{1:t}|x_0)} = \int_{y}{\hat{q}(x_{1:t},y| x_0)dy} \\ = \int_{y}{\hat{q}(x_{1:t}|y, x_0)\hat{q}(y|x_0)dy} \\ = \int_{y}\hat{q}(y|x_0) \prod_{1}^{t}\hat{q}(x_t|x_{t-1},y)dy \\ = \int_{y}\hat{q}(y|x_0) \prod_{1}^{t} q(x_t|x_{t-1})dy \\ = \int_{y}\hat{q}(y|x_0) q(x_{1:t}|x_0)dy \\ = q(x_{1:t}|x_0) q^(x1:t∣x0)=∫yq^(x1:t,y∣x0)dy=∫yq^(x1:t∣y,x0)q^(y∣x0)dy=∫yq^(y∣x0)1∏tq^(xt∣xt−1,y)dy=∫yq^(y∣x0)1∏tq(xt∣xt−1)dy=∫yq^(y∣x0)q(x1:t∣x0)dy=q(x1:t∣x0)
- 根据已知(3),对y做积分并将y移至条件的位置,并做替换得到第三行;
- 根据已知(1),做替换得到第四行;
- ∏ 1 t q ( x t ∣ x t − 1 ) \prod_{1}^{t}q(x_t|x_{t-1}) ∏1tq(xt∣xt−1)的连乘就是 q ( x 1 : t ∣ x 0 ) q(x_{1:t}|x_0) q(x1:t∣x0),变换得第五行;
- 根据全概率公式得 ∫ y q ^ ( y ∣ x 0 ) = 1 \int_{y}\hat{q}(y|x_0) = 1 ∫yq^(y∣x0)=1
现在等式(3)如下:
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 : t ) d x 0 : t − 1 = q ( x t ) {\color{red}\hat{q}(x_t)} = \int_{x_0:t-1} {\hat{q}(x_{0:t})}dx_{0:t-1} \\ = \int_{x_0:t-1} {\hat{q}(x_{0})\hat{q}(x_{1:t}|x_0)}dx_{0:t-1} \\ = \int_{x_0:t-1} {q(x_{0}) q(x_{1:t}|x_0)}dx_{0:t-1} \\ = \int_{x_0:t-1} {q(x_{0:t})}dx_{0:t-1} = q(x_t) q^(xt)=∫x0:t−1q^(x0:t)dx0:t−1=∫x0:t−1q^(x0)q^(x1:t∣x0)dx0:t−1=∫x0:t−1q(x0)q(x1:t∣x0)dx0:t−1=∫x0:t−1q(x0:t)dx0:t−1=q(xt)
至此我们解决了等式(2),即
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} q^(xt−1∣xt,y)=q^(y∣xt)q^(xt−1∣xt)q^(y∣xt,xt−1)
3. 求 q ^ ( y ∣ x t , x t − 1 ) \hat{q}(y|x_{t},x_{t-1}) q^(y∣xt,xt−1)
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \color{red}\hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} q^(xt−1∣xt,y)=q^(y∣xt)q^(xt−1∣xt)q^(y∣xt,xt−1)
目前已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) ( 4 ) q ^ ( x t − 1 ∣ x t ) = q ( x t − 1 ∣ x t ) ( 5 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) \\ \hat{q}(x_t|x_{t-1}) = q(x_t|x_{t-1})~~(4) \\ \hat{q}(x_{t-1}|x_{t}) = q(x_{t-1}|x_{t})~~(5) q^(xt∣xt−1,y)=q(xt∣xt−1) (1)q^(x0)=q(x0) (2)q^(x1:T∣x0,y)=t=1∏nq^(xt∣xt−1,y)∽Markov (3)q^(xt∣xt−1)=q(xt∣xt−1) (4)q^(xt−1∣xt)=q(xt−1∣xt) (5)
用贝叶斯得:
q ^ ( y ∣ x t , x t − 1 ) = q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( y ∣ x t − 1 ) {\color{red}\hat{q}(y|x_{t},x_{t-1})} = \frac{\hat{q}(x_t|y,x_{t-1})\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(x_t|y,x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = q(x_t|x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(x_t|x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(y|x_{t-1}) q^(y∣xt,xt−1)=q^(xt∣xt−1)q^(xt∣y,xt−1)q^(y∣xt−1)=q^(xt∣y,xt−1)q^(xt∣xt−1)q^(y∣xt−1)=q(xt∣xt−1)q^(xt∣xt−1)q^(y∣xt−1)=q^(xt∣xt−1)q^(xt∣xt−1)q^(y∣xt−1)=q^(y∣xt−1)
- 根据已知(1)得第三行;
- 根据已知(4)得第四行,化简得第五行。
4. 整合
在等式中:
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) = Z q ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \color{red}\hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} \\ = Z {\color{blue}q(x_{t-1}|x_t)} {\color{blue}\hat{q}(y|x_{t-1}}) q^(xt−1∣xt,y)=q^(y∣xt)q^(xt−1∣xt)q^(y∣xt,xt−1)=Zq(xt−1∣xt)q^(y∣xt−1)
- q ^ ( y ∣ x t ) {\color{green}\hat{q}(y|x_t)} q^(y∣xt)是常数,用Z来表示;
- q ^ ( x t − 1 ∣ x t ) = q ( x t − 1 ∣ x t ) {\color{green}\hat{q}(x_{t-1}|x_t)} = q(x_{t-1}|x_t) q^(xt−1∣xt)=q(xt−1∣xt),即DDPM的采样过程;
- q ^ ( y ∣ x t − 1 ) \hat{q}(y|x_{t-1}) q^(y∣xt−1)就是一个分类器的分类结果;
- Z作为normalize factor,是一个正则化的因子,为了保证 q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt−1∣xt,y)符合概率分布。
三、采样方式
现在需要用两个模型来分别预测右边的两项
q ^ ( x t − 1 ∣ x t , y ) = Z q ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 ) = Z P θ ( x t ∣ x t + 1 ) P φ ( y ∣ x t ) (4) \hat{q}(x_{t-1}|x_t,y) = Z {q(x_{t-1}|x_t)} {\hat{q}(y|x_{t-1})}=Z P_θ(x_t|x_{t+1})P_φ(y|x_{t})\tag4 q^(xt−1∣xt,y)=Zq(xt−1∣xt)q^(y∣xt−1)=ZPθ(xt∣xt+1)Pφ(y∣xt)(4)
- P θ ( x t ∣ x t + 1 ) → q ( x t − 1 ∣ x t ) P_θ(x_t|x_{t+1})→q(x_{t-1}|x_t) Pθ(xt∣xt+1)→q(xt−1∣xt),即DDPM的模型,服从高斯分布 N ( μ , Σ 2 ) N(\mu,Σ^2) N(μ,Σ2)
- P φ ( y ∣ x t ) → q ^ ( y ∣ x t − 1 ) P_φ(y|x_{t})→\hat{q}(y|x_{t-1}) Pφ(y∣xt)→q^(y∣xt−1),即分类器模型
我们有了这两个模型之后,怎么让这两个模型之间有一定的联系,然后得到最终的结果 。假设此时刻是t+1,而这个时刻是没有办法得到 x t x_t xt的,没有 x t x_t xt就没有办法得到 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(y∣xt)确切的值,所以很难通过数学计算得到 q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt−1∣xt,y),因此只能通过预估得到 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(y∣xt)
1. 计算 P θ ( x t ∣ x t + 1 ) P_θ(x_t|x_{t+1}) Pθ(xt∣xt+1)
取对数 l o g P θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C 1 (5) logP_θ(x_t|x_{t+1}) = -\frac{1}{2}(x_t-\mu)^TΣ^{-1}(x_t-\mu)+C_1\tag5 logPθ(xt∣xt+1)=−21(xt−μ)TΣ−1(xt−μ)+C1(5)
公式如下:
2. 计算 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(y∣xt)
x t x_{t} xt是不知道的,需要 x t x_{t} xt进行预估。但我们知道,扩散模型的方差 Σ Σ Σ一般都很小,在图像上表示为,其他位置都趋近于0,只有在 x = μ x=\mu x=μ附近概率比较高。那么就可以认为 x t ≈ μ x_t≈\mu xt≈μ,
取对数 l o g P φ ( y ∣ x t ) logP_φ(y|x_{t}) logPφ(y∣xt),并在 x t ≈ μ x_t≈\mu xt≈μ处进行泰勒展开得
l o g P φ ( y ∣ x t ) = l o g P φ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + o ( ∣ ∣ x t − μ ∣ ∣ 2 ) ≈ ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + C 2 (6) logP_φ(y|x_{t}) = logP_φ(y|x_{t})|_{x_{t=\mu}}+(x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}} +o(||x_t-\mu||^2) \\ ≈ (x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}}+C_2\tag6 logPφ(y∣xt)=logPφ(y∣xt)∣xt=μ+(xt−μ)T▽xtlogPφ(y∣xt)∣xt=μ+o(∣∣xt−μ∣∣2)≈(xt−μ)T▽xtlogPφ(y∣xt)∣xt=μ+C2(6)
解释:
在多维空间中,泰勒展开和梯度的定义是自然延伸的。以下是关于泰勒展开为这种形式的原因,以及梯度表示为何采用倒三角(即 ∇)符号的解释。
在泰勒展开中出现 𝑇 次方(即转置)是因为我们要确保中间项是一个标量,而不是一个向量。让我们更深入地理解这个问题。
故
l o g P θ ( x t ∣ x t + 1 ) + l o g P φ ( y ∣ x t ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + C = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C ( ▽ = g ) = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C ′ = l o g P ( x t − 1 ∣ x t , y ) (7) logP_θ(x_t|x_{t+1}) + logP_φ(y|x_{t}) =-\frac{1}{2}(x_t-\mu)^TΣ^{-1}(x_t-\mu)+(x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}}+C \\ = -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+\frac{1}{2}g^TΣg+C~~~ (\triangledown=g) \\ = -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+C' \\ = logP(x_{t-1}|x_t,y) \tag7 logPθ(xt∣xt+1)+logPφ(y∣xt)=−21(xt−μ)TΣ−1(xt−μ)+(xt−μ)T▽xtlogPφ(y∣xt)∣xt=μ+C=−21(xt−μ−Σg)TΣ−1(xt−μ−Σg)+21gTΣg+C (▽=g)=−21(xt−μ−Σg)TΣ−1(xt−μ−Σg)+C′=logP(xt−1∣xt,y)(7)
- 因为这项 1 2 g T Σ g \frac{1}{2}g^TΣg 21gTΣg中没有 x t x_t xt,所以是这项是常数,化简为第三行;
- 令 P θ ( x t ∣ x t + 1 ) P φ ( y ∣ x t ) = P ( x t − 1 ∣ x t , y ) P_θ(x_t|x_{t+1})P_φ(y|x_{t})=P(x_{t-1}|x_t,y) Pθ(xt∣xt+1)Pφ(y∣xt)=P(xt−1∣xt,y)得第四行
3. 终极目标
l o g P ( x t − 1 ∣ x t , y ) = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C ′ (8) logP(x_{t-1}|x_t,y)= -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+C' \tag8 logP(xt−1∣xt,y)=−21(xt−μ−Σg)TΣ−1(xt−μ−Σg)+C′(8)
观察等式(8)和等式(5),可以发现这两个式子是一样的,故可知,等式(8)也服从高斯分布 N ∽ ( μ + Σ g , Σ 2 ) N∽(\mu+Σg,Σ^2) N∽(μ+Σg,Σ2),即
P ( x t − 1 ∣ x t , y ) ∽ N ( μ + Σ g , Σ 2 ) (9) P(x_{t-1}|x_t,y)∽N(\mu+Σg,Σ^2)\tag9 P(xt−1∣xt,y)∽N(μ+Σg,Σ2)(9)
所以 x t = μ + Σ g + Σ ε x_t = \mu+Σg +Σε xt=μ+Σg+Σε
- μ \mu μ需要用预测的 ε ε ε来求
- g = ▽ = ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ g=\triangledown=\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}} g=▽=▽xtlogPφ(y∣xt)∣xt=μ,这个东西实际上就是分类器的梯度
至此,就可以得到条件生成的结果,这个条件包含在了分类器里
四、泛用采样
这步的作用是为了摆脱对 Σ Σ Σ的依赖,因为 Σ Σ Σ有等于0的情况。所以就借用了score的概念。
在这个分布中,直接求解并不容易,但可以使用 score-based models 的方式进行求解(对 score-based models 不熟悉的读者可以先阅读一个视频看懂score-based模型的底层原理和[生成模型新方向]: score-based generative models 这两个讲解作为前置知识),也就是利用 score function。
这里简单描述一下:
- score的定义: 一个概率分布 P ( x ) P(x) P(x),score就是 l o g P ( x ) logP(x) logP(x)的梯度 ▽ x l o g P ( x ) \triangledown xlogP(x) ▽xlogP(x)
- score-based generative models的论文中,提到需要训练一个模型来预测score。score based这些论文认为扩散模型实际上是在预测score
P ( x t − 1 ∣ x t , y ) ∽ N ( μ + Σ g , Σ 2 ) (9) P(x_{t-1}|x_t,y)∽N(\mu+Σg,Σ^2)\tag9 P(xt−1∣xt,y)∽N(μ+Σg,Σ2)(9)
在DDPM中,有这样一个分布
P ( x t ∣ x 0 ) ∽ N ( a ‾ t x 0 , 1 − a ‾ t ) P(x_t|x_0)∽\mathcal{N}(\sqrt{\overline{a}_t}x_{0}, 1 - \overline{a}_t) P(xt∣x0)∽N(atx0,1−at)
取log:
l o g P ( x t ∣ x 0 ) = − ( x t − a ‾ t ⋅ x 0 ) 2 2 ( 1 − a ‾ t ) (10) logP(x_t | x_0) = -\frac{(x_t - \sqrt{\overline{a}_t} \cdot x_0)^2}{2(1 -\overline{a}_t)}\tag{10} logP(xt∣x0)=−2(1−at)(xt−at⋅x0)2(10)
对 x t x_t xt求梯度得:
▽ x t l o g P ( x t ∣ x 0 ) = − x t − a ‾ t ⋅ x 0 1 − a ‾ t (11) \triangledown_{x_t}logP(x_t | x_0) = -\frac{x_t - \sqrt{\overline{a}_t} \cdot x_0}{1 -\overline{a}_t}\tag{11} ▽xtlogP(xt∣x0)=−1−atxt−at⋅x0(11)
根据公式: x t = 1 − a ‾ t × E + a ‾ t x 0 x_t =\sqrt{1 - \overline{a}_t} × \mathcal{E} + \sqrt{\overline{a}_t}x_{0} xt=1−at×E+atx0得
E = x t − a ‾ t x 0 1 − a ‾ t \mathcal{E} = \frac{x_t - \sqrt{\overline{a}_t}x_{0}}{\sqrt{1 - \overline{a}_t}} E=1−atxt−atx0
故等式(11)可以改写为:
▽ x t l o g P ( x t ∣ x 0 ) = − x t − a ‾ t ⋅ x 0 1 − a ‾ t = − E 1 − a ‾ t (12) \triangledown_{x_t}logP(x_t | x_0) = -\frac{x_t - \sqrt{\overline{a}_t} \cdot x_0}{1 -\overline{a}_t} = - \frac{\mathcal{E}}{\sqrt{1 - \overline{a}_t}}\tag{12} ▽xtlogP(xt∣x0)=−1−atxt−at⋅x0=−1−atE(12)
上面这个等式的 E \mathcal{E} E是抽样得到的,现在我们把它改为预测的
▽ x t l o g P θ ( x t ) = − E θ ^ ( x t ) 1 − a ‾ t (13) \triangledown_{x_t}logP_θ(x_t) = - \frac{\hat{\mathcal{E}_θ}(x_t)}{\sqrt{1 - \overline{a}_t}}\tag{13} ▽xtlogPθ(xt)=−1−atEθ^(xt)(13)
因为模型实际上就是给定 x t , t x_t,t xt,t,然后用Unet预测 E \mathcal{E} E,所以 ▽ x t l o g P θ ( x t ) \triangledown_{x_t}logP_θ(x_t) ▽xtlogPθ(xt)实际上就是预测的score
根据等式(4)(13)得:
▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) = ▽ l o g P θ ( x t ) + ▽ l o g P φ ( y ∣ x t ) = − E θ ^ ( x t ) 1 − a ‾ t + g (14) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) = \triangledown logP_θ(x_t) + \triangledown logP_φ(y|x_{t}) \\ = - \frac{\hat{\mathcal{E}_θ}(x_t)}{\sqrt{1 - \overline{a}_t}} +g\tag{14} ▽xtlogPθ(xt)Pφ(y∣xt)=▽logPθ(xt)+▽logPφ(y∣xt)=−1−atEθ^(xt)+g(14)
等式右边也可以解释为:第一项是无条件生成的 score function,第二项是分类器的梯度,这个梯度表示的是从噪声指向条件 y 的方向,把这个方向加到无条件生成的 score 上,就可以让降噪的方向也指向 y 的方向。
将 ▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) ▽xtlogPθ(xt)Pφ(y∣xt)看作一个整体可知
▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) = − E ^ ′ 1 − a ‾ t (15) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) = - \frac{\hat{\mathcal{E}}'}{\sqrt{1 - \overline{a}_t}}\tag{15} ▽xtlogPθ(xt)Pφ(y∣xt)=−1−atE^′(15)
根据(14)(15)可知
E ^ ′ = E θ ^ − 1 − a ‾ t g (15) \hat{\mathcal{E}}' = \hat{\mathcal{E}_θ} - {\sqrt{1 - \overline{a}_t}}g\tag{15} E^′=Eθ^−1−atg(15)
这个式子表示:
- E θ ^ \hat{\mathcal{E}_θ} Eθ^是DDPM模型预测的噪音
- 此时已经不需要用 μ + Σ g \mu+Σg μ+Σg,就摆脱了对 Σ Σ Σ的依赖
- 只需要用预测出来的噪音+一个非常小的扰动,就可以得到一个全新的 E \mathcal{E} E
- 然后使用DDPM\DDIM的采样方式就可以得到最终的生成结果
五、伪代码分析
Algorithm 1 2 都是采样过程
-
Algorithm 1 就是前面将的 + Σ g +Σg +Σg形式
- 首先从(0,1)高斯分布中提取一个 x T x_T xT;
- 然后一步一步迭代;
- 通过预测得到 μ , Σ \mu,Σ μ,Σ;
- 用分类器计算分类器的梯度,并计算 + Σ g +Σg +Σg;
- s是调节系数,作者发现s大的时候,模型会更关注分类器,这样得到的图片质量越高,但是多样性越低,因为更多的趋向于给定的标签y;
- 最后采样
-
Algorithm 2 就是改变 E ^ \hat{\mathcal{E}} E^
- 使用的是DDIM的采样方式
再看一个分类引导的伪代码加深理解
classifier_model = ... # 加载一个训好的图像分类模型
y = 1 # 生成类别为 1 的图像,假设类别 1 对应“狗”这个类
guidance_scale = 7.5 # 控制类别引导的强弱,越大越强
input = get_noise(...) # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():noise_pred = unet(input, t).sample# 用 input 和预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample# classifier guidance 步骤class_guidance = classifier_model.get_class_guidance(input, y)input += class_guidance * guidance_scals # 把梯度加上去
这里的关键是倒数第二行:
class_guidance = classifier_model.get_class_guidance(input, y)
这一行做的事情是,把当前生成的图 input 和我们想要的类别 y 一起喂给分类模型。
分类模型分出来的类别不一定是 y,但我们想让它尽量靠近 y,所以计算预测值和 y 的 loss。
然后,这里像训练分类模型时的梯度反向传播一样,计算梯度。不同的是训分类模型时要得到权重参数的梯度,方便梯度更新,而这里只需要保留对 input 的梯度就好。
最后把计算出来的梯度根据 guidance_scale 加到图像上。
六、分类器训练
- 训练数据:和扩散模型用的数据一样
- 分类器训练的数据是由噪音的,噪音强度时随机的,目的是为了保证分类器在任何时刻都能有作用,可以给出一个好的梯度方向
七、代码实现
虽然推导看起来依然很复杂,但需要改动的代码其实非常少,获得梯度之后再用梯度更新一下就可以了。这里给出一些关键的代码片段。
1. 获取分类器梯度
获取分类器对 x t \mathbf{x}_t xt的梯度其实也比较直接,可以直接使用 Pytorch 的自动求导工具。先让 x \mathbf{x} x带上梯度,然后输入分类器获取概率分布,最后再提取出 y 对应的一项计算梯度。这里有一个比较神奇的点,就是一般来说分类模型的输入都是不计算梯度的,不过这里的输入也是带梯度的,感觉类似于 DETR 里的 learnable query:
import torch
import torch.nn.functional as Fdef classifier_guidance(x: torch.Tensor,t: torch.Tensor,y: torch.Tensor,classifier: torch.nn.Module
):# 开启自动求导上下文,确保计算梯度with torch.enable_grad():# 分离 x 的梯度信息并启用对其的梯度计算,得到新的张量 x_with_grad# 这样可以在后续步骤中计算 x_with_grad 的梯度,而不会影响 x 本身x_with_grad = x.detach().requires_grad_(True)# 使用分类器对带有梯度的 x 和时间步 t 进行前向传播,得到未归一化的分数 (logits)logits = classifier(x_with_grad, t)# 对 logits 应用 log softmax 函数,得到 log 概率分布# 这样可以提高数值稳定性,并便于后续的梯度计算log_prob = F.log_softmax(logits, dim=-1)# 从 log 概率分布中选取目标类别 y 对应的 log 概率值# 使用 y.view(-1) 将 y 转换为一维,以便作为索引selected = log_prob[range(len(logits)), y.view(-1)]# 计算目标类别 y 的 log 概率总和对 x_with_grad 的梯度# 该梯度将作为指导信号,用于引导生成模型朝向目标类别 y 的方向return torch.autograd.grad(selected.sum(), x_with_grad)[0]
这一部分也就相当于 ▽ x t l o g P ( y ∣ x t ) \triangledown_{x_t}logP(y | x_t) ▽xtlogP(y∣xt)这一项,这在上一章的两种解释中都是相通的。而如何使用得到的梯度对采样过程进行引导,会根据推导不同有两种实现方式。
2. 第一种引导的实现
这种方法相对比较好理解,就是用梯度朝着指向 y 的方向对生成结果进行一个修正:
for timestep in tqdm(scheduler.timesteps):# 在当前时间步中预测噪声分量with torch.no_grad():noise_pred = unet(images, timestep).sample # 使用 U-Net 预测噪声,禁用梯度计算以节省内存# 根据预测出的噪声和当前时间步,计算出去噪后的 x_{t-1}(即生成过程的前一状态)images = scheduler.step(noise_pred, timestep, images).prev_sample# 计算分类器的梯度指导,帮助生成模型生成符合目标类别的样本guidance = classifier_guidance(images, timestep, y, classifier)# 将分类器梯度(guidance)加到去噪后的样本上,以引导生成过程更加符合目标类别images += guidance_scale * guidance
3. 第二种引导的实现
这种实现方式和 openai 的官方实现相同,也就是直接按照原论文的 x t ∽ N ( μ + s Σ g , Σ ) x_t ∽ N(\mu+sΣg,Σ) xt∽N(μ+sΣg,Σ)得到结果:
# 先预测均值和方差
mean, variance = p_mean_var['mean'], p_mean_var['variance'] # 提取预测出的均值和方差,用于更新生成样本# 计算分类器的梯度,用于引导生成过程
guidance = classifier_guidance(images, timestep, y, classifier)# 根据原始的均值和方差,以及分类器提供的梯度指导信号,计算出新的均值
# 新的均值通过加入 guidance 来调整生成样本的方向,使其更符合目标类别
new_mean = mean.float() + guidance_scale * variance * guidance.float()
在这份代码中,p_mean_var 就是模型预测出的均值和方差。因为官方实现基于 Improved DDPM 修改,所以方差也是可学习的。根据公式可以计算出新的均值,得到新的均值和方差后,再从对应的高斯分布中进行采样即可。
Classifier-Free Guidance论文工作
Classifier Guidance 只能用分类模型控制生成的类别。如果分类模型是分 80 类,那么 Classifier Guidance 也只能引导 diffusion 模型生成固定的 80 类,多一类都不好使。
Classifier Free Guidance 这个方法就厉害了,虽然得重训 diffusion 模型,但训好了以后可以直接起飞,没有类别数量的限制。
一、 伪代码分析
1. 训练
- puncond:无条件的y;
- 从数据集中采样一张图片和对应的分类;
有puncond的概率把c复赋值为空(非条件),1-puncond的概率保留c,不赋值的话就是条件;通过这种形式把条件和非条件放到一起训练;
(这是唯一多的一步,其他基本都一样)- 取一个时间样本λ;
- 抽样随机噪音;
- 通过均值标准差计算 z λ z_λ zλ;
- 算误差。
2. 采样
- w是条件因子,c可以给定,也可以不给定
- 步骤3,只有当有条件c时,才会条件生成,否则就是非条件生成
以文本条件为例(也就是文生图)
clip_model = ... # 加载一个官方的 clip 模型text = "一只狗" # 输入文本
text_embeddings = clip_model.text_encode(text) # 编码条件文本
empty_embeddings = clip_model.text_encode("") # 编码空文本
text_embeddings = torch.cat(empty_embeddings, text_embeddings) # 把它俩 concate 到一起作为条件input = get_noise(...) # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():# 这里同时预测出了有文本的和空文本的图像噪声noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).sample# Classifier-Free Guidance 引导noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # 拆成无条件和有条件的噪声# 把【“无条件噪声”指向“有条件噪声”】看做一个向量,根据 guidance_scale 的值放大这个向量# (当 guidance_scale = 1 时,下面这个式子退化成 noise_pred_text)noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample
画个图,大约是下面这么个意思:
上图中的两个圈就表示我们用 unet 算出的 noise_pred_uncond(无条件噪声) 和 noise_pred_text(“一只狗”文本条件噪声) 。
红色的箭头表示从“无条件”到“‘一只狗’条件”的向量,给它乘上 guidance_scale,通过调节 guidance_scale 的数值大小,我们就能控制文本条件噪声贴近文本语义的程度。
如果我们想让生成的图更遵循“一只狗”这个文本语义,就把 guidance_scale 设大一点,生成的图像会更贴近“一只狗”的文本语义,但是多样性也会降低。反之如果我们想让生成的图像更多样丰富一些,就把 guidance_scale 设小一点。通常来讲这个值被设为 7.5 比较合适。
总结一下,Classifier-Free Guidance 需要在训练过程中同时训练模型的两个能力,一个是有条件生成,一个是无条件生成。
个人的理解:无条件生成是有条件生成的基础,生成的质量和多样性是由无条件生成的分数保证的,如果只有有条件生成而没有无条件生成,那么生成效果可能不佳。
显然 Classifier-Free Guidance 效果更好些,即能生成无穷多的图像类别,又不需要重新训练一个基于噪声的分类模型。所以现在最常见的都是 Classifier-Free Guidance。
二、U-Net模型如何融入语义信息
1. CrossAttention
首先通过embedding layer或者是clip等模型将文本转换为文本特征嵌入,即text embedding过程。
之后text embedding和原本模型中的image进行融合。最常见的方式是利用CrossAttention(stable diffusion采用的就是这个方法)。
具体来说是把text embedding作为注意力机制中的key和value,把原始图片表征作为query。相当于计算每张图片和对应句子中单词的一个相似度得分,把得分转换成单词的权重,[权重乘以单词的embedding]加和作为最终的特征。
import torch
import torch.nn as nn
from einops import rearrangeclass SpatialCrossAttention(nn.Module):def __init__(self,dim,context_dim,heads=4,dim_head=32):super(SpatialCrossAttention,self).__init__()self.scale = dim_head**-0.5self.heads = headshidden_dim = dim_head*headsself.proj_in = nn.Conv2d(dim,context_dim,kernel_size=1,stride=1,padding=0)self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)self.to_out = nn.Conv2d(hidden_dim, dim, 1)def forward(self,x,context=None):x_q = self.proj_in(x)b,c,h,w = x_q.shapex_q = rearrange(x_q,"b c h w -> b (h w) c")if context is None:context = x_qif context.ndim == 2:context = rearrange(context,"b c -> b () c")q = self.to_q(x_q)k = self.to_k(context)v = self.to_v(context)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))sim = torch.einsum('b i d, b j d -> b i j',q,k)*self.scaleattn = sim.softmax(dim=-1)out = torch.einsum('b i j, b j d -> b i d',attn,v)out = rearrange(out, '(b h) n d -> b n (h d)',h=self.heads)out = rearrange(out,'b (h w) c -> b c h w', h=h, w=w)out = self.to_out(out)return outCrossAttn = SpatialCrossAttention(dim=32,context_dim=1024)
x = torch.randn(8,32,256,256)
context = torch.randn(8,1024)
out = CrossAttn(x,context)
2. Channel-wise attention
融入方式与time-embedding的融入方式相同。基于channel-wise的融入粒度没有CrossAttention细,一般使用类别数量有限的特征融入,如时间embedding、类别embedding。语义信息的融入更推荐使用CrossAttention。
#如果参数为“default”,则时间融合方式为相加#如果参数为“scale_shift”,则时间融合方式为scale_shift方法if temb_channels is not None:if self.time_embedding_norm == "default":self.time_emb_proj = linear_cls(temb_channels, out_channels)elif self.time_embedding_norm == "scale_shift":self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)else:raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)if self.time_embedding_norm == "default":if temb is not None:hidden_states = hidden_states + tembhidden_states = self.norm2(hidden_states)
elif self.time_embedding_norm == "scale_shift":if temb is None:raise ValueError(f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}")time_scale, time_shift = torch.chunk(temb, 2, dim=1)hidden_states = self.norm2(hidden_states)hidden_states = hidden_states * (1 + time_scale) + time_shift
else:hidden_states = self.norm2(hidden_states)
三、采样过程的代码实现
类似于 Improved DDPM。一个示意性的代码如下:
for timestep in tqdm(scheduler.timesteps):# 预测噪声with torch.no_grad():noise_pred = unet(images, timestep, condition).sample# Classifier-Free Guidancenoise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)noise_pred = (1.0 - guidance_scale) * noise_pred_uncond + guidance_scale * noise_pred_condimages = scheduler.step(noise_pred, timestep, images).prev_sample
参考资料:
https://lichtung612.github.io/posts/3-diffusion-models/
https://zhuanlan.zhihu.com/p/660518657