当前位置: 首页 > news >正文

【Classifier Guidance/Classifier-free Guidance】理论推导与代码实现

Classifier Guidance论文链接:Diffusion Models Beat GANs on Image Synthesis
Classifier-free Guidance论文链接:Classifier-Free Diffusion Guidance
原理讲解:一个视频看懂什么是classifier guidance/classifier-free guidance

前言

扩散模型在提出之后是有两大优势的,第一是它生成效果比较好,保真度比较高;其次一点是它生成的这个图片的多样性要明显好于其他一些模型。但扩散模型在保真度的指标上,是一直没有超过Gan的,Gan模型在扩散模型提出之前一直是统治着图像生成领域,保真度用FID来表示,FID越小就说明这张图的保真度越高。直到OpenAI这篇论文的问世,扩散模型才真正的在图像生成领域击败了Gan。

在前边的几篇文章中我们已经学习了 DDPM 以及分别对其训练和采样过程进行改进的工作,不过这些方法都只能进行无条件生成,而无法对生成过程进行控制。我们这次学习的不再是无条件生成,而是通过一定方式对生成过程进行控制,比较常见的有两种:Classifier Guidance 与 Classifier-Free Guidance。

首先说功能上的区别:

classifier guidanceclassifier free guidance
是否需要重训模型不需要,拿训好的 diffusion 就行需要,得用这方法从头训 diffusion
是否需要训别的模型需要,得用加噪的图像训个分类模型相当于不需要,文生图用 clip 就行
最终效果可以控制生成的类别。分类模型能分多少类,用这方法就能控制生成多少类。任何条件都可以控制

Classifier Guidance论文工作

  1. 优化网络结构(OpenAI在 Improved DDPM 的基础上继续进行了一些改进)

    • 在模型的尺寸基本不变的前提下,提升模型的深度与宽度之比,相当于使用更深的模型;
    • 增加多头注意力中 head 的数量;
    • 使用多分辨率 attention,即 32x32、16x16 和 8x8,而不是只在 16x16 的尺度计算 attention;
    • 使用 BigGAN 的残差模块来进行上下采样;
    • 将残差连接的权重改为 1 2 \frac{1}{\sqrt{2}} 2 1
      ~  
      经过一系列改进,DDPM 的性能超过了 GAN,文章把改进后的模型称为 Ablated Diffusion Model(ADM)。
  2. Classifier Guidance

    • 上边的工程改进并不是本文要讨论的重点,我们言归正传来讲 Classifier Guidance。顾名思义,这种可控生成的方式引入了一个额外的分类器,具体来说,是使用分类器的梯度对生成的过程进行引导。
    • 这个方法现在已经被广泛应用于stable diffusion这些图像生成的大模型,当然现在可能对于这些大模型来说,相比于Classifier guidance更多的是用到Classifier free guidance,因为原理会更简单一些。

一、问题分析

网上有很多讲这个Classifier Guidance的文章直接说引入了一个分类器,然后通过这个分类器的梯度,对采样方法进行偏移,从而得到更符合条件的一个图像,因为这本身是一种逆向思维,不太好理解。为什么会想到引用这个分类器,第一次看到的时候会比较难以理解,所以这里进行一个正向的推导。

首先我们还是先来分析下问题,现在扩散模型有两个问题:

  1. FID值偏高(图像的保真度相比于GAN来说,在指标上是比较偏高的)
  2. 多样性比较高(扩散模型多样性比较好,导致结果不可控)

说到这里肯定会想到条件生成, P ( X ) → P ( X ∣ y ) P(X) →P(X | y) P(X)P(Xy) y指的是class,即类别信息(在GAN的研究中以及证明条件生成会使FID指标更好),所以现在的问题从优化DDPM模型变成如何构造Conditional-DDPM。下面的推导将证明要实现CDDPM只要引入一个分类器,用这个分类器的梯度在生成时进行监督,所以就有了Classifier Guidance。

二、公式推导

Classifier Guidance本身是一个采样过程,直接使用已经训练好的DDPM模型,通过Classifier Guidance采样方式,进行优化,就能够得到更好,保证度更高的图片,并且这个生成结果是可控的。它有点像DDIM,这就是Classifier Guidance的一个优势,但Classifier-free Guidance是不一样的,Classifier-free Guidance是要重新进行训练的。

DDPM的逆向过程为 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)现在我们希望给它变成带类别的,即
q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt1xt,y)通过贝叶斯可以得到
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t ) q ^ ( y ∣ x t ) (1) \hat{q}(x_{t-1}|x_t,y) = \frac{\hat{q}(x_{t-1}|x_t) \hat{q}(y|x_{t-1},x_t)}{\hat{q}(y|x_t)}\tag1 q^(xt1xt,y)=q^(yxt)q^(xt1xt)q^(yxt1,xt)(1)这个式子怎么来的呢,是因为
[ q ^ ( x t − 1 ∣ y ) = q ^ ( x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( y ) ] [\hat{q}(x_{t-1}|y) = \frac{\hat{q}(x_{t-1}) \hat{q}(y|x_{t-1})}{\hat{q}(y)}] [q^(xt1y)=q^(y)q^(xt1)q^(yxt1)]然后给这个式子左右两边加 x t x_t xt即可得到式子1

现在我们的目标变成了求式子1,假设 q ^ \hat{q} q^服从 q ^ ( x t − 1 ∣ x t , y ) ∽ N ( μ ^ , σ 2 ^ ) \hat{q}(x_{t-1}|x_t,y)∽N(\hat{\mu},\hat{\sigma^2}) q^(xt1xt,y)N(μ^,σ2^),这样就可以通过之前的采样方式来生成图像

式子1中我们要求右侧的三项,因为这项 q ^ ( y ∣ x t ) \hat{q}(y|x_t) q^(yxt)中不含 x t − 1 x_{t-1} xt1,所以这一项是一个常数,不需要求。前面我们提到还是要使用DDPM训练的模型,即我们虽然不知道 p ^ \hat{p} p^的逆向过程,但是知道正向过程。

1. 已知项

因为我们需要把这个方法用在已经训练好的模型上,所以必须使得正向过程一致。

目前已知3个等式
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) q^(xtxt1,y)=q(xtxt1)  (1)q^(x0)=q(x0)  (2)q^(x1:Tx0,y)=t=1nq^(xtxt1,y)Markov  (3)

