当前位置: 首页 > news >正文

【多模态】CLIP模型技术学习

前言

最近多模态太火了,学了MiniCPM-V、ViT和transformer,现在开始学CLIP,笔记记录一下,如果有理解不到位的欢迎批评指正。

不需要下游任务微调,图-文对比学习训练的模型就能胜任下游任务?!

  • CLIP出自OpenAI发表在ICML 2021的论文Learning Transferable Visual Models From Natural Language Supervision
  • 文中提出了一种图文跨模态对齐的方法(基于对比学习),并且发现训练的模型可以很好地泛化到新的任务
    在这里插入图片描述

CLIP模型架构

CLIP包含一个文本编码器和一个图像编码器

  • 文本编码器:Transformer
  • 图像编码器:ViT,或者ResNet
    在这里插入图片描述

CLIP模型代码结构

  • CLIP类的代码,forward()里获取图片和文本特征
  • 文本特征提取与bert类似,取了[EOS]对应的向量,self.text_projection是把文本投影到图-文隐空间维度,这里没有加激活函数使用非线性变换
  • 提取图片特征后也是会乘一个proj矩阵,让文本和图片特征统一到同一个维度
def encode_image(self, image):return self.visual(image.type(self.dtype))
def encode_text(self, text):x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]x = x + self.positional_embedding.type(self.dtype)x = x.permute(1, 0, 2)  # NLD -> LND,batch_first的改一下维度,在之前ViT的代码中可以看到x = self.transformer(x)x = x.permute(1, 0, 2)  # LND -> NLDx = self.ln_final(x).type(self.dtype)# x.shape = [batch_size, n_ctx, transformer.width]# take features from the eot embedding (eot_token is the highest number in each sequence)x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projectionreturn x
def forward(self, image, text):image_features = self.encode_image(image)text_features = self.encode_text(text)return image_features, text_features

CLIP模型代码结构——ResNet图片特征

在CLIP中如果使用Resnet作为图片编码器,相比于原始的ResNet文中做了一点修改,在layer4中后面没有使用avg_pool计算均值,而是使用一个AttentionPool2d层计算加权均值
在这里插入图片描述
AttentionPool2d

  • CLIP类的代码中,AttentionPool2d里面的query向量是输入特征的均值
  • 在多头注意力模块MHA中,输入query大小为[1,N,C],key=value大小为[(HW+1),N,C]
  • MHA模块的输出为和输入query是一样的大小为[1,N,C],最终返回的为[N,C]
# AttentionPool2d
def forward(self, x):x = x.flatten(start_dim=2).permute(2, 0, 1)  # [N,C,H,W] -> [(HW),N,C]x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # [(HW+1),N,C]x = x + self.positional_embedding[:, None, :].to(x.dtype)  # [(HW+1),N,C]x, _ = F.multi_head_attention_forward(query=x[:1], key=x, value=x,out_proj_weight=self.c_proj.weight)return x.squeeze(0)

CLIP 模型代码结构——inference

推理过程

  • 如果要做分类任务,先需要写一个prompt,label是分类任务的标签,例如cifar-100分类,需要写This is a photo of a cat / This is a photo of a dog …一百个prompt作为文本描述
  • 然后@运算符进行矩阵乘法,计算要推理的图和文本的相似度
  • Label可以是训练集中没有的,实现zero-shot
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():image_features = model.encode_image(image_input).float()text_features = model.encode_text(text_tokens).float()text_features /= text_features.norm(dim=-1, keepdim=True)text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

CLIP 模型代码结构——train

  • 与推理过程类似,计算图文向量的相似度
  • 训练过程中恰好相似度矩阵的对角线就是图文匹配的label
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_tokens)image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()labels = torch.arange(len(logits_per_image)).to(logits_per_image.device)
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss  = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2

下一篇

下一篇应该可以看看SAM、MiniCPM-V里面压缩图片编码时采用的类似Q-former结构、SigLip或者MiniCPM-V的文本LLM了


http://www.mrgr.cn/news/53648.html

相关文章:

  • Java应用程序的测试覆盖率之设计与实现(三)-- jacoco cli 客户端
  • Flink 大数据实战演练02 实现篇
  • 16天自制CppServer-day01
  • Android 从0搭建初始化MVVM项目框架(二):添加版本依赖管理、分包分模块、组件化Aroute
  • YOLOv8模型改进 第十三讲 添加卷积和注意力融合模块(CAFM) 提升小目标和遮挡检测
  • PSPICE FOR TI笔记记录1
  • 2024批量下载公众号文章内容/话题/图片/封面/视频/音频,导出excel和pdf,文章数据包含阅读数/点赞数/分享数/留言数
  • 普通java web项目转为maven项目
  • 原地移除数组中所有的元素val 含源码
  • 如何快速学会盲打
  • 2024.09.27校招 实习 内推 面经
  • 5步轻松上手!零基础也能掌握Go语言编程
  • 明日周刊-第23期
  • 性能测试中性能调优的基本原则有哪些
  • 大模型(LLM)推理体系全览
  • SFT、RLHF、DPO、IFT —— LLM 微调的进化之路_如何搭建自己的dpo
  • Cesium for UE-04-一些说明
  • Docker本地镜像发布到阿里云镜像服务的简易指南
  • 从 PDF 表到见解:在 RAG 中解析 PDF 的另一种方法
  • 基于51单片机的电子时钟数码管显示proteus仿真
  • 正则化-权重衰减
  • Vue Google 广告的配置
  • 数据库原理与应用(基于MySQL):实验六数据查询
  • rpm 命令
  • PPT自动化:如何判断PPT中的shape类型(python-pptx中常见shape类型及其代码速查表)
  • 【学习笔记】理解 C++ 中 reinterpret_cast 和 C 风格类型转换的区别