[Long Read] Stable Diffusion Code Reading Notes
For a better reading experience, you can head over to my blog, y0k1n0的小破站.
Configuration files referenced in this post:
AutoencoderKL: stable-diffusion\configs\autoencoder\autoencoder_kl_32x32x4.yaml
latent-diffusion: stable-diffusion\configs\latent-diffusion\lsun_churches-ldm-kl-8.yaml
ldm/modules/diffusionmodules/model.py
Normalize function
```python
def Normalize(in_channels, num_groups=32):
    """Create a GroupNorm layer.

    Args:
        in_channels: number of input channels
        num_groups: number of groups. Defaults to 32.

    Returns:
        an instance of torch.nn.GroupNorm
    """
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
```
This function defines the normalization layer used throughout the model; group normalization helps improve training speed and model stability.
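As a quick illustration (added here, not part of model.py), GroupNorm normalizes each group of channels independently of the batch size, which is why it stays stable with the small batches common in diffusion training:

```python
import torch

norm = torch.nn.GroupNorm(num_groups=32, num_channels=128, eps=1e-6, affine=True)
x = torch.randn(2, 128, 16, 16)   # (batch, channels, height, width)
y = norm(x)                       # each group of 128/32 = 4 channels is normalized separately
print(y.shape)                    # torch.Size([2, 128, 16, 16]) -- the shape is unchanged
```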
ResnetBlock class
This class defines the residual block used throughout the model; its forward pass is illustrated in the figure below.
The annotated code is as follows:
class ResnetBlock(nn.Module):def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,dropout, temb_channels=512):"""Resnet模块实现Args:in_channels: 输入通道数dropout: Dropout率out_channels: 输出通道数. Defaults to None.conv_shortcut: 是否使用卷积快速链接. Defaults to False.temb_channels: 时间嵌入通道数. Defaults to 512.""" super().__init__()self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcutself.norm1 = Normalize(in_channels)self.conv1 = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1)if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels,out_channels)self.norm2 = Normalize(out_channels)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)if self.in_channels != self.out_channels:if self.use_conv_shortcut:self.conv_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0)def forward(self, x, temb):"""前线传播方法,用于计算输入张量x经过Resnet block后的输出Args:x: 输入张量temb: 时间嵌入Returns:残差块的输出""" h = xh = self.norm1(h)h = nonlinearity(h)h = self.conv1(h)if temb is not None:h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] # 拓展temp为四维h = self.norm2(h)h = nonlinearity(h)h = self.dropout(h)h = self.conv2(h)if self.in_channels != self.out_channels:if self.use_conv_shortcut:x = self.conv_shortcut(x)else:x = self.nin_shortcut(x)return x+h
The norm1 and norm2 layers defined here come from torch.nn.GroupNorm and are group normalization layers.
AttnBlock class
This class implements the standard (vanilla) self-attention mechanism; its forward pass is shown below.
The detailed code is as follows:
class AttnBlock(nn.Module):def __init__(self, in_channels):"""经典自注意力模块Args:in_channels: 输入通道数""" super().__init__()self.in_channels = in_channelsself.norm = Normalize(in_channels)self.q = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) # 对通道层做的线性变换self.k = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0)self.v = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0)self.proj_out = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) # 用于将经过注意力计算后的输出重新投影到输入维度的卷积层def forward(self, x):"""前向传播,计算输入x的自注意力Args:x: 输入向量Returns:_description_""" h_ = xh_ = self.norm(h_)q = self.q(h_)k = self.k(h_)v = self.v(h_)# compute attentionb,c,h,w = q.shapeq = q.reshape(b,c,h*w)q = q.permute(0,2,1) # b,hw,ck = k.reshape(b,c,h*w) # b,c,hww_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 计算注意力权重,每个位置之间的关系w_ = w_ * (int(c)**(-0.5)) # 对注意力权重进行缩放,保证数值稳定w_ = torch.nn.functional.softmax(w_, dim=2)# attend to valuesv = v.reshape(b,c,h*w)w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]h_ = h_.reshape(b,c,h,w)h_ = self.proj_out(h_)return x+h_
LinearAttention class
The LinearAttention class implements an optimized self-attention algorithm: by changing the order of the matrix multiplications, it reduces the time complexity from $O(N^2)$ to $O(N)$ in the number of spatial positions $N$.
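The speedup comes from the associativity of matrix multiplication; a short derivation is added here for clarity, with $N = hw$ spatial positions and per-head dimension $d$:

$$
\mathrm{cost}\big((QK^\top)\,V\big) = O(N^2 d)
\qquad\text{vs.}\qquad
\mathrm{cost}\big(Q\,(K^\top V)\big) = O(N d^2)
$$

so when $d \ll N$ the cost becomes linear in $N$. Note that in the code below the softmax is applied to $K$ along the position axis instead of to $QK^\top$, so linear attention is an approximation of standard attention rather than an exact reordering.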
Detailed code:
class LinearAttention(nn.Module):def __init__(self, dim, heads=4, dim_head=32): """实现了一个线性注意力机制加速注意力计算,实现方式与AttnBlock类似,但比AttnBlock快\n先计算v和softmax(k)的乘积在与q乘法,复杂度从O(N^2)到O(N)Args:dim: 输入特征维度heads: 注意力头数量. Defaults to 4.dim_head: 每个注意力头维度. Defaults to 32.""" super().__init__()self.heads = headshidden_dim = dim_head * heads # 隐藏层维度self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) #1x1卷积层: q, k, v每个向量的维度都是hidden_dimself.to_out = nn.Conv2d(hidden_dim, dim, 1) # 1x1卷积层:重新投影回初始维度def forward(self, x):b, c, h, w = x.shapeqkv = self.to_qkv(x)q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v)out = torch.einsum('bhde,bhdn->bhen', context, q)out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)return self.to_out(out)
LinAttnBlock class
The LinAttnBlock class inherits from LinearAttention; it sets the feature dimension to the number of input channels and uses a single attention head.
Detailed code:
```python
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        """A special case of LinearAttention with a single attention head.

        Args:
            in_channels: number of input channels
        """
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
```
make_attn function
```python
def make_attn(in_channels, attn_type="vanilla"):
    """Attention module factory.

    Args:
        in_channels: number of input channels
        attn_type: attention type. Defaults to "vanilla".

    Returns:
        an instance of the selected attention module
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)  # identity mapping: the output equals the input
    else:
        return LinAttnBlock(in_channels)
```
The make_attn function selects which attention module to use. Depending on the value of attn_type, it provides the following three options (a usage sketch follows the list):
- vanilla: the standard self-attention module, see the AttnBlock class
- linear: the optimized self-attention module with O(N) time complexity, see the LinearAttention class
- none: an identity mapping (nn.Identity) that does nothing, the output is exactly the input
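A small usage sketch (illustrative only, assuming the classes above are in scope):

```python
import torch

attn = make_attn(256, attn_type="vanilla")   # AttnBlock with 256 channels
lin  = make_attn(256, attn_type="linear")    # LinAttnBlock (single-head linear attention)
noop = make_attn(256, attn_type="none")      # nn.Identity: passes features through unchanged

x = torch.randn(1, 256, 32, 32)
assert attn(x).shape == lin(x).shape == noop(x).shape == x.shape   # all variants preserve the shape
```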
Downsample class
The Downsample class downsamples the feature map and provides two methods:
- average pooling
- strided convolution
The with_conv flag decides which of the two is used (a shape check follows the code below).
Annotated code:
class Downsample(nn.Module):def __init__(self, in_channels, with_conv):"""图像下采样模块Args:in_channels: 输入通道数with_conv: 是否使用卷积下采样""" super().__init__()self.with_conv = with_convif self.with_conv:# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=2,padding=0) # 使用卷积层将图像尺寸减小为原来的一半def forward(self, x):if self.with_conv:pad = (0,1,0,1) # 手动进行非对称填充,右面和底面填充1个像素x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)else:x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) # 平均池化,张量尺寸减半return x
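Either path halves the spatial resolution; a quick shape check (illustrative, assuming the class above is imported):

```python
import torch

x = torch.randn(1, 64, 32, 32)
down_conv = Downsample(64, with_conv=True)    # 3x3 conv, stride 2, with manual asymmetric padding
down_pool = Downsample(64, with_conv=False)   # 2x2 average pooling, stride 2
print(down_conv(x).shape)                     # torch.Size([1, 64, 16, 16])
print(down_pool(x).shape)                     # torch.Size([1, 64, 16, 16])
```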
Encoder class
The Encoder class implements the encoding of the source input; architecturally it uses the downsampling path and the middle block of a U-Net.
The forward pass of the model is shown in the figure below:
According to the configuration used here, AutoencoderKL applies no AttnBlock during downsampling. Its purpose is to encode the input image into a description of the distribution of the latent variable z, i.e. its mean and variance.
The Encoder class code is as follows:
class Encoder(nn.Module):def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",**ignore_kwargs):"""为AutoEncoderKL的编码器部分Args:ch: 初始通道数,用于第一层卷积out_ch: 最终输出的通道数num_res_blocks: 每个分辨率层中的残差块数量attn_resolutions: 在哪些分辨率下使用注意力机制in_channels: 输入图像的通道数resolution: 输入图像的分辨率z_channels: 最终潜在空间的维度数ch_mult: 通道数的倍增系数,每一层的通道数是初始通道数乘以一个倍增系数. Defaults to (1,2,4,8).dropout: 用于控制ResnetBlock中的丢弃率. Defaults to 0.0.resamp_with_conv: 下采样时是否使用卷积操作. Defaults to True.double_z: 控制输出的通道数是否加倍,用于生成均值和标准差. Defaults to True.use_linear_attn: 是否使用线性注意力代替标准注意力. Defaults to False.attn_type: 使用的注意力类型. Defaults to "vanilla".""" super().__init__()if use_linear_attn: attn_type = "linear"self.ch = chself.temb_ch = 0 # 时间嵌入的通道数self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channels# downsamplingself.conv_in = torch.nn.Conv2d(in_channels,self.ch,kernel_size=3,stride=1,padding=1) # 图像大小保持不变curr_res = resolutionin_ch_mult = (1,)+tuple(ch_mult) # (1, 1, 2, 4, 8)self.in_ch_mult = in_ch_multself.down = nn.ModuleList()for i_level in range(self.num_resolutions): # i_level初值为1block = nn.ModuleList()attn = nn.ModuleList()block_in = ch*in_ch_mult[i_level]block_out = ch*ch_mult[i_level]for i_block in range(self.num_res_blocks):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))down = nn.Module()down.block = blockdown.attn = attnif i_level != self.num_resolutions-1:down.downsample = Downsample(block_in, resamp_with_conv)curr_res = curr_res // 2self.down.append(down)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout)# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in,2*z_channels if double_z else z_channels,kernel_size=3,stride=1,padding=1)def forward(self, x):"""前向传播方法,经过下采样和中间层得到潜在变量zArgs:x: 输入特征图Returns:潜在变量z,维度为z_channels或2*z_channels,包括均值和方差""" # timestep embeddingtemb = None# downsamplinghs = [self.conv_in(x)]for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions-1:hs.append(self.down[i_level].downsample(hs[-1]))# middleh = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# endh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return h
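A shape sketch for a kl-f8-style setup (the numbers below are illustrative assumptions in the spirit of the referenced YAML, not values read from it): with double_z=True the encoder outputs 2*z_channels channels holding mean and log-variance, and three downsampling stages shrink the resolution by a factor of 8.

```python
import torch

encoder = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                  attn_resolutions=[], dropout=0.0, in_channels=3,
                  resolution=256, z_channels=4, double_z=True)
img = torch.randn(1, 3, 256, 256)
moments = encoder(img)
print(moments.shape)   # torch.Size([1, 8, 32, 32]): 2 * z_channels channels, resolution / 8
```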
Decoder class
The Decoder class decodes the latent variable z into a generated image; architecturally it is the right half of a U-Net, i.e. the middle block plus the upsampling path.
The forward pass of the model is shown in the figure below:
The Decoder class code is as follows:
class Decoder(nn.Module):def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,attn_type="vanilla", **ignorekwargs):"""解码器,将潜在变量z转换为生成图像Args:ch: 初始通道数,控制网络中的通道数out_ch: 最终输出的通道数num_res_blocks: 每一层中 Resnet Block 的数量attn_resolutions: 决定在哪些分辨率层应用注意力机制in_channels: 输入通道数resolution: 原始输入的分辨率z_channels: 潜在空间的通道数,即编码后的特征图大小ch_mult: 通道倍增系数,用于控制每层的通道数变化. Defaults to (1,2,4,8).dropout: Dropout 的概率. Defaults to 0.0.resamp_with_conv: 是否使用卷积进行上采样. Defaults to True.give_pre_end: 如果为 True, 返回最终卷积之前的特征图. Defaults to False.tanh_out: 如果为 True, 使用 tanh 函数将输出值范围限制在 [-1, 1]. Defaults to False.use_linear_attn: 是否使用线性注意力. Defaults to False.attn_type: 注意力类型. Defaults to "vanilla".""" super().__init__()if use_linear_attn: attn_type = "linear"self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.give_pre_end = give_pre_endself.tanh_out = tanh_out# compute in_ch_mult, block_in and curr_res at lowest resin_ch_mult = (1,)+tuple(ch_mult) # (1, 1, 2, 4, 8)block_in = ch*ch_mult[self.num_resolutions-1]curr_res = resolution // 2**(self.num_resolutions-1)self.z_shape = (1,z_channels,curr_res,curr_res)print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))# z to block_inself.conv_in = torch.nn.Conv2d(z_channels,block_in,kernel_size=3,stride=1,padding=1)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout)# upsamplingself.up = nn.ModuleList()for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()attn = nn.ModuleList()block_out = ch*ch_mult[i_level]for i_block in range(self.num_res_blocks+1):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))up = nn.Module()up.block = blockup.attn = attnif i_level != 0:up.upsample = Upsample(block_in, resamp_with_conv)curr_res = curr_res * 2self.up.insert(0, up) # 将up插入到self.up列表的开头# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in,out_ch,kernel_size=3,stride=1,padding=1)def forward(self, z):"""前向传播方法,从最初的潜在变量z解码得到生成图像Args:z: 潜在变量zReturns:解码得到的生成图像""" #assert z.shape[1:] == self.z_shape[1:]self.last_z_shape = z.shape# timestep embeddingtemb = None# z to block_inh = self.conv_in(z)# middleh = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# upsamplingfor i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks+1):h = self.up[i_level].block[i_block](h, temb)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:h = self.up[i_level].upsample(h)# endif self.give_pre_end:return hh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)if self.tanh_out:h = torch.tanh(h)return h
distributions/distributions.py
DiagonalGaussianDistribution class
The diagonal Gaussian distribution class takes the latent code z that the Encoder produces from the input feature x, builds a diagonal Gaussian from the mean and variance contained in z, and provides methods for computing the mean and variance, sampling, computing the KL divergence, and computing the negative log-likelihood.
__init__ method
def __init__(self, parameters, deterministic=False):"""对角高斯分布,存储对角高斯分布的均值方差等参数,并提供了采样方式Args:parameters: 潜在变量zdeterministic: 参数是否为确定性分布. Defaults to False.""" self.parameters = parametersself.mean, self.logvar = torch.chunk(parameters, 2, dim=1)self.logvar = torch.clamp(self.logvar, -30.0, 20.0) # 防止方差过大过小self.deterministic = deterministicself.std = torch.exp(0.5 * self.logvar) # 标准差self.var = torch.exp(self.logvar) # 方差if self.deterministic:self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) # 确定性分布方差为0
The constructor splits the latent variable z into the mean and (log-)variance of the diagonal Gaussian. If deterministic is true, the variance is set to 0 and the Gaussian degenerates into a deterministic distribution.
sample method
```python
def sample(self):
    """Sample from the diagonal Gaussian: x = mu + sigma * eps, with eps ~ N(0, I).

    Returns:
        the sampled tensor
    """
    x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
    return x
```
The sample method draws a sample with the reparameterization trick:

$$x=\mu+\sigma\epsilon,\qquad \epsilon \sim \mathcal{N}(0, I)$$

kl method
```python
def kl(self, other=None):
    """Compute the KL divergence.

    Args:
        other: the distribution to compute the KL divergence against;
               if None, the standard normal distribution is used.

    Returns:
        the KL divergence, summed over the non-batch dimensions
    """
    if self.deterministic:
        return torch.Tensor([0.])
    else:
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        else:
            return 0.5 * torch.sum(
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var - 1.0 - self.logvar + other.logvar,
                dim=[1, 2, 3])
```
- The KL divergence measures how different two distributions are.
- When other is None, the KL divergence against the standard normal distribution (mean 0, variance 1) is computed (a numerical check follows below):

$$D_{KL}(q\,\|\,p) = 0.5 \cdot \sum \left( \mu^2 + \sigma^2 - 1 - \log(\sigma^2) \right)$$

- When other is not None, the KL divergence against another diagonal Gaussian is computed:

$$D_{KL}(q\,\|\,p) = 0.5 \cdot \sum \left( \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} + \frac{\sigma_q^2}{\sigma_p^2} - 1 - \log \frac{\sigma_q^2}{\sigma_p^2} \right)$$
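As a numerical sanity check (illustrative, not part of distributions.py), the closed-form expression against the standard normal matches torch.distributions:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean, logvar = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
std = torch.exp(0.5 * logvar)

closed_form = 0.5 * torch.sum(mean ** 2 + std ** 2 - 1.0 - logvar, dim=[1, 2, 3])
reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=[1, 2, 3])
print(torch.allclose(closed_form, reference, atol=1e-5))   # True
```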
nll method
```python
def nll(self, sample, dims=[1,2,3]):
    """Compute the negative log-likelihood of a sample under this distribution.

    Args:
        sample: the observed sample
        dims: dimensions to sum over. Defaults to [1, 2, 3].

    Returns:
        the negative log-likelihood
    """
    if self.deterministic:
        return torch.Tensor([0.])
    logtwopi = np.log(2.0 * np.pi)
    return 0.5 * torch.sum(
        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
        dim=dims)
```
- The negative log-likelihood (NLL) measures how well a data point fits the distribution:

$$\text{NLL} = 0.5 \cdot \sum \left( \log(2\pi) + \log(\sigma^2) + \frac{(x - \mu)^2}{\sigma^2} \right)$$

where $x$ is the observed sample, $\mu$ the mean, and $\sigma^2$ the variance.
mode method
```python
def mode(self):
    """The mode of a Gaussian is its mean.

    Returns:
        the mode (mean) of the distribution
    """
    return self.mean
```
- The mode of the distribution is simply its mean, since the mode of a Gaussian coincides with its mean.
discriminator/model.py
NLayerDiscriminator class
This class implements a PatchGAN discriminator, used to decide whether local patches of the input image are real. It extracts features with a stack of convolutions of increasing channel count followed by LeakyReLU activations, and outputs a per-patch real/fake prediction map. The model structure is as follows:
Fully annotated code:
class NLayerDiscriminator(nn.Module):"""Defines a PatchGAN discriminator as in Pix2Pix--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py"""def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):"""PatchGAN 判别器,用于判断输入图像的局部区域是否为真实图像。它通过多层卷积逐步提取特征,并输出一个单通道的特征图,表示每个局部区域的真实性。这个结构中的层数和滤波器数量可以根据需求调整。Args:input_nc: 输入图像的通道数. Defaults to 3.ndf: 第一层卷积层的输出通道数. Defaults to 64.n_layers: 卷积层的层数. Defaults to 3.use_actnorm: 是否使用激活归一化层. Defaults to False."""super(NLayerDiscriminator, self).__init__()if not use_actnorm:norm_layer = nn.BatchNorm2delse:norm_layer = ActNormif type(norm_layer) == functools.partial: # BatchNorm2d 自带仿射变换(即有偏置和缩放参数)use_bias = norm_layer.func != nn.BatchNorm2delse:use_bias = norm_layer != nn.BatchNorm2dkw = 4 # 卷积核大小padw = 1 # 填充大小sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]nf_mult = 1nf_mult_prev = 1for n in range(1, n_layers): # gradually increase the number of filtersnf_mult_prev = nf_multnf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction mapself.main = nn.Sequential(*sequence)def forward(self, input):"""PatchGAN 判别器,用于判断输入图像的局部区域是否为真实图像。它通过多层卷积逐步提取特征,并输出一个单通道的特征图,表示每个局部区域的真实性。这个结构中的层数和滤波器数量可以根据需求调整。Args:input: 输入图像Returns:通道数为1的卷积,用于判断图像真实性""" """Standard forward."""return self.main(input)
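With the default settings (n_layers=3, 4x4 kernels), each output logit corresponds to roughly a 70x70 patch of the input; a quick shape check (illustrative):

```python
import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
img = torch.randn(1, 3, 256, 256)
logits = disc(img)
print(logits.shape)   # torch.Size([1, 1, 30, 30]): one real/fake logit per local patch
```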
losses/vqperceptual.py
hinge_d_loss function
- For real samples, we want the discriminator score to be well above 1, so the term 1. - logits_real penalizes scores below 1.
- For generated samples, we want the discriminator score to be below -1, so the term 1. + logits_fake penalizes scores above -1; the loss is written out below.
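Written out, the hinge discriminator loss implemented below is:

$$
L_D = \tfrac{1}{2}\,\mathbb{E}\big[\max(0,\,1-D(x_{\text{real}}))\big] + \tfrac{1}{2}\,\mathbb{E}\big[\max(0,\,1+D(x_{\text{fake}}))\big]
$$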
Annotated code:
```python
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the GAN discriminator.

    Args:
        logits_real: discriminator outputs for real samples
        logits_fake: discriminator outputs for generated samples

    Returns:
        the discriminator loss
    """
    loss_real = torch.mean(F.relu(1. - logits_real))  # real samples: push logits_real above 1
    loss_fake = torch.mean(F.relu(1. + logits_fake))  # fake samples: push logits_fake below -1
    d_loss = 0.5 * (loss_real + loss_fake)            # average the two terms
    return d_loss
```
lpips.py
vgg16 class
vgg16 uses fixed pre-trained weights. The network is split into five slices; the output of each slice is stored and returned, together with its name, as the overall output of the forward pass.
Annotated code:
class vgg16(torch.nn.Module):def __init__(self, requires_grad=False, pretrained=True):"""预训练的VGG16网络Args:requires_grad: 是否需要梯度信息. Defaults to False.pretrained: 是否使用预训练权重. Defaults to True.""" super(vgg16, self).__init__()vgg_pretrained_features = models.vgg16(pretrained=pretrained).featuresself.slice1 = torch.nn.Sequential()self.slice2 = torch.nn.Sequential()self.slice3 = torch.nn.Sequential()self.slice4 = torch.nn.Sequential()self.slice5 = torch.nn.Sequential()self.N_slices = 5for x in range(4):self.slice1.add_module(str(x), vgg_pretrained_features[x])for x in range(4, 9):self.slice2.add_module(str(x), vgg_pretrained_features[x])for x in range(9, 16):self.slice3.add_module(str(x), vgg_pretrained_features[x])for x in range(16, 23):self.slice4.add_module(str(x), vgg_pretrained_features[x])for x in range(23, 30):self.slice5.add_module(str(x), vgg_pretrained_features[x])if not requires_grad:for param in self.parameters():param.requires_grad = Falsedef forward(self, X):"""将整个网络分为五个部分,记录每个部分的输出并返回Args:X: 输入特征xReturns:包含网络中五个部分的输出特征的字典""" h = self.slice1(X)h_relu1_2 = hh = self.slice2(h)h_relu2_2 = hh = self.slice3(h)h_relu3_3 = hh = self.slice4(h)h_relu4_3 = hh = self.slice5(h)h_relu5_3 = hvgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)return out
NetLinLayer class
This class implements a simple 1x1 convolutional layer used to change the number of channels.
Detailed code:
class NetLinLayer(nn.Module):def __init__(self, chn_in, chn_out=1, use_dropout=False):"""通过1x1卷积层将VGG16网络的输出映射到通道数为1的特征向量Args:chn_in: 输入通道数chn_out: 输出通道数. Defaults to 1.use_dropout: 是否使用dropout. Defaults to False.""" super(NetLinLayer, self).__init__()layers = [nn.Dropout(), ] if (use_dropout) else []layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]self.model = nn.Sequential(*layers)
ScalingLayer class
class ScalingLayer(nn.Module):def __init__(self):"""缩放层,对输入的张量标准化处理""" super(ScalingLayer, self).__init__()self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])def forward(self, inp):"""前向传播,标准化输入张量Args:inp: 输入张量Returns:标准化后的张量""" return (inp - self.shift) / self.scale
LPIPS class
The LPIPS class computes the perceptual loss between two input images; the model is shown in the figure below.
Annotated code:
class LPIPS(nn.Module):# Learned perceptual metricdef __init__(self, use_dropout=True):"""计算感知损失,通过预训练的VGG16网络衡量两张图像之间的视觉相似性Args:use_dropout: 用于控制是否在 NetLinLayer 中使用 dropout 层. Defaults to True.""" super().__init__()self.scaling_layer = ScalingLayer()self.chns = [64, 128, 256, 512, 512] # VGG16 网络中提取的不同特征层的通道数self.net = vgg16(pretrained=True, requires_grad=False) # 预训练的VGG16网络self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)self.load_from_pretrained()for param in self.parameters():param.requires_grad = Falsedef load_from_pretrained(self, name="vgg_lpips"):ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)print("loaded pretrained LPIPS loss from {}".format(ckpt))@classmethoddef from_pretrained(cls, name="vgg_lpips"):if name != "vgg_lpips":raise NotImplementedErrormodel = cls()ckpt = get_ckpt_path(name)model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)return modeldef forward(self, input, target):"""计算两个特征图之间的像素差异度(感知差异度)Args:input: 输入的特征图target: 与输入特征图比较差异度的特征图Returns:感知差异""" in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) # 标准化缩放outs0, outs1 = self.net(in0_input), self.net(in1_input)feats0, feats1, diffs = {}, {}, {}lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]for kk in range(len(self.chns)):feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) # 标准化处理diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 # 求每个像素之间差异的平方res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] # 映射到一个通道上面去求均值[c, 1, 1]val = res[0]for l in range(1, len(self.chns)):val += res[l] # 累加获得最终的感知差异度return val
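A minimal usage sketch (illustrative; it assumes the pretrained LPIPS weights can be downloaded by get_ckpt_path):

```python
import torch

lpips = LPIPS(use_dropout=True).eval()
img_a = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs are expected roughly in [-1, 1]
img_b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(img_a, img_b)
print(dist.shape)   # torch.Size([1, 1, 1, 1]); larger values mean perceptually more different
```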
contperceptual.py
LPIPSWithDiscriminator class
This class computes the losses used to update both the generator (the autoencoder) and the discriminator.
- Generator update:
  - compute the reconstruction loss and the perceptual loss
  - combine them into a negative log-likelihood loss
  - compute the KL divergence (against the standard normal)
  - compute the generator's adversarial loss from the discriminator
  - total loss = NLL loss + weighted KL divergence + weighted adversarial loss (see the formula below)
- Discriminator update:
  - run the discriminator on the real and on the reconstructed images
  - compute the adversarial (hinge) loss
  - total loss = the discriminator's adversarial loss
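For the generator update, the total objective assembled in the code below is (kl_weight, d_weight and disc_factor as named in the code):

$$
L_{\text{gen}} = L_{\text{nll}} + w_{\text{kl}}\,L_{\text{kl}} + d_{\text{weight}}\cdot f_{\text{disc}}\cdot L_{g},
\qquad L_{g} = -\mathbb{E}\big[D(\hat{x})\big]
$$

where $\hat{x}$ is the reconstruction and $D$ the discriminator.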
Detailed code:
class LPIPSWithDiscriminator(nn.Module):def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,disc_loss="hinge"):"""损失函数类,包括感知损失和判别器损失Args:disc_start: 判别器开始工作的时间点logvar_init: 初始化对数方差的初始值. Defaults to 0.0.kl_weight: KL散度的权重。. Defaults to 1.0.pixelloss_weight: 像素级损失的权重. Defaults to 1.0.disc_num_layers: 判别器中的层数. Defaults to 3.disc_in_channels: 判别器输入的通道数. Defaults to 3.disc_factor: 判别器的损失因子. Defaults to 1.0.disc_weight: 自适应判别器权重. Defaults to 1.0.perceptual_weight: 感知损失的权重. Defaults to 1.0.use_actnorm: 是否在判别器中使用 actnorm 层. Defaults to False.disc_conditional: 判别器是否为条件 GAN. Defaults to False.disc_loss: 判别器使用的损失函数. Defaults to "hinge".""" super().__init__()assert disc_loss in ["hinge", "vanilla"]self.kl_weight = kl_weight # 0.000001self.pixel_weight = pixelloss_weightself.perceptual_loss = LPIPS().eval()self.perceptual_weight = perceptual_weight# output log varianceself.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,n_layers=disc_num_layers,use_actnorm=use_actnorm).apply(weights_init)self.discriminator_iter_start = disc_start # 50001self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_lossself.disc_factor = disc_factorself.discriminator_weight = disc_weight # 0.5self.disc_conditional = disc_conditionaldef calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):if last_layer is not None:nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]else:nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()d_weight = d_weight * self.discriminator_weightreturn d_weightdef forward(self, inputs, reconstructions, posteriors, optimizer_idx,global_step, last_layer=None, cond=None, split="train",weights=None):"""AutoEncoderKL参数损失函数计算Args:inputs: 原始输入图像reconstructions: 模型重建的图像posteriors: 用于计算 KL 散度的后验分布optimizer_idx: 用于区分是在更新生成器(=1)还是判别器(=0)global_step: 当前的训练步数last_layer: 用于自适应权重计算的最后一层. Defaults to None.cond: 条件 GAN 的输入,如果存在,生成器和判别器都会将其作为输入的一部分. Defaults to None.split: 训练模式. Defaults to "train".weights: 用于加权重建损失. 
Defaults to None.Returns:最终的损失函数,日志文件""" rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) # 重构损失:inputs和resconstructions之差的绝对值if self.perceptual_weight > 0:p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) # 计算感知损失,通过VGG网络计算rec_loss = rec_loss + self.perceptual_weight * p_loss # 更新损失为重构损失+感知损失nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar # 负对数似然损失weighted_nll_loss = nll_lossif weights is not None:weighted_nll_loss = weights*nll_lossweighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] # 计算每个样本平均的负对数似然损失nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]kl_loss = posteriors.kl() # 计算KL损失kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]# now the GAN partif optimizer_idx == 0:# 生成器更新if cond is None:assert not self.disc_conditionallogits_fake = self.discriminator(reconstructions.contiguous()) # 计算判别器对于重建图像的预测else:assert self.disc_conditionallogits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))g_loss = -torch.mean(logits_fake) # 反转损失函数, 优化最小化g_loss等价于最大化判别器对于重建图像的预测,即最大化判别器认为重建图像真实性if self.disc_factor > 0.0:try:d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) # 计算自适应权重except RuntimeError:assert not self.trainingd_weight = torch.tensor(0.0)else:d_weight = torch.tensor(0.0)disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) # 根据时间步判断是否使用判别器损失loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss # 加权后的重建损失+加权后的KL散度+加权后的判别器损失log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),"{}/rec_loss".format(split): rec_loss.detach().mean(),"{}/d_weight".format(split): d_weight.detach(),"{}/disc_factor".format(split): torch.tensor(disc_factor),"{}/g_loss".format(split): g_loss.detach().mean(),}return loss, logif optimizer_idx == 1:# 判别器更新if cond is None:logits_real = self.discriminator(inputs.contiguous().detach()) # 计算真实图像损失logits_fake = self.discriminator(reconstructions.contiguous().detach()) # 计算重建图像损失else:logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) # 判断是否计算判别器损失d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) # 计算对抗损失log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),"{}/logits_real".format(split): logits_real.detach().mean(),"{}/logits_fake".format(split): logits_fake.detach().mean()}return d_loss, log
models/autoencoder.py
AutoencoderKL class
This class implements the first training stage (the KL-regularized autoencoder).
encode method
```python
def encode(self, x):
    """Encoder.

    Args:
        x: input feature map

    Returns:
        the posterior distribution over the latent variable
    """
    h = self.encoder(x)                                 # latent feature map
    moments = self.quant_conv(h)                        # project to mean / log-variance "moments"
    posterior = DiagonalGaussianDistribution(moments)   # wrap as a diagonal Gaussian posterior
    return posterior
```
This method passes the input feature map through the encoder, projects the result with quant_conv, and wraps it as a diagonal Gaussian posterior.
The model structure is as follows:
decode method
```python
def decode(self, z):
    """Decoder.

    Args:
        z: sampled latent variable

    Returns:
        the decoded output feature map
    """
    z = self.post_quant_conv(z)   # project back to the decoder's input channel dimension
    dec = self.decoder(z)
    return dec
```
The decoder first maps the latent z through post_quant_conv back to the channel dimension expected by the decoder, then decodes it into the target feature map.
The model structure is as follows:
forward method
```python
def forward(self, input, sample_posterior=True):
    """Forward pass: encode the input into a posterior distribution, sample from it
    (or take its mode) and decode the sample back into an image.

    Args:
        input: input feature map
        sample_posterior: whether to sample from the posterior. Defaults to True.

    Returns:
        the decoded image and the posterior distribution
    """
    posterior = self.encode(input)
    if sample_posterior:
        z = posterior.sample()
    else:
        z = posterior.mode()
    dec = self.decode(z)
    return dec, posterior
```
The model is shown in the figure below.
ddpm.py
DiffusionWrapper class
This class is a wrapper that handles the various kinds of conditioning input and feeds the conditioning together with the (noisy) input into the diffusion model.
Annotated code:
class DiffusionWrapper(pl.LightningModule):def __init__(self, diff_model_config, conditioning_key):"""一个用于扩散模型的包装器,提供了一种灵活的方式来处理不同的条件输入类型Args:diff_model_config: 一个配置字典,用于创建扩散模型的配置conditioning_key: 决定如何将条件信息与扩散模型结合""" super().__init__()self.diffusion_model = instantiate_from_config(diff_model_config)self.conditioning_key = conditioning_keyassert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):"""处理条件输入,将条件输入和输入图像结合,并通过模型Args:x: 输入图像或噪声t: 扩散过程中的时间步数c_concat: 在 concat 和 hybrid 模式下使用,条件信息会与输入图像拼接在一起. Defaults to None.c_crossattn: 在 crossattn、hybrid 和 adm 模式下使用,作为上下文传递给模型. Defaults to None.Raises:NotImplementedError: _description_Returns:_description_""" if self.conditioning_key is None:out = self.diffusion_model(x, t) # 直接将输入图像和时间步数传给扩散模型,不使用条件信息elif self.conditioning_key == 'concat':xc = torch.cat([x] + c_concat, dim=1) # 将输入图像 x 和条件信息 c_concat 拼接在一起,然后传给扩散模型out = self.diffusion_model(xc, t)elif self.conditioning_key == 'crossattn':cc = torch.cat(c_crossattn, 1) # 将条件信息 c_crossattn 拼接在一起,作为上下文信息传给扩散模型out = self.diffusion_model(x, t, context=cc)elif self.conditioning_key == 'hybrid': # 同时使用拼接和上下文信息,输入图像和条件信息 c_concat 拼接后传给模型,同时将条件信息 c_crossattn 作为上下文传递xc = torch.cat([x] + c_concat, dim=1)cc = torch.cat(c_crossattn, 1)out = self.diffusion_model(xc, t, context=cc)elif self.conditioning_key == 'adm': # 使用 ADM 特定的条件信息,将 c_crossattn[0] 作为 y 传给模型cc = c_crossattn[0]out = self.diffusion_model(x, t, y=cc)else:raise NotImplementedError()return out
DDPM class
DDPM forward process

$$
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}_{t-2} \\
&= \cdots \\
&= \sqrt{\bar{\alpha}_t}\,x_{0} + \sqrt{1-\bar{\alpha}_t}\,\bar{\epsilon}
\end{aligned}
$$

Using the reparameterization trick, this can be written as:

$$
x_t \sim p(x_t\mid x_{t-1})=\mathcal{N}\!\left(x_t;\ \sqrt{\alpha_t}\,x_{t-1},\ (1-\alpha_t)I\right),\qquad
x_t \sim p(x_t\mid x_{0})=\mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_{0},\ (1-\bar{\alpha}_t)I\right)
$$
DDPM reverse process
By Bayes' theorem,

$$p(x_{t-1}\mid x_t)=\frac{p(x_t\mid x_{t-1})\,p(x_{t-1})}{p(x_t)}$$

Conditioning on $x_0$, Bayes' theorem can be applied as:

$$p(x_{t-1}\mid x_t, x_0)=\frac{p(x_t\mid x_{t-1}, x_0)\,p(x_{t-1}\mid x_0)}{p(x_t\mid x_0)}$$

Substituting the forward-process Gaussians and simplifying gives

$$p(x_{t-1}\mid x_t, x_0)=\mathcal{N}\!\left( x_{t-1};\ \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\,(1-\alpha_t)}{1-\bar{\alpha}_t}x_0,\ \left( \frac{\sqrt{1-\alpha_t}\,\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{1-\bar{\alpha}_t}} \right)^{\!2} I \right)$$

Substituting $x_0=\frac{x_t-\sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}$ into this expression yields

$$p(x_{t-1}\mid x_t, x_0)=\mathcal{N}\!\left( x_{t-1};\ \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right),\ \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,I \right)$$

where $\epsilon$ is the noise added in the forward process, which the U-Net is trained to predict.
q_mean_variance method
The parameters of the diffusion distribution $q(x_t\mid x_0)$ can be computed as follows:

$$x_t \sim q(x_t\mid x_{0})=\mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_{0},\ (1-\bar{\alpha}_t)I\right)$$

- mean $\mu$: $\sqrt{\bar{\alpha}_t}\,x_0$
- variance $\sigma^2$: $1-\bar{\alpha}_t$
- log-variance $\log(\sigma^2)$: $\log(1-\bar{\alpha}_t)$

Annotated code:
```python
def q_mean_variance(self, x_start, t):
    """Compute the mean and variance of the diffusion distribution
    q(x_t | x_0) = N(x_t; sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) I).

    Args:
        x_start: noiseless input tensor of shape [N x C x ...]
        t: diffusion step (0-indexed)

    Returns:
        mean, variance, log-variance
    """
    mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
    return mean, variance, log_variance
```
predict_start_from_noise method
The original image can be recovered from the noise via:

$$x_0=\frac{x_t-\sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}$$

where $\epsilon$ is the noise predicted by the model.
Annotated code:
```python
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the original image x_0 from the noisy image x_t at step t and the noise eps.

    Args:
        x_t: image at diffusion step t
        t: diffusion step index
        noise: the noise eps

    Returns:
        the predicted original image x_0
    """
    return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)
```
q_posterior method
The formula for inferring $x_{t-1}$ from $x_t$ is:

$$p(x_{t-1}\mid x_t, x_0)=\mathcal{N}\!\left( x_{t-1};\ \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\,(1-\alpha_t)}{1-\bar{\alpha}_t}x_0,\ \left( \frac{\sqrt{1-\alpha_t}\,\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{1-\bar{\alpha}_t}} \right)^{\!2} I \right)$$

- mean: $\mu=\frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\,(1-\alpha_t)}{1-\bar{\alpha}_t}x_0$
- variance: $\sigma^2=\left( \frac{\sqrt{1-\alpha_t}\,\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{1-\bar{\alpha}_t}} \right)^2$
- log-variance: $\log(\sigma^2)$, computed in the code as $\log\!\big(\max(\sigma^2,\ 10^{-20})\big)$ to avoid numerical issues when the variance is very small

Annotated code:
def q_posterior(self, x_start, x_t, t):"""函数计算的是在时间步t时,给定初始图象x_t和扩散过程的图像x_t,逆向扩散过程q(x_{t-1}|x_t, x_0)的后验分布的均值和方差Args:x_start: 扩散过程的初始图像x_t: 扩散过程中时间步t的图像t: 当前的时间步索引Returns:均值, 方差, 对数方差(裁剪处理后,避免方差过小不稳定)""" posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clipped
p_mean_variance method
This method feeds the current image and the timestep into the model to obtain the predicted noise, uses that prediction to estimate the original image $x_0$, and then uses the estimate to compute the mean, variance and log-variance of $q(x_{t-1}\mid x_t, x_0)$.
Annotated code:
def p_mean_variance(self, x, t, clip_denoised: bool):"""计算并返回模型的均值、后验方差和后验对数方差Args:x: 当前时间步的图像数据t: 时间步clip_denoised: 布尔值,指示是否将去噪后的结果裁剪到一个指定的范围内Returns:_description_""" model_out = self.model(x, t) # 预测得到的噪声if self.parameterization == "eps":x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) # 直接从噪声预测初始图像elif self.parameterization == "x0":x_recon = model_out # 预测原图像if clip_denoised:x_recon.clamp_(-1., 1.)model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)return model_mean, posterior_variance, posterior_log_variance
p_sample method
Given the current image and the timestep, this method computes $p(x_{t-1}\mid x_t, x_0)$ and uses it to sample the image at step $t-1$.
Fully annotated code:
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):"""函数通过给定的图像信息和时间步,计算p(x_{t-1}|x_t, x_0),并据此预测x_1步的图像信息Args:x: 输入图像或特征图t: 当前时间步或噪声水平clip_denoised: 是否在去噪后裁剪图像. Defaults to True.是否重复使用相同的噪声: _description_. Defaults to False.Returns:返回x_{t-1}去噪的图像""" b, *_, device = *x.shape, x.devicemodel_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)noise = noise_like(x.shape, device, repeat_noise) # 从标准正态分布采样# no noise when t == 0nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) # 确保时间步为0的时候不引入噪声return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
p_sample_loop method
This method performs the step-by-step sampling loop of the diffusion model, gradually turning a noise image into a clean image; the return_intermediates flag decides whether the intermediate predictions are returned as well.
Detailed code:
def p_sample_loop(self, shape, return_intermediates=False):"""用于在扩散模型中进行逐步采样,逐渐将噪声图像还原为清晰的图像Args:shape: 生成的图像的形状,[b, c, h, w]return_intermediates: 指示是否返回每个时间步的中间结果. Defaults to False.Returns:预测的x_0图像信息""" device = self.betas.deviceb = shape[0]img = torch.randn(shape, device=device) # 生成初始噪声图像intermediates = [img]for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),clip_denoised=self.clip_denoised)if i % self.log_every_t == 0 or i == self.num_timesteps - 1:intermediates.append(img)if return_intermediates:return img, intermediatesreturn img
q_sample method
This method adds noise to $x_0$ to obtain $x_t$ directly in a single step, using:

$$x_t=\sqrt{\bar{\alpha}_t}\,x_{0}+\sqrt{1-\bar{\alpha}_t}\,\bar{\epsilon}_t,\qquad \bar{\epsilon}_t \sim \mathcal{N}(0, I)$$
```python
def q_sample(self, x_start, t, noise=None):
    """Add noise to x_0 to obtain x_t in a single step.

    Args:
        x_start: the original, noiseless image
        t: timestep (noise level)
        noise: noise tensor. Defaults to None.

    Returns:
        the noised image x_t
    """
    # if no noise is given, draw standard Gaussian noise shaped like x_start
    noise = default(noise, lambda: torch.randn_like(x_start))
    return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
```
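For reference, a minimal sketch (not the repository code; register_schedule in ddpm.py supports several schedules) of how the buffers used above can be precomputed from a linear beta schedule:

```python
import numpy as np
import torch

def make_schedule(timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Precompute the quantities q_sample needs from a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)   # \bar{alpha}_t
    return {
        "sqrt_alphas_cumprod": torch.tensor(np.sqrt(alphas_cumprod)),
        "sqrt_one_minus_alphas_cumprod": torch.tensor(np.sqrt(1.0 - alphas_cumprod)),
    }

sched = make_schedule()
print(sched["sqrt_alphas_cumprod"][0].item())    # ~1.0: at t=0 the image is almost untouched
print(sched["sqrt_alphas_cumprod"][-1].item())   # ~0.0: at t=T-1 the image is almost pure noise
```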
get_loss method
Computes the loss between the prediction and the target; depending on self.loss_type it uses either an L1 or an L2 loss and returns the final loss value.
def get_loss(self, pred, target, mean=True):"""计算预测值和目标值之间的损失Args:pred: 模型的预测输出target: 真实图像mean: 是否对损失值进行平均并返回标量损失. Defaults to True.Returns:损失值""" if self.loss_type == 'l1':loss = (target - pred).abs() # l1损失if mean:loss = loss.mean()elif self.loss_type == 'l2':if mean:loss = torch.nn.functional.mse_loss(target, pred) # l2损失else:loss = torch.nn.functional.mse_loss(target, pred, reduction='none')else:raise NotImplementedError("unknown loss type '{loss_type}'")return loss
p_losses method
This method computes the loss between the true noise and the predicted noise.
def p_losses(self, x_start, t, noise=None):"""计算真实噪声和预测噪声之间的差值Args:x_start: 输入图像t: 最大时间步noise: 噪声张量. Defaults to None.Returns:噪声重建损失+变分损失""" noise = default(noise, lambda: torch.randn_like(x_start))x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 加噪后的图像model_out = self.model(x_noisy, t)loss_dict = {}if self.parameterization == "eps":target = noiseelif self.parameterization == "x0":target = x_startelse:raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])log_prefix = 'train' if self.training else 'val'loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})loss_simple = loss.mean() * self.l_simple_weightloss_vlb = (self.lvlb_weights[t] * loss).mean()loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})loss = loss_simple + self.original_elbo_weight * loss_vlbloss_dict.update({f'{log_prefix}/loss': loss})return loss, loss_dict
forward method
```python
def forward(self, x, *args, **kwargs):
    """Draw a random timestep in [0, num_timesteps) for every sample in the batch
    and compute the loss at those timesteps.

    Args:
        x: input real images

    Returns:
        total loss and a dict of logged values
    """
    # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
    # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    return self.p_losses(x, t, *args, **kwargs)
```
shared_step method
```python
def shared_step(self, batch):
    """Extract the images from the input batch and compute the loss via the forward pass.

    Args:
        batch: a batch of data

    Returns:
        loss and the loss log dict
    """
    x = self.get_input(batch, self.first_stage_key)
    loss, loss_dict = self(x)
    return loss, loss_dict
```
training_step method
def training_step(self, batch, batch_idx):"""执行训练步骤并返回损失Args:batch: 当前批次的数据batch_idx: 当前批次的索引Returns:总损失""" loss, loss_dict = self.shared_step(batch) # 获得损失和损失日志self.log_dict(loss_dict, prog_bar=True,logger=True, on_step=True, on_epoch=True)self.log("global_step", self.global_step,prog_bar=True, logger=True, on_step=True, on_epoch=False)if self.use_scheduler: # 使用调度器监控修改学习率lr = self.optimizers().param_groups[0]['lr']self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)return loss # 返回损失
UNetModel class
For the UNetModel class the focus here is on how its individual parts are composed; a detailed code walkthrough is not given.
Timestep embedding
The structure of the timestep embedding is shown in the figure below.
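As a rough sketch of what this part typically looks like (an assumption based on the standard diffusion U-Net design, not a copy of openaimodel.py), the integer timestep is first mapped to a sinusoidal embedding and then pushed through a small two-layer MLP:

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t, dim, max_period=10000):
    """Map integer timesteps t of shape (B,) to sinusoidal embeddings of shape (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

model_channels = 128
time_embed = nn.Sequential(                        # small MLP applied on top of the sinusoids
    nn.Linear(model_channels, 4 * model_channels),
    nn.SiLU(),
    nn.Linear(4 * model_channels, 4 * model_channels),
)
emb = time_embed(timestep_embedding(torch.tensor([0, 250, 999]), model_channels))
print(emb.shape)   # torch.Size([3, 512]); this embedding is then injected into each residual block
```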
ResBlock
The structure of ResBlock is shown in the figure below.
When neither upsampling nor downsampling is used, the up/down module here is replaced by a torch.nn.Identity.
AttentionBlock
The structure of the AttentionBlock class is as follows:
Output layer
The structure of the output layer is shown in the figure below.
Downsampling blocks
The structure of the downsampling blocks is shown in the figure below:
Middle block
The structure of the middle block is shown in the figure below:
Upsampling blocks
The structure of the upsampling blocks is shown in the figure below:
Overall architecture
The overall architecture is as follows:
LatentDiffusion class
Most methods of this class mirror those of the DDPM class, so they are not explained in detail.
instantiate_first_stage method
This method instantiates the first-stage model from the configuration and freezes its parameters.
def instantiate_first_stage(self, config):"""用于根据给定的配置实例化第一阶段模型Args:config: 配置信息""" model = instantiate_from_config(config)self.first_stage_model = model.eval() # 设置为评估模式self.first_stage_model.train = disabled_train # 禁用模型训练for param in self.first_stage_model.parameters(): # 冻结模型参数param.requires_grad = False
instantiate_cond_stage method
This method instantiates the conditioning model from the configuration and decides, based on its arguments, whether to freeze its parameters.
def instantiate_cond_stage(self, config):"""用于实例化条件生成模型Args:config: 条件模型配置文件""" if not self.cond_stage_trainable: # 不可训练模型会设置为评估模式并冻结参数if config == "__is_first_stage__": # 使用第一阶段的模式作为条件模型print("Using first stage also as cond stage.")self.cond_stage_model = self.first_stage_modelelif config == "__is_unconditional__": # 不适用条件模型print(f"Training {self.__class__.__name__} as an unconditional model.")self.cond_stage_model = None# self.be_unconditional = Trueelse: # 从配置文件中加载条件模型model = instantiate_from_config(config)self.cond_stage_model = model.eval()self.cond_stage_model.train = disabled_trainfor param in self.cond_stage_model.parameters():param.requires_grad = Falseelse:assert config != '__is_first_stage__'assert config != '__is_unconditional__'model = instantiate_from_config(config)self.cond_stage_model = model
__init__ method
def __init__(self,first_stage_config,cond_stage_config,num_timesteps_cond=None,cond_stage_key="image",cond_stage_trainable=False,concat_mode=True,cond_stage_forward=None,conditioning_key=None,scale_factor=1.0,scale_by_std=False,*args, **kwargs):"""LatentDiffusion,实现了潜在空间上的扩散模型Args:first_stage_config: 自动编码器配置cond_stage_config: 条件编码器配置num_timesteps_cond: 用于控制时间步数的条件. Defaults to None.cond_stage_key: 条件阶段的输入数据类型. Defaults to "image".cond_stage_trainable: 条件阶段是是否训练. Defaults to False.concat_mode: _descri定义条件如何与输入拼接ption_. Defaults to True.cond_stage_forward: 规定条件阶段的前向传播方式. Defaults to None.conditioning_key: 指定如何进行条件处理. Defaults to None.scale_factor: 输入输出缩放因子. Defaults to 1.0.scale_by_std: 是否按照标准差缩放. Defaults to False.""" self.num_timesteps_cond = default(num_timesteps_cond, 1) # 1self.scale_by_std = scale_by_std # trueassert self.num_timesteps_cond <= kwargs['timesteps']# for backwards compatibility after implementation of DiffusionWrapperif conditioning_key is None:conditioning_key = 'concat' if concat_mode else 'crossattn'if cond_stage_config == '__is_unconditional__':conditioning_key = Noneckpt_path = kwargs.pop("ckpt_path", None)ignore_keys = kwargs.pop("ignore_keys", [])super().__init__(conditioning_key=conditioning_key, *args, **kwargs)self.concat_mode = concat_mode # falseself.cond_stage_trainable = cond_stage_trainable # falseself.cond_stage_key = cond_stage_key # imagetry:self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 # 下采样层数except:self.num_downs = 0if not scale_by_std:self.scale_factor = scale_factorelse:self.register_buffer('scale_factor', torch.tensor(scale_factor))self.instantiate_first_stage(first_stage_config)self.instantiate_cond_stage(cond_stage_config)self.cond_stage_forward = cond_stage_forward # Noneself.clip_denoised = Falseself.bbox_tokenizer = None self.restarted_from_ckpt = Falseif ckpt_path is not None:self.init_from_ckpt(ckpt_path, ignore_keys)self.restarted_from_ckpt = True
encode_first_stage method
This method calls AutoencoderKL's encode function to encode the input tensor.
def encode_first_stage(self, x):"""调用第一阶段编码器模型Args:x: 输入张量Returns:返回输入张量经编码器的结果""" if hasattr(self, "split_input_params"): # 没有split_input_paramsif self.split_input_params["patch_distributed_vq"]:ks = self.split_input_params["ks"] # eg. (128, 128)stride = self.split_input_params["stride"] # eg. (64, 64)df = self.split_input_params["vqf"]self.split_input_params['original_image_size'] = x.shape[-2:]bs, nc, h, w = x.shapeif ks[0] > h or ks[1] > w:ks = (min(ks[0], h), min(ks[1], w))print("reducing Kernel")if stride[0] > h or stride[1] > w:stride = (min(stride[0], h), min(stride[1], w))print("reducing stride")fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)z = unfold(x) # (bn, nc * prod(**ks), L)# Reshape to img shapez = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )output_list = [self.first_stage_model.encode(z[:, :, :, :, i])for i in range(z.shape[-1])]o = torch.stack(output_list, axis=-1)o = o * weighting# Reverse reshape to img shapeo = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)# stitch crops togetherdecoded = fold(o)decoded = decoded / normalizationreturn decodedelse:return self.first_stage_model.encode(x)else:return self.first_stage_model.encode(x)
get_first_stage_encoding method
Depending on the type of the encoder output, this method chooses the appropriate way to obtain a latent (sampling from a DiagonalGaussianDistribution, or using a tensor directly) and rescales it by scale_factor.
def get_first_stage_encoding(self, encoder_posterior):"""根据encoder结果的对象类型选择合适的采样方式并缩放Args:encoder_posterior: encoder返回的编码的潜在变量Returns:缩放后的采样向量""" if isinstance(encoder_posterior, DiagonalGaussianDistribution): # 如果为高斯分布则采样z = encoder_posterior.sample()elif isinstance(encoder_posterior, torch.Tensor): # 如果是张量则直接返回z = encoder_posteriorelse:raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")return self.scale_factor * z
on_train_batch_start method
This hook is called at the start of every training batch; on the very first batch it sets the scale factor to the reciprocal of the standard deviation of the latent encodings.
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):"""在训练每个批次的开始时被调用Args:batch: 一批次的数据batch_idx: 批次的iddataloader_idx: _description_""" # only for very first batchif self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: # 确保以下操作只在第一个 epoch、第一个 global step、第一个 batch中执行assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'# set rescale weight to 1./std of encodingsprint("### USING STD-RESCALING ###")x = super().get_input(batch, self.first_stage_key)x = x.to(self.device)encoder_posterior = self.encode_first_stage(x) # 使用第一阶段编码器对数据进行编码,返回编码后的后验分布z = self.get_first_stage_encoding(encoder_posterior).detach() # 采样后得到的潜在变量del self.scale_factorself.register_buffer('scale_factor', 1. / z.flatten().std())print(f"setting self.scale_factor to {self.scale_factor}")print("### USING STD-RESCALING ###")
_get_denoise_row_from_list method
This method decodes images from a given list of latent samples and arranges them into a grid for visualization.
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):"""从给定的样本列表中解码图像,并将其组织为网格格式以便于可视化Args:samples: 输入样本列表desc: 进度条的描述. Defaults to ''.force_no_decoder_quantization: 是否强制使用量化. Defaults to False.Returns:_description_""" denoise_row = []for zd in tqdm(samples, desc=desc):denoise_row.append(self.decode_first_stage(zd.to(self.device),force_not_quantize=force_no_decoder_quantization))n_imgs_per_row = len(denoise_row)denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, Wdenoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)return denoise_grid
get_input method
The get_input method extracts the inputs from a given batch and encodes the conditioning. Depending on its flags it can return, among other things, the latent encoding of the input x, the encoded conditioning, the original input x, the decoded reconstruction of the latent, and the raw conditioning input.
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,cond_key=None, return_original_cond=False, bs=None):"""从给定的批量数据中提取输入,并进行条件编码Args:batch: 输入的批量数据k: 获取输入的关键字return_first_stage_outputs: 是否返回第一阶段的输出. Defaults to False.force_c_encode: 强制条件编码的标志. Defaults to False.cond_key: 条件输入的关键字. Defaults to None.return_original_cond: 是否返回原始条件信息. Defaults to False.bs: 批量大小. Defaults to None.Returns:原输入、原输入`x`的潜在变量编码、潜在变量的解码结果、源条件输入、条件编码输出""" x = super().get_input(batch, k)if bs is not None:x = x[:bs]x = x.to(self.device)encoder_posterior = self.encode_first_stage(x) # 编码第一阶段的输入z = self.get_first_stage_encoding(encoder_posterior).detach() # 禁用梯度计算if self.model.conditioning_key is not None: # 检查是否有条件输入# 提取相应的条件数据if cond_key is None:cond_key = self.cond_stage_keyif cond_key != self.first_stage_key:if cond_key in ['caption', 'coordinates_bbox']:xc = batch[cond_key]elif cond_key == 'class_label':xc = batchelse:xc = super().get_input(batch, cond_key).to(self.device)else:xc = xif not self.cond_stage_trainable or force_c_encode:if isinstance(xc, dict) or isinstance(xc, list):# import pudb; pudb.set_trace()c = self.get_learned_conditioning(xc) # 获取条件编码else:c = self.get_learned_conditioning(xc.to(self.device))else:c = xcif bs is not None:c = c[:bs]if self.use_positional_encodings: # 添加位置编码信息pos_x, pos_y = self.compute_latent_shifts(batch)ckey = __conditioning_keys__[self.model.conditioning_key]c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}else:c = Nonexc = Noneif self.use_positional_encodings:pos_x, pos_y = self.compute_latent_shifts(batch)c = {'pos_x': pos_x, 'pos_y': pos_y}out = [z, c] # 潜在变量, 条件编码信息if return_first_stage_outputs:xrec = self.decode_first_stage(z)out.extend([x, xrec]) # 源输入, decoder解码信息if return_original_cond:out.append(xc) # 源条件输入信息return out
decode_first_stage method
Decodes the latent representation z back into an image. When split_input_params is not used, the branch guarded by if hasattr(self, "split_input_params"): can be ignored.
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):if predict_cids:if z.dim() == 4:z = torch.argmax(z.exp(), dim=1).long()z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)z = rearrange(z, 'b h w c -> b c h w').contiguous()z = 1. / self.scale_factor * zif hasattr(self, "split_input_params"):if self.split_input_params["patch_distributed_vq"]:ks = self.split_input_params["ks"] # eg. (128, 128)stride = self.split_input_params["stride"] # eg. (64, 64)uf = self.split_input_params["vqf"]bs, nc, h, w = z.shapeif ks[0] > h or ks[1] > w:ks = (min(ks[0], h), min(ks[1], w))print("reducing Kernel")if stride[0] > h or stride[1] > w:stride = (min(stride[0], h), min(stride[1], w))print("reducing stride")fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)z = unfold(z) # (bn, nc * prod(**ks), L)# 1. Reshape to img shapez = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )# 2. apply model loop over last dimif isinstance(self.first_stage_model, VQModelInterface):output_list = [self.first_stage_model.decode(z[:, :, :, :, i],force_not_quantize=predict_cids or force_not_quantize)for i in range(z.shape[-1])]else:output_list = [self.first_stage_model.decode(z[:, :, :, :, i])for i in range(z.shape[-1])]o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)o = o * weighting# Reverse 1. reshape to img shapeo = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)# stitch crops togetherdecoded = fold(o)decoded = decoded / normalization # norm is shape (1, 1, h, w)return decodedelse:if isinstance(self.first_stage_model, VQModelInterface):return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)else:return self.first_stage_model.decode(z)else:if isinstance(self.first_stage_model, VQModelInterface):return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)else:return self.first_stage_model.decode(z)
encode_first_stage method
This method calls the first-stage encoder model to obtain the posterior distribution over the latent variable. Again, when split_input_params is not used, the branch guarded by if hasattr(self, "split_input_params"): can be ignored.
def encode_first_stage(self, x):"""调用第一阶段编码器模型Args:x: 输入张量Returns:返回输入张量经编码器的posterior""" if hasattr(self, "split_input_params"):if self.split_input_params["patch_distributed_vq"]:ks = self.split_input_params["ks"] # eg. (128, 128)stride = self.split_input_params["stride"] # eg. (64, 64)df = self.split_input_params["vqf"]self.split_input_params['original_image_size'] = x.shape[-2:]bs, nc, h, w = x.shapeif ks[0] > h or ks[1] > w:ks = (min(ks[0], h), min(ks[1], w))print("reducing Kernel")if stride[0] > h or stride[1] > w:stride = (min(stride[0], h), min(stride[1], w))print("reducing stride")fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)z = unfold(x) # (bn, nc * prod(**ks), L)# Reshape to img shapez = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )output_list = [self.first_stage_model.encode(z[:, :, :, :, i])for i in range(z.shape[-1])]o = torch.stack(output_list, axis=-1)o = o * weighting# Reverse reshape to img shapeo = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)# stitch crops togetherdecoded = fold(o)decoded = decoded / normalizationreturn decodedelse:return self.first_stage_model.encode(x)else:return self.first_stage_model.encode(x) # posterior
shared_step method
This method obtains the latent variable and the conditioning encoding from the batch and computes the loss under the given conditioning c.
def shared_step(self, batch, **kwargs):"""根据潜在变量和条件编码信息,计算在给定条件下的损失函数数值Args:batch: 批次号Returns:给定条件下的损失函数值""" x, c = self.get_input(batch, self.first_stage_key) # 获取潜在变量z和条件编码信息loss = self(x, c) # 调用前向传播,计算在给定条件下的损失函数值return loss
apply_model method
This method runs the wrapped diffusion model on the noisy latent to obtain its output (the predicted noise, or x0 depending on the parameterization). As before, when split_input_params is not used, the branch guarded by if hasattr(self, "split_input_params"): can be ignored.
def apply_model(self, x_noisy, t, cond, return_ids=False):if isinstance(cond, dict):# hybrid case, cond is exptected to be a dictpasselse:if not isinstance(cond, list):cond = [cond]key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'cond = {key: cond}if hasattr(self, "split_input_params"):assert len(cond) == 1 # todo can only deal with one conditioning atmassert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128)stride = self.split_input_params["stride"] # eg. (64, 64)h, w = x_noisy.shape[-2:]fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)z = unfold(x_noisy) # (bn, nc * prod(**ks), L)# Reshape to img shapez = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]if self.cond_stage_key in ["image", "LR_image", "segmentation",'bbox_img'] and self.model.conditioning_key: # todo check for completenessc_key = next(iter(cond.keys())) # get keyc = next(iter(cond.values())) # get valueassert (len(c) == 1) # todo extend to list with more than one elemc = c[0] # get elementc = unfold(c)c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]elif self.cond_stage_key == 'coordinates_bbox':assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'# assuming padding of unfold is always 0 and its dilation is always 1n_patches_per_row = int((w - ks[0]) / stride[0] + 1)full_img_h, full_img_w = self.split_input_params['original_image_size']# as we are operating on latents, we need the factor from the original image size to the# spatial latent size to properly rescale the crops for regenerating the bbox annotationsnum_downs = self.first_stage_model.encoder.num_resolutions - 1rescale_latent = 2 ** (num_downs)# get top left postions of patches as conforming for the bbbox tokenizer, therefore we# need to rescale the tl patch coordinates to be in between (0,1)tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)for patch_nr in range(z.shape[-1])]# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)patch_limits = [(x_tl, y_tl,rescale_latent * ks[0] / full_img_w,rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]# tokenize crop coordinates for the bounding boxes of the respective patchespatch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)for bbox in patch_limits] # list of length l with tensors of shape (1, 2)print(patch_limits_tknzd[0].shape)# cut tknzd crop position from conditioningassert isinstance(cond, dict), 'cond must be dict to be fed into model'cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)print(cut_cond.shape)adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')print(adapted_cond.shape)adapted_cond = self.get_learned_conditioning(adapted_cond)print(adapted_cond.shape)adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])print(adapted_cond.shape)cond_list = [{'c_crossattn': [e]} for e in adapted_cond]else:cond_list = [cond for i in 
range(z.shape[-1])] # Todo make this more efficient# apply model by loop over cropsoutput_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]assert not isinstance(output_list[0],tuple) # todo cant deal with multiple model outputs check this never happenso = torch.stack(output_list, axis=-1)o = o * weighting# Reverse reshape to img shapeo = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)# stitch crops togetherx_recon = fold(o) / normalizationelse:x_recon = self.model(x_noisy, t, **cond) # 重建图像if isinstance(x_recon, tuple) and not return_ids: # 如果重建图像为元组并没有指定return x_recon[0]else:return x_recon
p_losses method
The p_losses method extends the DDPM version to conditional inputs. It likewise computes the loss between the predicted noise and the original noise, rescales it with the learned log-variance, and adds the weighted variational (VLB) term to obtain the final loss.
def p_losses(self, x_start, cond, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start)) # 初始噪声x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 加噪后的图像model_output = self.apply_model(x_noisy, t, cond) # 输出的重建图像loss_dict = {}prefix = 'train' if self.training else 'val'if self.parameterization == "x0":target = x_startelif self.parameterization == "eps":target = noiseelse:raise NotImplementedError()loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) # 计算得到的损失loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})logvar_t = self.logvar[t].to(self.device)loss = loss_simple / torch.exp(logvar_t) + logvar_t # 对初始损失的调整# loss = loss_simple / torch.exp(self.logvar) + self.logvarif self.learn_logvar:loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})loss_dict.update({'logvar': self.logvar.data.mean()})loss = self.l_simple_weight * loss.mean()loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})loss += (self.original_elbo_weight * loss_vlb) # 添加变分损失loss_dict.update({f'{prefix}/loss': loss})return loss, loss_dict
forward method
The forward method computes the loss between the true noise and the predicted noise for the input image under the given conditioning.
def forward(self, x, c, *args, **kwargs):t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()if self.model.conditioning_key is not None:assert c is not Noneif self.cond_stage_trainable: # 获取条件编译输出c = self.get_learned_conditioning(c)if self.shorten_cond_schedule: # TODO: drop this optiontc = self.cond_ids[t].to(self.device)c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))return self.p_losses(x, c, t, *args, **kwargs)
p_sample_loop method
The progressive_denoising method is similar to this one, so it is not described separately.
This method implements the sampling loop that generates an image step by step: starting from pure noise, it denoises progressively until the final image is produced (this is the part that corresponds most directly to the figures in the paper).
def p_sample_loop(self, cond, shape, return_intermediates=False,x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,mask=None, x0=None, img_callback=None, start_T=None,log_every_t=None):"""用于逐步生成图像的采样循环,实现了从纯噪声开始逐步去噪,直到生成最终的图像Args:cond: 条件信息,用于指导生成图像,通常与输入图像相关联shape: 生成图像的形状return_intermediates: 是否返回中间的去噪结果. Defaults to False.x_T: 初始的随机噪声图像,如果为 None,则从标准正态分布中采样噪声. Defaults to None.verbose: 是否显示进度条. Defaults to True.callback: 每一步迭代时的回调函数,可用于监控生成过程. Defaults to None.timesteps: 生成过程中的时间步数。如果未指定,将使用默认的时间步数. Defaults to None.quantize_denoised: 是否对去噪后的图像进行量化. Defaults to False.mask: 可选的掩码,用于在生成时部分保留原图像. Defaults to None.x0: 在有 mask 的情况下,表示被掩盖的部分图像. Defaults to None.img_callback: _description_. Defaults to None.start_T: 开始的时间步,控制从哪一步开始生成. Defaults to None.log_every_t: 设置记录中间结果的步数间隔. Defaults to None.Returns:_description_""" if not log_every_t:log_every_t = self.log_every_tdevice = self.betas.deviceb = shape[0]if x_T is None:img = torch.randn(shape, device=device)else:img = x_Tintermediates = [img]if timesteps is None:timesteps = self.num_timestepsif start_T is not None:timesteps = min(timesteps, start_T)iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(range(0, timesteps))if mask is not None:assert x0 is not Noneassert x0.shape[2:3] == mask.shape[2:3] # spatial size has to matchfor i in iterator:ts = torch.full((b,), i, device=device, dtype=torch.long)if self.shorten_cond_schedule:assert self.model.conditioning_key != 'hybrid'tc = self.cond_ids[ts].to(cond.device)cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))img = self.p_sample(img, cond, ts,clip_denoised=self.clip_denoised,quantize_denoised=quantize_denoised)if mask is not None:img_orig = self.q_sample(x0, ts)img = img_orig * mask + (1. - mask) * imgif i % log_every_t == 0 or i == timesteps - 1:intermediates.append(img)if callback: callback(i)if img_callback: img_callback(img, i)if return_intermediates:return img, intermediatesreturn img
In each iteration, img is passed through the p_sample method to obtain the predicted image of the previous step; this is repeated until the initial image $x_0$ is reached.
sample method
This method is a thin wrapper around p_sample_loop: it prepares the (optional) conditioning information and passes it, together with the target shape, to p_sample_loop to carry out the sampling.
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,verbose=True, timesteps=None, quantize_denoised=False,mask=None, x0=None, shape=None,**kwargs):if shape is None:shape = (batch_size, self.channels, self.image_size, self.image_size)if cond is not None:if isinstance(cond, dict):cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) elselist(map(lambda x: x[:batch_size], cond[key])) for key in cond}else:cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]return self.p_sample_loop(cond,shape,return_intermediates=return_intermediates, x_T=x_T,verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,mask=mask, x0=x0)
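To close, a hypothetical end-to-end call for an unconditional model (the names, sizes and the already-loaded model instance are illustrative assumptions):

```python
import torch

# model: an already trained / loaded LatentDiffusion instance
model.eval()
with torch.no_grad():
    z = model.sample(cond=None, batch_size=4,
                     shape=(4, model.channels, model.image_size, model.image_size))
    imgs = model.decode_first_stage(z)   # map the sampled latents back to pixel space
print(imgs.shape)                        # e.g. torch.Size([4, 3, 256, 256]) for a 32x32x4 latent space
```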