Drivestduio 代码笔记与理解
Rigid Node: 表示 car
或者trucks
Deformable Node : 表示一些 分布之外的 non-rigid 的运动物体, 比如远处的行人等和Cyclist。
在 load_objects
会读取每一个 dynamic objects 的 'bounding box’的信息,具体如下:
frame_instances
记录了每一帧都有哪些 instance, 以及对应 每一帧 其 位姿信息等;
instances_info
包含每一帧 对于哪些 instance 是可见的。
1. 读取 Bounding Box 的基本信息
逻辑是先遍历场景的 instance, 然后 再每一个 instance 的信息。
## 存放每一帧 的instance 的pose
instances_pose = np.zeros((num_full_frames, num_instances, 4, 4))for k, v in instances_info.items():instances_model_types[int(k)] = OBJECT_CLASS_NODE_MAPPING[v["class_name"]]for frame_idx, obj_to_world, box_size in zip(v["frame_annotations"]["frame_idx"], v["frame_annotations"]["obj_to_world"], v["frame_annotations"]["box_size"]):# the first ego pose as the origin of the world coordinate system.obj_to_world = np.array(obj_to_world).reshape(4, 4)obj_to_world = np.linalg.inv(ego_to_world_start) @ obj_to_worldinstances_pose[frame_idx, int(k)] = np.array(obj_to_world)instances_size[frame_idx, int(k)] = np.array(box_size)
根据 per_frame_instance_mask
来得到 每一帧对于哪些instance 是可见的。
per_frame_instance_mask = np.zeros((num_full_frames, num_instances))for frame_idx, valid_instances in frame_instances.items():per_frame_instance_mask[int(frame_idx), valid_instances] = 1
2. 使用Bounding Box 的信息初始化高斯
这里需要介绍一个 dynamic vehicle 非常重要的坐标系,物体坐标系(Object系),
其通常位于汽车的车辆中心
。所以任何一帧的 Lidar 通过w2o
矩阵可以将Lidar 点转换到 canonical space, 完成对于多帧 Lidar 的聚集
将 Lidar 点 (世界坐标系下面的) 通过 转化矩阵 w2o
转换到 Object 坐标系下面 , 然后 根据 Bounding Box 的 Size 去保留 在 BBX 内部的点云,准备进行初始化。
o2w = self.pixel_source.instances_pose[fi, ins_id]o_size = self.pixel_source.instances_size[ins_id]# convert the lidar points to the instance's coordinate systemw2o = torch.inverse(o2w)o_pts = transform_points(lidar_pts, w2o)# 将BBX 之外的点通过 Mask 滤除,这一步是在局部 Object 坐标系下面进行的mask = ((o_pts[:, 0] > -o_size[0] / 2)& (o_pts[:, 0] < o_size[0] / 2)& (o_pts[:, 1] > -o_size[1] / 2)& (o_pts[:, 1] < o_size[1] / 2)& (o_pts[:, 2] > -o_size[2] / 2)& (o_pts[:, 2] < o_size[2] / 2))valid_pts = o_pts[mask]valid_colors = self.lidar_source.colors[lidar_dict["lidar_mask"]][mask]
通过 比较 在 instances_pose
的pose (O2W系) 移动,仅仅对于 动态的 instance 进行保留 。
因为车辆的移动其实可以看成是
O2W
坐标系的移动。 相当于车辆是静止的,但是环境是运动的
if only_moving:# consider only the instances with non-zero flowslogger.info(f"Filtering out the instances with non-moving trajectories")new_instance_dict = {}for k, v in instance_dict.items():if v["num_pts"] > 0: ## 仅仅考虑有点的 instance# flows = v["flows"]# if flows.norm(dim=-1).mean() > moving_thres:# v.pop("flows")# new_instance_dict[k] = v# logger.info(f"Instance {k} has {v['num_pts']} lidar sample points")frame_info = self.pixel_source.per_frame_instance_mask[:, k]instances_pose = self.pixel_source.instances_pose[:, k]instances_trans = instances_pose[:, :3, 3]valid_trans = instances_trans[frame_info]traj_length = valid_trans[1:] - valid_trans[:-1]traj_length = torch.norm(traj_length, dim=-1).sum()if traj_length > traj_length_thres:new_instance_dict[k] = vlogger.info(f"Instance {k} has {v['num_pts']} lidar sample points")instance_dict = new_instance_dict
将所有帧的 Lidar Aggregated 到 Canonical Space 下面,如图所示:
静态高斯的初始化
静态的 高斯初始化 = Lidar_samples + 半球内的随机采样点。 随机采样点是 PVG
这篇文章所介绍的, 在 球内部 和 球外面进行均匀采样。
Rigid 高斯的初始化
从 Canonical Space
累计的 点云进行 高斯的各项属性的初始化, 读取 点云的 坐标和颜色,然后进行初始化。 并记录了 每个bbx 的大小以及 每个instance 在每一帧的可见性,分别用 self.instances_size
和 self.instances_fv
表示。
## (num_instances, 3) BBX 的大小self.instances_size = torch.stack(instances_size).to(self.device) # # (num_frame, num_instances) instance 在每一帧的可见性
self.instances_fv = torch.cat(instances_fv, dim=1).to(self.device)
值得注意的是, Drivestudio 将每一帧的每一个 instance 的 BBX 的 的 pose 也作为参数去考虑优化:
# (num_frame, num_instances, 4) 四元数self.instances_quats = Parameter(self.quat_act(instances_quats))# (num_frame, num_instances, 3) 平移
self.instances_trans = Parameter(instances_trans)
高斯参数的优化器设置:
所有的 Rigid Nodes 会把放进一个 优化字典当中,然后一起优化,并不是每个 instance 去独立的优化。
Rigid 的每一个GS 都是像原始的 3DGS 一样,配置 每一个属性的 学习率去进行优化的。
groups.append({'params': params,'name': params_name,'lr': optim_cfg.lr,'eps': optim_cfg.eps,'weight_decay': optim_cfg.weight_decay})
groups 构建好之后,全部一起当作字典丢进 Adam 优化器去进行优化
self.optimizer = torch.optim.Adam(groups, lr=0.0, eps=1e-15)
Sky Model
Drivestudio 使用场景的 Environment map
来对于 天空的颜色进行建模. Sky 被建模成一个 长方体 cube, 然后使用基于光线方向(Opengl系)来在 environment cube 上进行纹理查询。这个 environment map 虽然没有任何网络,但是其本身的参数也是需要被优化的。 对应的 Code 如下
class EnvLight(torch.nn.Module):def __init__(self,class_name: str,resolution=1024,device: torch.device = torch.device("cuda"),**kwargs):super().__init__()self.class_prefix = class_name + "#"self.device = deviceself.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")## 需要被优化的 environment mapself.base = torch.nn.Parameter(0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),)def forward(self, image_infos):l = image_infos["viewdirs"]l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape)l = l.contiguous()prefix = l.shape[:-1]if len(prefix) != 3: # reshape to [B, H, W, -1]l = l.reshape(1, 1, -1, l.shape[-1])light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube')light = light.view(*prefix, -1)return light
开始训练:
针对每一个 Node 提取出场景的N 个动态对象高斯。 如果是 Rigid 物体的高斯,前面的代码是采用 Object
系存储的,需要转换到 World
系,然后提取出来。
以 平移变化来分析:
首先我们有 frame_id
标记 我们训练的是哪一帧,取出这一帧的所有 instance 对应的 旋转和 rot_cur_frame
平移trans_cur_frame
. 假设我们有M个动态点,将这M个动态点 和 应用在M个旋转和平移向量上,同时得到了这个所有动态类别在场景frame_id
对应的位置和坐标。
def transform_means(self, means: torch.Tensor) -> torch.Tensor:"""transform the means of instances to world spaceaccording to the pose at the current frame"""assert means.shape[0] == self.point_ids.shape[0], \"its a bug here, we need to pass the mask for points_ids"quats_cur_frame = self.instances_quats[self.cur_frame] # (num_instances, 4)rot_cur_frame = quat_to_rotmat(self.quat_act(quats_cur_frame)) # (num_instances, 3, 3)## 求出每个点的旋转rot_per_pts = rot_cur_frame[self.point_ids[..., 0]] # (num_points, 3, 3)trans_cur_frame = self.instances_trans[self.cur_frame] # (num_instances, 3)## 求出每个点的平移trans_per_pts = trans_cur_frame[self.point_ids[..., 0]]# transform the means to world spacemeans = torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1) + trans_per_ptsreturn means
之后使用 gsplat
作为渲染的框架,执行渲染, 这里的动态和静态实际上都是转换到 世界系的 高斯 然后一起渲染的。 为了渲染 动态物体,将场景高斯的 动态物体的 Opacity
设置为0, 其他的属性不用改变。