OmniGen: Unified Image Generation (Code Walkthrough)
Table of Contents
- Introduction to the OmniGen project
- Overall model architecture
- Processing the input text
- Processing the data fed into the LLM
- Input image latents are converted into 1024 tokens
- Inserting the input image embeddings into the text embeddings
- The forward function
- Summary
GitHub project page
Introduction to the OmniGen project
A unified image generation model.
The emergence of large language models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. In image generation, however, a unified model that can handle diverse tasks within a single framework remains largely unexplored. This work introduces OmniGen, a new diffusion model for unified image generation. Unlike popular diffusion models (e.g., Stable Diffusion), OmniGen no longer needs extra modules such as ControlNet or IP-Adapter to handle different control conditions.
OmniGen has the following characteristics:
1) Unification: beyond text-to-image generation, it natively supports various downstream tasks such as image editing, subject-driven generation, and visual-conditional generation. OmniGen can also handle classical computer-vision tasks, such as edge detection and human pose recognition, by converting them into image generation tasks.
2) Simplicity: the architecture is highly simplified and removes the need for an additional text encoder. It is also more user-friendly than existing diffusion models: complex tasks can be completed through instructions without extra preprocessing steps (e.g., human pose estimation), which greatly simplifies the image generation workflow.
3) Knowledge transfer: benefiting from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, handles unseen tasks and domains, and exhibits new capabilities. The authors also explore the model's reasoning ability and potential applications of a chain-of-thought mechanism.
This work is a first attempt at a general-purpose image generation model, and several open problems remain.
Tasks it can handle
Text-to-image generation
Mixed-modality prompts
For example, text-based image editing and style transfer
Image super-resolution and image brightening
In-context understanding: given reference examples of how an image should be processed, it can infer what the task requires
Overall, the model's understanding ability is strong, and its results are quite good.
Overall model architecture
The way the input data is processed here is worth studying.
Processing the input text
The input text is processed by OmniGenProcessor:
class OmniGenProcessor:
    def __init__(self, text_tokenizer, max_image_size: int = 1024):
        self.text_tokenizer = text_tokenizer
        self.max_image_size = max_image_size

        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])

        self.collator = OmniGenCollator()
        self.separate_collator = OmniGenSeparateCollator()

    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           allow_patterns="*.json")
        text_tokenizer = AutoTokenizer.from_pretrained(model_name)
        return cls(text_tokenizer)

    def process_image(self, image):
        image = Image.open(image).convert('RGB')
        return self.image_transform(image)

    def process_multi_modal_prompt(self, text, input_images):
        text = self.add_prefix_instruction(text)
        if input_images is None or len(input_images) == 0:
            model_inputs = self.text_tokenizer(text)
            return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}

        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]

        for i in range(1, len(prompt_chunks)):
            if prompt_chunks[i][0] == 1:
                prompt_chunks[i] = prompt_chunks[i][1:]

        image_tags = re.findall(pattern, text)
        image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]

        unique_image_ids = sorted(list(set(image_ids)))
        assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
        # total images must be the same as the number of image tags
        assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"

        input_images = [input_images[x-1] for x in image_ids]

        all_input_ids = []
        img_inx = []
        idx = 0
        for i in range(len(prompt_chunks)):
            all_input_ids.extend(prompt_chunks[i])
            if i != len(prompt_chunks) - 1:
                start_inx = len(all_input_ids)
                size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
                img_inx.append([start_inx, start_inx + size])
                all_input_ids.extend([0] * size)

        return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}

    def add_prefix_instruction(self, prompt):
        user_prompt = '<|user|>\n'
        generation_prompt = 'Generate an image according to the following instructions\n'
        assistant_prompt = '<|assistant|>\n<|diffusion|>'
        prompt_suffix = "<|end|>\n"
        prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
        return prompt

    def __call__(self,
                 instructions: List[str],
                 input_images: List[List[str]] = None,
                 height: int = 1024,
                 width: int = 1024,
                 negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
                 use_img_cfg: bool = True,
                 separate_cfg_input: bool = False,
                 ) -> Dict:
        if input_images is None:
            use_img_cfg = False
        if isinstance(instructions, str):
            instructions = [instructions]
            input_images = [input_images]

        input_data = []
        for i in range(len(instructions)):
            cur_instruction = instructions[i]
            cur_input_images = None if input_images is None else input_images[i]
            if cur_input_images is not None and len(cur_input_images) > 0:
                cur_input_images = [self.process_image(x) for x in cur_input_images]
            else:
                cur_input_images = None
                assert "<img><|image_1|></img>" not in cur_instruction

            mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)

            neg_mllm_input, img_cfg_mllm_input = None, None
            neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
            if use_img_cfg:
                if cur_input_images is not None and len(cur_input_images) >= 1:
                    img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
                    img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
                else:
                    img_cfg_mllm_input = neg_mllm_input

            input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))

        if separate_cfg_input:
            return self.separate_collator(input_data)
        return self.collator(input_data)


class OmniGenCollator:
    def __init__(self, pad_token_id=2, hidden_size=3072):
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size

    def create_position(self, attention_mask, num_tokens_for_output_images):
        position_ids = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)]  # we add a time embedding into the sequence, so add one more token
            position_ids.append(temp_position)
        return torch.LongTensor(position_ids)

    def create_mask(self, attention_mask, num_tokens_for_output_images):
        extended_mask = []
        padding_images = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        seq_len = text_length + img_length + 1  # we add a time embedding into the sequence, so add one more token
        inx = 0
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            pad_l = text_length - temp_l

            temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))

            image_mask = torch.zeros(size=(temp_l+1, img_length))
            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)

            image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
            temp_mask = torch.cat([temp_mask, image_mask], dim=0)

            if pad_l > 0:
                pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)

                pad_mask = torch.ones(size=(pad_l, seq_len))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)

            true_img_length = num_tokens_for_output_images[inx]
            pad_img_length = img_length - true_img_length
            if pad_img_length > 0:
                temp_mask[:, -pad_img_length:] = 0
                temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
            else:
                temp_padding_imgs = None

            extended_mask.append(temp_mask.unsqueeze(0))
            padding_images.append(temp_padding_imgs)
            inx += 1
        return torch.cat(extended_mask, dim=0), padding_images

    def adjust_attention_for_input_images(self, attention_mask, image_sizes):
        for b_inx in image_sizes.keys():
            for start_inx, end_inx in image_sizes[b_inx]:
                attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
        return attention_mask

    def pad_input_ids(self, input_ids, image_sizes):
        max_l = max([len(x) for x in input_ids])
        padded_ids = []
        attention_mask = []
        new_image_sizes = []

        for i in range(len(input_ids)):
            temp_ids = input_ids[i]
            temp_l = len(temp_ids)
            pad_l = max_l - temp_l
            if pad_l == 0:
                attention_mask.append([1]*max_l)
                padded_ids.append(temp_ids)
            else:
                attention_mask.append([0]*pad_l+[1]*temp_l)
                padded_ids.append([self.pad_token_id]*pad_l+temp_ids)

            if i in image_sizes:
                new_inx = []
                for old_inx in image_sizes[i]:
                    new_inx.append([x+pad_l for x in old_inx])
                image_sizes[i] = new_inx

        return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes

    def process_mllm_input(self, mllm_inputs, target_img_size):
        num_tokens_for_output_images = []
        for img_size in target_img_size:
            num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)

        pixel_values, image_sizes = [], {}
        b_inx = 0
        for x in mllm_inputs:
            if x['pixel_values'] is not None:
                pixel_values.extend(x['pixel_values'])
                for size in x['image_sizes']:
                    if b_inx not in image_sizes:
                        image_sizes[b_inx] = [size]
                    else:
                        image_sizes[b_inx].append(size)
            b_inx += 1
        pixel_values = [x.unsqueeze(0) for x in pixel_values]

        input_ids = [x['input_ids'] for x in mllm_inputs]
        padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
        position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
        attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
        attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)

        return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes

    def __call__(self, features):
        mllm_inputs = [f[0] for f in features]
        cfg_mllm_inputs = [f[1] for f in features]
        img_cfg_mllm_input = [f[2] for f in features]
        target_img_size = [f[3] for f in features]

        if img_cfg_mllm_input[0] is not None:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
            target_img_size = target_img_size + target_img_size + target_img_size
        else:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs
            target_img_size = target_img_size + target_img_size

        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)

        data = {"input_ids": all_padded_input_ids,
                "attention_mask": all_attention_mask,
                "position_ids": all_position_ids,
                "input_pixel_values": all_pixel_values,
                "input_image_sizes": all_image_sizes,
                "padding_images": all_padding_images,
                }
        return data


class OmniGenSeparateCollator(OmniGenCollator):
    def __call__(self, features):
        mllm_inputs = [f[0] for f in features]
        cfg_mllm_inputs = [f[1] for f in features]
        img_cfg_mllm_input = [f[2] for f in features]
        target_img_size = [f[3] for f in features]

        all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []

        padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
        all_padded_input_ids.append(padded_input_ids)
        all_attention_mask.append(attention_mask)
        all_position_ids.append(position_ids)
        all_pixel_values.append(pixel_values)
        all_image_sizes.append(image_sizes)
        all_padding_images.append(padding_images)

        if cfg_mllm_inputs[0] is not None:
            padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
            all_padded_input_ids.append(padded_input_ids)
            all_attention_mask.append(attention_mask)
            all_position_ids.append(position_ids)
            all_pixel_values.append(pixel_values)
            all_image_sizes.append(image_sizes)
            all_padding_images.append(padding_images)
        if img_cfg_mllm_input[0] is not None:
            padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
            all_padded_input_ids.append(padded_input_ids)
            all_attention_mask.append(attention_mask)
            all_position_ids.append(position_ids)
            all_pixel_values.append(pixel_values)
            all_image_sizes.append(image_sizes)
            all_padding_images.append(padding_images)

        data = {"input_ids": all_padded_input_ids,
                "attention_mask": all_attention_mask,
                "position_ids": all_position_ids,
                "input_pixel_values": all_pixel_values,
                "input_image_sizes": all_image_sizes,
                "padding_images": all_padding_images,
                }
        return data
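As a quick orientation, here is a hypothetical usage sketch of the processor; the import path, checkpoint name, and image paths are assumptions for illustration, not taken from the code above.

# Hypothetical usage sketch of OmniGenProcessor; the import path, checkpoint
# name, and image paths are assumptions for illustration.
from OmniGen import OmniGenProcessor

processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")
prompt = "Make <img><|image_1|></img> has the same style of <img><|image_2|></img>."
model_inputs = processor(
    [prompt],
    input_images=[["./content.png", "./style.png"]],  # image_1, image_2
    height=1024,
    width=1024,
    separate_cfg_input=True,
)
# model_inputs contains input_ids, attention_mask, position_ids,
# input_pixel_values, input_image_sizes and padding_images for the
# conditional, negative and image-guidance branches.
print(type(model_inputs["input_ids"]), len(model_inputs["input_ids"]))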
For example, suppose the input text is:
prompt= "Make <img><|image_1|></img> has the same style of <img><|image_2|></img>.Maintain the consistency of the content in <img><|image_1|></img> and ensure that there is only the style of <img><|image_2|></img>, without its content.",
First, the input text is wrapped with a prefix and suffix instruction; the result is:
<|user|>
Generate an image according to the following instructions
Make <img><|image_1|></img> has the same style of <img><|image_2|></img>.Maintain the consistency of the content in <img><|image_1|></img> and ensure that there is only the style of <img><|image_2|></img>, without its content.<|end|>
<|assistant|>
<|diffusion|>
The text is then split at the positions where image tags appear:
# split the whole text with this regular expression
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
This splits the whole text into 5 chunks.
Then the leading BOS token is stripped from every chunk except the first:
for i in range(1, len(prompt_chunks)):
    if prompt_chunks[i][0] == 1:
        prompt_chunks[i] = prompt_chunks[i][1:]
Next, the image tags and their ids in the text are identified.
You can see that there are 4 image tags, but only 2 distinct images are actually used.
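To make this concrete, here is a minimal sketch of the tag and id extraction on the example prompt above (prefix instruction omitted); the variable names simply mirror the processor code:

import re

# Minimal sketch of the tag/id extraction on the example prompt (prefix omitted).
pattern = r"<\|image_\d+\|>"
text = ("Make <img><|image_1|></img> has the same style of <img><|image_2|></img>."
        "Maintain the consistency of the content in <img><|image_1|></img> and ensure that "
        "there is only the style of <img><|image_2|></img>, without its content.")

print(len(re.split(pattern, text)))                # 5 chunks
image_tags = re.findall(pattern, text)             # 4 tags
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
print(image_tags)                                  # ['<|image_1|>', '<|image_2|>', '<|image_1|>', '<|image_2|>']
print(image_ids)                                   # [1, 2, 1, 2]
print(sorted(set(image_ids)))                      # [1, 2] -> only two distinct input images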
The image placeholders and the text tokens are then merged into a single sequence:
all_input_ids = []
img_inx = []
idx = 0
for i in range(len(prompt_chunks)):
    all_input_ids.extend(prompt_chunks[i])
    if i != len(prompt_chunks) - 1:
        start_inx = len(all_input_ids)
        size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
        img_inx.append([start_inx, start_inx + size])
        all_input_ids.extend([0] * size)
Processing the data fed into the LLM
Input image latents are converted into 1024 tokens
The convolution module Conv2d(4, 3072, kernel_size=(2, 2), stride=(2, 2)) turns an input of shape (1, 4, 64, 64) into (1, 1024, 3072).
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images: bool = False):
    if isinstance(latents, list):
        return_list = False
        if padding_latent is None:
            padding_latent = [None] * len(latents)
            return_list = True
        patched_latents, num_tokens, shapes = [], [], []
        for latent, padding in zip(latents, padding_latent):
            height, width = latent.shape[-2:]
            if is_input_images:
                # this convolution turns the (4, 64, 64) input latent into (1, 1024, 3072)
                latent = self.input_x_embedder(latent)
            else:
                latent = self.x_embedder(latent)
            pos_embed = self.cropped_pos_embed(height, width)
            latent = latent + pos_embed
            if padding is not None:
                latent = torch.cat([latent, padding], dim=-2)
            patched_latents.append(latent)

            num_tokens.append(pos_embed.size(1))
            shapes.append([height, width])
        if not return_list:
            latents = torch.cat(patched_latents, dim=0)
        else:
            latents = patched_latents
    else:
        height, width = latents.shape[-2:]
        if is_input_images:
            latents = self.input_x_embedder(latents)
        else:
            latents = self.x_embedder(latents)
        pos_embed = self.cropped_pos_embed(height, width)
        latents = latents + pos_embed
        num_tokens = latents.size(1)
        shapes = [height, width]
    return latents, num_tokens, shapes
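To verify the shape arithmetic above, here is a minimal self-contained sketch that assumes the patch embedder is essentially a strided Conv2d whose spatial output is flattened into a token sequence:

import torch
import torch.nn as nn

# Minimal sketch, assuming the patch embedder boils down to a strided Conv2d
# whose spatial output is flattened into tokens (mirroring the
# Conv2d(4, 3072, kernel_size=(2, 2), stride=(2, 2)) mentioned above).
proj = nn.Conv2d(4, 3072, kernel_size=2, stride=2)

latent = torch.randn(1, 4, 64, 64)          # VAE latent of a 512x512 image
tokens = proj(latent)                        # (1, 3072, 32, 32)
tokens = tokens.flatten(2).transpose(1, 2)   # (1, 1024, 3072): 32 * 32 = 1024 tokens
print(tokens.shape)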
Inserting the input image embeddings into the text embeddings
The full OmniGen model:
class OmniGen(nn.Module, PeftAdapterMixin):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        transformer_config: Phi3Config,
        patch_size=2,
        in_channels=4,
        pe_interpolation: float = 1.0,
        pos_embed_max_size: int = 192,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.pos_embed_max_size = pos_embed_max_size

        hidden_size = transformer_config.hidden_size

        self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
        self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)

        self.time_token = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.pe_interpolation = pe_interpolation
        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

        self.llm = Phi3Transformer(config=transformer_config)
        self.llm.config.use_cache = False

    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        config = Phi3Config.from_pretrained(model_name)
        model = cls(config)
        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
            print("Loading safetensors")
            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
        else:
            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
        model.load_state_dict(ckpt)
        return model

    def initialize_weights(self):
        assert not hasattr(self, "llama")

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels

        x = x.reshape(shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}.")
        if width > self.pos_embed_max_size:
            raise ValueError(f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}.")

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        # print(top, top + height, left, left + width, spatial_pos_embed.size())
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images: bool = False):
        if isinstance(latents, list):
            return_list = False
            if padding_latent is None:
                padding_latent = [None] * len(latents)
                return_list = True
            patched_latents, num_tokens, shapes = [], [], []
            for latent, padding in zip(latents, padding_latent):
                height, width = latent.shape[-2:]
                if is_input_images:
                    latent = self.input_x_embedder(latent)
                else:
                    latent = self.x_embedder(latent)
                pos_embed = self.cropped_pos_embed(height, width)
                latent = latent + pos_embed
                if padding is not None:
                    latent = torch.cat([latent, padding], dim=-2)
                patched_latents.append(latent)

                num_tokens.append(pos_embed.size(1))
                shapes.append([height, width])
            if not return_list:
                latents = torch.cat(patched_latents, dim=0)
            else:
                latents = patched_latents
        else:
            height, width = latents.shape[-2:]
            if is_input_images:
                latents = self.input_x_embedder(latents)
            else:
                latents = self.x_embedder(latents)
            pos_embed = self.cropped_pos_embed(height, width)
            latents = latents + pos_embed
            num_tokens = latents.size(1)
            shapes = [height, width]
        return latents, num_tokens, shapes

    def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids,
                padding_latent=None, past_key_values=None, return_past_key_values=True):
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

        if input_img_latents is not None:
            input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
        if input_ids is not None:
            # embed the text condition
            condition_embeds = self.llm.embed_tokens(input_ids).clone()
            input_img_inx = 0
            for b_inx in input_image_sizes.keys():
                for start_inx, end_inx in input_image_sizes[b_inx]:
                    condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                    input_img_inx += 1
            if input_img_latents is not None:
                # the 1024-token spans reserved at each image position are filled
                # with the corresponding image embeddings as the condition
                assert input_img_inx == len(input_latents)
            input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
        else:
            input_emb = torch.cat([time_token, x], dim=1)
        # feed all embeddings (text, time and noise tokens, each of dim 3072) into the LLM
        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
        output, past_key_values = output.last_hidden_state, output.past_key_values
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            for i in range(x.size(0)):
                latent = x[i:i+1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])

        if return_past_key_values:
            return latents, past_key_values
        return latents

    @torch.no_grad()
    def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids,
                         cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        self.llm.config.use_cache = use_kv_cache
        model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
        if use_img_cfg:
            cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        else:
            cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]

        return torch.cat(model_out, dim=0), past_key_values

    @torch.no_grad()
    def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids,
                                  cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        self.llm.config.use_cache = use_kv_cache
        if past_key_values is None:
            past_key_values = [None] * len(attention_mask)

        x = torch.split(x, len(x) // len(attention_mask), dim=0)
        timestep = timestep.to(x[0].dtype)
        timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

        model_out, pask_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
            model_out.append(temp_out)
            pask_key_values.append(temp_pask_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), pask_key_values
All inputs are mapped into a 3072-dimensional hidden space.
When there are guiding images, three groups of noise predictions are produced: text-guided, unconditional, and image-guided.
For each group, the text condition, the time embedding, and the noise tokens are concatenated in that order and fed into the LLM as the conditioning input to predict the noise.
Because each image position in the text condition was previously filled with 1024 placeholder tokens, those placeholder spans are replaced with the corresponding image embeddings, as sketched below.
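Here is a toy, shapes-only sketch of that assembly step; the sequence lengths and the placeholder span indices are made-up example values, not taken from a real run:

import torch

# Toy, shapes-only illustration of how the conditioning sequence is assembled.
# All lengths and the placeholder span below are made-up example values.
hidden_size = 3072
text_len, img_tokens, noise_tokens = 20, 1024, 1024

condition_embeds = torch.randn(1, text_len + img_tokens, hidden_size)  # text with a 1024-token placeholder span
input_latents = torch.randn(1, img_tokens, hidden_size)                # embedded input image
start_inx, end_inx = 5, 5 + img_tokens                                 # span recorded by the processor
condition_embeds[:, start_inx:end_inx] = input_latents                 # replace placeholders with image embeddings

time_token = torch.randn(1, 1, hidden_size)                            # timestep embedding
x = torch.randn(1, noise_tokens, hidden_size)                          # patchified noisy latent
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)        # [text | time | noise]
print(input_emb.shape)                                                 # (1, 20 + 1024 + 1 + 1024, 3072)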
The forward function
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids,
            padding_latent=None, past_key_values=None, return_past_key_values=True):
    input_is_list = isinstance(x, list)
    x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
    time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

    if input_img_latents is not None:
        input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
    if input_ids is not None:
        # embed the text condition
        condition_embeds = self.llm.embed_tokens(input_ids).clone()
        input_img_inx = 0
        for b_inx in input_image_sizes.keys():
            # use the image spans recorded earlier by the processor
            for start_inx, end_inx in input_image_sizes[b_inx]:
                condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                input_img_inx += 1
        if input_img_latents is not None:
            # the 1024-token spans reserved at each image position are filled
            # with the corresponding image embeddings as the condition
            assert input_img_inx == len(input_latents)
        input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
    else:
        input_emb = torch.cat([time_token, x], dim=1)
    # feed all embeddings (text, time and noise tokens, each of dim 3072) into the LLM
    output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
    output, past_key_values = output.last_hidden_state, output.past_key_values
    if input_is_list:
        image_embedding = output[:, -max(num_tokens):]
        time_emb = self.t_embedder(timestep, dtype=x.dtype)
        x = self.final_layer(image_embedding, time_emb)
        latents = []
        for i in range(x.size(0)):
            latent = x[i:i+1, :num_tokens[i]]
            latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
            latents.append(latent)
    else:
        image_embedding = output[:, -num_tokens:]
        time_emb = self.t_embedder(timestep, dtype=x.dtype)
        x = self.final_layer(image_embedding, time_emb)
        latents = self.unpatchify(x, shapes[0], shapes[1])

    if return_past_key_values:
        return latents, past_key_values
    return latents
Finally, the last num_tokens positions of the LLM output (shape (1, 1024, 3072)) are selected, linearly projected to (1, 1024, 16), and then reshaped into (1, 4, 64, 64).
# select the last num_tokens (1024) noise positions from the output
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
# this step projects x to (1, 1024, 16)
x = self.final_layer(image_embedding, time_emb)
# finally reshape (1, 1024, 16) into (1, 4, 64, 64)
latents = self.unpatchify(x, shapes[0], shapes[1])
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # split the time embedding into two parts, each of shape (1, 3072)
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        # modulate x with the time embedding; x: (1, 1024, 3072)
        x = modulate(self.norm_final(x), shift, scale)
        # project (1, 1024, 3072) down to (1, 1024, 16)
        x = self.linear(x)
        return x
The modulate function applies a shift-and-scale (affine) transformation:
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
FinalLayer(
  (norm_final): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  (linear): Linear(in_features=3072, out_features=16, bias=True)
  (adaLN_modulation): Sequential(
    (0): SiLU()
    (1): Linear(in_features=3072, out_features=6144, bias=True)
  )
)
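A quick shape check, assuming the FinalLayer and modulate definitions above are in scope:

import torch

# Shape check for FinalLayer (assumes FinalLayer and modulate above are defined).
layer = FinalLayer(hidden_size=3072, patch_size=2, out_channels=4)
x = torch.randn(1, 1024, 3072)   # image-token hidden states from the LLM
c = torch.randn(1, 3072)         # timestep embedding
print(layer(x, c).shape)         # torch.Size([1, 1024, 16])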
Finally, unpatchify reshapes the (1, 1024, 16) output into (1, 4, 64, 64):
def unpatchify(self, x, h, w):
    """
    x: (N, T, patch_size**2 * C)
    imgs: (N, H, W, C)
    """
    c = self.out_channels

    x = x.reshape(shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], c, h, w))
    return imgs
Summary
The OmniGen model as a whole adopts a DiT-like architecture.
For the input text and image information:
The text is simply tokenized into token ids; each image is encoded by the VAE into a (4, 64, 64) latent and then mapped to 1024 tokens.
The text token embeddings (with the input-image tokens substituted into their reserved placeholder spans), the timestep embedding, and the noise tokens are then concatenated in that order to form the overall conditioning input.
A Phi-3 model is then used to predict the noise (internally a DiT-style transformer that uses only self-attention).
Because there are three kinds of conditioning input (the user prompt, the negative prompt, and the image prompt), three conditions are constructed.
So the process above is run three times, and the three noise predictions are combined with weighted guidance scales to obtain the final noise.
Iterating this process for 50 steps yields the final image.
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
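To tie the summary together, here is a highly simplified, self-contained sketch of the sampling loop; fake_forward stands in for OmniGen's forward pass over the three stacked conditions, and the timestep schedule, Euler-style update, and guidance scales are assumptions for illustration, not the repository's actual scheduler:

import torch

# Simplified sketch of the 50-step denoising loop with the guidance formula above.
# fake_forward is a stand-in for OmniGen's forward pass; the timestep schedule,
# Euler-style update and guidance scales are assumptions for illustration only.
def fake_forward(latents, t):
    # returns three noise predictions (cond, uncond, img_cond) of the same shape
    return torch.randn(3, *latents.shape[1:]).chunk(3, dim=0)

num_steps = 50
cfg_scale, img_cfg_scale = 3.0, 1.6          # example guidance scales (assumed)
latents = torch.randn(1, 4, 64, 64)          # start from pure noise
timesteps = torch.linspace(1.0, 0.0, num_steps + 1)

for i in range(num_steps):
    cond, uncond, img_cond = fake_forward(latents, timesteps[i])
    pred = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
    latents = latents + (timesteps[i + 1] - timesteps[i]) * pred   # Euler-style update
# latents would then be decoded by the VAE into the final image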