Modifying the attn_processor in the UNet (for designing new attention modules)
Contents
- What some UNet variables contain
- attn_processor
- unet.config
- unet_sd
- Defining your own attn processor by modifying the original one
How IP-Adapter sets up its attention processors
Reference code: the official IP-Adapter training code from Tencent AI Lab
What some UNet variables contain
# init adapter modules
# assumes `unet` is an already-loaded UNet2DConditionModel (SD v1.5)
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor

# dict that will hold our rebuilt attention processors
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    # attn1 is self-attention, so it has no cross-attention dim; attn2 uses the cross-attention dim
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # record the channel count (hidden size) of the block this processor lives in
    if name.startswith("mid_block"):
        # 'block_out_channels': [320, 640, 1280, 1280]
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        # the character right after "up_blocks." is the index of the up block
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        layer_name = name.split(".processor")[0]
        weights = {
            # copy this cross-attention layer's original k/v weights out of unet_sd
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        # put our own IPAttnProcessor into the new dict in place of the original processor
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        # initialize the new processor with the original SD UNet cross-attention weights
        attn_procs[name].load_state_dict(weights)
# finally, install the rebuilt processor dict into the unet
unet.set_attn_processor(attn_procs)
attn_processor
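In the UNet, unet.attn_processors is a dict of all attention processors. unet.state_dict() holds only the weights.
To modify the processors, build a dict with the same keys and swap in a different processor type for each module you want to change.
Let's look at what unet.attn_processors looks like.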
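unet.attn_processors is a dict with 32 entries.
Each key gives the location of a processor in the UNet, so it follows the UNet's structure. There are 16 transformer blocks with attention:
- 2, 2 and 2 in the three down blocks
- 1 in the mid block
- 3, 3 and 3 in the three up blocks
Each transformer block has one self-attention and one cross-attention module, which gives 32 attention modules in total.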
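To check the count of 32, here is a small sketch that rebuilds the keys from the block layout just described. It is a hypothetical helper, not a diffusers function. The layout (which blocks contain attention and how many transformers each has) is hard-coded for SD v1.5 rather than read from `unet.config`. The key order follows the printouts later in this article: down blocks, then up blocks, then the mid block.

```python
# Hypothetical helper (not a diffusers API): rebuild the attention-processor keys
# of the SD v1.5 UNet from its block layout.

# (block prefix, number of Transformer2DModel "attentions" in that block)
BLOCK_LAYOUT = [
    ("down_blocks.0", 2),
    ("down_blocks.1", 2),
    ("down_blocks.2", 2),
    ("up_blocks.1", 3),
    ("up_blocks.2", 3),
    ("up_blocks.3", 3),
    ("mid_block", 1),
]


def enumerate_processor_keys(layout=BLOCK_LAYOUT, transformer_layers_per_block=1):
    keys = []
    for prefix, num_attentions in layout:
        for a in range(num_attentions):
            for t in range(transformer_layers_per_block):
                base = f"{prefix}.attentions.{a}.transformer_blocks.{t}"
                keys.append(f"{base}.attn1.processor")  # self-attention
                keys.append(f"{base}.attn2.processor")  # cross-attention
    return keys


keys = enumerate_processor_keys()
print(len(keys))  # 32
```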
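From this we can read off what each part of a key means. For example:
'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor'
- down_blocks.1.: the index of the down block. It can be 0, 1 or 2, the three down blocks that contain attention; down_blocks.3 is a plain DownBlock2D. Here it is the second down block.
- attentions.0.: the index of the transformer (Transformer2DModel) inside that block. It is 0 or 1, because each down block has 2 (up blocks have 3). Here it is the first transformer of that down block.
- transformer_blocks.0.: always 0 here, because transformer_layers_per_block is 1.
- attn1.processor: whether this is self-attention or cross-attention. attn1 is the self-attention layer and attn2 the cross-attention layer; every transformer block has one of each.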
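The same breakdown can be done in code. The helper `parse_processor_key` below is hypothetical and not part of diffusers. It splits the key on "." instead of reading the single character after the prefix, which is what the IP-Adapter loop does, so it would also cope with block indices above 9.

```python
def parse_processor_key(name):
    """Split a key such as
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor'
    into its parts (hypothetical helper)."""
    parts = name.split(".")
    if parts[0] == "mid_block":
        # the mid block has no numeric block index
        block_type, block_id, rest = "mid_block", None, parts[1:]
    else:
        block_type, block_id, rest = parts[0], int(parts[1]), parts[2:]
    # rest looks like ['attentions', i, 'transformer_blocks', j, 'attn1' or 'attn2', 'processor']
    return {
        "block_type": block_type,
        "block_id": block_id,
        "attention_id": int(rest[1]),
        "transformer_block_id": int(rest[3]),
        "is_cross_attention": rest[4] == "attn2",
    }


print(parse_processor_key("down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor"))
```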
unet.config
unet.config holds the UNet's configuration parameters:
FrozenDict([('sample_size', 64),('in_channels', 4),
('out_channels', 4),
('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']),
('only_cross_attention', False),
('block_out_channels', [320, 640, 1280, 1280]),('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('dropout', 0.0), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('transformer_layers_per_block', 1), ('reverse_transformer_layers_per_block', None),('encoder_hid_dim', None), ('encoder_hid_dim_type', None), ('attention_head_dim', 8), ('num_attention_heads', None), ('dual_cross_attention', False), ('use_linear_projection', False), ('class_embed_type', None), ('addition_embed_type', None), ('addition_time_embed_dim', None), ('num_class_embeds', None), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('resnet_skip_time_act', False), ('resnet_out_scale_factor', 1.0), ('time_embedding_type', 'positional'), ('time_embedding_dim', None), ('time_embedding_act_fn', None), ('timestep_post_act', None), ('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim', None), ('attention_type', 'default'), ('class_embeddings_concat', False), ('mid_block_only_cross_attention', None), ('cross_attention_norm', None), ('addition_embed_type_num_heads', 64), ('_use_default_values', ['addition_embed_type', 'encoder_hid_dim', 'transformer_layers_per_block', 'addition_embed_type_num_heads', 'upcast_attention', 'conv_in_kernel', 'attention_type', 'resnet_out_scale_factor', 'time_embedding_dim', 'time_embedding_act_fn', 'conv_out_kernel', 'reverse_transformer_layers_per_block', 'mid_block_type', 'class_embeddings_concat', 'time_embedding_type', 'use_linear_projection', 'class_embed_type', 'only_cross_attention', 'resnet_time_scale_shift', 'encoder_hid_dim_type', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'addition_time_embed_dim', 'cross_attention_norm', 'dropout', 'timestep_post_act', 'resnet_skip_time_act', 'num_attention_heads', 'time_cond_proj_dim', 'mid_block_only_cross_attention', 'num_class_embeds']), 
('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path', '/media/dell/DATA/RK/pretrained_model/stable-diffusion-v1-5')])
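The loop at the top of this article reads only two config fields. The first is `cross_attention_dim` (768, the input size of `to_k`/`to_v` in cross-attention). The second is `block_out_channels`. The up path runs from the deepest block back to the shallowest, so the up blocks index `block_out_channels` in reverse.

Below is a minimal, framework-free sketch of that mapping. The channel list is copied from the config above, and `hidden_size_for` is a hypothetical name.

```python
BLOCK_OUT_CHANNELS = [320, 640, 1280, 1280]  # 'block_out_channels' from unet.config


def hidden_size_for(name, block_out_channels=BLOCK_OUT_CHANNELS):
    """Hidden size of the block a processor key belongs to (mirrors the IP-Adapter loop)."""
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name.split(".")[1])
        # up blocks go from deep (1280) back to shallow (320)
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name.split(".")[1])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected processor name: {name}")


print(hidden_size_for("up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor"))  # 320
```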
unet_sd
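unet_sd (unet.state_dict()) is a dict holding the weights of every submodule in every layer.
Here we copy the original k/v weights of a cross-attention layer out of unet_sd and use them to initialize our own attention processor:
weights = {
    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}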
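To see why `load_state_dict(weights)` succeeds, compare the shapes. The original `to_k`/`to_v` of a cross-attention layer map `cross_attention_dim` (768) to the block's hidden size, exactly like the new `to_k_ip`/`to_v_ip`. So the copied weights already have the right shape; `nn.Linear` stores its weight as `(out_features, in_features)`.

The sketch below replays the key lookup on a hypothetical stand-in for `unet.state_dict()` that holds shapes instead of tensors. `copy_kv_for_ip` is an illustrative name.

```python
def copy_kv_for_ip(unet_sd, processor_name):
    """Build the init dict for an IP processor from the layer's original k/v weights."""
    # 'down_blocks.0....attn2.processor' -> 'down_blocks.0....attn2'
    layer_name = processor_name.split(".processor")[0]
    return {
        "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
        "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
    }


# Hypothetical stand-in for unet.state_dict(): weight shapes instead of tensors.
layer = "down_blocks.0.attentions.0.transformer_blocks.0.attn2"
fake_sd = {
    layer + ".to_q.weight": (320, 320),  # query comes from the 320-channel feature map
    layer + ".to_k.weight": (320, 768),  # key comes from the 768-dim text embedding
    layer + ".to_v.weight": (320, 768),
}
weights = copy_kv_for_ip(fake_sd, layer + ".processor")
print(weights)  # {'to_k_ip.weight': (320, 768), 'to_v_ip.weight': (320, 768)}
```

A Linear(in_features=768, out_features=320) layer, which is what to_k_ip is in this block, has exactly this (320, 768) weight shape.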
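Inspecting the modified unet's attn_processors
Here all values() of unet.attn_processors are wrapped in a torch.nn.ModuleList, so the processors' parameters are registered as submodules: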
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
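These are the attention processors after IP-Adapter's replacement. The names end in 2_0 because the classes are imported as aliases of AttnProcessor2_0 and IPAttnProcessor2_0: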
ModuleList(
  (0): AttnProcessor2_0()
  (1): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (2): AttnProcessor2_0()
  (3): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (4): AttnProcessor2_0()
  (5): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (6): AttnProcessor2_0()
  (7): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (8): AttnProcessor2_0()
  (9): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (10): AttnProcessor2_0()
  (11): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (12): AttnProcessor2_0()
  (13): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (14): AttnProcessor2_0()
  (15): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (16): AttnProcessor2_0()
  (17): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (18): AttnProcessor2_0()
  (19): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (20): AttnProcessor2_0()
  (21): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (22): AttnProcessor2_0()
  (23): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (24): AttnProcessor2_0()
  (25): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (26): AttnProcessor2_0()
  (27): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (28): AttnProcessor2_0()
  (29): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (30): AttnProcessor2_0()
  (31): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
)
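`AttnProcessor2_0` has no parameters. So everything `adapter_modules` contributes to training comes from the `to_k_ip`/`to_v_ip` weights of the 16 `IPAttnProcessor2_0` entries.

The calculation below counts them from the shapes in the printout: 768 inputs, the block's hidden size as outputs, no bias. It is plain arithmetic, not a call into PyTorch.

```python
# Hidden size of every cross-attention (attn2) layer, in printout order:
# down_blocks.0/1/2 (2 each), up_blocks.1/2/3 (3 each), mid_block (1)
CROSS_ATTN_HIDDEN_SIZES = (
    [320] * 2 + [640] * 2 + [1280] * 2
    + [1280] * 3 + [640] * 3 + [320] * 3
    + [1280]
)
CROSS_ATTENTION_DIM = 768  # unet.config.cross_attention_dim


def ip_adapter_param_count(hidden_sizes=CROSS_ATTN_HIDDEN_SIZES, cross_dim=CROSS_ATTENTION_DIM):
    # each IPAttnProcessor holds to_k_ip and to_v_ip: Linear(cross_dim, hidden, bias=False)
    return sum(2 * cross_dim * h for h in hidden_sizes)


print(ip_adapter_param_count())  # 19169280
```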
Defining your own attn processor by modifying the original one
Define the new attn processor classes in the same attention_processor.py file as the original ones.
The original attn processors in attention_processor.py:
import torch
import torch.nn as nn


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # self-attention: keys/values come from the image features themselves
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


# The processor newly defined by IP-Adapter:
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # separate key/value projections for the image-prompt tokens
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            # (the last num_tokens tokens are the image-prompt tokens)
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter: same query, separate keys/values for the image tokens
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        # add the image-prompt result, weighted by scale
        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
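The core of `IPAttnProcessor` is decoupled cross-attention, which works in three steps:
- The last `num_tokens` tokens of `encoder_hidden_states` are treated as image-prompt tokens.
- The query attends to the text tokens and to the image tokens separately.
- The image result is added with weight `scale`.

The pure-Python sketch below shows only that data flow, for a single head and a single sample. It is a simplification:
- Identity projections stand in for `to_q`, `to_k`/`to_v` and `to_k_ip`/`to_v_ip`.
- Masks, `to_out` and the residual path are left out.
- The function names are hypothetical.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q, k, v):
    """Single-head scaled dot-product attention on lists of vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        probs = softmax(scores)
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out


def decoupled_cross_attention(query, encoder_hidden_states, num_tokens=4, scale=1.0):
    # split off the image-prompt tokens at the end, as IPAttnProcessor does with end_pos
    end_pos = len(encoder_hidden_states) - num_tokens
    text_tokens = encoder_hidden_states[:end_pos]
    ip_tokens = encoder_hidden_states[end_pos:]
    # identity projections stand in for to_k/to_v and to_k_ip/to_v_ip (simplification)
    text_out = attention(query, text_tokens, text_tokens)
    ip_out = attention(query, ip_tokens, ip_tokens)
    return [[t + scale * i for t, i in zip(tr, ir)] for tr, ir in zip(text_out, ip_out)]


query = [[1.0, 0.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [2.0, 0.0], [0.0, 2.0]]
print(decoupled_cross_attention(query, tokens, num_tokens=2, scale=0.5))
```

With `scale=0.0` the output reduces to ordinary text cross-attention. That is one way to sanity-check a custom processor before training it.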
Define two new processors of our own and put them in the same file:
class StyleAttnProcessor(nn.Module):
    r"""
    Custom processor for the style attention (currently a copy of AttnProcessor;
    put the style-specific changes in __call__).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class LayoutAttnProcessor(nn.Module):
    r"""
    Custom processor for the layout attention (currently a copy of AttnProcessor;
    put the layout-specific changes in __call__).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
Then import these two attn processors:
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, \
    LayoutAttnProcessor, StyleAttnProcessor
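The replacement is done as follows. Two groups of cross-attention (attn2) processors are replaced:
- those of the third down block (down_blocks.2) get the layout attn;
- those of the first up block that has attention (up_blocks.1; up_blocks.0 is a plain UpBlock2D) get the style attn.
The self-attention (attn1) layers in those blocks are caught by the first branch and keep AttnProcessor.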
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    # keys of the third down block start with this prefix
    elif name.startswith("down_blocks.2.attentions"):
        attn_procs[name] = LayoutAttnProcessor()
    # keys of the first up block with attention start with this prefix
    elif name.startswith("up_blocks.1.attentions"):
        attn_procs[name] = StyleAttnProcessor()
    else:
        layer_name = name.split(".processor")[0]
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        attn_procs[name].load_state_dict(weights)
# install the new processors into the unet
unet.set_attn_processor(attn_procs)
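Before running the loop on a real model, the dispatch order can be checked without loading any weights. This sketch mirrors the `if/elif` chain above on the 32 SD v1.5 keys. The names are hypothetical, and it returns string labels instead of real processor objects.

The `attn1` check comes first, so the self-attention layers in `down_blocks.2` and `up_blocks.1` keep `AttnProcessor`.

```python
from collections import Counter


def choose_processor(name):
    """Label of the processor the replacement loop would pick for this key (hypothetical)."""
    if name.endswith("attn1.processor"):
        return "AttnProcessor"
    if name.startswith("down_blocks.2.attentions"):
        return "LayoutAttnProcessor"
    if name.startswith("up_blocks.1.attentions"):
        return "StyleAttnProcessor"
    return "IPAttnProcessor"


# SD v1.5 layout: (block prefix, number of attentions in the block)
LAYOUT = [("down_blocks.0", 2), ("down_blocks.1", 2), ("down_blocks.2", 2),
          ("up_blocks.1", 3), ("up_blocks.2", 3), ("up_blocks.3", 3), ("mid_block", 1)]
keys = [f"{prefix}.attentions.{a}.transformer_blocks.0.{kind}.processor"
        for prefix, n in LAYOUT for a in range(n) for kind in ("attn1", "attn2")]

counts = Counter(choose_processor(k) for k in keys)
# AttnProcessor: 16, IPAttnProcessor: 11, LayoutAttnProcessor: 2, StyleAttnProcessor: 3
for label, count in sorted(counts.items()):
    print(label, count)
```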
After the modification, attn_processors looks like this:
{
 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
 ),
 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
 ),
 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
 ),
 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
 ),
 # the attn2 processors here have been replaced by our own layout attn
 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(),
 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(),
 # the attn2 processors here have been replaced by our own style attn
 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': StyleAttnProcessor(),
 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': StyleAttnProcessor(),
 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor': StyleAttnProcessor(),
 'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
 ),
 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
 ),
 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
 ),
 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
 ),
 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
 ),
 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
 ),
 'mid_block.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(),
 'mid_block.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
   (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
   (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
 )
}