当前位置: 首页 > news >正文

Hunuan-DiT代码阅读

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock((norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(attn1): FlashSelfMHAModified((Wqkv): Linear(in_features=1408, out_features=4224, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashSelfAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=1408, out_features=6144, bias=True)(act): GELU(approximate='tanh')(drop1): Dropout(p=0, inplace=False)(norm): Identity()(fc2): Linear(in_features=6144, out_features=1408, bias=True)(drop2): Dropout(p=0, inplace=False))(default_modulation): Sequential((0): FP32_SiLU()(1): Linear(in_features=1408, out_features=1408, bias=True))(attn2): FlashCrossMHAModified((q_proj): Linear(in_features=1408, out_features=1408, bias=True)(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashCrossAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128

http://www.mrgr.cn/news/47354.html

相关文章:

  • 下载huggingface模型到本地
  • CDC和RDC分别适用于哪些场景?
  • 第十九章 基于逻辑回归的信用卡欺诈检测
  • Python数据分析-数据预处理、统计与分析
  • vue3数字滚动插件vue3-count-to
  • 基于SpringBoot+Vue+Uniapp警务辅助人员管理小程序系统的设计与实现
  • 嵌入式面试——FreeRTOS篇(四) 信号量
  • 升序 Asc、降序 Desc 极简理解
  • kali在git外网的代理
  • 【图论】(一)图论理论基础与岛屿问题
  • C#开发基础之使用 Mutex 控制应用程序的单实例启动
  • Linux
  • 【常用的安装破解版指令】MAC安装破解版软件显示文件损坏时
  • 一文掌握Prompt大模型提示词技巧:从战略到战术
  • PolarCTF靶场[web]file、ezphp WP
  • 目标检测:yolov9训练自己的数据集,新手小白也能学会训练模型,一看就会
  • JavaScript进阶--作用域-函数进阶
  • 第二十一章 基于随机森林气温预测
  • qiankun 主项目和子项目都是 vue2,部署在不同的服务器上,nginx 配置
  • 240604 模板进阶