LeRobot源码剖析——对机器人各个动作策略的统一封装:包含ALOHA ACT、Diffusion Policy、VLA模型π0
前言
过去2年多的深入超过此前7年,全靠夜以继日的勤奋,一天当两天用,抠论文 抠代码 和大模型及具身同事讨论,是目前日常
而具身库里,idp3、π0、lerobot值得反复研究,故,近期我一直在抠π0及lerobot的源码
- 本文一开始是此文《LeRobot——Hugging Face打造的机器人开源库:包含对顶层script、与底层基础层dataset的源码分析》的第四部分 策略,考虑到为避免该文的篇幅过长,故把该文的第四部分 策略独立出来,成本文
- 且如此文所述
LeRobot 项目采用了模块化设计,整个项目大概分为以下三层
最顶层:scripts (命令行工具),属于发号施令层
中间层:common/policies (控制策略),属于中间执行层
底层支撑层:common/datasets, common/envs, common/robot_devices
基础配置层: configs, common/utils
属于底层基础层
对于lerobot中的common/policies:策略实现,其包含以下策略
- act:Action Chunking Transformer 策略
- diffusion:扩散策略
- tdmpc:时序差分模型预测控制
- vqbet:向量量化行为变换器
- pi0:基础策略实现
第一部分 封装的ALOHA ACT策略
如本博客中的此文《一文通透动作分块算法ACT:斯坦福ALOHA团队推出的动作序列预测算法(Action Chunking with Transformers)》所述
下图左侧是CVAE编码器——包含一个transformer encoder,右侧是CVAE解码器——包含一个transformer encoder和一个transformer decoder)
- 上图左侧的CVAE 编码器(采用类似BERT的transformer编码器实现),其预测样式变量
的分布的均值和方差,该分布被参数化为对角高斯分布
其输入是来自当前关节位置,和来自示范数据集的长度为的目标动作序列,前面再加上一个习得的类似于BERT中的“[CLS]”token,从而形成了一个
长度的输入
通过编码器之后,使用“[CLS]”对应的特征用于预测“风格变量”的均值和方差,这相当于CVAE 编码器的输出(当然,其同时也是CVAE解码器的输入之一)
- 上图右侧的CVAE解码器(即策略),通过
和当前观测(当前观测包括图像cam1~cam4、机器人关节位置joints)的条件来预测动作序列(即接下来的k个动作)
且他们使用ResNet图像编码器、transformer encoder,和transformer decoder来实现CVAE解码器
ACT模型的核心思想是同时预测一系列未来动作(称为"动作块"),而不是传统方法中单步预测动作。这种设计使机器人能够表现出更连贯、更具前瞻性的行为模式,特别适合需要精确协调的复杂任务
1.1 policies/act/modeling_act.py
`ACTPolicy`类继承自`PreTrainedPolicy`,作为用户接口层,负责输入/输出归一化、动作选择和训练过程管理
它包含两种关键的动作选择机制:
- 一种是简单地维护预测动作的队列
- 另一种是使用`ACTTemporalEnsembler`进行时序集成,通过加权平均多次预测结果来提高稳定性
时序集成器使用指数权重函数(`w_i = exp(-temporal_ensemble_coeff * i)`),可以调整对新旧预测的重视程度
底层神经网络`ACT`类采用多模态Transformer架构,包括:
- 可选的变分自编码器(VAE)编码器,用于在训练时捕获动作空间的潜在分布
- 基于ResNet的视觉骨干网络,用于提取图像特征
- Transformer编码器,处理来自不同输入模态(潜变量、机器人状态、环境状态、图像特征)的标记
- Transformer解码器,通过交叉注意力机制整合编码器信息并生成动作序列
- 动作回归头,将解码器输出转换为具体的控制信号
- 位置编码在整个架构中起着关键作用,包括一维和二维的正弦位置编码,使模型能够处理序列和空间信息
模型支持两种训练方式:使用变分目标(带KL散度正则化)或直接使用L1损失
1.1.1 ACTPolicy类
1.1.2 ACTTemporalEnsembler类
1.1.3 ACT类
如代码中的ASCII图所示
TransformerUsed alone for inference(acts as VAE decoderduring training)┌───────────────────────┐│ Outputs ││ ▲ ││ ┌─────►┌───────┐ │┌──────┐ │ │ │Transf.│ ││ │ │ ├─────►│decoder│ │┌────┴────┐ │ │ │ │ │ ││ │ │ │ ┌───┴───┬─►│ │ ││ VAE │ │ │ │ │ └───────┘ ││ encoder │ │ │ │Transf.│ ││ │ │ │ │encoder│ │└───▲─────┘ │ │ │ │ ││ │ │ └▲──▲─▲─┘ ││ │ │ │ │ │ │inputs └─────┼──┘ │ image emb. ││ state emb. │└───────────────────────┘
整体结构包含三个主要组件:
- 用于捕获动作分布的VAE编码器(训练时使用)
- 处理多模态观察的Transformer编码器
- 以及生成动作序列的Transformer解码器
1.3.1.1 __init__方法的实现
初始化方法构建了一个由多个精心设计的组件组成的网络:
- 首先是可选的变分自编码器(VAE)部分,它采用BERT风格的设计,以CLS标记、机器人状态和动作序列作为输入,通过编码过程捕获动作分布的潜在表示
VAE编码器使用固定的正弦位置编码和多层投影来处理不同类型的输入def __init__(self, config: ACTConfig):# 初始化父类nn.Modulesuper().__init__() # 存储配置参数 self.config = config # 如果启用VAE模式if self.config.use_vae: # 创建VAE编码器self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) # 创建分类标记嵌入层,只有1个标记self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) # 为机器人关节状态创建投影层,将其映射到隐藏维度# 如果提供了机器人状态特征if self.config.robot_state_feature: # 从原始维度映射到模型维度self.vae_encoder_robot_state_input_proj = nn.Linear(self.config.robot_state_feature.shape[0], config.dim_model )# 为动作(关节空间目标)创建投影层,将其映射到隐藏维度self.vae_encoder_action_input_proj = nn.Linear(# 动作特征的原始维度self.config.action_feature.shape[0], # 映射到模型维度config.dim_model, )
当`use_vae`设为False时,这部分会被完全跳过,模型将使用全零向量作为潜变量# 从VAE编码器的输出创建到潜在分布参数空间的投影层(输出均值和方差)# *2是因为需要输出均值和方差self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) # 为VAE编码器的输入创建固定的正弦位置嵌入,为批次维度添加一个维度# *2是因为需要输出均值和方差num_input_token_encoder = 1 + config.chunk_size Ç# 如果有机器人状态,则增加一个标记if self.config.robot_state_feature: num_input_token_encoder += 1# 注册一个不需要梯度的缓冲区self.register_buffer( # 缓冲区名称"vae_encoder_pos_enc", # 创建正弦位置编码并扩展批次维度create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), )
- 视觉处理采用配置化的预训练骨干网络(通常是ResNet),通过`IntermediateLayerGetter`提取深层特征
这种设计使模型能够处理原始相机输入,而不需要手工设计的特征提取器# 用于图像特征提取的骨干网络# 如果使用图像特征if self.config.image_features: # 从torchvision.models获取指定的骨干网络backbone_model = getattr(torchvision.models, config.vision_backbone)( # 控制是否使用空洞卷积replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], # 使用预训练权重weights=config.pretrained_backbone_weights, # 使用冻结的批量归一化层(不更新统计信息)norm_layer=FrozenBatchNorm2d, )# 注意:这里假设我们使用的是ResNet模型(因此layer4是最终特征图)# 注意:这个forward方法返回一个字典:{"feature_map": output}self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # 创建一个获取中间层输出的包装器
- 核心Transformer结构包含编码器和解码器
# Transformer(在使用变分目标训练时充当VAE解码器)self.encoder = ACTEncoder(config) # 创建Transformer编码器self.decoder = ACTDecoder(config) # 创建Transformer解码器
前者处理包括潜变量、机器人状态(即机器人的关节角度joints等状态信息)、环境状态和图像特征在内的多模态输入
后者通过交叉注意力机制生成动作序列# Transformer编码器输入投影。标记将被结构化为# [latent, (robot_state), (env_state), (image_feature_map_pixels)]# 从骨干网络最后一层特征数到模型维度if self.config.robot_state_feature: # 为机器人状态创建投影层self.encoder_robot_state_input_proj = nn.Linear( # 从原始维度映射到模型维度self.config.robot_state_feature.shape[0], config.dim_model )# 如果使用环境状态特征if self.config.env_state_feature: # 为环境状态创建投影层self.encoder_env_state_input_proj = nn.Linear( # 从原始维度映射到模型维度self.config.env_state_feature.shape[0], config.dim_model )# 为潜在向量创建投影层self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) # 如果使用图像特征if self.config.image_features: # 为图像特征创建1x1卷积投影层self.encoder_img_feat_input_proj = nn.Conv2d( # 从骨干网络最后一层特征数到模型维度backbone_model.fc.in_features, config.dim_model, kernel_size=1 )
特别值得注意的是位置编码的处理:一维特征使用简单的嵌入层,而图像特征使用复杂的二维正弦位置编码(通过`ACTSinusoidalPositionEmbedding2d`实现),确保模型能够理解空间关系# Transformer解码器# 为transformer的解码器创建可学习的位置嵌入(类似于DETR的对象查询)# 为每个动作块位置创建嵌入self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # 在transformer解码器输出上的最终动作回归头# 从模型维度映射到动作维度self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) # 重置模型参数self._reset_parameters()
# Transformer编码器位置嵌入# 为潜在向量预留1个标记n_1d_tokens = 1 # 如果有机器人状态,则增加一个标记if self.config.robot_state_feature: n_1d_tokens += 1# 如果有环境状态,则增加一个标记if self.config.env_state_feature: n_1d_tokens += 1# 为一维特征创建位置嵌入self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) # 如果使用图像特征if self.config.image_features: # 创建二维正弦位置嵌入self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
该架构的模块化设计使其能够适应不同的任务需求:它可以处理多摄像头输入、不同的状态表示,并且可以通过配置参数调整如块大小、层数、头数等性能关键因素。最终,通过动作回归头,模型将解码器的输出映射为具体的控制信号,形成一个完整的感知-决策-控制流程,使机器人能够执行连贯、前瞻性的动作序列
1.3.1.2 _reset_parameters的实现
对于视觉处理,模型使用预训练的ResNet骨干网络(可配置)提取特征,并支持多摄像头输入
def _reset_parameters(self):"""Xavier-uniform initialization of the transformer parameters as in the original code."""# 遍历编码器和解码器的所有参数for p in chain(self.encoder.parameters(), self.decoder.parameters()): # 如果参数维度大于1(通常是权重矩阵if p.dim() > 1: )# 使用Xavier均匀初始化nn.init.xavier_uniform_(p)
1.3.1.3 forward方法的实现
前向传播流程清晰分明:可选的VAE编码阶段(仅用于训练)、输入准备阶段、Transformer编码-解码阶段和输出阶段
- Transformer部分的设计特别注重处理多模态输入和位置编码
编码器处理包括潜在向量、机器人状态、环境状态和图像特征的标记序列,每种输入都有相应的投影层将其映射到共同的嵌入维度
位置编码同样精心设计,包括一维序列的正弦位置编码和图像特征的二维正弦位置编码 - 解码器则使用可学习的位置嵌入(类似DETR的对象查询)和交叉注意力机制从编码器输出生成动作序列
具体而言,方法首先处理批次大小确定,并根据配置和运行模式决定如何准备潜在向量
- 当启用VAE且处于训练模式时,它构建一个BERT风格的输入序列——如下图左下角所示,以CLS标记开始,后跟机器人状态(如果配置),最后是动作序列
- 这些输入经过嵌入层投影到统一维度空间,并添加正弦位置编码以保留序列顺序信息
经过VAE编码器处理后,如上图右上角所示,CLS标记的输出被用来生成潜在空间分布参数(均值和对数方差),最后通过重参数化技巧(mu + exp(log_sigma/2) * 随机噪声)采样得到潜在向量
这是VAE训练的关键步骤,确保梯度可以通过随机采样过程反向传播# 将cls标记输出投影为潜在分布参数latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) # 前半部分为均值参数mu = latent_pdf_params[:, : self.config.latent_dim] # 后半部分为对数方差参数,这是2*log(sigma),这样做是为了匹配原始实现log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] # 使用重参数化技巧采样潜在变量,mu + exp(log_sigma/2)*噪声latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
若不使用VAE,则简单地使用全零向量作为潜在表示
接下来的多模态融合阶段展示了处理异构数据的精妙设计
- 方法首先准备Transformer编码器的输入「接收包含多模态输入(机器人状态、环境状态和/或摄像头图像)的批次数据」:
从投影后的潜在向量开始
根据配置添加机器人状态和环境状态标记# 准备transformer编码器的输入,首先添加投影后的潜在变量encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] # 准备一维特征的位置嵌入encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
对于图像处理,它遍历每个摄像头视角,通过ResNet骨干网络提取特征# 机器人状态标记,如果配置包含机器人状态特征if self.config.robot_state_feature: # 添加投影后的机器人状态 encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) # 环境状态标记,如果配置包含环境状态特征if self.config.env_state_feature: # 添加投影后的环境状态encoder_in_tokens.append( self.encoder_env_state_input_proj(batch["observation.environment_state"]))
添加二维位置编码,然后将所有特征拼接并重排为序列形式。这种设计允许模型无缝地整合来自不同来源的信息# 相机观察特征和位置嵌入,如果配置包含图像特征if self.config.image_features: # 用于存储所有相机的特征all_cam_features = [] # 用于存储所有相机特征的位置嵌入all_cam_pos_embeds = [] # 遍历每个相机for cam_index in range(batch["observation.images"].shape[-4]): # 通过骨干网络提取特征cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"] # 生成2D位置嵌入并转换为与特征相同的数据类型,(B, C, h, w) , 将特征投影到模型维度cam_features = self.encoder_img_feat_input_proj(cam_features) # 添加到所有相机特征列表 all_cam_features.append(cam_features) # 添加到所有相机位置嵌入列表all_cam_pos_embeds.append(cam_pos_embed)
- 最后的Transformer处理阶段将所有准备好的标记和位置编码输入编码器,产生上下文化的表示
解码器以全零输入开始,通过交叉注意力机制关注编码器输出的相关部分,生成动作表示序列
最终通过线性层将这些表示映射为具体的动作向量
1.1.4 ACTEncoder类和ACTEncoderLayer类
1.1.5 ACTDecoder类和ACTDecoderLayer类
1.1.6 ACTSinusoidalPositionEmbedding2d类
1.2 policies/act/configuration_act.py
// 待更
第二部分 封装的Diffusion Policy
// 待更
第三部分 封装的pi0:涉及配置、模型训练/推理、attention优化等
该模块主要包含以下组件
- 转换工具 (conversion_scripts/)
包含将 pi0 模型转换为 HuggingFace 格式的脚本
提供了与 JAX 实现进行对比的工具
包含性能基准测试脚本 - 配置系统 (configuration_pi0.py)
定义了 `PI0Config` 类,继承自 `PreTrainedConfig`
配置了模型的输入/输出结构、归一化映射、图像预处理参数
支持特定的机器人配置,例如针对 Aloha 机器人的适配
包含训练相关的参数设置,如学习率、权重衰减等 - 注意力机制优化 (flex_attention.py)
提供了基于 PyTorch 的灵活注意力机制实现
针对 PyTorch 2.5.0 及以上版本的优化
支持分组查询注意力(GQA)以提高效率 - 核心模型实现 (modeling_pi0.py)
实现了 `PI0Policy` 类,封装了训练和推理功能
实现了 `PI0FlowMatching` 类,这是核心的流匹配模型
包含对机器人电机角度的转换处理,尤其是针对 Aloha 机器人的特殊处理 - paligemma_with_expert.py
可能马上就有同学疑问了,那这个模块和π0的官方实现库——π0官方库的实现详见此文《π0源码剖析——从π0模型架构的实现(如何基于PaLI-Gemma和扩散策略去噪生成动作),到基于C/S架构下的模型训练与部署》的分析,有何区别或不同呢?
- 实现语言和框架差异
openpi: 使用 JAX 框架实现,这是一个为高性能数值计算设计的库
lerobot/pi0: 使用 PyTorch 框架实现,是 JAX 版本的移植版本
包括从代码注释中也可以明确看到:"Designed by Physical Intelligence. Ported from Jax by Hugging Face",表明 lerobot 中的实现是由 Hugging Face 团队将原始 JAX 代码移植到 PyTorch - 集成与生态系统
openpi: 作为独立库存在,专注于 π0 模型本身
lerobot/pi0: 集成到更大的 LeRobot 框架中,遵循 LeRobot 的设计模式和接口标准
例如,lerobot/pi0 实现中的 `PI0Policy` 类继承自 LeRobot 的 `PreTrainedPolicy` 接口,这使它能够与整个 LeRobot 框架的数据处理、训练和评估流程无缝集成
当然了,π0官方库本身也提供了类似「将Libero数据集转换为LeRobot数据集」的脚本 - 多模态模型整合与加速模型推理
openpi: 可能需要手动配置与外部模型的交互
lerobot/pi0 中实现了一个特殊的 `PaliGemmaWithExpertModel` 类,用于整合 PaliGemma 多模态模型与 Gemma 专家模型
且lerobot 实现包含了针对 PyTorch 的优化,如灵活注意力机制 (`flex_attention.py`),用于加速模型推理——实现了KV cache
支持不同的注意力实现方式 (eager、fa2、flex),可以根据硬件和性能需求进行选择 - 权重转换机制
lerobot/pi0 包含专门的转换脚本 (`conversion_scripts/convert_pi0_to_hf_lerobot.py`),用于将原始 JAX 模型权重转换为 PyTorch 格式
这显示 lerobot 的实现是基于原始模型的移植,而不是独立实现 - 特有的适配性扩展
lerobot/pi0 添加了一些针对特定机器人硬件的适配功能,这些在原始 openpi 实现中可能不存在或实现方式不同:
Aloha 机器人适配: 通过 `adapt_to_pi_aloha` 参数配置,提供了专门处理 Aloha 机器人关节角度和夹爪位置的转换函数
空相机支持: 通过 `empty_cameras` 参数支持额外的空相机输入,用于模拟缺失的摄像头视角 - 接口更简洁、使用更简单
lerobot 版本提供了更简洁的接口,例如:# 使用预训练模型 policy = Pi0Policy.from_pretrained("lerobot/pi0")# 微调模型 python lerobot/scripts/train.py \ --policy.path=lerobot/pi0 \ --dataset.repo_id=danaaubakirova/koch_test
总之,lerobot/common/policies/pi0 本质上是 openpi 官方 JAX 实现的 PyTorch 移植版本,由 Hugging Face 团队开发,专门适配 LeRobot 框架。这个移植版本保持了原始算法的核心功能,同时添加了适配性扩展和针对pytorch的优化,使其能够更好地适应 LeRobot 生态系统和更广泛的机器人硬件
两者最根本的区别在于实现语言(JAX vs. PyTorch),和集成框架(独立库 vs. LeRobot 框架组件)
3.1 转换conversion_scripts:把JAX 实现的 π0 转换为 PyTorch 格式
在conversion_scripts目录中,主要有以下4个文件:
- benchmark.py
- compare_with_jax.py
- conversion_utils.py
- convert_pi0_to_hf_lerobot.py
conversion_scripts 模块的主要目的是将 Physical Intelligence 公司开发的原始 JAX 实现的 π0 模型转换为 PyTorch 格式,以便在 LeRobot 框架中使用
从代码中可以确认
- 脚本支持将三种不同的模型变体转换为 PyTorch 格式:
`pi0_base`: 基础模型
`pi0_aloha_sim`: 适用于 ALOHA 仿真环境的模型,包含空相机支持
`pi0_aloha_towel`: 适用于 ALOHA 真实机器人的模型,支持特殊的关节角度转换 - 原始 JAX π0 模型和转换后的 PyTorch 实现都使用了 Gemma 模型作为动作专家,而不是简单的 MLP 结构。这一点在 conversion_utils.py 中的 `get_gemma_config()` 函数中得到了体现,该函数配置了一个 18 层、1024 隐藏单元的 Gemma 模型
3.1.1 核心实现convert_pi0_to_hf_lerobot.py:将JAX格式的π0模型权重转换为PyTorch格式
这是核心转换脚本,负责将原始 JAX/Orbax 格式的 π0 模型权重转换为 PyTorch/HuggingFace 格式。主要功能包括:
转换流程
- 从 Orbax 检查点加载 JAX 格式的模型权重
- 提取 PaliGemma 视觉编码器和语言模型的权重
- 提取 Gemma 动作专家模型的权重
- 提取线性投影层的权重
- 重新映射权重以匹配 PyTorch 模型的结构
- 根据目标模型类型(pi0_base、pi0_aloha_sim、pi0_aloha_towel)应用不同的配置
- 保存为 HuggingFace 兼容格式
核心转换工作在`slice_paligemma_state_dict`和`slice_gemma_state_dict`函数中完成。这些函数执行精细的参数映射,处理各种Transformer组件(注意力层、MLP、层归一化等)的权重和偏置。每个函数都需要处理大量的张量重塑、转置和重组操作,以保持模型架构的语义等价性。例如,注意力层的查询、键和值投影矩阵需要特别注意,因为JAX和PyTorch的张量排列约定不同
3.1.1.1 slice_initial_orbax_checkpoint
脚本首先通过Orbax检查点管理器从OCDBT(Orbax CheckPoint Directory-Based Tree)格式加载原始模型参数。它使用`slice_initial_orbax_checkpoint`函数将嵌套的参数树结构扁平化,并分离出PaliGemma参数和投影参数
3.1.1.2 slice_paligemma_state_dict
`slice_paligemma_state_dict`函数处理视觉编码器(基于SigLIP)、多模态投影器和语言模型(Gemma)的前半部分,同时将专家模型的参数分离出来
- 函数首先处理参数命名约定的变体,通过检查状态字典中是否存在`"/value"`后缀来确定参数存储格式
def slice_paligemma_state_dict(state_dict, config): # 定义函数,用于将JAX格式的PaliGemma参数转换为PyTorch格式suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" # 确定参数键值的后缀,根据参数存储格式不同而变化# fmt: off # 关闭代码格式化,保持原格式# patch embeddings # 处理图像补丁嵌入层参数state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( # 提取并转换补丁嵌入权重,调整维度顺序3, 2, 0, 1)state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") # 提取补丁嵌入偏置# 处理位置嵌入参数state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape( # 提取位置嵌入权重并重塑形状-1, config.vision_config.hidden_size)
- 随后进行三个主要阶段的处理
第一阶段处理视觉编码器部分。它首先转换补丁嵌入(patch embeddings)和位置嵌入(positional embeddings),调整张量形状和维度顺序以匹配PyTorch模型的期望格式
然后,函数提取全部27层视觉Transformer的参数,包括层归一化(layernorm)、多层感知机(MLP)和多头注意力机制(attention)的权重和偏置。对于每个注意力子层,它都需要进行精确的形状转换和转置操作,确保查询(query)、键(key)、值(value)和输出投影(output projection)矩阵都被正确映射# 提取视觉层参数,基础模型中有27层# 提取第一个层归一化的缩放参数encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") # 提取第一个层归一化的偏置参数encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") # 提取第二个层归一化的缩放参数encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") # 提取第二个层归一化的偏置参数encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") # 提取MLP第一个全连接层的权重encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") # 提取MLP第一个全连接层的偏置encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") # 提取MLP第二个全连接层的权重encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")# 提取MLP第二个全连接层的偏置 encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") # 提取注意力机制中键投影的权重encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") # 提取注意力机制中键投影的偏置encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") # 提取注意力机制中值投影的权重encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") # 提取注意力机制中值投影的偏置encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") # 提取注意力机制中查询投影的权重encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") # 提取注意力机制中查询投影的偏置encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") # 提取注意力机制中输出投影的权重encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") # 提取注意力机制中输出投影的偏置encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
# 遍历所有视觉层(共27层)for i in range(config.vision_config.num_hidden_layers): # 设置第i层的第一个层归一化权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() # 设置第i层的第一个层归一化偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] # 设置第i层的第二个层归一化权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() # 设置第i层的第二个层归一化偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] # 设置第i层MLP的第一个全连接层权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() # 设置第i层MLP的第一个全连接层偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] # 设置第i层MLP的第二个全连接层权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() # 设置第i层MLP的第二个全连接层偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] # 设置第i层注意力的键投影权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() # 设置第i层注意力的键投影偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) # 设置第i层注意力的值投影权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() # 设置第i层注意力的值投影偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) # 设置第i层注意力的查询投影权重state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() # 设置第i层注意力的查询投影偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)# 设置第i层注意力的输出投影权重 state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() # 设置第i层注意力的输出投影偏置state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) # 设置视觉模型最终层归一化的权重state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() # 设置视觉模型最终层归一化的偏置state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
- 第二阶段处理多模态投影器和词嵌入,这是连接视觉和语言模型的关键桥梁。投影器参数需要转置以适应框架间的张量排列差异
# multimodal projector # 处理多模态投影器参数# 设置多模态投影器线性层的权重state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() # 设置多模态投影器线性层的偏置state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
- 第三阶段转换语言模型(Gemma)部分,处理18层Transformer结构。这一部分特别复杂,因为JAX中的einsum表示和PyTorch的线性层表示有很大不同。代码通过复杂的转置和重塑操作将注意力计算的矩阵调整为正确的形状和排列
特别是对查询投影的处理需要进行三次转置和一次重塑,将(8, 2048, 256)的原始形状转换为PyTorch模型中期望的(2048, 2048)形状# 处理文本解码器(Gemma)部分# 提取词嵌入向量embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") # 设置语言模型词嵌入层的权重state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector # 提取einsum注意力和MLP表示,Gemma-2B中有18层# 提取注意力向量einsum参数llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") # 提取键值einsum参数llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") # 提取查询einsum参数llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") # 提取MLP门控einsum参数llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") # 提取MLP线性层参数llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") # TODO verify correctness of layer norm loading # 待办:验证层归一化加载的正确性# 提取注意力前的层归一化参数llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") # 提取前馈网络前的层归一化参数llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
MLP中的门控投影(gate_proj)、上投影(up_proj)和下投影(down_proj)权重也需要类似的处理# 遍历文本模型的所有层(共18层)for i in range(config.text_config.num_hidden_layers): # 查询einsum参数形状为(8, 2048, 256)# 重塑查询投影权重q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) # 设置第i层查询投影权重state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped # 重塑键投影权重k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() # 设置第i层键投影权重state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped # 重塑值投影权重v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()# 设置第i层值投影权重 state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped # 输出投影处理# 重塑输出投影权重o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) # 设置第i层输出投影权重state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped # mlp layers # 处理MLP层参数# 获取门控投影权重gate_proj_weight = llm_mlp_gating_einsum[i, 0] # 设置第i层MLP门控投影权重state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() # 获取上投影权重up_proj_weight = llm_mlp_gating_einsum[i, 1] # 设置第i层MLP上投影权重state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() # 设置第i层MLP下投影权重state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() # 设置第i层输入层归一化权重state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] # 设置第i层注意力后层归一化权重state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] # 设置语言模型最终归一化层权重state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") # 设置语言模型输出头权重(与词嵌入共享权重)state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
- 最后,函数将参数分为主模型参数和专家模型参数,返回两个分离的状态字典。这种分离允许后续代码分别处理PaliGemma主体和Gemma专家组件,支持PI0模型的混合架构设计
# 恢复代码格式化# 初始化专家模型参数字典expert_dict = {} # 初始化最终状态字典final_state_dict = {} # 遍历状态字典中的所有键值对for key, value in state_dict.items(): # 如果键不在以下列表中(不是专家模型参数)if key not in [ f"llm/final_norm_1/scale{suffix}",f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",f"llm/layers/attn/kv_einsum_1/w{suffix}",f"llm/layers/attn/q_einsum_1/w{suffix}",f"llm/layers/mlp_1/gating_einsum{suffix}",f"llm/layers/mlp_1/linear{suffix}",f"llm/layers/pre_attention_norm_1/scale{suffix}",f"llm/layers/pre_ffw_norm_1/scale{suffix}",]:# 将值转换为PyTorch张量并添加到最终状态字典final_state_dict[key] = torch.from_numpy(value) else:# 将专家模型参数添加到专家字典expert_dict[key] = value # 返回最终状态字典和专家字典return final_state_dict, expert_dict
3.1.1.3 slice_gemma_state_dict
而`slice_gemma_state_dict`函数专门处理Gemma专家模型部分。对于27层视觉编码器和18层语言模型,脚本中的循环分别为每层精确地重映射参数
3.1.1.4 convert_pi0_checkpoint
最后,`convert_pi0_checkpoint`函数整合了所有过程:加载参数、处理投影层权重、处理PaliGemma和Gemma权重、创建适当的模型配置、实例化PI0Policy模型、加载状态字典、转换为指定精度,并保存模型使其与Hugging Face的`from_pretrained`方法兼容
脚本根据检查点路径自动检测是基础模型还是特定于Aloha机器人的变体,并相应地调整配置参数。此外,它支持不同的精度格式(float32、bfloat16、float16),以适应各种硬件和部署场景
3.1.2 conversion_utils.py:为转换提供关键的配置函数
这是一个辅助工具模块,为转换过程提供了关键的配置函数。具体功能包括:
- `get_paligemma_config()`: 创建标准的 PaliGemma 配置对象,设置了图像尺寸、补丁大小以及各种模型参数,如隐藏层大小、注意力头数量等,确保 PyTorch 版本的配置与原始 JAX 模型匹配
- `get_gemma_config()`: 创建 Gemma 动作专家模型的配置对象,指定了隐藏层大小(1024)、层数(18)、注意力头数量(8)等参数
具体而言
- `get_paligemma_config`函数创建了PaliGemma多模态模型的完整配置,它同时包含视觉和文本处理能力
函数首先设置基本的标记配置(如填充标记、开始标记和结束标记的ID),然后定义视觉处理相关参数
视觉部分使用224×224像素的图像输入和14×14像素的补丁大小,产生256个图像标记# 定义函数获取PaliGemma配置,参数precision指定模型精度 def get_paligemma_config(precision: str): # 初始化基本配置字典config = { "image_token_index": None, # 图像标记索引,初始设为None"pad_token_id": 0, # 填充标记ID为0"bos_token_id": 2, # 序列开始标记ID为2"eos_token_id": 1, # 序列结束标记ID为1}
函数为文本处理部分配置了一个18层的Transformer架构,每层有8个注意力头但只有1个键值头(表示使用了分组查询注意力机制,这是Gemma模型的特点),隐藏层维度为2048image_size = 224 # 设置图像大小为224像素(边长)patch_size = 14 # 设置图像patch大小为14像素(边长)# 计算图像patch数量:总像素除以每个patch的像素num_image_tokens = (image_size**2) // (patch_size**2)
视觉编码器被配置为27层,具有16个注意力头,隐藏层维度为1152# 设置图像token索引值config["image_token_index"] = 257152 text_config = { # 定义文本处理部分(语言模型)的配置"vocab_size": 257152, # 词汇表大小"num_hidden_layers": 18, # 隐藏层数量"num_key_value_heads": 1, # 键值头数量(用于分组查询注意力)"head_dim": 256, # 每个注意力头的维度"torch_dtype": precision, # 使用传入的精度参数"hidden_size": 2048, # 隐藏层大小"hidden_activation": "gelu_pytorch_tanh", # 隐藏层激活函数"num_attention_heads": 8, # 注意力头数量"intermediate_size": 16384, # 前馈网络中间层大小"is_encoder_decoder": False, # 不是编码器-解码器架构}
这些精心选择的参数确保了模型能够有效处理图像信息并与文本进行融合# 定义视觉处理部分的配置vision_config = { "torch_dtype": precision, # 使用传入的精度参数"image_size": image_size, # 图像大小"patch_size": patch_size, # patch大小"num_image_tokens": num_image_tokens, # 图像token数量"hidden_size": 1152, # 视觉模型隐藏层大小"intermediate_size": 4304, # 视觉模型中间层大小"num_hidden_layers": 27, # 视觉模型隐藏层数量"num_attention_heads": 16, # 视觉模型注意力头数量"projector_hidden_act": "gelu_fast", # 投影器隐藏层激活函数"vision_use_head": False, # 不使用视觉头}
# 创建最终PaliGemma配置对象final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) return final_config # 返回配置对象
- 相比之下,`get_gemma_config`函数创建了Gemma专家模型的配置,它共享许多与PaliGemma文本部分相同的结构特征,但隐藏层大小减半至1024,中间层大小也从16384减少到4096
这种设计使Gemma专家模型更加轻量,同时保持足够的表达能力来补充PaliGemma的处理能力# 定义函数获取Gemma配置,参数precision指定模型精度 def get_gemma_config(precision: str): # 初始化基本配置字典config = { "image_token_index": None, # 图像标记索引,初始设为None"pad_token_id": 0, # 填充标记ID为0 "bos_token_id": 2, # 序列开始标记ID为2"eos_token_id": 1, # 序列结束标记ID为1}# 设置图像标记索引值config["image_token_index"] = 257152 # 定义文本处理模型的配置text_config = { "vocab_size": 257152, # 词汇表大小"num_hidden_layers": 18, # 隐藏层数量"num_key_value_heads": 1, # 键值头数量(用于分组查询注意力)"head_dim": 256, # 每个注意力头的维度"torch_dtype": precision, # 使用传入的精度参数"hidden_size": 1024, # 隐藏层大小(注意比PaliGemma的文本部分小一半)"hidden_activation": "gelu_pytorch_tanh", # 隐藏层激活函数"num_attention_heads": 8, # 注意力头数量"intermediate_size": 4096, # 前馈网络中间层大小(比PaliGemma小很多)"is_encoder_decoder": False, # 不是编码器-解码器架构}
final_config = GemmaConfig() # 创建空的Gemma配置对象final_config.update(text_config) # 使用text_config更新配置对象return final_config # 返回配置对象
两个配置函数都接受精度参数(如float32、bfloat16或float16),使模型能够适应不同的硬件和内存需求
3.2 配置configuration_pi0.py:配置VLM PaliGemma和动作专家Gemma 300M
`PI0Config`类是LeRobot项目中π0(PI0)策略模型的核心配置组件。作为一个使用Python的`dataclass`装饰器实现的配置类,它提供了一套全面的参数集,用于定义模型的输入/输出结构、预处理步骤、微调选项以及训练设置
这个类通过`@PreTrainedConfig.register_subclass("pi0")`装饰器注册为可序列化的预训练配置,使其能与LeRobot的模型加载和保存机制无缝集成。
配置类定义了三个主要参数组
- 首先是输入/输出结构参数,包括观察步数(`n_obs_steps`)、处理块大小(`chunk_size`)和动作步数(`n_action_steps`)
# 定义PI0配置类,继承自PreTrainedConfig class PI0Config(PreTrainedConfig): # Input / output structure. # 输入/输出结构配置n_obs_steps: int = 1 # 观察步数,默认为1步chunk_size: int = 50 # 处理块的大小,默认为50n_action_steps: int = 50 # 动作步数,默认为50
- 它还指定了不同输入类型的归一化方式,视觉输入使用恒等映射,而状态和动作数据则进行均值-标准差归一化
# 定义归一化映射字典normalization_mapping: dict[str, NormalizationMode] = field( # 使用lambda函数作为默认值工厂default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, # 视觉数据使用恒等映射(不归一化)"STATE": NormalizationMode.MEAN_STD, # 状态数据使用均值-标准差归一化"ACTION": NormalizationMode.MEAN_STD, # 动作数据使用均值-标准差归一化})
- 图像预处理部分配置将所有输入图像调整为224×224像素大小,并支持添加空摄像机视图,这在Aloha仿真环境中用于补充顶部摄像头的视角
# 图像预处理配置resize_imgs_with_padding: tuple[int, int] = (224, 224) # 调整图像大小并填充至224x224像素# 添加空白图像# 用于pi0_aloha_sim,它添加了除顶部相机外的左右手腕空白相机empty_cameras: int = 0 # 空白相机数量,默认为0
此配置还包含了特定于机器人控制的参数。`adapt_to_pi_aloha`参数启用从标准Aloha空间到PI内部运行时使用的空间的转换,而`use_delta_joint_actions_aloha`则控制是否使用相对于当前状态的关节差值,这对于精确的机器人控制至关重要
# 将关节和夹持器值从标准Aloha空间转换为# pi内部运行时使用的空间,该空间用于训练基础模型adapt_to_pi_aloha: bool = False # 是否适应PI Aloha格式,默认为False# 在传递给模型之前,将关节维度转换为相对于当前状态的增量# 夹持器维度将保持绝对值,# 是否使用Aloha的关节动作增量,默认为Falseuse_delta_joint_actions_aloha: bool = False # 分词器配置tokenizer_max_length: int = 48 # 分词器最大长度,默认为48# 投影器配置proj_width: int = 1024 # 投影宽度,默认为1024# 解码配置num_steps: int = 10 # 解码步数,默认为10
模型的注意力机制、微调和训练设置也有详细配置
`attention_implementation`参数支持多种注意力计算实现("eager"、"fa2"或"flex"),而`freeze_vision_encoder`和`train_expert_only`参数允许选择性地冻结模型组件以进行高效的微调
# 注意力机制工具配置# 是否使用缓存,默认为Trueuse_cache: bool = True # 注意力实现方式,默认为"eager",也可以是"fa2"或"flex"attention_implementation: str = "eager" # 微调设置freeze_vision_encoder: bool = True # 是否冻结视觉编码器,默认为Truetrain_expert_only: bool = False # 是否仅训练专家部分,默认为Falsetrain_state_proj: bool = True # 是否训练状态投影,默认为True
训练优化器和学习率调度器的默认设置基于AdamW优化器,并使用余弦衰减与预热的学习率调度策略,这是现代大型预训练模型的常见选择
此外,该类的`__post_init__`方法执行重要的输入验证,确保配置的一致性,例如检查动作步数不超过处理块大小,并验证当前只支持单个观察步骤。它还通过显式的`NotImplementedError`标记
3.3 paligemma_with_expert.py:将PaliGemma与Gemma集成在一起
paligemma_with_expert.py是PI0架构的核心模型类,它巧妙地将PaliGemma视觉-语言模型与Gemma专家语言模型集成在一起,形成了一个强大的多模态推理系统。该类继承自Hugging Face的`PreTrainedModel`,使其能够与Transformers生态系统无缝集成
3.3.1 对旋转位置编码RoPE的简单实现
这个文件首先定义了一个`apply_rope`函数,用于应用旋转位置编码RoPE到输入张量,这是一种在注意力计算中直接编码位置信息的技术
与传统的绝对位置编码不同,RoPE通过在复数域中进行旋转变换,在保持向量内积不变的同时编码相对位置信息
原理讲解详见此文《一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long(含NTK-aware简介)》
- 该函数首先计算输入张量`x`最后维度的一半(`d_half`),因为RoPE基于二维旋转,对嵌入向量的每对元素进行操作
# 定义旋转位置编码(RoPE)函数,接收输入张量、位置张量和最大波长参数 def apply_rope(x, positions, max_wavelength=10_000): """Applies RoPE positions [B, L] to x [B, L, H, D].""" # 将RoPE位置编码应用于输入张量,B是批次大小,L是序列长度,H是头数,D是头维度# 计算头维度的一半,因为RoPE处理时会将每个向量分成两半d_half = x.shape[-1] // 2
- 然后,它获取设备和数据类型信息,并将输入转换为float32以确保计算精度
dtype = x.dtype # 获取输入张量的数据类型x = x.to(torch.float32) # 将输入张量转换为float32类型以确保计算精度
- 接下来,函数计算频率指数`freq_exponents`,它是通过将`2.0/D`(其中D是嵌入维度)乘以一个从0到`d_half-1`的序列得到的。这些指数用于创建时间尺度`timescale`,形成一个几何级数`max_wavelength**freq_exponents`
核心计算步骤是通过将位置值除以相应的时间尺度来获得弧度值`radians`。这种方式使得不同维度的嵌入以不同的频率旋转,低维度旋转缓慢,高维度旋转迅速,从而在不同尺度上捕获位置信息# 计算频率指数,不同维度使用不同频率的旋转freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) # 计算时间尺度,形成几何级数,低维度旋转慢,高维度旋转快timescale = max_wavelength**freq_exponents # 计算旋转弧度,位置值除以相应的时间尺度radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) # 扩展弧度张量维度以便于后续计算radians = radians[..., None, :]
- 然后,函数计算这些弧度的正弦和余弦值
sin = torch.sin(radians) # 计算弧度的正弦值cos = torch.cos(radians) # 计算弧度的余弦值
- 最后,函数将嵌入向量沿最后一个维度分为两半,并分别应用旋转变换:
- 前半部分:`x1 * cos - x2 * sin`# 将输入张量沿最后一个维度分成两半x1, x2 = x.split(d_half, dim=-1) # 创建与输入张量相同形状的空张量来存储结果res = torch.empty_like(x) # 应用旋转变换的第一部分:前半部分 = x1*cos - x2*sinres[..., :d_half] = x1 * cos - x2 * sin # 应用旋转变换的第二部分:后半部分 = x2*cos + x1*sinres[..., d_half:] = x2 * cos + x1 * sin
- 后半部分:`x2 * cos + x1 * sin`
这个过程实际上是在二维空间中对向量对执行旋转,旋转角度与位置成正比。这种方法的巧妙之处在于,它使得注意力机制能够自然地感知相对位置(即两个token之间的距离),而不仅仅是绝对位置,这对模型理解序列中的长距离依赖关系和结构关系至关重要
然后定义了两个主要类:`PaliGemmaWithExpertConfig`和`PaliGemmaWithExpertModel`,接下来,分别介绍这两个类的实现
3.3.2 PaliGemmaWithExpertConfig:管理和配置PaliGemmaWithExpertModel
`PaliGemmaWithExpertConfig`类是为`PaliGemmaWithExpertModel`定义配置的类,它继承自Hugging Face的`PretrainedConfig`
该类的作用是管理和配置一个复合模型,该模型由PaliGemma(一个视觉-语言模型)和Gemma专家模型组合而成
这个配置类声明了`model_type`为"PaliGemmaWithExpertModel",并通过`sub_configs`字典定义了两个子配置类型:
- paligemma_config
- gemma_expert_config
它们都使用`AutoConfig`作为基类。这种结构使模型能够独立配置两个组件,同时保持它们在一个统一的框架内
# 定义PaliGemma与专家模型的组合配置类,继承自预训练配置基类
class PaliGemmaWithExpertConfig(PretrainedConfig): # 设置模型类型标识符model_type = "PaliGemmaWithExpertModel" # 定义子配置映射,指定使用AutoConfig处理两个子模型sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
构造函数接受多个参数,其中三个关键控制参数决定了模型的行为方式:
- `freeze_vision_encoder`,默认为True,控制是否冻结视觉编码器参数
- `train_expert_only`,默认为True,决定是否只训练专家模型部分
- `attention_implementation`,默认为"eager",指定使用哪种注意力机制实现(可选值为"eager"、"fa2"或"flex")
def __init__(self,paligemma_config: dict | None = None, # PaliGemma模型的配置字典,可选gemma_expert_config: dict | None = None, # Gemma专家模型的配置字典,可选freeze_vision_encoder: bool = True, # 是否冻结视觉编码器,默认为Truetrain_expert_only: bool = True, # 是否仅训练专家模型部分,默认为Trueattention_implementation: str = "eager", # 注意力机制的实现方式,默认为"eager"**kwargs, # 额外的关键字参数):# 保存是否冻结视觉编码器的设置self.freeze_vision_encoder = freeze_vision_encoder # 保存是否仅训练专家模型的设置self.train_expert_only = train_expert_only # 保存注意力实现方式的设置self.attention_implementation = attention_implementation
此外,对于该构造函数
- 如果没有提供`paligemma_config`,构造函数会创建一个默认配置,这个配置指定了PaliGemma模型的详细参数,包括
词汇表大小(257152)、隐藏层维度(2048)
文本配置(如注意力头数量、隐藏层数)if paligemma_config is None: # 如果没有提供PaliGemma配置# Default config from Pi0 # 使用PI0的默认配置# 从映射中获取PaliGemma配置类并实例化self.paligemma_config = CONFIG_MAPPING["paligemma"]( transformers_version="4.48.1", # Transformers库版本_vocab_size=257152, # 词汇表大小bos_token_id=2, # 开始标记IDeos_token_id=1, # 结束标记IDhidden_size=2048, # 隐藏层大小image_token_index=257152, # 图像标记索引model_type="paligemma", # 模型类型pad_token_id=0, # 填充标记IDprojection_dim=2048, # 投影维度
和视觉配置(如SigLIP视觉模型的参数)# 文本配置text_config={ # 隐藏层激活函数"hidden_activation": "gelu_pytorch_tanh", "hidden_size": 2048, # 隐藏层大小"intermediate_size": 16384, # 中间层大小"model_type": "gemma", # 文本模型类型为gemma"num_attention_heads": 8, # 注意力头数量"num_hidden_layers": 18, # 隐藏层数量"num_image_tokens": 256, # 图像token数量"num_key_value_heads": 1, # 键值头数量(分组注意力)"torch_dtype": "float32", # PyTorch数据类型"vocab_size": 257152, # 词汇表大小},
# 视觉配置vision_config={ "hidden_size": 1152, # 隐藏层大小"intermediate_size": 4304, # 中间层大小"model_type": "siglip_vision_model", # 视觉模型类型为SigLIP"num_attention_heads": 16, # 注意力头数量"num_hidden_layers": 27, # 隐藏层数量"num_image_tokens": 256, # 图像标记数量"patch_size": 14, # 图像块大小"projection_dim": 2048, # 投影维度"projector_hidden_act": "gelu_fast", # 投影器隐藏层激活函数"torch_dtype": "float32", # PyTorch数据类型"vision_use_head": False, # 是否使用视觉头},
- 同样,如果没有提供`gemma_expert_config`,也会创建一个默认的Gemma专家模型配置,配置中包含注意力头参数、隐藏层参数、激活函数选择等关键设置
if gemma_expert_config is None: # 如果没有提供Gemma专家配置# Default config from Pi0 # 使用PI0的默认配置self.gemma_expert_config = CONFIG_MAPPING["gemma"]( # 从映射中获取Gemma配置类并实例化attention_bias=False, # 是否使用注意力偏置attention_dropout=0.0, # 注意力dropout率bos_token_id=2, # 开始tokenIDeos_token_id=1, # 结束token IDhead_dim=256, # 注意力头维度hidden_act="gelu_pytorch_tanh", # 隐藏层激活函数hidden_activation="gelu_pytorch_tanh", # 隐藏层激活函数(冗余)hidden_size=1024, # 隐藏层大小initializer_range=0.02, # 初始化范围intermediate_size=4096, # 中间层大小max_position_embeddings=8192, # 最大位置嵌入数model_type="gemma", # 模型类型num_attention_heads=8, # 注意力头数量num_hidden_layers=18, # 隐藏层数量num_key_value_heads=1, # 键值头数量(分组注意力)pad_token_id=0, # 填充标记IDrms_norm_eps=1e-06, # RMS归一化的epsilon值rope_theta=10000.0, # RoPE位置编码的theta参数torch_dtype="float32", # PyTorch数据类型transformers_version="4.48.1", # Transformers库版本use_cache=True, # 是否使用缓存vocab_size=257152, # 词汇表大小)
最后,在`__post_init__`方法中,配置类执行两项重要的验证:
- 首先检查`train_expert_only`和`freeze_vision_encoder`的设置是否兼容(如果只训练专家模型,则视觉编码器必须被冻结)
- 其次验证`attention_implementation`参数值是否有效。这些验证确保模型配置的一致性,防止训练过程中可能出现的问题
通过这种详细的配置机制,`PaliGemmaWithExpertModel`能够灵活地适应不同的训练和推理需求,同时保持设置的一致性和有效性
3.3.3 PaliGemmaWithExpertModel:分别初始化VLM PaliGemma、Gemma 300M
`PaliGemmaWithExpertModel`是一个结合了PaliGemma视觉-语言模型和Gemma专家语言模型的架构
- 在初始化阶段,模型实例化了PaliGemma和Gemma两个子模型`PaliGemmaForConditionalGeneration`处理视觉和初始语言理解
以及`GemmaForCausalLM`作为专家模型处理后续的推理和生成任务
并移除了不需要的Gemma词嵌入层(因为输入嵌入已由PaliGemma处理)# 定义PaliGemma与专家模型的组合类,继承自PreTrainedModel class PaliGemmaWithExpertModel(PreTrainedModel): config_class = PaliGemmaWithExpertConfig # 指定配置类# 初始化方法,接收配置参数def __init__(self, config: PaliGemmaWithExpertConfig): super().__init__(config=config) # 调用父类初始化方法self.config = config # 保存配置# 实例化PaliGemma模型self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) # 实例化Gemma专家模型self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
通过`to_bfloat16_like_physical_intelligence`方法,模型将关键组件转换为bfloat16格式,提高计算效率并减少内存占用,同时与原始Physical Intelligence实现保持一致# 移除未使用的词嵌入层,设置为None,因为输入嵌入已由PaliGemma处理self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence() # 将模型转换为bfloat16格式self.set_requires_grad() # 设置各部分是否参与梯度更新
- 该模型实现了灵活的训练控制机制
`set_requires_grad`
和重写的`train`方法def set_requires_grad(self): # 设置模型各部分是否需要梯度if self.config.freeze_vision_encoder: # 如果配置为冻结视觉编码器self.paligemma.vision_tower.eval() # 将视觉塔设置为评估模式for params in self.paligemma.vision_tower.parameters(): # 遍历视觉塔的所有参数params.requires_grad = False # 设置不需要梯度if self.config.train_expert_only: # 如果配置为只训练专家模型self.paligemma.eval() # 将整个PaliGemma设置为评估模式for params in self.paligemma.parameters(): # 遍历PaliGemma的所有参数params.requires_grad = False # 设置不需要梯度
确保即使在训练模式下,冻结的组件(如视觉编码器或整个PaliGemma模型)也保持在评估状态def train(self, mode: bool = True): # 重写train方法,控制训练模式super().train(mode) # 调用父类的train方法if self.config.freeze_vision_encoder: # 如果配置为冻结视觉编码器self.paligemma.vision_tower.eval() # 即使在训练模式下,也将视觉塔设为评估模式if self.config.train_expert_only: # 如果配置为只训练专家模型self.paligemma.eval() # 即使在训练模式下,也将PaliGemma设为评估模式
这种设计使得用户可以根据任务需求和计算资源灵活地选择微调策略,比如仅训练Gemma专家部分而保持视觉-语言基础模型不变 - 模型提供了两个关键的嵌入辅助方法:
`embed_image`将图像转换为特征表示
`embed_language_tokens`将语言token转换为嵌入表示def embed_image(self, image: torch.Tensor): # 图像嵌入方法return self.paligemma.get_image_features(image) # 使用PaliGemma获取图像特征
这些方法为下一节「3.4.3 PI0FlowMatching类的实现:嵌入处理、训练、推理(迭代去噪生成最终动作)」中的`PI0FlowMatching`类的`embed_prefix`功能提供了底层支持# 语言token嵌入方法def embed_language_tokens(self, tokens: torch.Tensor): # 使用PaliGemma语言模型的嵌入层处理tokenreturn self.paligemma.language_model.model.embed_tokens(tokens)
- `forward`方法是一个精心设计的复杂函数,它实现了PaliGemma和Gemma Expert两个模型的联合前向计算过程。正如代码中的TODO注释所示,这确实是一个"巨大的前向函数",但其复杂性是有必要的,因为它实现了两个独立模型在层级上的深度集成
该函数首先准备两个模型列表`models`,并从输入嵌入中获取批次大小
随后,它执行了一个关键的层循环,遍历PaliGemma文本配置中指定的层数。在每一层,函数对两个模型的输入应用相同的处理步骤:层归一化(input_layernorm)、计算查询/键/值投影# 待办:将这个巨大的前向传播方法拆分为模块或函数def forward(self,attention_mask: Optional[torch.Tensor] = None, # 注意力掩码position_ids: Optional[torch.LongTensor] = None, # 位置ID# 过去的键值对缓存past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, # 输入嵌入列表[前缀嵌入, 后缀嵌入]inputs_embeds: List[torch.FloatTensor] = None, use_cache: Optional[bool] = None, # 是否使用缓存fill_kv_cache: Optional[bool] = None, # 是否填充键值缓存):# 定义模型列表,包含PaliGemma语言模型和Gemma专家模型models = [self.paligemma.language_model.model, self.gemma_expert.model]
然后连接并应用旋转位置编码(RoPE)# RMSNorm # RMS归一化处理num_layers = self.paligemma.config.text_config.num_hidden_layers # 获取层数head_dim = self.paligemma.config.text_config.head_dim # 获取注意力头维度for layer_idx in range(num_layers): # 遍历每一层query_states = [] # 初始化查询状态列表key_states = [] # 初始化键状态列表value_states = [] # 初始化值状态列表# 遍历输入嵌入for i, hidden_states in enumerate(inputs_embeds): if hidden_states is None: # 如果隐藏状态为Nonecontinue # 继续下一次循环# 获取当前模型的当前层layer = models[i].layers[layer_idx] # 应用输入层归一化hidden_states = layer.input_layernorm(hidden_states) # 获取输入形状(除去最后一维)input_shape = hidden_states.shape[:-1] # 构建隐藏形状,适合多头注意力hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) # 转换为bfloat16类型hidden_states = hidden_states.to(dtype=torch.bfloat16) query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) # 计算查询状态并重塑key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) # 计算键状态并重塑value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) # 计算值状态并重塑query_states.append(query_state) # 添加到查询状态列表key_states.append(key_state) # 添加到键状态列表value_states.append(value_state) # 添加到值状态列表# B:批次大小,L:序列长度,H:头数,D:头维度# concatenate on the number of embeddings/tokens # 在嵌入/标记数量维度上连接query_states = torch.cat(query_states, dim=1) # 连接所有查询状态key_states = torch.cat(key_states, dim=1) # 连接所有键状态value_states = torch.cat(value_states, dim=1) # 连接所有值状态
代码中包含了高效的键值缓存机制,这对推理性能至关重要query_states = apply_rope(query_states, position_ids) # 应用RoPE位置编码到查询状态key_states = apply_rope(key_states, position_ids) # 应用RoPE位置编码到键状态
当设置`use_cache=True`时,函数会根据`fill_kv_cache`参数决定是填充新的缓存还是追加到现有缓存。这允许模型在自回归生成过程中重复使用之前计算的键值对,大大减少了计算量
经过RoPE处理后通过选择的注意力实现(由`get_attention_interface`方法确定,以确定"eager"、"fa2"或"flex")计算注意力输出if use_cache and past_key_values is None: # 如果使用缓存且过去的键值对为Nonepast_key_values = {} # 初始化为空字典if use_cache: # 如果使用缓存if fill_kv_cache: # 如果需要填充键值缓存past_key_values[layer_idx] = { # 存储当前层的键值对"key_states": key_states, # 存储键状态"value_states": value_states, # 存储值状态}else: # 如果不填充缓存,则使用已有缓存# # 待办:这里可以进行一些优化# 连接过去和当前的键状态key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) # 连接过去和当前的值状态value_states = torch.cat( [past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface() # 获取注意力接口att_output = attention_interface( # 计算注意力输出attention_mask, batch_size, head_dim, query_states, key_states, value_states)att_output = att_output.to(dtype=torch.bfloat16) # 转换为bfloat16类型
插入解释一下这个get_attention_interface方法
其中的flex_attention_forward下下文的3.5节,至于eager_attention_forward下面马上要介绍def get_attention_interface(self):if self.config.attention_implementation == "fa2":// fa2对应flash_attention_forwardattention_interface = self.flash_attention_forwardelif self.config.attention_implementation == "flex":// flex对应于pi0/paligemma_with_expert.py的开头的引入:from lerobot.common.policies.pi0.flex_attention import flex_attention_forwardattention_interface = flex_attention_forwardelse:// 对应下面马上要介绍的eager_attention_forwardattention_interface = self.eager_attention_forwardreturn attention_interface
计算得到的注意力输出被分割并通过输出投影、残差连接和前馈网络(MLP)处理
最后应用最终的层归一化# att_output的第一部分是前缀(直到序列长度)outputs_embeds = [] # 初始化输出嵌入列表start = 0 # 初始化起始索引for i, hidden_states in enumerate(inputs_embeds): # 遍历输入嵌入layer = models[i].layers[layer_idx] # 获取当前模型的当前层if hidden_states is not None: # 如果隐藏状态不为Noneend = start + hidden_states.shape[1] # 计算结束索引# 如果数据类型不匹配if att_output.dtype != layer.self_attn.o_proj.weight.dtype: # 转换数据类型att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) # 应用输出投影out_emb = layer.self_attn.o_proj(att_output[:, start:end]) # 待办:第一个dropout(默认为0.0)# 第一个残差连接out_emb += hidden_states # 克隆第一个残差后的结果 after_first_residual = out_emb.clone() # 应用注意力后的层归一化out_emb = layer.post_attention_layernorm(out_emb) # 应用多层感知机out_emb = layer.mlp(out_emb) # 待办:第二个dropout(默认为0.0)# 添加第二个残差连接out_emb += after_first_residual # 添加到输出嵌入列表outputs_embeds.append(out_emb) start = end # 更新起始索引else: # 如果隐藏状态为Noneoutputs_embeds.append(None) # 添加None到输出嵌入列表inputs_embeds = outputs_embeds # 更新输入嵌入为输出嵌入,准备下一层处理
# 最终归一化outputs_embeds = [] # 初始化最终输出嵌入列表# 遍历输入嵌入for i, hidden_states in enumerate(inputs_embeds): # 如果隐藏状态不为Noneif hidden_states is not None: out_emb = models[i].norm(hidden_states) # 应用最终层归一化outputs_embeds.append(out_emb) # 添加到输出嵌入列表else: outputs_embeds.append(None) # 添加None到输出嵌入列表# 返回输出嵌入和过去的键值对return outputs_embeds, past_key_values
- `eager_attention_forward`方法实现了标准的多头注意力机制,支持分组查询注意力(Grouped Query Attention,允许多个查询头共享相同的键值头,这是Gemma架构的特点)优化
它将查询、键和值向量进行矩阵乘法操作,应用注意力掩码,执行softmax归一化,并计算最终的注意力输出def eager_attention_forward(self, attention_mask, batch_size, head_dim, query_states, key_states, value_states):num_att_heads = self.config.paligemma_config.text_config.num_attention_headsnum_key_value_heads = self.config.paligemma_config.text_config.num_key_value_headsnum_key_value_groups = num_att_heads // num_key_value_heads
3.4 modeling_pi0.py:含模型训练、模型推理(迭代去噪生成动作)
根据本博客的此文《π0——用于通用机器人控制的VLA模型:一套框架控制7种机械臂(基于PaliGemma和流匹配的3B模型)》
可知pi0 模型采用了一个复杂的架构,主要由以下部分组成:
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌┴─────┐ │
│ kv cache │Gemma │ │
│ ┌──────────►│Expert│ │
│ │ │ │ │
│ ┌┴────────┐ │x 10 │ │
│ │ │ └▲──▲──┘ │
│ │PaliGemma│ │ │ │
│ │ │ │ robot state │
│ │ │ noise │
│ └▲──▲─────┘ │
│ │ │ │
│ │ image(s) │
│ language tokens │
└──────────────────────────────┘
该模块依赖于:
- PyTorch 作为基础深度学习框架
- Transformers 库中的 PaliGemma 和 Gemma 模型
- LeRobot 框架中的数据处理和规范化工具
3.4.1 库的导入与几个辅助函数的实现
具体而言,该代码首先导入了必要的库,包括PyTorch和其自定义的模块,如`PaliGemmaWithExpertModel`。文件顶部的文档字符串提供了模型的概述、论文链接、安装说明以及使用示例
代码中定义了几个辅助函数:`create_sinusoidal_pos_embedding`用于生成正弦余弦位置编码向量;`sample_beta`用于生成Beta分布样本;`make_att_2d_masks`用于创建二维注意力掩码;`resize_with_pad`用于调整图像大小并进行填充;`pad_vector`用于向量填充;`normalize`和`unnormalize`用于值的标准化与还原;以及一系列用于机器人抓取器转换的函数
3.4.2 PI0Policy类的实现:将「PI0FlowMatching模型」集成到LeRobot框架中进行训练和推理
`PI0Policy`是一个包装类,用于将下一节的「PI0FlowMatching模型」集成到LeRobot框架中进行训练和推理
相当于PI0Policy类侧重高层抽象与环境的交互,而PI0FlowMatching侧重底层算法底线,当使用模型时,用户主要通过PI0Policy与系统交互,而不需要直接接触PI0FlowMatching的复杂实现细节
该类继承自`PreTrainedPolicy`,提供了一个统一的接口来处理多模态输入(图像、机器人状态、语言指令)并生成机器人动作序列
- 在初始化阶段,`PI0Policy`接收一个配置对象和可选的数据集统计信息,设置了输入输出的归一化处理器,初始化了PaliGemma语言分词器和PI0FlowMatching模型核心。它还创建了一个动作队列,用于高效地管理预测的动作序列
- 该类的`select_action`方法是其核心推理接口,它实现了一个智能的队列机制:当动作队列为空时,它会处理完整的输入批次(包括准备图像、状态和语言指令),然后使用模型一次性生成多步动作序列并填充队列;在每次调用时,它只返回队列中的下一个动作,从而提高执行效率。这种设计特别适合于需要连续动作控制的机器人环境
- 在训练过程中,`forward`方法负责计算损失函数
它首先对输入进行归一化处理,准备好所有模态数据def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:# 执行完整的训练前向传播并计算损失if self.config.adapt_to_pi_aloha: # 如果配置为适配PI-Aloha模型batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) # 对机器人状态观测进行PI-Aloha解码转换batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) # 对动作进行PI-Aloha编码的逆变换
然后调用模型的前向传播函数计算每个步骤和每个电机的损失batch = self.normalize_inputs(batch) # 对输入数据进行归一化处理batch = self.normalize_targets(batch) # 对目标数据进行归一化处理images, img_masks = self.prepare_images(batch) # 准备并处理图像输入及其掩码state = self.prepare_state(batch) # 准备机器人状态数据lang_tokens, lang_masks = self.prepare_language(batch) # 准备语言指令的标记和掩码actions = self.prepare_action(batch) # 准备动作数据actions_is_pad = batch.get("actions_id_pad") # 获取动作填充标识(如果存在)
该方法还实现了智能的损失处理,包括对填充区域的剔除和统计跟踪。loss_dict = {} # 初始化损失追踪字典,用于记录损失计算过程losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) # 调用核心PI0FlowMatching模型计算损失loss_dict["losses_after_forward"] = losses.clone() # 记录模型前向传播后的原始损失
- 该类还包含几个专门的预处理方法:`prepare_images`方法对图像进行调整大小、填充和归一化,以适应SigLIP视觉模型的要求;`prepare_language`方法对文本指令进行分词处理;`prepare_state`和`prepare_action`方法对状态和动作向量进行填充
- 特别值得注意的是适配Aloha系统的方法(`_pi_aloha_decode_state`、`_pi_aloha_encode_actions`等),这些方法通过翻转特定关节和转换抓取器位置,实现了与Aloha系统的兼容,展示了该模型在不同机器人平台间的适应性
3.4.3 PI0FlowMatching类的实现:嵌入处理、训练、推理(迭代去噪生成最终动作)
`PI0FlowMatching`类是π0模型的核心实现,这是一个先进的视觉-语言-动作流模型,专为通用机器人控制而设计。该模型通过融合视觉输入、语言指令和机器人状态来生成精确的机器人动作序列
该类采用了流匹配(Flow Matching)技术,这是一种类似于扩散模型的方法,但具有更高效的训练和采样特性
在初始化阶段,它创建了一个PaliGemmaWithExpertModel实例(将PaliGemma视觉-语言模型与Gemma专家模型结合),并设置了处理状态、动作和时间信息的投影层
类的核心功能分为嵌入处理、训练流程和推理流程三个主要方面
首先是嵌入处理,分为embed_prefix和embed_suffix
`embed_prefix`方法处理模型的前缀输入:图像和语言输入,使用PaliGemma模型将图像嵌入到特征空间,并对语言token进行嵌入,同时创建适当的注意力掩码以允许图像和语言token之间的全面注意力交互
- 首先,该方法通过迭代输入的图像列表,将每个图像传递给`paligemma_with_expert.embed_image`函数,生成图像嵌入。这些嵌入随后被转换为bfloat16数据类型,以优化内存使用和计算效率
接着,方法应用了一个重要的归一化步骤,将图像嵌入乘以嵌入维度的平方根,这是Transformer架构中常用的缩放技术,有助于稳定训练过程和梯度流动
对于每个图像,方法还创建了相应的掩码,来标记哪些位置包含有效的图像内容,这些掩码将在后续的注意力计算中使用 - 对于语言输入,该方法使用`paligemma_with_expert.embed_language_tokens`函数将文本标记转换为嵌入表示,并同样应用了归一化,乘以嵌入维度的平方根。语言嵌入和相应的掩码也被添加到累积列表中
- 在处理完所有输入后,方法创建了注意力掩码(`att_masks`)来控制不同输入组件之间的交互。值得注意的是,图像标记之间以及图像和语言标记之间被设置为完全可以相互关注(值为0),这允许模型充分融合视觉和语言信息。最后,方法将所有嵌入和掩码沿着序列维度(dim=1)连接起来,并对注意力掩码进行适当的扩展,以匹配批次大小
- 返回的三元组包含连接后的嵌入、填充掩码和注意力掩码,这些将作为PaliGemma模型的输入,使其能够处理多模态信息并生成上下文丰富的表示,进而用于后续的机器人动作生成。代码中的TODO注释也表明了未来可能的优化方向,如预分配内存和移除循环以提高性能
`embed_suffix`方法负责处理模型的"后缀"输入——即机器人状态、带噪声的动作和时间步信息,将时间步使用正弦-余弦位置编码表示,并通过一个两层MLP网络融合动作和时间信息
与`embed_prefix`方法处理视觉和语言输入不同,这个方法专注于为Gemma专家模型准备必要的状态和动作表示
- 首先,方法通过线性投影层`state_proj`对机器人状态进行编码,将其转换为bfloat16数据类型以保持计算效率,并添加一个额外的维度使其成为一个单独的标记。对应的掩码被设置为全1,表示这是有效数据。注意力掩码值被设为1,这意味着前缀元素(图像和语言标记)不应关注这个状态标记,从而创建了信息流的单向边界
- 接下来,方法处理时间步信息,使用正弦-余弦位置编码进行嵌入。这种编码技术特别适合表示连续的时间值,通过在不同时间尺度上(从4e-3到4.0的周期范围)使用正弦和余弦函数,创建了一个能够有效区分不同时间点的表示
- 方法还对带噪声的动作应用了线性投影`action_in_proj`
然后,它巧妙地将时间嵌入扩展为与动作嵌入相同的形状,并在特征维度上连接它们
这个组合后的表示经过一个小型的多层感知机(MLP)处理:首先通过`action_time_mlp_in`线性层,然后应用SiLU激活函数(也称为Swish),最后通过`action_time_mlp_out`线性层。这一过程有效地融合了动作和时间信息,创建了上下文感知的表示 - 在注意力掩码的设置上,方法采用了精心设计的模式:第一个动作标记被设为1,表示前缀元素不应关注它;而剩余的动作标记被设为0,允许完全的交叉注意力
这种设计确保了模型中信息的适当流动——状态和初始动作标记作为上下文独立的起始点,而后续的动作标记则能够关注和利用所有可用信息 - 最后,方法将所有嵌入和掩码连接起来,并对注意力掩码进行适当的扩展和格式化,以便在后续的Transformer处理中使用
这种结构化的表示方式是流匹配算法成功运行的关键,使模型能够从噪声动作平滑过渡到目标动作
其次是训练过程
`forward`方法是PI0流匹配模型的核心训练流程,它实现了从多模态输入生成机器人动作的完整前向传播路径,并计算训练损失。这个方法基于流匹配(Flow Matching)技术,这是一种类似于扩散模型但更适合连续动作空间的生成方法
- 首先,该方法确保有可用的噪声和时间参数
如果未提供,它会分别调用`sample_noise`生成标准正态分布噪声和`sample_time`从Beta分布采样时间步(范围在0.001到0.999之间)def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None) -> Tensor:"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
# 如果没有提供噪声,则生成与动作形状相同的标准正态分布噪声if noise is None:noise = self.sample_noise(actions.shape, actions.device)# 如果没有提供时间步,则从Beta分布采样时间(范围在0.001到0.999之间)if time is None:time = self.sample_time(actions.shape[0], actions.device)
- 然后,它执行一个关键的线性插值操作:`x_t = time_expanded * noise + (1 - time_expanded) * actions`,这创建了目标动作的噪声版本,其中时间接近1时更接近纯噪声,接近0时更接近真实动作
同时计算`u_t = noise - actions`,表示从真实动作到噪声的向量场方向# 扩展时间维度以便与动作形状匹配,用于后续广播操作time_expanded = time[:, None, None]# 创建噪声化的动作:时间接近1时更接近噪声,接近0时更接近真实动作x_t = time_expanded * noise + (1 - time_expanded) * actions
如此文《π0源码剖析——从π0模型架构的实现(如何基于PaLI-Gemma和扩散策略去噪生成动作),到基于C/S架构下的模型训练与部署》中「1.2.4.3 损失函数compute_loss:训练模型去噪的准确率」一节所说的:# 计算从真实动作到噪声的向量场方向,这是模型需要学习预测的目标u_t = noise - actions
创建带噪动作序列 x_t,相当于x_t是噪声化的动作,随着时间从0到1,原始动作
逐渐添加真实噪声
,变为纯噪声
而
代表所加的真实噪声,便是咱们所要预测的所添加的噪声
的ground truth
故所添加的噪声即 = 加满噪声的动作
- 原始动作
- 接下来,方法分别调用`embed_prefix`和`embed_suffix`处理输入组件:
前者处理图像和语言token
后者处理机器人状态和噪声化的动作# 处理图像和语言输入,生成前缀嵌入表示和对应的掩码prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
这两个函数返回的嵌入和掩码被连接起来# 处理机器人状态和噪声化动作,生成后缀嵌入表示和对应的掩码suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
并使用`make_att_2d_masks`函数创建二维注意力掩码,控制不同输入元素之间的信息流动# 沿序列维度连接前缀和后缀的填充掩码pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)# 沿序列维度连接前缀和后缀的注意力掩码att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
位置ID通过累积求和填充掩码并减1来生成# 创建二维注意力掩码,控制不同输入元素之间的信息流att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
# 通过累积求和填充掩码并减1来计算位置ID,用于位置编码position_ids = torch.cumsum(pad_masks, dim=1) - 1
- 随后,方法将准备好的输入传递给`paligemma_with_expert`模型进行处理,获取后缀输出(主要是动作表示)
这个输出被裁剪为仅保留对应于动作步骤的部分# 将准备好的输入传递给PaliGemma和Gemma专家模型,获取输出表示(_, suffix_out), _ = self.paligemma_with_expert.forward(attention_mask=att_2d_masks,position_ids=position_ids,past_key_values=None,inputs_embeds=[prefix_embs, suffix_embs],use_cache=False,fill_kv_cache=False,)
转换为float32数据类型# 从输出中提取最后n_action_steps个标记,对应于动作表示suffix_out = suffix_out[:, -self.config.n_action_steps :]
并通过`action_out_proj`投影到动作空间,得到预测的向量场`v_t`# 将输出转换为float32数据类型,保持精度一致性suffix_out = suffix_out.to(dtype=torch.float32)
# 通过线性投影将后缀输出转换为动作向量场预测v_t = self.action_out_proj(suffix_out)
- 最后,方法计算预测向量场`v_t`与真实向量场`u_t`之间的均方误差作为损失函数
这种训练方式使模型学习从任意噪声状态到目标动作的向量场,在推理时可以通过从随机噪声开始,沿着这个向量场逐步前进来生成平滑、精确的动作序列# 计算预测向量场v_t与真实向量场u_t之间的均方误差损失losses = F.mse_loss(u_t, v_t, reduction="none")# 返回逐元素损失张量,供调用者进一步处理return losses
最后是推理:依次sample_actions、denoise_step
首先,`sample_actions`方法是PI0流匹配模型的核心推理函数,负责根据视觉、语言指令和机器人状态生成一系列动作
与训练时的`forward`方法不同,这个方法实现了从随机噪声到有意义的动作序列的生成过程,采用了类似于扩散模型的逐步降噪技术
- 首先,该方法获取批次大小和设备信息
如果未提供噪声,则生成形状为(批次大小, 动作步数, 最大动作维度)的标准正态分布噪声作为起始点def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:# 执行完整的推理前向传播并计算动作(批次大小 x 步骤数 x 电机数)bsize = state.shape[0] # 获取批次大小(从状态tensor的第一维)device = state.device # 获取当前设备(CPU或GPU)
接着,它调用`embed_prefix`处理图像和语言输入,创建嵌入表示和对应的掩码if noise is None: # 如果没有提供噪声actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) # 创建噪声形状:(批次大小, 动作步数, 最大动作维度)noise = self.sample_noise(actions_shape, device) # 采样标准正态分布噪声
并通过`make_att_2d_masks`函数将其转换为二维注意力掩码,同时计算位置ID# 处理图像和语言输入,生成前缀嵌入及相关掩码prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks)
# 为前缀创建二维注意力掩码prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) # 计算前缀位置ID(累积和减1)prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
- 一个关键的优化是计算并缓存前缀(图像和语言)输入的键值对
这是通过调用`paligemma_with_expert.forward`并设置`use_cache=True`和`fill_kv_cache=True`实现的
由于前缀输入在整个推理过程中保持不变,这种缓存机制避免了重复计算,显著提高了效率# 计算图像和语言的键值缓存,提高推理效率_, past_key_values = self.paligemma_with_expert.forward(attention_mask=prefix_att_2d_masks, # 设置注意力掩码position_ids=prefix_position_ids, # 设置位置IDpast_key_values=None, # 初始没有过去的键值对inputs_embeds=[prefix_embs, None], # 只传入前缀嵌入(图像和语言)use_cache=self.config.use_cache, # 使用缓存机制fill_kv_cache=True, # 填充键值缓存)
- 然后,方法设置欧拉法数值积分的时间步长`dt`(负值,因为时间从1倒数到0),初始化噪声状态`x_t`,并将时间设置为1.0(表示起始的纯噪声状态)
接下来进入主要的降噪循环,直到时间接近或达到0:# 计算欧拉积分的时间步长(负值,因为从1倒数到0)dt = -1.0 / self.config.num_steps # 转换为tensordt = torch.tensor(dt, dtype=torch.float32, device=device)
1. 将当前时间扩展为与批次大小匹配的张量
2. 调用`denoise_step`方法预测当前状态和时间下的向量场`v_t`——即预测噪声x_t = noise # 初始化噪声状态为纯噪声time = torch.tensor(1.0, dtype=torch.float32, device=device) # 设置初始时间为1.0(表示纯噪声状态)while time >= -dt / 2: # 降噪循环,直到时间接近或达到0expanded_time = time.expand(bsize) # 扩展时间为批次大小匹配的tensor
3. 执行欧拉步骤更新`x_t`(通过公式`x_t += dt * v_t`)v_t = self.denoise_step( # 执行一步降噪,预测向量场state, # 机器人状态prefix_pad_masks, # 前缀填充掩码past_key_values, # 键值缓存x_t, # 当前噪声状态expanded_time, # 当前时间步)
注意,本质就是对去噪,而
便是预测的噪声,
是时间步长——如上面说过的「时间步长`dt`为负值(因为是从t=1向t=0方向演化),生成初始随机噪声作为起点,且时间上约定:"t=1是噪声,t=0是目标分布"」
这种欧拉积分实际上是在求解概率流ODE——Ordinary Differential Equation,从噪声分布逐步转换到目标动作分布。通过迭代调用`denoise_step`,模型能够逐渐去除噪声,显现出与输入条件(图像、语言和状态)相符的有意义动作序列# 欧拉步骤,更新噪声状态(沿向量场方向移动)x_t += dt * v_t
4. 更新时间(`time += dt`)
最后返回去噪后的动作序列time += dt # 更新时间(向0移动)
return x_t # 返回最终去噪后的动作序列
其次,`denoise_step`方法是PI0流匹配模型中的核心推理组件,负责在流匹配过程中执行单个降噪步骤。该方法接收机器人状态、前缀填充掩码、键值缓存、当前噪声状态和时间步作为输入,并返回向量场预测——return v_t,指导噪声朝着目标动作转变
- 首先,方法调用`embed_suffix`函数处理机器人状态、噪声动作和时间步信息,生成相应的嵌入表示和掩码。这些表示包含了状态和噪声动作在当前时间点的完整上下文
def denoise_step(self,state,prefix_pad_masks,past_key_values,x_t,timestep,):# 在给定的时间步对噪声`x_t`应用一个降噪步骤# 处理状态、噪声动作和时间步,生成后缀嵌入及相关掩码suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
- 接下来,方法构建复杂的注意力掩码系统,以实现前缀(已缓存的图像和语言表示)和后缀(状态和动作)之间的适当交互。它计算后缀序列长度、批次大小和前缀长度,然后扩展前缀掩码维度以匹配所需的注意力掩码形状
同时,它使用`make_att_2d_masks`函数为后缀创建二维注意力掩码,并将两个掩码沿第三维连接,形成完整的注意力掩码suffix_len = suffix_pad_masks.shape[1] # 获取后缀序列的长度batch_size = prefix_pad_masks.shape[0] # 获取批次大小prefix_len = prefix_pad_masks.shape[1] # 获取前缀序列的长度prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) # 将前缀掩码扩展为三维形状,适合注意力计算
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) # 为后缀创建二维注意力掩码full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) # 沿第三维连接前缀和后缀掩码,形成完整注意力掩码
- 一个关键的处理步骤是位置ID的计算,它先计算前缀偏移量(通过对前缀掩码求和),然后加上后缀填充掩码的累积和并减1
这确保了位置编码的连续性,使模型能够正确处理序列位置信息# 计算前缀偏移量(每个样本有效前缀标记的数量)prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] # 计算位置ID,确保前缀和后缀的位置编码连续position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
- 然后,方法调用`paligemma_with_expert.forward`
但与训练阶段不同的是,这里只传入后缀嵌入(前缀部分已通过`past_key_values`缓存),这大大提高了推理效率
方法设置`fill_kv_cache=False`,表示使用现有缓存而非创建新缓存# 调用PaliGemma和Gemma专家模型的前向传播outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks, # 传入完整注意力掩码position_ids=position_ids, # 传入位置IDpast_key_values=past_key_values, # 传入缓存的键值对(来自前缀处理)inputs_embeds=[None, suffix_embs], # 只传入后缀嵌入(前缀已缓存)use_cache=self.config.use_cache, # 是否使用缓存机制fill_kv_cache=False, # 不填充新的键值缓存(使用现有缓存))
- 最后,方法提取后缀输出,特别是与动作步骤对应的部分,将其转换为float32数据类型(保持计算精度)
并通过`action_out_proj`投影到动作空间,得到向量场预测`v_t`# 提取后缀输出(对应于Gemma专家模型输出)suffix_out = outputs_embeds[1] # 只保留最后n_action_steps个标记的输出(对应动作部分)suffix_out = suffix_out[:, -self.config.n_action_steps :] # 转换为float32数据类型以保持计算精度suffix_out = suffix_out.to(dtype=torch.float32)
# 通过线性投影将输出转换为动作空间中的向量场预测v_t = self.action_out_proj(suffix_out) # 返回预测的向量场(指导噪声如何移动到目标点)return v_t
这个方法体现了流匹配算法的精髓——它不是直接预测动作,而是预测动作空间中的向量场,指导噪声状态如何逐步转变为有意义的动作。在`sample_actions`方法的循环中,这个函数被反复调用,通过欧拉积分逐步将随机噪声转化为精确、平滑且符合条件的机器人动作序列
3.5 flex_attention.py:实现了分组查询注意力
3.5.1 对分组查询注意力(GQA)的回顾
`flex_attention_forward`函数实现了PyTorch 2.5之后引入的FlexAttention机制,这是一种高效的注意力计算方案,专为大型语言模型设计,特别是使用分组查询注意力(GQA)的模型
关于GQA的介绍,详见此文《https://blog.csdn.net/v_JULY_v/article/details/134228287》
在PI0架构中,这是三种可选的注意力实现之一(其他两种为"eager"和"fa2"),提供了优化的内存使用和计算效率
3.5.2 每个键值KV头服务于8个查询Q头——相当于value头数/key头数是query头数的1/8
函数开始时记录输入张量的原始数据类型,然后设置分组查询注意力的参数:8个注意力头但只有1个键值头,每个键值KV头服务于8个查询Q头——相当于value头数/key头数是query头数的1/8,这种配置是Gemma模型的特点,能在保持表达能力的同时显著减少内存占用和计算量
original_dtype = query_states.dtype # 保存查询状态的原始数据类型num_att_heads = 8 # 设置注意力头数量为8num_key_value_heads = 1 # 设置键值头数量为1(分组查询注意力的特点)num_key_value_groups = num_att_heads // num_key_value_heads # 计算每个键值头对应的查询头组数
接下来,函数对键状态和值状态执行精心设计的扩展操作,使单个键值头能够被多个查询头共享。这通过添加维度、扩展和重塑键值张量来实现,确保它们与查询头的数量匹配
- 比如先对K做添加、扩展、重塑
# 在键状态张量中添加一个维度,用于后续展开key_states = key_states[:, :, :, None, :] # 扩展键状态张量以匹配所有查询头key_states = key_states.expand( # 扩展为[批次大小, 序列长度, 键值头数, 每组查询头数, 头维度]batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim )# 重塑键状态张量以便于计算key_states = key_states.reshape( # 重塑为[批次大小, 序列长度, 总注意力头数, 头维度])batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
- 然后再对V做添加、扩展、重塑
# 在值状态张量中添加一个维度,用于后续展开value_states = value_states[:, :, :, None, :] # 扩展值状态张量以匹配所有查询头value_states = value_states.expand( # 扩展为[批次大小, 序列长度, 键值头数, 每组查询头数, 头维度]batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim )# 重塑值状态张量以便于计算value_states = value_states.reshape( # 重塑为[批次大小, 序列长度, 总注意力头数, 头维度]batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim )
- 最后做转置
# 转置查询状态张量,将头维度移到前面 [批次大小, 注意力头数, 序列长度, 头维度]query_states = query_states.transpose(1, 2) # 转置键状态张量,将头维度移到前面 [批次大小, 注意力头数, 序列长度, 头维度]key_states = key_states.transpose(1, 2) # 转置值状态张量,将头维度移到前面 [批次大小, 注意力头数, 序列长度, 头维度]value_states = value_states.transpose(1, 2)
为了保证计算精度,函数将所有状态转换为float32类型
# 将查询状态转换为float32类型以提高计算精度query_states = query_states.to(torch.float32) # 将键状态转换为float32类型以提高计算精度key_states = key_states.to(torch.float32) # 将值状态转换为float32类型以提高计算精度value_states = value_states.to(torch.float32)
然后处理因果掩码(causal mask)。掩码确保每个位置只能关注当前及之前的位置,这对自回归生成至关重要
# 将输入的注意力掩码赋值给因果掩码变量causal_mask = attention_mask # 如果因果掩码不为空if causal_mask is not None: # 调整掩码形状以匹配注意力头和序列长度causal_mask = causal_mask[:, None, :, : key_states.shape[2]] # 如果掩码的注意力头维度为1,但查询状态有多个头if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: # 扩展掩码以匹配查询状态的注意力头数causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
3.5.2 针对FlexAttention的优化,函数实现的一个巧妙的块处理系统
针对FlexAttention的优化,函数实现了一个巧妙的块处理系统:
- 通过`precomputed_mask_factory`创建掩码访问函数,将序列长度向上取整为128(块大小)的倍数,并添加适当的填充
# 定义预计算掩码工厂函数def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: # 内部定义掩码修改函数,接收批次、头、查询索引和键值索引def mask_mod(b, h, q_idx, kv_idx): # 危险区域:如果索引超出形状,会在设备端触发断言# 返回指定位置的掩码值return precomputed_mask[b][h][q_idx][kv_idx] return mask_mod # 返回掩码修改函数# 获取因果掩码的形状参数b_mask, h_mask, q_len, kv_len = causal_mask.shape # 设置块大小为128,用于优化计算block_size = 128 # 将查询长度向上取整到块大小的倍数q_len_rounded = _round_up_to_multiple(q_len, block_size) # 将键值长度向上取整到块大小的倍数kv_len_rounded = _round_up_to_multiple(kv_len, block_size) # 关键:我们需要在这里扩展,否则会得到CUDA索引错误# 计算查询维度需要的填充量pad_q = q_len_rounded - q_len # 计算键值维度需要的填充量pad_k = kv_len_rounded - kv_len # 对因果掩码进行填充,使其大小符合块大小要求padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) # 创建填充掩码的修改函数mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
- 代码中最关键的部分是对掩码的处理和块掩码的创建
首先通过`create_mask`生成完整的4D掩码
然后通过`create_block_mask`将其转换为更高效的块式表示# 创建4D掩码mask_4d = create_mask( # 使用原始掩码修改函数mod_fn=mask_mod_fn_orig, B=b_mask, # 批次大小H=h_mask, # 头数量Q_LEN=q_len_rounded, # 查询长度(已取整)KV_LEN=kv_len_rounded, # 键值长度(已取整)device=causal_mask.device, # 设备与因果掩码相同_compile=False, # 不使用编译)
这些块构造函数接受`mask_mod`函数作为输入,该函数提供了安全访问掩码值的方法,特别注意了越界访问可能导致的设备端断言错误# 为4D掩码创建掩码修改函数mask_mod_fn_padded = precomputed_mask_factory(mask_4d) block_mask = create_block_mask( # 创建块掩码mask_mod=mask_mod_fn_padded, # 使用填充后的掩码修改函数B=b_mask, # 批次大小H=h_mask, # 头数Q_LEN=q_len_rounded, # 向上取整后的查询长度KV_LEN=kv_len_rounded, # 向上取整后的键值长度BLOCK_SIZE=block_size, # 块大小device=causal_mask.device, # 使用与因果掩码相同的设备_compile=False, # 不编译)
- 最后,函数调用`flex_attention`内核,该内核在底层实现了高效的注意力计算
# 掩码在内核中应用,理想情况下比score_mod更高效# 调用FlexAttention函数计算注意力输出和权重attn_output, attention_weights = flex_attention( query_states, # 查询状态key_states, # 键状态value_states, # 值状态block_mask=block_mask, # 块掩码# 启用分组查询注意力(GQA),因为我们已经对查询/键状态进行了相应的形状调整enable_gqa=True, # 设置缩放因子,默认为head_dim的平方根的倒数scale=head_dim**-0.5 if scaling is None else scaling, # 返回对数和指数值return_lse=True, )
- 结果被转换回原始数据类型,转置并重塑为期望的输出格式[批次大小, 序列长度, 嵌入维度]
# 将注意力输出转换回原始数据类型attn_output = attn_output.to(dtype=original_dtype) # [B, Q_LEN, H, head_dim],转置注意力输出并确保内存连续attn_output = attn_output.transpose(1, 2).contiguous() # 重塑注意力输出的形状attn_output = attn_output.reshape( batch_size, # 批次大小-1, # 自动计算第二维大小# 合并头数和头维度attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] )# 返回注意力输出return attn_output
// 待更