接下来我们已经解决的项将用绿色表示 q ^ ( y ∣ x t ) \color{green}{\hat{q}(y|x_t)} q^(yxt),正在解决的用红色表示 正在解决 \color{red}{\text{正在解决}} 正在解决

2. 求 q ^ ( x t − 1 ∣ x t ) \hat{q}(x_{t-1}|x_t) q^(xt1xt)

q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) (1) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{red}\hat{q}(x_{t-1}|x_t)} \hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)}\tag1 q^(xt1xt,y)=q^(yxt)q^(xt1xt)q^(yxt,xt1)(1)
使用贝叶斯得
q ^ ( x t − 1 ∣ x t ) = q ^ ( x t ∣ x t − 1 ) q ^ ( x t − 1 ) q ^ ( x t ) (2) {\color{red}\hat{q}(x_{t-1}|x_t)} = \frac{{\color{red}\hat{q}(x_t|x_{t-1})} \hat{q}(x_{t-1})}{\hat{q}(x_t)}\tag2 q^(xt1xt)=q^(xt)q^(xtxt1)q^(xt1)(2)

步骤一:
先求 q ^ ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} q^(xtxt1),在已知(1)中,已经知道了有y的等式,我们想要得到在上一个状态 𝑥 𝑡 − 1 𝑥_{𝑡−1} xt1的条件下转移到当前状态 𝑥 𝑡 𝑥_𝑡 xt的无条件概率 q ^ ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} q^(xtxt1)。为了消除类别标签 𝑦 对结果的影响,可以利用全概率公式对 𝑦 积分:
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) ∫ y q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) {\color{red}\hat{q}(x_t|x_{t-1})} = \int_{y} \hat{q}(x_t,y|x_{t-1})dy \\ = \int_{y}{\hat{q}(x_t|y,x_{t-1})}{\hat{q}(y|x_{t-1})}dy \\ = \int_y{q(x_t|x_{t-1})}{\hat{q}(y|x_{t-1})}dy \\ = q(x_t|x_{t-1}) \int_{y} {\hat{q}(y|x_{t-1})}dy \\ = q(x_t|x_{t-1}) q^(xtxt1)=yq^(xt,yxt1)dy=yq^(xty,xt1)q^(yxt1)dy=yq(xtxt1)q^(yxt1)dy=q(xtxt1)yq^(yxt1)dy=q(xtxt1)

  • 第一行中y在变量里,我们已知的y在条件里,用条件概率的定义式转化为第二行;
  • 第二行到第三行变换用了已知(1)的等式;
  • 在第三行中,由于是对y积分,前一项就是常数项,可以提出来,后面一项因为变量只含y,不管条件是什么,全概率公式加起来为1,即 ∫ y q ^ ( y ∣ x t − 1 ) d y = 1 \int_{y} {\hat{q}(y|x_{t-1})}dy=1 yq^(yxt1)dy=1

等式(2)现在可以表示为
q ^ ( x t − 1 ∣ x t ) = q ^ ( x t ∣ x t − 1 ) q ^ ( x t − 1 ) q ^ ( x t ) (2) {\color{red}\hat{q}(x_{t-1}|x_t)} = \frac{{\color{green}\hat{q}(x_t|x_{t-1})} {\color{red}\hat{q}(x_{t-1})}}{\color{red}\hat{q}(x_t)}\tag2 q^(xt1xt)=q^(xt)q^(xtxt1)q^(xt1)(2)

目前已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) ( 4 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) \\ \hat{q}(x_t|x_{t-1}) = q(x_t|x_{t-1})~~(4) q^(xtxt1,y)=q(xtxt1)  (1)q^(x0)=q(x0)  (2)q^(x1:Tx0,y)=t=1nq^(xtxt1,y)Markov  (3)q^(xtxt1)=q(xtxt1)  (4)

步骤二:
q ^ ( x t ) {\color{red}\hat{q}(x_t)} q^(xt),继续用全概率公式
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 (3) {\color{red}\hat{q}(x_t)} = \int_{x_0:t-1} {\hat{q}(x_{0:t})}dx_{0:t-1} \\ = \int_{x_0:t-1} {\hat{q}(x_{0})\hat{q}(x_{1:t}|x_0)}dx_{0:t-1} \tag3 q^(xt)=x0:t1q^(x0:t)dx0:t1=x0:t1q^(x0)q^(x1:tx0)dx0:t1(3)

  • x 0 : t − 1 ⇨ x 0 , x 1 . . . x t − 1 x_{0:t-1}⇨ x_0, x_1... x_{t-1} x0:t1x0,x1...xt1这是一个联合的随机变量;
  • 根据已知(3), x 0 x_0 x0在条件的位置,所以使用条件概率变换为第二行;

