笔记:BLIP源码之(2)模型是如何定义的
模型是怎么定义的:model之前的继承方式是怎么样的,用了什么api,论文里面的一个公式就调用了很多function
调用 blip_retrieval
这个函数,得到本论文用到的model,接下来需要一层一层剖析
model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
blip_retrieval
函数:
def blip_retrieval(pretrained='',**kwargs):# BLIP_Retrieval类传入参数后实例化的对象就是model# 创建 BLIP 模型实例model = BLIP_Retrieval(**kwargs)if pretrained:# 如果指定了预训练模型的路径,则调用 load_checkpoint 函数加载预训练模型model,msg = load_checkpoint(model,pretrained)print("missing keys:")# missing_keys 属性:在加载预训练模型时,模型中存在但在预训练模型文件中缺失的参数键print(msg.missing_keys)return model
当加载预训练模型时,模型的参数通常以键值对的形式存储。每个键表示一个参数变量,对应的值表示该参数的具体数值。
如果缺失了键信息,可能出现了以下情况:模型结构的改变,部分参数未被保存,预训练模型文件损坏等。
打印出缺失的键信息可以帮助我们了解哪些参数键在加载过程中无法获取到对应的数值。这样的信息可能对模型的进一步使用、调试或修复是有帮助的。
接下来看 BLIP_Retrieval
类的代码实现:
首先给出retrieval_coco.yaml
中和BLIP_Retrieval 类有关的配置参数:
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
vit: 'base'
#
vit_grad_ckpt: True
vit_ckpt_layer: 4
image_size: 384
queue_size: 57600
negative_all_rank: True
论文原文:
根据原文可知,self.visual_encoder
就是 image encoder,作用是:把一张输入图像划分成很多patches,并且编码他们成为一个embeddings的序列,还要加入一个 cls token
来表示全局图像特征
。
class BLIP_Retrieval(nn.Module):def __init__(self, med_config = 'configs/med_config.json', image_size = 384,vit = 'base',vit_grad_ckpt = False,vit_ckpt_layer = 0, embed_dim = 256, queue_size = 57600,momentum = 0.995,negative_all_rank = False,):# 在构造函数中,通过调用父类 nn.Module 的构造函数 super().__init__() # 来确保父类的初始化操作被正确执行 super().__init__()self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)self.tokenizer = init_tokenizer()
1. visual encoder
create_vit
函数的定义如下:
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):assert vit in ['base', 'large'], "vit parameter must be base or large"if vit=='base':vision_width = 768visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,drop_path_rate=0 or drop_path_rate) ''' 因为配置的vit对应的值是base,所以省略了large的代码 '''# 返回视觉编码器实例,以及 vision_widthreturn visual_encoder, vision_width
接下来还要找 VisionTransformer
类的代码实现:
class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, use_grad_checkpointing=False, ckpt_layer=0):super().__init__()# 模型的特征数量 和 嵌入维度 是一样的self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models# 创建了一个偏函数,该偏函数将 nn.LayerNorm 类作为函数,同时固定了 eps 参数的值为 1e-6norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)# PatchEmbed 类是用于将输入图像切分成多个大小相等的图像块并进行嵌入的操作# 每个patches都会被转换为一个嵌入向量,该嵌入向量表示了该图像块的特征# 即 img_size=384,patch_size = 16,总共均匀分成 384/16 = 24,这24个图像被转成embed_size = 768(16x16x3)self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)# 384/16 = 24num_patches = self.patch_embed.num_patches# "CLS" 标记,代表 "classification"。它的目的是为了让模型能够在处理图像时同时学习到全局信息# self.cls_token 是一个可学习的参数,它的值将在训练过程中通过反向传播进行更新,以适应特定的任务和数据。# 通过这种方式,模型可以在学习过程中适应不同的图像分类、检测、生成等任务,并捕捉到全局信息的重要性self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# 表示位置编码(Positional Encoding),用于为每个图像块(包括 "CLS" 标记)提供位置信息。self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))# 在位置编码后对其进行随机丢弃操作,以减少过拟合# drop_rate 是丢弃概率,控制要丢弃的元素比例self.pos_drop = nn.Dropout(p=drop_rate)# drop_path_rate = 0,depth = 12# torch.linspace(0, drop_path_rate, depth) 创建了一个张量,其中包含了从 0 到 drop_path_rate 之间的# depth 个均匀间隔的值,将张量中的每个值转换为 Python 数值,并将这些数值存储在列表 dpr 中# Drop Path 应用于 transformer encoder的attention 和 多层感知机(MLP)中# 在下文的block类中可以看到# 通过随机丢弃一部分连接来增加模型的泛化能力dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule# block就是 Transformer Encoderself.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,# 当前自注意力层是最后几层之一,需要应用梯度检查点技术,i>= 12-4use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer))for i in range(depth)])self.norm = norm_layer(embed_dim)# 截断正态分布是一种正态分布的变体,它将生成的值限制在一定的范围内,以避免生成过大或过小的值。# 这里的 std=.02 表示生成的值的标准差为 0.02trunc_normal_(self.pos_embed, std=.02) # 使用截断正态分布初始化 self.pos_embed 参数trunc_normal_(self.cls_token, std=.02) # 截断正态分布初始化 self.cls_token 参数# 对模型进行初始化self.apply(self._init_weights)# 是一个模型的初始化方法。它通过遍历模型的所有模块,对线性层和层归一化(LayerNorm)层进行特定的参数初始化def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:# 使用常数初始化nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0) # 将偏置 m.bias 的值设为 0nn.init.constant_(m.weight, 1.0)def no_weight_decay(self):# 指定不需要进行权重衰减(weight decay)的参数return {'pos_embed', 'cls_token'}def forward(self, x, register_blk=-1):B = x.shape[0] # x:(B,N,D)x = self.patch_embed(x)# position of cls: (B, 0, D)# self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# expand 是用来扩展维度的,在第一个维度上复制 B 次,保持第二维度和第三维度不变# 因此,cls_token的shape从(1,1,768)变成(B,1,D)cls_tokens = self.cls_token.expand(B, -1, -1)# 在第一个维度进行拼接,x变成 (B,N+1,D)x = torch.cat((cls_tokens, x), dim=1)# 每个输入张量中的补丁位置添加位置信息x = x + self.pos_embed[:,:x.size(1),:]x = self.pos_drop(x) # 至此,得到了 Transformer Encoder的输入# x会进入 Transformer Encoderfor i,blk in enumerate(self.blocks):x = blk(x, register_blk==i)# Transformer Encoder的输出 会再经过 normx = self.norm(x)# x.shape = (B, N+1, D),N:num of patches D:dimension of a patchreturn x# 加载预训练模型的参数def load_pretrained(self, checkpoint_path, prefix=''):""" Load weights from .npz checkpoints for official Google Brain Flax implementation"""_load_weights(self, checkpoint_path, prefix)
因为调用了Block
类,所以附上论文的图,注意,这个Block定义的就是框出来的部分:先是layer Norm,再是Attention,再是residual,之后又接上Norm,再有residual,即:residual在每一块结束,Norm在每一块开始前:
贴上代码:
class Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):super().__init__()self.norm1 = norm_layer(dim)# Attention模块中 multi-head后还跟了project,因此,既有attn_drop,还有proj_dropself.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here# 如果drop_path <= 0.则不做操作self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()# 第2个Normself.norm2 = norm_layer(dim)# mlp_ratio = 4, MLP 比例参数,表示隐藏层维度相对于输入维度的比例# mlp_hidden_dim :768 x 4mlp_hidden_dim = int(dim * mlp_ratio)# mlp层self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)# use_grad_checkpointing = true,则会对自注意力模块和 MLP 模块进行梯度检查点封装if use_grad_checkpointing: # 是否使用梯度检查点技术# 对attention和mlp进行梯度检查点封装self.attn = checkpoint_wrapper(self.attn)self.mlp = checkpoint_wrapper(self.mlp)def forward(self, x, register_hook=False):# 经过attentin和mlp后都进行一下 drop path 再进行residual addx = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))x = x + self.drop_path(self.mlp(self.norm2(x)))return x
梯度检查点技术的目的是为了减少计算和内存消耗,特别是在模型中存在大量的计算图时。通过使用梯度检查点技术,可以将计算图中的一部分操作在前向传播时计算并保存,而在反向传播时只需计算梯度,从而减少内存占用和计算时间。
通过对自注意力模块和 MLP 模块应用梯度检查点封装,可以在一定程度上优化模型的计算和内存消耗
Block中又调用了Attention
类,注意:多头注意力后面还跟了一个线性层:
class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_heads # 768//12 = 64# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weightsself.scale = qk_scale or head_dim ** -0.5 # 64 ** -0.5 = 1/8self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.attn_gradients = Noneself.attention_map = None
'''省略部分代码'''def forward(self, x, register_hook=False):B, N, C = x.shape # batch_size N(patches) Dimension(768)# 经过self.qkv(x)得到 (B N 768*3) ,reshape 之后成 (B,N,3,12,64),再permute成(3,B,12,N,64)qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# q,k,v的shape:(B,12,N,64)q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)# q:(B,12,N,64) k.transpose(-2, -1):(B,12,64,N) 得到attn的shape:(B,12,N,N)attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)if register_hook:# 保存注意力图和注意力梯度self.save_attention_map(attn)# 通过 register_hook 方法将 self.save_attn_gradients 函数# 注册为注意力权重张量 attn 的梯度钩子函数。这样,在计算注意力# 权重的梯度时,钩子函数将被调用并执行自定义的操作,例如保存梯度值或进行其他处理。attn.register_hook(self.save_attn_gradients) # (B,12,N,N) * (B,12,N,64) = (B,12,N,64) transpose后:(B,N,12,64) ,reshape 后:(B,N,768)x = (attn @ v).transpose(1, 2).reshape(B, N, C)# 接了一个线性层x = self.proj(x)x = self.proj_drop(x)return x
register_hook
是 PyTorch 中的一个方法,用于注册一个钩子函数(hook function)到张量上。钩子函数可以在张量的梯度计算过程中执行自定义操作,例如记录梯度、修改梯度、分析梯度等。通过注册钩子函数,可以在模型的前向传播和反向传播过程中,对张量的值或梯度进行监控、记录和分析,以实现一些特定的需求,如可视化、调试、梯度修正等。
Block中还调用了Mlp
类:
class Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks原论文说的:The MLP contains two layers with a GELU non-linearity."""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()# in_features = 768 ,hidden_features = 768 x 4out_features = out_features or in_featureshidden_features = hidden_features or in_features# 非线性层1self.fc1 = nn.Linear(in_features, hidden_features)# 激活函数,用的是 nn.GELUself.act = act_layer()# # 非线性层2self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x
2. multimodal mixture of encoder-decoder (MED)
class BLIP_Retrieval(nn.Module):'''忽略部分代码'''self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)'''之前第一点把visual_encoder 相关代码做记录了,现在往后看代码'''# Tokenizer(分词器)用于将输入的文本分割成单词、子词或字符等更小的单位,# 以便进行后续的处理和编码self.tokenizer = init_tokenizer() # 导入med的配置文件med_config = BertConfig.from_json_file(med_config)# 更改配置文件:令 MED的encoder_width 等于 vision_width(768,因为选的是base)med_config.encoder_width = vision_width# Text Encoder(文本编码器)是指将分词后的文本转换为向量表示的模型或组件# 文本编码器可以是基于预训练模型的深度神经网络(如 BERT、GPT 等),# 也可以是其他常用的编码模型(如 Word2Vec、GloVe 等),本论文是自定义了 BertModel类:self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
调用了自定义的init_tokenizer
函数,但是主要还是调用了BertTokenizer
,只是额外加入了special_tokens
:
# BERT tokenizer
def init_tokenizer():tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 表示解码器的开始位置tokenizer.add_special_tokens({'bos_token':'[DEC]'})# 表示编码器的开始位置tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] return tokenizer
因为还调用了自定义的BertModel
类,附上原论文内容和相关代码:
Image-grounded text encoder
, which injects visual information by inserting one additional cross-attention (CA) layer between the self-attention (SA) layer and the feed forward network (FFN) for each transformer block of the text encoder.(这和原始transformer论文中的decoder一样)。
A task-specific [Encode] token is appended to the text, and the output embedding of [Encode] is used as the multimodal representation of the image-text pair.换句话说,[Encode]" token 的嵌入包含了图像和文本的融合信息,可以作为图像-文本对的表示
Image-grounded text decoder
, which replaces the bidirectional self-attention layers in the image-grounded text encoder with causal self-attention layers. (这意味着解码器在生成序列时只能依赖当前位置之前的信息,不会引入未来信息的依赖)。
A [Decode] token is used to signal the beginning of a sequence, and an end-of-sequence token is used to signal its end.(为了指示序列的开始,使用了一个特殊的 “[Decode]” token,而使用一个特殊的序列结束标记(end-of-sequence token)来标识序列的结束)
class BertModel(BertPreTrainedModel):"""The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer ofcross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass."""def __init__(self, config, add_pooling_layer=True):super().__init__(config)self.config = configself.embeddings = BertEmbeddings(config)self.encoder = BertEncoder(config)self.pooler = BertPooler(config) if add_pooling_layer else Noneself.init_weights()
BertModel继承自BertPreTrainedModel
,附上代码:
class BertPreTrainedModel(PreTrainedModel):"""An abstract class to handle weights initialization and a simple interface for downloading and loading pretrainedmodels."""config_class = BertConfigbase_model_prefix = "bert"_keys_to_ignore_on_load_missing = [r"position_ids"]def _init_weights(self, module):""" Initialize the weights """if isinstance(module, (nn.Linear, nn.Embedding)):# Slightly different from the TF version which uses truncated_normal for initialization# cf https://github.com/pytorch/pytorch/pull/5617module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)elif isinstance(module, nn.LayerNorm):module.bias.data.zero_()module.weight.data.fill_(1.0)if isinstance(module, nn.Linear) and module.bias is not None:module.bias.data.zero_()
还调用了很多其他的类来一起组成这个BertModel
,首先是BertEmbeddings
:
class BertEmbeddings(nn.Module):"""Construct the embeddings from word and position embeddings."""def __init__(self, config):super().__init__()# pad_token_id": 0self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load# any TensorFlow checkpoint file# config.layer_norm_eps:归一化操作的小数精度self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)# position_ids (1, len position emb) is contiguous in memory and exported when serialized# 注册一个缓冲区(buffer)的张量,并命名为 "position_ids"# 这个缓冲区被注册后,在模型被序列化时,其数据将被导出并保存下来,以便在加载模型时重新使用self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))# 绝对位置编码self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")self.config = configdef forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):if input_ids is not None:input_shape = input_ids.size()else:# [:-1]:保留除最后一个维度外的所有维度input_shape = inputs_embeds.size()[:-1]seq_length = input_shape[1]if position_ids is None:# position_ids (1, len position emb)# 从 past_key_values_length 开始,到 seq_length + past_key_values_length - 1 结束的位置 做嵌入position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]if inputs_embeds is None:inputs_embeds = self.word_embeddings(input_ids)embeddings = inputs_embedsif self.position_embedding_type == "absolute":position_embeddings = self.position_embeddings(position_ids)embeddings += position_embeddings# word embedding + position embedding + LayerNorm+dropoutembeddings = self.LayerNorm(embeddings)embeddings = self.dropout(embeddings)return embeddings
还有BertEncoder
,由12个BertLayer组成:
class BertEncoder(nn.Module):def __init__(self, config):super().__init__()self.config = config# 由12个BertLayer组成self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])self.gradient_checkpointing = Falsedef forward(self,hidden_states,attention_mask=None,head_mask=None,encoder_hidden_states=None,encoder_attention_mask=None,past_key_values=None,use_cache=None,output_attentions=False,output_hidden_states=False,return_dict=True,mode='multimodal',):all_hidden_states = () if output_hidden_states else Noneall_self_attentions = () if output_attentions else Noneall_cross_attentions = () if output_attentions and self.config.add_cross_attention else Nonenext_decoder_cache = () if use_cache else Nonefor i in range(self.config.num_hidden_layers):layer_module = self.layer[i]if output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)layer_head_mask = head_mask[i] if head_mask is not None else Nonepast_key_value = past_key_values[i] if past_key_values is not None else Noneif self.gradient_checkpointing and self.training:if use_cache:logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")use_cache = Falsedef create_custom_forward(module):def custom_forward(*inputs):return module(*inputs, past_key_value, output_attentions)return custom_forwardlayer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module),hidden_states,attention_mask,layer_head_mask,encoder_hidden_states,encoder_attention_mask,mode=mode,)else:layer_outputs = layer_module(hidden_states,attention_mask,layer_head_mask,encoder_hidden_states,encoder_attention_mask,past_key_value,output_attentions,mode=mode,)hidden_states = layer_outputs[0]if use_cache:next_decoder_cache += (layer_outputs[-1],)if output_attentions:all_self_attentions = all_self_attentions + (layer_outputs[1],)if output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)if not return_dict:return tuple(vfor v in [hidden_states,next_decoder_cache,all_hidden_states,all_self_attentions,all_cross_attentions,]if v is not None)return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,past_key_values=next_decoder_cache,hidden_states=all_hidden_states,attentions=all_self_attentions,cross_attentions=all_cross_attentions,)
BertEncoder
又调用了 BertLayer
:
class BertLayer(nn.Module):# 1、先经过self-attention、线性层、LayerNorm# 2、再经过cross-attention(可选)、 线性层、LayerNorm# 3、先经过线性层,把 hidden_size 映射到 intermediate_size ,再从intermediate_size 映射到 hidden_size# 4、再经过 intermediate_size 到 hidden_size的映射,再经过LayerNormdef __init__(self, config, layer_num):super().__init__()self.config = configself.chunk_size_feed_forward = config.chunk_size_feed_forwardself.seq_len_dim = 1self.attention = BertAttention(config) self.layer_num = layer_num if self.config.add_cross_attention:self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)self.intermediate = BertIntermediate(config)self.output = BertOutput(config)def forward(self,hidden_states,attention_mask=None,head_mask=None,encoder_hidden_states=None,encoder_attention_mask=None,past_key_value=None,output_attentions=False,mode=None,):# decoder uni-directional self-attention cached key/values tuple is at positions 1,2self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else Noneself_attention_outputs = self.attention(hidden_states,attention_mask,head_mask,output_attentions=output_attentions,past_key_value=self_attn_past_key_value,)attention_output = self_attention_outputs[0]outputs = self_attention_outputs[1:-1]present_key_value = self_attention_outputs[-1]# 多模态模式 需要cross attentionif mode=='multimodal':assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"cross_attention_outputs = self.crossattention(attention_output,attention_mask,head_mask,encoder_hidden_states,encoder_attention_mask,output_attentions=output_attentions,)attention_output = cross_attention_outputs[0]outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output)outputs = (layer_output,) + outputsoutputs = outputs + (present_key_value,)return outputsdef feed_forward_chunk(self, attention_output):intermediate_output = self.intermediate(attention_output)layer_output = self.output(intermediate_output, attention_output)return layer_output
BertLayer
又调用了BertAttention
和BertIntermediate
、Bertoutput
,先贴上BertAttention
的代码:
class BertAttention(nn.Module):def __init__(self, config, is_cross_attention=False):super().__init__()self.self = BertSelfAttention(config, is_cross_attention)self.output = BertSelfOutput(config)self.pruned_heads = set()# 剪枝注意力头# 剪枝是一种模型压缩技术,用于减少模型的大小和计算开销# 该方法接受一个heads参数,表示要剪枝的注意力头的索引列表def prune_heads(self, heads):if len(heads) == 0:returnheads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads)# Prune linear layersself.self.query = prune_linear_layer(self.self.query, index)self.self.key = prune_linear_layer(self.self.key, index)self.self.value = prune_linear_layer(self.self.value, index)self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)# Update hyper params and store pruned headsself.self.num_attention_heads = self.self.num_attention_heads - len(heads)self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_headsself.pruned_heads = self.pruned_heads.union(heads)def forward(self,hidden_states,attention_mask=None,head_mask=None,encoder_hidden_states=None,encoder_attention_mask=None,past_key_value=None,output_attentions=False,):# self是自注意力层self_outputs = self.self(hidden_states,attention_mask,head_mask,encoder_hidden_states,encoder_attention_mask,past_key_value,output_attentions,)# 再经过output层attention_output = self.output(self_outputs[0], hidden_states)# self_outputs[1:]:attention_probs、past_key_valueoutputs = (attention_output,) + self_outputs[1:] # add attentions if we output themreturn outputs # 返回:attention_output、attention_probs、past_key_value
BertAttention
又分别调用了 BertSelfAttention
和BertSelfOutput
,
先贴上 BertSelfAttention
代码,通过 is_cross_attention
来决定 是原始transformer论文中的encoder还是decoder,如果为true,则输入是 encoder_hidden_states
,也就相当于Transformer中的decoder,如果是false,输入是hidden_states
,则相当于Transformer中的encoder。
class BertSelfAttention(nn.Module):def __init__(self, config, is_cross_attention):super().__init__()self.config = configif config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):raise ValueError("The hidden size (%d) is not a multiple of the number of attention ""heads (%d)" % (config.hidden_size, config.num_attention_heads))# config.num_attention_heads = 12self.num_attention_heads = config.num_attention_heads# 768/12 = 64self.attention_head_size = int(config.hidden_size / config.num_attention_heads)# self.all_head_size = 12 * 64 = 768self.all_head_size = self.num_attention_heads * self.attention_head_size# 线性层:768 -> 768self.query = nn.Linear(config.hidden_size, self.all_head_size)# 如果设置了 需要 cross_attentionif is_cross_attention:# config.encoder_width:768self.key = nn.Linear(config.encoder_width, self.all_head_size)self.value = nn.Linear(config.encoder_width, self.all_head_size)else:# config.hidden_size = 768self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)self.dropout = nn.Dropout(config.attention_probs_dropout_prob)self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":self.max_position_embeddings = config.max_position_embeddingsself.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)self.save_attention = False '''省略部分代码'''# 改变形状def transpose_for_scores(self, x):new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(*new_x_shape)return x.permute(0, 2, 1, 3)def forward(self,hidden_states,attention_mask=None,head_mask=None,encoder_hidden_states=None,encoder_attention_mask=None,past_key_value=None,output_attentions=False,):mixed_query_layer = self.query(hidden_states)# If this is instantiated as a cross-attention module, the keys# and values come from an encoder; the attention mask needs to be# such that the encoder's padding tokens are not attended to.is_cross_attention = encoder_hidden_states is not Noneif is_cross_attention:key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))attention_mask = encoder_attention_maskelif past_key_value is not None:key_layer = self.transpose_for_scores(self.key(hidden_states))value_layer = self.transpose_for_scores(self.value(hidden_states))key_layer = torch.cat([past_key_value[0], key_layer], dim=2)value_layer = torch.cat([past_key_value[1], value_layer], dim=2)else:key_layer = self.transpose_for_scores(self.key(hidden_states))value_layer = self.transpose_for_scores(self.value(hidden_states))query_layer = self.transpose_for_scores(mixed_query_layer)past_key_value = (key_layer, value_layer)# Take the dot product between "query" and "key" to get the raw attention scores.attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":seq_length = hidden_states.size()[1]position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)distance = position_ids_l - position_ids_rpositional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibilityif self.position_embedding_type == "relative_key":relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)attention_scores = attention_scores + relative_position_scoreselif self.position_embedding_type == "relative_key_query":relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_keyattention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:# Apply the attention mask is (precomputed for all layers in BertModel forward() function)attention_scores = attention_scores + attention_mask# Normalize the attention scores to probabilities.attention_probs = nn.Softmax(dim=-1)(attention_scores)if is_cross_attention and self.save_attention:self.save_attention_map(attention_probs)attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs_dropped = self.dropout(attention_probs)# Mask heads if we want toif head_mask is not None:attention_probs_dropped = attention_probs_dropped * head_maskcontext_layer = torch.matmul(attention_probs_dropped, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)outputs = outputs + (past_key_value,)return outputs
再贴上BertSelfOutput
的代码:
class BertSelfOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, hidden_states, input_tensor):# 全联接层 + dropout +LayerNorm+残差链接hidden_states = self.dense(hidden_states)hidden_states = self.dropout(hidden_states)# 残差链接hidden_states = self.LayerNorm(hidden_states + input_tensor)return hidden_states
附上BertIntermediate
的代码:
class BertIntermediate(nn.Module):def __init__(self, config):super().__init__()# 先经过线性层,把 hidden_size 映射到 intermediate_sizeself.dense = nn.Linear(config.hidden_size, config.intermediate_size)# 再经过激活层if isinstance(config.hidden_act, str):self.intermediate_act_fn = ACT2FN[config.hidden_act]else:self.intermediate_act_fn = config.hidden_actdef forward(self, hidden_states):hidden_states = self.dense(hidden_states)hidden_states = self.intermediate_act_fn(hidden_states)# 得到中间层的隐藏状态return hidden_states
再贴上BertOutput
的代码:
class BertOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.intermediate_size, config.hidden_size)self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, hidden_states, input_tensor):# 全连接层 + dropou+LayerNorm+残差链接hidden_states = self.dense(hidden_states)hidden_states = self.dropout(hidden_states)hidden_states = self.LayerNorm(hidden_states + input_tensor)return hidden_states
总结:
BertEncoder 调用了 BertLayer, BertLayer调用了 BertAttention、BertIntermediate、BertOutput,其中BertAttention 又调用了 BertSelfAttention、BertOutput
再回到BertModel
这个类:
BertPooler
的代码:
class BertPooler(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.activation = nn.Tanh()# 只对第一个token做poolingdef forward(self, hidden_states):# We "pool" the model by simply taking the hidden state corresponding# to the first token.first_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor)pooled_output = self.activation(pooled_output)return pooled_output