
Annotated Notes on SAM's Core Code

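I have been reading SAM2 recently, and along the way I annotated the code so I can review it later and share it.
PS: all tensor shapes assume the default configuration.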

SAM

_build_sam

The SAM model has three parts: ImageEncoderViT, PromptEncoder and MaskDecoder.

```python
from functools import partial

import torch

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size  # 1024 // 16 = 64
    sam = Sam(
        # A standard ViT that encodes the image
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            # Whether to use relative positional encoding in attention
            use_rel_pos=True,
            # global_attn_indexes works together with window_size:
            # blocks whose index is NOT in global_attn_indexes use
            # local window attention with window_size
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        # Prompts: points, boxes and masks.
        # Multiple points are supported; each needs a label (1: foreground, 0: background)
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        # Mask decoder
        mask_decoder=MaskDecoder(
            # Number of masks in multimask mode (default 3), to resolve prompt ambiguity
            num_multimask_outputs=3,
            # Cross-attention in both directions: tokens-to-image and image-to-tokens
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
```
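The interplay between global_attn_indexes and window_size can be sketched in plain Python. The helper names are hypothetical. The sketch assumes two behaviours of ImageEncoderViT: global blocks get window size 0 and all others get 14, and the token grid is padded up to a multiple of the window. The [2, 5, 8, 11] indexes are assumed to be the vit_b setting.

```python
from typing import List, Sequence, Tuple


def block_window_sizes(depth: int, global_attn_indexes: Sequence[int],
                       window_size: int = 14) -> List[int]:
    """Per-block window size: 0 means global attention, otherwise local windows."""
    return [0 if i in global_attn_indexes else window_size for i in range(depth)]


def padded_grid(side: int, window_size: int = 14) -> Tuple[int, int]:
    """Pad a side x side token grid up to a multiple of window_size.

    Returns (padded_side, windows_per_side).
    """
    pad = (window_size - side % window_size) % window_size
    padded = side + pad
    return padded, padded // window_size


# Assumed vit_b-style settings: 12 blocks, global attention at blocks 2, 5, 8, 11.
sizes = block_window_sizes(12, [2, 5, 8, 11])
print(sizes)            # [14, 14, 0, 14, 14, 0, 14, 14, 0, 14, 14, 0]
print(padded_grid(64))  # (70, 5): the 64x64 grid becomes 5x5 windows of 14x14
```

Most blocks only attend within 14x14 windows, which keeps attention cost manageable on a 64x64 grid. The few global blocks still let information flow across the whole image.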

ImageEncoderViT

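A standard ViT. With the default configuration it maps [B, 3, 1024, 1024] -> [B, 256, 64, 64]. The 16x16 patches give a 64x64 token grid, and a convolutional neck projects the channels down to out_chans=256.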
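The fixed 1024x1024 input comes from resizing the image's longest side to 1024 and zero-padding the bottom and right. The sketch below works through the shape arithmetic. It assumes the round-half-up resize used by ResizeLongestSide and padding to img_size; the function names are hypothetical.

```python
def preprocess_shape(old_h: int, old_w: int, long_side: int = 1024) -> tuple:
    """Resize so the longest side equals long_side, keeping the aspect ratio."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def encoder_shapes(old_h: int, old_w: int, img_size: int = 1024,
                   patch_size: int = 16, out_chans: int = 256) -> dict:
    new_h, new_w = preprocess_shape(old_h, old_w, img_size)
    # Zero-pad bottom/right so the encoder always sees img_size x img_size.
    padding = (img_size - new_h, img_size - new_w)
    grid = img_size // patch_size  # 1024 // 16 = 64 patches per side
    return {
        "resized": (new_h, new_w),
        "padding": padding,
        "embedding": (out_chans, grid, grid),
        "num_tokens": grid * grid,
    }


info = encoder_shapes(480, 640)
print(info)
# {'resized': (768, 1024), 'padding': (256, 0), 'embedding': (256, 64, 64), 'num_tokens': 4096}
```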

PromptEncoder

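Embeds the prompts (points, boxes, masks). The final output shapes are:

    # dense_embeddings   Bx256x64x64
    # sparse_embeddings  Bx(N+1)x256, where N is the number of points (one padding point is appended when there is no box)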
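The following small sketch of the PromptEncoder.forward logic counts where the N+1 comes from. It also counts how many tokens the mask decoder's transformer finally receives. The helper names are hypothetical, and it assumes the default decoder with 1 IoU token and 4 mask tokens.

```python
from typing import Optional


def num_sparse_tokens(num_points: Optional[int], has_box: bool) -> int:
    """Sparse prompt tokens per prompt, following PromptEncoder.forward:
    points get one extra 'not-a-point' padding token only when no box is given,
    and a box contributes 2 corner tokens."""
    n = 0
    if num_points is not None:
        n += num_points + (0 if has_box else 1)
    if has_box:
        n += 2
    return n


def decoder_tokens(num_sparse: int, num_mask_tokens: int = 4) -> int:
    """Tokens entering the two-way transformer: IoU token + mask tokens + sparse prompts."""
    return 1 + num_mask_tokens + num_sparse


print(num_sparse_tokens(1, False))                  # 2 -> the N+1 case
print(num_sparse_tokens(2, True))                   # 4 -> two points plus a box
print(decoder_tokens(num_sparse_tokens(2, False)))  # 8
```

Two clicks without a box give 3 sparse tokens, so the decoder sees 5 + 3 = 8 tokens.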
```python
from typing import Optional, Tuple, Type

import torch
from torch import nn

from .common import LayerNorm2d

# PositionEmbeddingRandom is defined later in the same file (prompt_encoder.py)


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (int): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Used to normalize point/box coordinates
        self.input_image_size = input_image_size
        # Spatial (H, W) of the image embedding after patchifying, i.e. 64x64;
        # easy to confuse with embed_dim
        self.image_embedding_size = image_embedding_size
        # PE: random Fourier features (random Gaussian projection, then sin/cos)
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learned type embeddings shared by points and boxes; added on top of the PE
        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        # Used for invalid points (e.g. padding points)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        # Mask downscaling: two stride-2 convs (4x down), then a 1x1 conv to embed_dim
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding used when no mask prompt is given
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        # Pad with one invalid point (label -1) so the layout matches the box case
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        # PE of the coordinates
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        # Add the label (type) embeddings
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        # Same idea as _embed_points: the two corners are treated as points
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        # Start from an empty tensor so the return type is uniform:
        # if both points and boxes are None, an empty tensor is returned
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        # The mask prompt can come from:
        # 1. a low-resolution mask supplied by the user
        # 2. the mask predicted in the previous iteration
        if masks is not None:
            # 4x downsampling: 256x256 -> 64x64
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        # dense_embeddings   Bx256x64x64
        # sparse_embeddings  BxN_sparsex256, N_sparse depends on the prompts
        #                    (e.g. 4 for two points plus one box)
        return sparse_embeddings, dense_embeddings
```
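PositionEmbeddingRandom is not a fixed sin-cos table. As far as I can tell from the code, it does three things:

1. Maps normalized coordinates to [-1, 1].
2. Projects them with a fixed random Gaussian matrix and scales by 2π.
3. Concatenates the sin and cos of the result.

Below is a pure-Python sketch of that idea. RandomFourierPE is a hypothetical name. num_pos_feats=128 mirrors PositionEmbeddingRandom(embed_dim // 2) with embed_dim=256, and the seeded RNG stands in for the stored buffer.

```python
import math
import random


class RandomFourierPE:
    """Sketch of random Fourier-feature positional encoding (lists instead of tensors)."""

    def __init__(self, num_pos_feats: int = 128, scale: float = 1.0, seed: int = 0):
        rng = random.Random(seed)
        # 2 x num_pos_feats Gaussian projection, fixed after construction
        self.gaussian = [[scale * rng.gauss(0.0, 1.0) for _ in range(num_pos_feats)]
                         for _ in range(2)]

    def encode_point(self, x: float, y: float, image_size=(1024, 1024)) -> list:
        h, w = image_size
        # Normalize pixel coords to [0, 1] (x by width, y by height), then to [-1, 1]
        u, v = 2 * (x / w) - 1, 2 * (y / h) - 1
        proj = [2 * math.pi * (u * gx + v * gy)
                for gx, gy in zip(self.gaussian[0], self.gaussian[1])]
        # Concatenate sin and cos -> 2 * num_pos_feats dims (256 by default)
        return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]


pe = RandomFourierPE()
emb = pe.encode_point(100 + 0.5, 200 + 0.5)  # +0.5 shifts to the pixel center
print(len(emb))  # 256
```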

MaskDecoder


```python
from typing import List, Tuple, Type

import torch
from torch import nn

from .common import LayerNorm2d

# MLP is defined later in the same file (mask_decoder.py)


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        # Embedding dimension of the transformer
        self.transformer_dim = transformer_dim
        # Two-way transformer used for mask prediction
        self.transformer = transformer
        self.num_multimask_outputs = num_multimask_outputs
        # Token for IoU prediction
        self.iou_token = nn.Embedding(1, transformer_dim)
        # +1: token 0 serves the single-mask output, tokens 1..3 the multimask outputs
        self.num_mask_tokens = num_multimask_outputs + 1
        # Tokens for mask prediction
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        # Upsample the image embedding 4x: [256, 64, 64] -> [32, 256, 256]
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        # Hypernetwork MLPs: one per mask token, mapping 256 -> 32
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )
        # IoU prediction head
        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,  # [1, 256, 64, 64]: one image
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Below, B = number of prompts for this image (sparse_prompt_embeddings.size(0))
        # Concatenate output tokens: iou_token and mask_tokens predict IoU and masks
        # [5, 256]
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        # Expand along the batch dimension
        # [B, 5, 256]
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        # Append the sparse prompt embeddings
        # [B, 5 + N_sparse, 256]
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        # [B, 256, 64, 64]
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        # Element-wise add of image embedding and dense prompt embedding
        # [B, 256, 64, 64]
        src = src + dense_prompt_embeddings
        # Expand image_pe the same way
        # [B, 256, 64, 64]
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        # This part corresponds to Figure 14 in the paper.
        # hs: the output tokens [B, 5 + N_sparse, 256]
        # src: image embedding after attention [B, 4096, 256] (4096 = 64x64)
        hs, src = self.transformer(src, pos_src, tokens)
        # IoU token [B, 256]
        iou_token_out = hs[:, 0, :]
        # Mask tokens [B, 4, 256]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        # Reshape back [B, 256, 64, 64]
        src = src.transpose(1, 2).view(b, c, h, w)
        # Upsample [B, 32, 256, 256]
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        # Each mask token goes through its own MLP
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        # Stack [B, 4, 32]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        # Matrix product of mask tokens and image embedding gives the masks [B, 4, 256, 256]
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        # [B, 4]
        iou_pred = self.iou_prediction_head(iou_token_out)

        # The masks are 256x256; upsampling them another 4x matches the 1024x1024 input
        return masks, iou_pred
```
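The last step of predict_masks is a per-pixel dot product between each mask token's 32-dimensional hypernetwork output and the 32-channel upscaled feature. forward then picks token 0 for single-mask output, or tokens 1..3 for multimask output. Below is a pure-Python sketch of both steps on toy sizes. Lists stand in for tensors, and the helper names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def hypernetwork_masks(hyper_in: List[List[float]], upscaled: List[Matrix]) -> List[Matrix]:
    """masks[k][y][x] = sum_c hyper_in[k][c] * upscaled[c][y][x].

    hyper_in: [num_masks, C], one C-dim vector per mask token (after its MLP)
    upscaled: [C, H, W], the upscaled image embedding
    returns:  [num_masks, H, W] mask logits
    """
    c_dim = len(upscaled)
    h, w = len(upscaled[0]), len(upscaled[0][0])
    masks = []
    for vec in hyper_in:
        if len(vec) != c_dim:
            raise ValueError("hyper_in vectors must have C entries")
        masks.append([[sum(vec[c] * upscaled[c][y][x] for c in range(c_dim))
                       for x in range(w)] for y in range(h)])
    return masks


def select_outputs(masks: list, iou_pred: list, multimask_output: bool) -> Tuple[list, list]:
    """Token 0 -> single-mask output; tokens 1..3 -> multimask output."""
    sl = slice(1, None) if multimask_output else slice(0, 1)
    return masks[sl], iou_pred[sl]


# C=2 channels on a 1x2 "image"; two mask tokens.
upscaled = [[[1.0, 0.0]], [[0.0, 1.0]]]
hyper_in = [[2.0, 3.0], [-1.0, 1.0]]
print(hypernetwork_masks(hyper_in, upscaled))  # [[[2.0, 3.0]], [[-1.0, 1.0]]]
```

Each mask token acts as a dynamic 1x1 convolution kernel over the shared feature map. That is why adding a mask only costs one extra MLP and one extra dot product per pixel.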

SAM-HQ

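Compared with SAM, the main differences are:

  1. Global-local fusion. High-frequency (early-layer, local detail) and low-frequency (global, semantic) features are fused, similar to an FPN. This improves accuracy on fine structures and small masks.
  2. An added HQ-Output Token. The original SAM structure is kept unchanged and only this new branch is fine-tuned, somewhat like LoRA, so the original SAM capabilities are largely preserved.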

Comparing mask_decoder_hq.py with mask_decoder.py, the constructor mainly adds the following modules:

```python
# In MaskDecoderHQ.__init__, after calling MaskDecoder's constructor
# HQ-SAM parameters
self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
# corresponding new MLP layer for HQ-Output-Token
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
# 4 SAM mask tokens + 1 HQ-Output-Token = 5
self.num_mask_tokens = self.num_mask_tokens + 1

# three conv fusion layers for obtaining HQ-Feature
# early-layer ViT feature: [vit_dim, 64, 64] -> [32, 256, 256]
self.compress_vit_feat = nn.Sequential(
    nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
    LayerNorm2d(transformer_dim),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2),
)
# final image embedding: [256, 64, 64] -> [32, 256, 256]
self.embedding_encoder = nn.Sequential(
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    LayerNorm2d(transformer_dim // 4),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
)
# refines the upscaled SAM decoder feature: [32, 256, 256] -> [32, 256, 256]
self.embedding_maskfeature = nn.Sequential(
    nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
    LayerNorm2d(transformer_dim // 4),
    nn.GELU(),
    nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1),
)
```
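All three branches land on the same 32x256x256 shape, which is what allows them to be summed. The check below applies PyTorch's Conv2d and ConvTranspose2d output-size formulas (dilation 1, no output_padding) in plain Python. LayerNorm2d and GELU do not change shapes and are omitted. The helper names and the layer-tuple encoding are hypothetical.

```python
from typing import List, Tuple

Layer = Tuple[str, int, int, int, int]  # (kind, out_channels, kernel, stride, padding)


def conv_transpose_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # ConvTranspose2d output size with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + (kernel - 1) + 1


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Conv2d output size with dilation=1
    return (size + 2 * padding - (kernel - 1) - 1) // stride + 1


def trace(size: int, layers: List[Layer]) -> List[Tuple[int, int, int]]:
    """Return (channels, H, W) after each conv layer, for a square input."""
    shapes = []
    for kind, ch, k, s, p in layers:
        size = conv_transpose_out(size, k, s, p) if kind == "convT" else conv_out(size, k, s, p)
        shapes.append((ch, size, size))
    return shapes


d = 256  # transformer_dim
compress_vit_feat = [("convT", d, 2, 2, 0), ("convT", d // 8, 2, 2, 0)]
embedding_encoder = [("convT", d // 4, 2, 2, 0), ("convT", d // 8, 2, 2, 0)]
embedding_maskfeature = [("conv", d // 4, 3, 1, 1), ("conv", d // 8, 3, 1, 1)]

print(trace(64, compress_vit_feat))       # [(256, 128, 128), (32, 256, 256)]
print(trace(64, embedding_encoder))       # [(64, 128, 128), (32, 256, 256)]
print(trace(256, embedding_maskfeature))  # [(64, 256, 256), (32, 256, 256)]
```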

The forward computation:

```python
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
    hq_token_only: bool,
    interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict masks given image and prompt embeddings.

    Arguments:
      image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
      image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
      sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
      dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
      multimask_output (bool): Whether to return multiple masks or a single
        mask.

    Returns:
      torch.Tensor: batched predicted masks
      torch.Tensor: batched predictions of mask quality
    """
    # First take an early-layer image embedding: its receptive field is still
    # small, so it carries high-frequency detail. [B, H, W, C] -> [B, C, H, W]
    vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after 1st global attention block in ViT
    # Feature fusion, similar to an FPN in detection models: after this line the
    # high-frequency and low-frequency features are fused (32 channels at 256x256)
    hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)

    # Entry point for mask prediction
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
        hq_features=hq_features,
    )

    # Select the correct mask or masks for output
    if multimask_output:
        # mask with highest score among the multimask outputs (tokens 1..3)
        mask_slice = slice(1, self.num_mask_tokens - 1)
        iou_pred = iou_pred[:, mask_slice]
        iou_pred, max_iou_idx = torch.max(iou_pred, dim=1)
        iou_pred = iou_pred.unsqueeze(1)
        masks_multi = masks[:, mask_slice, :, :]
        masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
    else:
        # single mask output, default
        mask_slice = slice(0, 1)
        iou_pred = iou_pred[:, mask_slice]
        masks_sam = masks[:, mask_slice]

    # The last token is the HQ-Output-Token
    masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)]
    if hq_token_only:
        masks = masks_hq
    else:
        masks = masks_sam + masks_hq
    # Prepare output
    return masks, iou_pred
```
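The selection logic at the end of forward is easier to see on one prompt with scalar stand-ins for the masks. Below is a sketch under two assumptions. First, there are 5 mask tokens (0 = single, 1..3 = multimask, 4 = HQ). Second, the IoU scores only need to cover tokens 0..3, because the HQ mask is always taken as-is. hq_select is a hypothetical name.

```python
from typing import List, Tuple


def hq_select(masks: List[float], iou_pred: List[float],
              multimask_output: bool, hq_token_only: bool) -> Tuple[float, float]:
    """Output selection of the HQ decoder's forward for ONE prompt (scalar 'masks').

    masks:    5 entries: 0 = single-mask token, 1..3 = multimask tokens, 4 = HQ token
    iou_pred: scores for tokens 0..3
    """
    num_mask_tokens = len(masks)
    if multimask_output:
        # Best of the multimask outputs, tokens 1 .. num_mask_tokens - 2
        best = max(range(1, num_mask_tokens - 1), key=lambda i: iou_pred[i])
        mask_sam, iou = masks[best], iou_pred[best]
    else:
        mask_sam, iou = masks[0], iou_pred[0]
    mask_hq = masks[num_mask_tokens - 1]
    # Either the HQ mask alone, or SAM + HQ logits added together
    return (mask_hq if hq_token_only else mask_sam + mask_hq), iou


masks = [1.0, 2.0, 3.0, 4.0, 10.0]
iou_pred = [0.5, 0.6, 0.9, 0.7]
print(hq_select(masks, iou_pred, multimask_output=True, hq_token_only=False))  # (13.0, 0.9)
```

Unlike SAM, multimask mode here returns a single mask: the best-scoring SAM mask, optionally refined by adding the HQ logits.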

predict_masks, annotated:

```python
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    hq_features: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens: IoU token, 4 SAM mask tokens, HQ-Output-Token
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + dense_prompt_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding_sam = self.output_upscaling(src)

    # Everything above is nearly identical to SAM; the part below is HQ-specific.
    # Fuse the SAM decoder feature with the HQ feature: the paper's global-local fusion
    upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b, 1, 1, 1)

    hyper_in_list: List[torch.Tensor] = []
    # Run each output mask token through an MLP; the last one (via hf_mlp)
    # is the paper's updated HQ-Output-Token
    for i in range(self.num_mask_tokens):
        if i < self.num_mask_tokens - 1:
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        else:
            hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))

    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding_sam.shape

    # Matrix product of the token outputs with the features: the SAM tokens use
    # the SAM feature, the HQ token uses the global-local fused feature
    masks_sam = (hyper_in[:, : self.num_mask_tokens - 1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
    masks_sam_hq = (hyper_in[:, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
    masks = torch.cat([masks_sam, masks_sam_hq], dim=1)
    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred
```

FastSAM

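Two independent stages:

  1. Whole-image instance segmentation. Outputs masks for all objects in the image (a single class).
  2. Prompt matching. The prompt is encoded and matched against the masks from step 1. Simple prompts such as boxes and points are matched directly by point location, IoU, and so on. Text prompts are embedded with CLIP and matched by similarity.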
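Below is a simplified sketch of the box and point matching in step 2, with masks represented as sets of (x, y) pixels. Both rules are simplifications of mine, not FastSAM's exact code:

- Box: pick the mask with the highest IoU against the box region.
- Point: keep masks that contain a foreground point and no background point.

Text prompts via CLIP are not covered, and the helper names are hypothetical.

```python
from typing import List, Set, Tuple

Pixel = Tuple[int, int]  # (x, y)


def box_pixels(x0: int, y0: int, x1: int, y1: int) -> Set[Pixel]:
    """Pixels inside a box with exclusive right/bottom edges."""
    return {(x, y) for x in range(x0, x1) for y in range(y0, y1)}


def select_by_box(masks: List[Set[Pixel]], box: Tuple[int, int, int, int]) -> int:
    """Index of the mask with the highest IoU against the box region."""
    region = box_pixels(*box)

    def iou(m: Set[Pixel]) -> float:
        union = len(m | region)
        return len(m & region) / union if union else 0.0

    return max(range(len(masks)), key=lambda i: iou(masks[i]))


def select_by_points(masks: List[Set[Pixel]], points: List[Pixel],
                     labels: List[int]) -> Set[Pixel]:
    """Union of masks containing >= 1 foreground point (label 1) and no background point (label 0)."""
    result: Set[Pixel] = set()
    for m in masks:
        hit_fg = any(lab == 1 and p in m for p, lab in zip(points, labels))
        hit_bg = any(lab == 0 and p in m for p, lab in zip(points, labels))
        if hit_fg and not hit_bg:
            result |= m
    return result


# Two toy instance masks from the "segment everything" stage
m0 = box_pixels(0, 0, 4, 4)
m1 = box_pixels(5, 5, 10, 10)
print(select_by_box([m0, m1], (4, 4, 10, 10)))                      # 1
print(select_by_points([m0, m1], [(1, 1), (6, 6)], [1, 0]) == m0)  # True
```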

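Compared with SAM, FastSAM outputs masks faster in segment-everything mode, because it does not need to sample a dense grid of prompts. This is also the direction MobileSAMv2 improves on.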
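To see where SAM's cost comes from in segment-everything mode, count the prompts. In SAM this mode prompts the decoder with a regular grid of points, processed in batches. Below is a plain-Python sketch of such a grid. It assumes points_per_side=32 and points_per_batch=64 (my understanding of SamAutomaticMaskGenerator's defaults) and a half-cell offset from the border. build_point_grid here is a standalone re-implementation, not an import.

```python
from typing import List, Tuple


def build_point_grid(n_per_side: int) -> List[Tuple[float, float]]:
    """n_per_side x n_per_side points in [0, 1]^2, offset half a cell from the border."""
    offset = 1 / (2 * n_per_side)
    if n_per_side == 1:
        ticks = [offset]
    else:
        step = (1 - 2 * offset) / (n_per_side - 1)
        ticks = [offset + i * step for i in range(n_per_side)]
    # Row-major order: x varies fastest
    return [(x, y) for y in ticks for x in ticks]


points_per_side, points_per_batch = 32, 64  # assumed defaults
grid = build_point_grid(points_per_side)
num_decoder_batches = -(-len(grid) // points_per_batch)  # ceiling division
print(len(grid), num_decoder_batches)  # 1024 16
```

With these settings, one image costs 1024 point prompts through the mask decoder, in 16 batches, before any filtering. FastSAM skips this by producing all instance masks in a single pass.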

MobileSAM

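This one is simpler. The overall logic is inherited from SAM, and SAM's image encoder (ViT) is distilled into a lightweight network (TinyViT) to reduce model size and latency.
Reading the TinyViT paper covers most of it: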
TinyViT: Fast Pretraining Distillation for Small Vision Transformers
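Since the rest of SAM is kept, the student encoder mainly has to reproduce the teacher's image embedding. Below is a minimal sketch of that objective. It assumes a plain MSE loss between the teacher (SAM's ViT-H encoder) and student (TinyViT) embeddings. Lists stand in for tensors, and embedding_mse is a hypothetical helper name.

```python
from typing import Sequence


def embedding_mse(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Mean squared error between two flattened image embeddings
    (256 * 64 * 64 values each for SAM's default configuration)."""
    if len(teacher) != len(student):
        raise ValueError("teacher and student embeddings must have the same size")
    return sum((t - s) ** 2 for t, s in zip(teacher, student)) / len(teacher)


# Toy 4-value "embeddings": teacher = ViT-H output, student = TinyViT output
teacher = [0.5, -1.0, 2.0, 0.0]
student = [0.5, -0.5, 1.0, 0.0]
print(embedding_mse(teacher, student))  # 0.3125
```

Matching embeddings directly means the teacher's outputs can be computed once and reused, and no prompts or mask labels are needed to train the student.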

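A few others:

  • Grounded-SAM. Text → detection → segmentation, focused on text-driven interactive detection and segmentation.
  • Semantic-SAM. Focuses on the relationship between object parts and the whole object, and on segmenting both.
  • SAM2. Brings in tracking, segmenting consecutive video frames end to end without a prompt on every frame.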

Applications

SAM is widely used, mainly for cutting out objects; the masked regions are then erased, replaced, or inpainted.

  • https://github.com/geekyutao/Inpaint-Anything
  • https://github.com/advimman/lama

References

https://github.com/facebookresearch/segment-anything
https://github.com/SysCV/SAM-HQ
https://github.com/CASIA-IVA-Lab/FastSAM
https://github.com/ChaoningZhang/MobileSAM
https://github.com/IDEA-Research/Grounded-Segment-Anything
https://github.com/UX-Decoder/Semantic-SAM