用全概率公式继续求 q ^ ( x 1 : t ∣ x 0 ) \color{red}\hat{q}(x_{1:t}|x_0) q^(x1:tx0)
q ^ ( x 1 : t ∣ x 0 ) = ∫ y q ^ ( x 1 : t , y ∣ x 0 ) d y = ∫ y q ^ ( x 1 : t ∣ y , x 0 ) q ^ ( y ∣ x 0 ) d y = ∫ y q ^ ( y ∣ x 0 ) ∏ 1 t q ^ ( x t ∣ x t − 1 , y ) d y = ∫ y q ^ ( y ∣ x 0 ) ∏ 1 t q ( x t ∣ x t − 1 ) d y = ∫ y q ^ ( y ∣ x 0 ) q ( x 1 : t ∣ x 0 ) d y = q ( x 1 : t ∣ x 0 ) {\color{red}\hat{q}(x_{1:t}|x_0)} = \int_{y}{\hat{q}(x_{1:t},y| x_0)dy} \\ = \int_{y}{\hat{q}(x_{1:t}|y, x_0)\hat{q}(y|x_0)dy} \\ = \int_{y}\hat{q}(y|x_0) \prod_{1}^{t}\hat{q}(x_t|x_{t-1},y)dy \\ = \int_{y}\hat{q}(y|x_0) \prod_{1}^{t} q(x_t|x_{t-1})dy \\ = \int_{y}\hat{q}(y|x_0) q(x_{1:t}|x_0)dy \\ = q(x_{1:t}|x_0) q^(x1:tx0)=yq^(x1:t,yx0)dy=yq^(x1:ty,x0)q^(yx0)dy=yq^(yx0)1tq^(xtxt1,y)dy=yq^(yx0)1tq(xtxt1)dy=yq^(yx0)q(x1:tx0)dy=q(x1:tx0)

  • 根据已知(3),对y做积分并将y移至条件的位置,并做替换得到第三行;
  • 根据已知(1),做替换得到第四行;
  • ∏ 1 t q ( x t ∣ x t − 1 ) \prod_{1}^{t}q(x_t|x_{t-1}) 1tq(xtxt1)的连乘就是 q ( x 1 : t ∣ x 0 ) q(x_{1:t}|x_0) q(x1:tx0),变换得第五行;
  • 根据全概率公式得 ∫ y q ^ ( y ∣ x 0 ) = 1 \int_{y}\hat{q}(y|x_0) = 1 yq^(yx0)=1

现在等式(3)如下:
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 : t ) d x 0 : t − 1 = q ( x t ) {\color{red}\hat{q}(x_t)} = \int_{x_0:t-1} {\hat{q}(x_{0:t})}dx_{0:t-1} \\ = \int_{x_0:t-1} {\hat{q}(x_{0})\hat{q}(x_{1:t}|x_0)}dx_{0:t-1} \\ = \int_{x_0:t-1} {q(x_{0}) q(x_{1:t}|x_0)}dx_{0:t-1} \\ = \int_{x_0:t-1} {q(x_{0:t})}dx_{0:t-1} = q(x_t) q^(xt)=x0:t1q^(x0:t)dx0:t1=x0:t1q^(x0)q^(x1:tx0)dx0:t1=x0:t1q(x0)q(x1:tx0)dx0:t1=x0:t1q(x0:t)dx0:t1=q(xt)
至此我们解决了等式(2),即
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} q^(xt1xt,y)=q^(yxt)q^(xt1xt)q^(yxt,xt1)

3. 求 q ^ ( y ∣ x t , x t − 1 ) \hat{q}(y|x_{t},x_{t-1}) q^(yxt,xt1)

q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \color{red}\hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} q^(xt1xt,y)=q^(yxt)q^(xt1xt)q^(yxt,xt1)

目前已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) ( 1 ) q ^ ( x 0 ) = q ( x 0 ) ( 2 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 n q ^ ( x t ∣ x t − 1 , y ) ∽ M a r k o v ( 3 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) ( 4 ) q ^ ( x t − 1 ∣ x t ) = q ( x t − 1 ∣ x t ) ( 5 ) \hat{q}(x_t|x_{t-1},y) = q(x_t|x_{t-1})~~(1) \\ \hat{q}(x_0) = q(x_0)~~(2) \\ \hat{q}(x_{1:T}|x_0,y) = \prod_{t=1}^{n}\hat{q}(x_t|x_{t-1},y) ∽ Markov~~(3) \\ \hat{q}(x_t|x_{t-1}) = q(x_t|x_{t-1})~~(4) \\ \hat{q}(x_{t-1}|x_{t}) = q(x_{t-1}|x_{t})~~(5) q^(xtxt1,y)=q(xtxt1)  (1)q^(x0)=q(x0)  (2)q^(x1:Tx0,y)=t=1nq^(xtxt1,y)Markov  (3)q^(xtxt1)=q(xtxt1)  (4)q^(xt1xt)=q(xt1xt)  (5)

用贝叶斯得:
q ^ ( y ∣ x t , x t − 1 ) = q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ^ ( y ∣ x t − 1 ) {\color{red}\hat{q}(y|x_{t},x_{t-1})} = \frac{\hat{q}(x_t|y,x_{t-1})\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(x_t|y,x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = q(x_t|x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(x_t|x_{t-1}) \frac{\hat{q}(y|x_{t-1})}{\hat{q}(x_t|x_{t-1})} \\ = \hat{q}(y|x_{t-1}) q^(yxt,xt1)=q^(xtxt1)q^(xty,xt1)q^(yxt1)=q^(xty,xt1)q^(xtxt1)q^(yxt1)=q(xtxt1)q^(xtxt1)q^(yxt1)=q^(xtxt1)q^(xtxt1)q^(yxt1)=q^(yxt1)

  • 根据已知(1)得第三行;
  • 根据已知(4)得第四行,化简得第五行。

4. 整合

在等式中:
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t , x t − 1 ) q ^ ( y ∣ x t ) = Z q ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 ) \hat{q}(x_{t-1}|x_t,y) = \frac{{\color{green}\hat{q}(x_{t-1}|x_t)} \color{red}\hat{q}(y|x_{t},x_{t-1})}{\color{green}\hat{q}(y|x_t)} \\ = Z {\color{blue}q(x_{t-1}|x_t)} {\color{blue}\hat{q}(y|x_{t-1}}) q^(xt1xt,y)=q^(yxt)q^(xt1xt)q^(yxt,xt1)=Zq(xt1xt)q^(yxt1)

  • q ^ ( y ∣ x t ) {\color{green}\hat{q}(y|x_t)} q^(yxt)是常数,用Z来表示;
  • q ^ ( x t − 1 ∣ x t ) = q ( x t − 1 ∣ x t ) {\color{green}\hat{q}(x_{t-1}|x_t)} = q(x_{t-1}|x_t) q^(xt1xt)=q(xt1xt),即DDPM的采样过程;
  • q ^ ( y ∣ x t − 1 ) \hat{q}(y|x_{t-1}) q^(yxt1)就是一个分类器的分类结果;
  • Z作为normalize factor,是一个正则化的因子,为了保证 q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt1xt,y)符合概率分布。

三、采样方式

现在需要用两个模型来分别预测右边的两项
q ^ ( x t − 1 ∣ x t , y ) = Z q ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 ) = Z P θ ( x t ∣ x t + 1 ) P φ ( y ∣ x t ) (4) \hat{q}(x_{t-1}|x_t,y) = Z {q(x_{t-1}|x_t)} {\hat{q}(y|x_{t-1})}=Z P_θ(x_t|x_{t+1})P_φ(y|x_{t})\tag4 q^(xt1xt,y)=Zq(xt1xt)q^(yxt1)=ZPθ(xtxt+1)Pφ(yxt)(4)

  • P θ ( x t ∣ x t + 1 ) → q ( x t − 1 ∣ x t ) P_θ(x_t|x_{t+1})→q(x_{t-1}|x_t) Pθ(xtxt+1)q(xt1xt),即DDPM的模型,服从高斯分布 N ( μ , Σ 2 ) N(\mu,Σ^2) N(μ,Σ2)
  • P φ ( y ∣ x t ) → q ^ ( y ∣ x t − 1 ) P_φ(y|x_{t})→\hat{q}(y|x_{t-1}) Pφ(yxt)q^(yxt1),即分类器模型

我们有了这两个模型之后,怎么让这两个模型之间有一定的联系,然后得到最终的结果 。假设此时刻是t+1,而这个时刻是没有办法得到 x t x_t xt的,没有 x t x_t xt就没有办法得到 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(yxt)确切的值,所以很难通过数学计算得到 q ^ ( x t − 1 ∣ x t , y ) \hat{q}(x_{t-1}|x_t,y) q^(xt1xt,y),因此只能通过预估得到 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(yxt)

1. 计算 P θ ( x t ∣ x t + 1 ) P_θ(x_t|x_{t+1}) Pθ(xtxt+1)

取对数 l o g P θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C 1 (5) logP_θ(x_t|x_{t+1}) = -\frac{1}{2}(x_t-\mu)^TΣ^{-1}(x_t-\mu)+C_1\tag5 logPθ(xtxt+1)=21(xtμ)TΣ1(xtμ)+C1(5)

公式如下:
在这里插入图片描述

2. 计算 P φ ( y ∣ x t ) P_φ(y|x_{t}) Pφ(yxt)

x t x_{t} xt是不知道的,需要 x t x_{t} xt进行预估。但我们知道,扩散模型的方差 Σ Σ Σ一般都很小,在图像上表示为,其他位置都趋近于0,只有在 x = μ x=\mu x=μ附近概率比较高。那么就可以认为 x t ≈ μ x_t≈\mu xtμ
在这里插入图片描述
取对数 l o g P φ ( y ∣ x t ) logP_φ(y|x_{t}) logPφ(yxt),并在 x t ≈ μ x_t≈\mu xtμ处进行泰勒展开得
l o g P φ ( y ∣ x t ) = l o g P φ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + o ( ∣ ∣ x t − μ ∣ ∣ 2 ) ≈ ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + C 2 (6) logP_φ(y|x_{t}) = logP_φ(y|x_{t})|_{x_{t=\mu}}+(x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}} +o(||x_t-\mu||^2) \\ ≈ (x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}}+C_2\tag6 logPφ(yxt)=logPφ(yxt)xt=μ+(xtμ)TxtlogPφ(yxt)xt=μ+o(∣∣xtμ2)(xtμ)TxtlogPφ(yxt)xt=μ+C2(6)
解释:
在多维空间中,泰勒展开和梯度的定义是自然延伸的。以下是关于泰勒展开为这种形式的原因,以及梯度表示为何采用倒三角(即 ∇)符号的解释。
在这里插入图片描述

在泰勒展开中出现 𝑇 次方(即转置)是因为我们要确保中间项是一个标量,而不是一个向量。让我们更深入地理解这个问题。
在这里插入图片描述

l o g P θ ( x t ∣ x t + 1 ) + l o g P φ ( y ∣ x t ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ + C = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C ( ▽ = g ) = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C ′ = l o g P ( x t − 1 ∣ x t , y ) (7) logP_θ(x_t|x_{t+1}) + logP_φ(y|x_{t}) =-\frac{1}{2}(x_t-\mu)^TΣ^{-1}(x_t-\mu)+(x_t-\mu)^T\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}}+C \\ = -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+\frac{1}{2}g^TΣg+C~~~ (\triangledown=g) \\ = -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+C' \\ = logP(x_{t-1}|x_t,y) \tag7 logPθ(xtxt+1)+logPφ(yxt)=21(xtμ)TΣ1(xtμ)+(xtμ)TxtlogPφ(yxt)xt=μ+C=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C   (=g)=21(xtμΣg)TΣ1(xtμΣg)+C=logP(xt1xt,y)(7)

  • 因为这项 1 2 g T Σ g \frac{1}{2}g^TΣg 21gTΣg中没有 x t x_t xt,所以是这项是常数,化简为第三行;
  • P θ ( x t ∣ x t + 1 ) P φ ( y ∣ x t ) = P ( x t − 1 ∣ x t , y ) P_θ(x_t|x_{t+1})P_φ(y|x_{t})=P(x_{t-1}|x_t,y) Pθ(xtxt+1)Pφ(yxt)=P(xt1xt,y)得第四行

3. 终极目标

l o g P ( x t − 1 ∣ x t , y ) = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C ′ (8) logP(x_{t-1}|x_t,y)= -\frac{1}{2}(x_t-\mu-Σg)^TΣ^{-1}(x_t-\mu-Σg)+C' \tag8 logP(xt1xt,y)=21(xtμΣg)TΣ1(xtμΣg)+C(8)
观察等式(8)和等式(5),可以发现这两个式子是一样的,故可知,等式(8)也服从高斯分布 N ∽ ( μ + Σ g , Σ 2 ) N∽(\mu+Σg,Σ^2) N(μ+Σg,Σ2),即

P ( x t − 1 ∣ x t , y ) ∽ N ( μ + Σ g , Σ 2 ) (9) P(x_{t-1}|x_t,y)∽N(\mu+Σg,Σ^2)\tag9 P(xt1xt,y)N(μ+Σg,Σ2)(9)
所以 x t = μ + Σ g + Σ ε x_t = \mu+Σg +Σε xt=μ+Σg+Σε

  • μ \mu μ需要用预测的 ε ε ε来求
  • g = ▽ = ▽ x t l o g P φ ( y ∣ x t ) ∣ x t = μ g=\triangledown=\triangledown x_t logP_φ(y|x_{t})|_{x_{t=\mu}} g==xtlogPφ(yxt)xt=μ,这个东西实际上就是分类器的梯度

至此,就可以得到条件生成的结果,这个条件包含在了分类器里

四、泛用采样

这步的作用是为了摆脱对 Σ Σ Σ的依赖,因为 Σ Σ Σ有等于0的情况。所以就借用了score的概念。

在这个分布中,直接求解并不容易,但可以使用 score-based models 的方式进行求解(对 score-based models 不熟悉的读者可以先阅读一个视频看懂score-based模型的底层原理和[生成模型新方向]: score-based generative models 这两个讲解作为前置知识),也就是利用 score function。
这里简单描述一下:

  • score的定义: 一个概率分布 P ( x ) P(x) P(x),score就是 l o g P ( x ) logP(x) logP(x)的梯度 ▽ x l o g P ( x ) \triangledown xlogP(x) xlogP(x)
  • score-based generative models的论文中,提到需要训练一个模型来预测score。score based这些论文认为扩散模型实际上是在预测score
    P ( x t − 1 ∣ x t , y ) ∽ N ( μ + Σ g , Σ 2 ) (9) P(x_{t-1}|x_t,y)∽N(\mu+Σg,Σ^2)\tag9 P(xt1xt,y)N(μ+Σg,Σ2)(9)

在DDPM中,有这样一个分布
P ( x t ∣ x 0 ) ∽ N ( a ‾ t x 0 , 1 − a ‾ t ) P(x_t|x_0)∽\mathcal{N}(\sqrt{\overline{a}_t}x_{0}, 1 - \overline{a}_t) P(xtx0)N(at x0,1at)
取log:
l o g P ( x t ∣ x 0 ) = − ( x t − a ‾ t ⋅ x 0 ) 2 2 ( 1 − a ‾ t ) (10) logP(x_t | x_0) = -\frac{(x_t - \sqrt{\overline{a}_t} \cdot x_0)^2}{2(1 -\overline{a}_t)}\tag{10} logP(xtx0)=2(1at)(xtat x0)2(10)
x t x_t xt求梯度得:
▽ x t l o g P ( x t ∣ x 0 ) = − x t − a ‾ t ⋅ x 0 1 − a ‾ t (11) \triangledown_{x_t}logP(x_t | x_0) = -\frac{x_t - \sqrt{\overline{a}_t} \cdot x_0}{1 -\overline{a}_t}\tag{11} xtlogP(xtx0)=1atxtat x0(11)
根据公式: x t = 1 − a ‾ t × E + a ‾ t x 0 x_t =\sqrt{1 - \overline{a}_t} × \mathcal{E} + \sqrt{\overline{a}_t}x_{0} xt=1at ×E+at x0
E = x t − a ‾ t x 0 1 − a ‾ t \mathcal{E} = \frac{x_t - \sqrt{\overline{a}_t}x_{0}}{\sqrt{1 - \overline{a}_t}} E=1at xtat x0
故等式(11)可以改写为:
▽ x t l o g P ( x t ∣ x 0 ) = − x t − a ‾ t ⋅ x 0 1 − a ‾ t = − E 1 − a ‾ t (12) \triangledown_{x_t}logP(x_t | x_0) = -\frac{x_t - \sqrt{\overline{a}_t} \cdot x_0}{1 -\overline{a}_t} = - \frac{\mathcal{E}}{\sqrt{1 - \overline{a}_t}}\tag{12} xtlogP(xtx0)=1atxtat x0=1at E(12)
上面这个等式的 E \mathcal{E} E是抽样得到的,现在我们把它改为预测的
▽ x t l o g P θ ( x t ) = − E θ ^ ( x t ) 1 − a ‾ t (13) \triangledown_{x_t}logP_θ(x_t) = - \frac{\hat{\mathcal{E}_θ}(x_t)}{\sqrt{1 - \overline{a}_t}}\tag{13} xtlogPθ(xt)=1at Eθ^(xt)(13)
因为模型实际上就是给定 x t , t x_t,t xt,t,然后用Unet预测 E \mathcal{E} E,所以 ▽ x t l o g P θ ( x t ) \triangledown_{x_t}logP_θ(x_t) xtlogPθ(xt)实际上就是预测的score

根据等式(4)(13)得:
▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) = ▽ l o g P θ ( x t ) + ▽ l o g P φ ( y ∣ x t ) = − E θ ^ ( x t ) 1 − a ‾ t + g (14) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) = \triangledown logP_θ(x_t) + \triangledown logP_φ(y|x_{t}) \\ = - \frac{\hat{\mathcal{E}_θ}(x_t)}{\sqrt{1 - \overline{a}_t}} +g\tag{14} xtlogPθ(xt)Pφ(yxt)=logPθ(xt)+logPφ(yxt)=1at Eθ^(xt)+g(14)
等式右边也可以解释为:第一项是无条件生成的 score function,第二项是分类器的梯度,这个梯度表示的是从噪声指向条件 y 的方向,把这个方向加到无条件生成的 score 上,就可以让降噪的方向也指向 y 的方向。

▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) xtlogPθ(xt)Pφ(yxt)看作一个整体可知
▽ x t l o g P θ ( x t ) P φ ( y ∣ x t ) = − E ^ ′ 1 − a ‾ t (15) \triangledown_{x_t}logP_θ(x_t)P_φ(y|x_{t}) = - \frac{\hat{\mathcal{E}}'}{\sqrt{1 - \overline{a}_t}}\tag{15} xtlogPθ(xt)Pφ(yxt)=1at E^(15)
根据(14)(15)可知
E ^ ′ = E θ ^ − 1 − a ‾ t g (15) \hat{\mathcal{E}}' = \hat{\mathcal{E}_θ} - {\sqrt{1 - \overline{a}_t}}g\tag{15} E^=Eθ^1at g(15)
这个式子表示:

  • E θ ^ \hat{\mathcal{E}_θ} Eθ^是DDPM模型预测的噪音
  • 此时已经不需要用 μ + Σ g \mu+Σg μ+Σg,就摆脱了对 Σ Σ Σ的依赖
  • 只需要用预测出来的噪音+一个非常小的扰动,就可以得到一个全新的 E \mathcal{E} E
  • 然后使用DDPM\DDIM的采样方式就可以得到最终的生成结果

五、伪代码分析

在这里插入图片描述
Algorithm 1 2 都是采样过程

  1. Algorithm 1 就是前面将的 + Σ g +Σg +Σg形式

    • 首先从(0,1)高斯分布中提取一个 x T x_T xT
    • 然后一步一步迭代;
    • 通过预测得到 μ , Σ \mu,Σ μ,Σ
    • 用分类器计算分类器的梯度,并计算 + Σ g +Σg +Σg
    • s是调节系数,作者发现s大的时候,模型会更关注分类器,这样得到的图片质量越高,但是多样性越低,因为更多的趋向于给定的标签y;
    • 最后采样
  2. Algorithm 2 就是改变 E ^ \hat{\mathcal{E}} E^

    • 使用的是DDIM的采样方式

再看一个分类引导的伪代码加深理解

classifier_model = ...  # 加载一个训好的图像分类模型
y = 1  # 生成类别为 1 的图像,假设类别 1 对应“狗”这个类
guidance_scale = 7.5  # 控制类别引导的强弱,越大越强
input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():noise_pred = unet(input, t).sample# 用 input 和预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample# classifier guidance 步骤class_guidance = classifier_model.get_class_guidance(input, y)input += class_guidance * guidance_scals  # 把梯度加上去

这里的关键是倒数第二行:

    class_guidance = classifier_model.get_class_guidance(input, y)

这一行做的事情是,把当前生成的图 input 和我们想要的类别 y 一起喂给分类模型。

分类模型分出来的类别不一定是 y,但我们想让它尽量靠近 y,所以计算预测值和 y 的 loss。

然后,这里像训练分类模型时的梯度反向传播一样,计算梯度。不同的是训分类模型时要得到权重参数的梯度,方便梯度更新,而这里只需要保留对 input 的梯度就好。

最后把计算出来的梯度根据 guidance_scale 加到图像上。

六、分类器训练

  1. 训练数据:和扩散模型用的数据一样
  2. 分类器训练的数据是由噪音的,噪音强度时随机的,目的是为了保证分类器在任何时刻都能有作用,可以给出一个好的梯度方向

七、代码实现

虽然推导看起来依然很复杂,但需要改动的代码其实非常少,获得梯度之后再用梯度更新一下就可以了。这里给出一些关键的代码片段。

1. 获取分类器梯度

获取分类器对 x t \mathbf{x}_t xt的梯度其实也比较直接,可以直接使用 Pytorch 的自动求导工具。先让 x \mathbf{x} x带上梯度,然后输入分类器获取概率分布,最后再提取出 y 对应的一项计算梯度。这里有一个比较神奇的点,就是一般来说分类模型的输入都是不计算梯度的,不过这里的输入也是带梯度的,感觉类似于 DETR 里的 learnable query:

import torch
import torch.nn.functional as Fdef classifier_guidance(x: torch.Tensor,t: torch.Tensor,y: torch.Tensor,classifier: torch.nn.Module
):# 开启自动求导上下文,确保计算梯度with torch.enable_grad():# 分离 x 的梯度信息并启用对其的梯度计算,得到新的张量 x_with_grad# 这样可以在后续步骤中计算 x_with_grad 的梯度,而不会影响 x 本身x_with_grad = x.detach().requires_grad_(True)# 使用分类器对带有梯度的 x 和时间步 t 进行前向传播,得到未归一化的分数 (logits)logits = classifier(x_with_grad, t)# 对 logits 应用 log softmax 函数,得到 log 概率分布# 这样可以提高数值稳定性,并便于后续的梯度计算log_prob = F.log_softmax(logits, dim=-1)# 从 log 概率分布中选取目标类别 y 对应的 log 概率值# 使用 y.view(-1) 将 y 转换为一维,以便作为索引selected = log_prob[range(len(logits)), y.view(-1)]# 计算目标类别 y 的 log 概率总和对 x_with_grad 的梯度# 该梯度将作为指导信号,用于引导生成模型朝向目标类别 y 的方向return torch.autograd.grad(selected.sum(), x_with_grad)[0]

这一部分也就相当于 ▽ x t l o g P ( y ∣ x t ) \triangledown_{x_t}logP(y | x_t) xtlogP(yxt)这一项,这在上一章的两种解释中都是相通的。而如何使用得到的梯度对采样过程进行引导,会根据推导不同有两种实现方式。

2. 第一种引导的实现

这种方法相对比较好理解,就是用梯度朝着指向 y 的方向对生成结果进行一个修正:

for timestep in tqdm(scheduler.timesteps):# 在当前时间步中预测噪声分量with torch.no_grad():noise_pred = unet(images, timestep).sample  # 使用 U-Net 预测噪声,禁用梯度计算以节省内存# 根据预测出的噪声和当前时间步,计算出去噪后的 x_{t-1}(即生成过程的前一状态)images = scheduler.step(noise_pred, timestep, images).prev_sample# 计算分类器的梯度指导,帮助生成模型生成符合目标类别的样本guidance = classifier_guidance(images, timestep, y, classifier)# 将分类器梯度(guidance)加到去噪后的样本上,以引导生成过程更加符合目标类别images += guidance_scale * guidance

3. 第二种引导的实现

这种实现方式和 openai 的官方实现相同,也就是直接按照原论文的 x t ∽ N ( μ + s Σ g , Σ ) x_t ∽ N(\mu+sΣg,Σ) xtNμ+sΣg,Σ得到结果:

# 先预测均值和方差
mean, variance = p_mean_var['mean'], p_mean_var['variance']  # 提取预测出的均值和方差,用于更新生成样本# 计算分类器的梯度,用于引导生成过程
guidance = classifier_guidance(images, timestep, y, classifier)# 根据原始的均值和方差,以及分类器提供的梯度指导信号,计算出新的均值
# 新的均值通过加入 guidance 来调整生成样本的方向,使其更符合目标类别
new_mean = mean.float() + guidance_scale * variance * guidance.float()

在这份代码中,p_mean_var 就是模型预测出的均值和方差。因为官方实现基于 Improved DDPM 修改,所以方差也是可学习的。根据公式可以计算出新的均值,得到新的均值和方差后,再从对应的高斯分布中进行采样即可。

Classifier-Free Guidance论文工作

Classifier Guidance 只能用分类模型控制生成的类别。如果分类模型是分 80 类,那么 Classifier Guidance 也只能引导 diffusion 模型生成固定的 80 类,多一类都不好使。

Classifier Free Guidance 这个方法就厉害了,虽然得重训 diffusion 模型,但训好了以后可以直接起飞,没有类别数量的限制。

一、 伪代码分析

1. 训练

在这里插入图片描述

  • puncond:无条件的y;
  • 从数据集中采样一张图片和对应的分类;
  • 有puncond的概率把c复赋值为空(非条件),1-puncond的概率保留c,不赋值的话就是条件;通过这种形式把条件和非条件放到一起训练;(这是唯一多的一步,其他基本都一样)
  • 取一个时间样本λ;
  • 抽样随机噪音;
  • 通过均值标准差计算 z λ z_λ zλ
  • 算误差。

2. 采样

在这里插入图片描述

  • w是条件因子,c可以给定,也可以不给定
  • 步骤3,只有当有条件c时,才会条件生成,否则就是非条件生成

以文本条件为例(也就是文生图)

clip_model = ...  # 加载一个官方的 clip 模型text = "一只狗"  # 输入文本
text_embeddings = clip_model.text_encode(text)  # 编码条件文本
empty_embeddings = clip_model.text_encode("")  # 编码空文本
text_embeddings = torch.cat(empty_embeddings, text_embeddings)  # 把它俩 concate 到一起作为条件input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():# 这里同时预测出了有文本的和空文本的图像噪声noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).sample# Classifier-Free Guidance 引导noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # 拆成无条件和有条件的噪声# 把【“无条件噪声”指向“有条件噪声”】看做一个向量,根据 guidance_scale 的值放大这个向量# (当 guidance_scale = 1 时,下面这个式子退化成 noise_pred_text)noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample

画个图,大约是下面这么个意思:
在这里插入图片描述

上图中的两个圈就表示我们用 unet 算出的 noise_pred_uncond(无条件噪声) 和 noise_pred_text(“一只狗”文本条件噪声) 。

红色的箭头表示从“无条件”到“‘一只狗’条件”的向量,给它乘上 guidance_scale,通过调节 guidance_scale 的数值大小,我们就能控制文本条件噪声贴近文本语义的程度。

如果我们想让生成的图更遵循“一只狗”这个文本语义,就把 guidance_scale 设大一点,生成的图像会更贴近“一只狗”的文本语义,但是多样性也会降低。反之如果我们想让生成的图像更多样丰富一些,就把 guidance_scale 设小一点。通常来讲这个值被设为 7.5 比较合适。

总结一下,Classifier-Free Guidance 需要在训练过程中同时训练模型的两个能力,一个是有条件生成,一个是无条件生成。

个人的理解:无条件生成是有条件生成的基础,生成的质量和多样性是由无条件生成的分数保证的,如果只有有条件生成而没有无条件生成,那么生成效果可能不佳。

显然 Classifier-Free Guidance 效果更好些,即能生成无穷多的图像类别,又不需要重新训练一个基于噪声的分类模型。所以现在最常见的都是 Classifier-Free Guidance。

二、U-Net模型如何融入语义信息

1. CrossAttention

首先通过embedding layer或者是clip等模型将文本转换为文本特征嵌入,即text embedding过程。

之后text embedding和原本模型中的image进行融合。最常见的方式是利用CrossAttention(stable diffusion采用的就是这个方法)。

具体来说是把text embedding作为注意力机制中的key和value,把原始图片表征作为query。相当于计算每张图片和对应句子中单词的一个相似度得分,把得分转换成单词的权重,[权重乘以单词的embedding]加和作为最终的特征。
在这里插入图片描述

import torch 
import torch.nn as nn
from einops import rearrangeclass SpatialCrossAttention(nn.Module):def __init__(self,dim,context_dim,heads=4,dim_head=32):super(SpatialCrossAttention,self).__init__()self.scale = dim_head**-0.5self.heads = headshidden_dim = dim_head*headsself.proj_in = nn.Conv2d(dim,context_dim,kernel_size=1,stride=1,padding=0)self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)self.to_out = nn.Conv2d(hidden_dim, dim, 1)def forward(self,x,context=None):x_q = self.proj_in(x)b,c,h,w = x_q.shapex_q = rearrange(x_q,"b c h w -> b (h w) c")if context is None:context = x_qif context.ndim == 2:context = rearrange(context,"b c -> b () c")q = self.to_q(x_q)k = self.to_k(context)v = self.to_v(context)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))sim = torch.einsum('b i d, b j d -> b i j',q,k)*self.scaleattn = sim.softmax(dim=-1)out = torch.einsum('b i j, b j d -> b i d',attn,v)out = rearrange(out, '(b h) n d -> b n (h d)',h=self.heads)out = rearrange(out,'b (h w) c -> b c h w', h=h, w=w)out = self.to_out(out)return outCrossAttn = SpatialCrossAttention(dim=32,context_dim=1024)
x = torch.randn(8,32,256,256)
context = torch.randn(8,1024)
out = CrossAttn(x,context)

2. Channel-wise attention

融入方式与time-embedding的融入方式相同。基于channel-wise的融入粒度没有CrossAttention细,一般使用类别数量有限的特征融入,如时间embedding、类别embedding。语义信息的融入更推荐使用CrossAttention。
在这里插入图片描述

 #如果参数为“default”,则时间融合方式为相加#如果参数为“scale_shift”,则时间融合方式为scale_shift方法if temb_channels is not None:if self.time_embedding_norm == "default":self.time_emb_proj = linear_cls(temb_channels, out_channels)elif self.time_embedding_norm == "scale_shift":self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)else:raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)if self.time_embedding_norm == "default":if temb is not None:hidden_states = hidden_states + tembhidden_states = self.norm2(hidden_states)
elif self.time_embedding_norm == "scale_shift":if temb is None:raise ValueError(f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}")time_scale, time_shift = torch.chunk(temb, 2, dim=1)hidden_states = self.norm2(hidden_states)hidden_states = hidden_states * (1 + time_scale) + time_shift
else:hidden_states = self.norm2(hidden_states)

三、采样过程的代码实现

类似于 Improved DDPM。一个示意性的代码如下:

for timestep in tqdm(scheduler.timesteps):# 预测噪声with torch.no_grad():noise_pred = unet(images, timestep, condition).sample# Classifier-Free Guidancenoise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)noise_pred = (1.0 - guidance_scale) * noise_pred_uncond + guidance_scale * noise_pred_condimages = scheduler.step(noise_pred, timestep, images).prev_sample

参考资料:
https://lichtung612.github.io/posts/3-diffusion-models/
https://zhuanlan.zhihu.com/p/660518657


http://www.mrgr.cn/news/66296.html

相关文章:

  • 冰雪奇缘!中科院一区算法+双向深度学习+注意力机制!SAO-BiTCN-BiGRU-Attention雪消融算法优化回归预测
  • 【Android】使用productFlavors构建多个变体
  • git clone,用https还是ssh
  • 【含开题报告+文档+源码】基于Web的房地产销售网站的设计与实现
  • Prometheus套装部署到K8S+Dashboard部署详解
  • 动态规划—目标和
  • 『VUE』19. scope避免组件之间样式互相覆盖(详细图文注释)
  • MATLAB - ROS 2 分析器
  • 欢迎使用Markdown编辑器
  • GaussDB高智能--库内AI引擎:模型管理数据集管理
  • 省级-社会保障水平数据(2007-2022年)
  • 视频编辑学习笔记
  • “大跳水”的全新奥迪A3,精准狙击年轻人的心
  • 【NOIP普及组】明明的随机数
  • 华为HarmonyOS借助AR引擎帮助应用实现虚拟与现实交互的能力3-获取设备位姿
  • 腾讯混元宣布大语言模型和3D模型正式开源
  • 外包干了6年,技术退步明显.......
  • 小张求职记五
  • C++【string类,模拟实现string类】
  • 数码管驱动电路音响LED驱动芯片VK1640
  • 【通俗理解】自由能与熵的关系是怎样的? ——从热力学第二定律看自由能最小化与熵最大化的趋势
  • C++ <string> 标头文件详解
  • 多线程--模拟实现定时器--Java
  • 了解分布式数据库系统中的CAP定理
  • 【初阶数据结构与算法】复杂度分析练习之轮转数组(多种方法)
  • 华为HarmonyOS借助AR引擎帮助应用实现虚拟与现实交互的能力2-管理AR会话