Street Gaussians 耗时分析
目录
render 函数耗时
耗时函数 300ms
render 函数耗时
def render(
    self,
    viewpoint_camera: Camera,
    pc: StreetGaussianModel,
    convert_SHs_python=None,
    compute_cov3D_python=None,
    scaling_modifier=None,
    override_color=None,
    exclude_list=None,
):
    """Render ``viewpoint_camera`` from the Street Gaussian scene ``pc``.

    Stages: restrict visibility and bind the camera, run
    ``self.render_kernel``, optionally composite the sky cubemap, optionally
    apply color correction, and clamp RGB to [0, 1] when
    ``cfg.mode != 'train'``. After each stage the cumulative elapsed time
    (seconds since the start of the call) is printed for profiling.

    Args:
        viewpoint_camera: Camera to render.
        pc: Scene model. Every name in ``pc.model_name_id`` that is not in
            ``exclude_list`` is passed to ``pc.set_visibility``.
        convert_SHs_python, compute_cov3D_python, scaling_modifier,
        override_color: Forwarded unchanged to ``self.render_kernel``.
        exclude_list: Model names to leave out of the visibility list.
            ``None`` (the default) excludes nothing.

    Returns:
        The dict returned by ``self.render_kernel``, with its ``'rgb'``
        entry replaced by the sky-composited / color-corrected / clamped
        image where those stages apply.
    """
    # None instead of a mutable [] default, which would be shared by every call.
    if exclude_list is None:
        exclude_list = []
    include_list = list(set(pc.model_name_id.keys()) - set(exclude_list))

    def synced_clock():
        # CUDA kernels are launched asynchronously: an unsynchronized clock
        # read only measures launch overhead, and waiting on earlier queued
        # GPU work gets charged to whichever later step first blocks.
        # Synchronizing makes each printed time cover the work actually done.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()

    start_time = synced_clock()

    # Step1: render foreground
    pc.set_visibility(include_list)
    pc.parse_camera(viewpoint_camera)
    print('time11-------', synced_clock() - start_time)

    result = self.render_kernel(viewpoint_camera, pc, convert_SHs_python, compute_cov3D_python, scaling_modifier, override_color)
    print('time12-------', synced_clock() - start_time)

    # Step2: render sky
    if pc.include_sky:
        # Sky fills what the Gaussians leave uncovered, weighted by
        # (1 - acc); acc is presumably the per-pixel accumulated opacity.
        sky_color = pc.sky_cubemap(viewpoint_camera, result['acc'].detach())
        result['rgb'] = result['rgb'] + sky_color * (1 - result['acc'])
    print('time13-------', synced_clock() - start_time)

    if pc.use_color_correction:
        result['rgb'] = pc.color_correction(viewpoint_camera, result['rgb'])
    print('time14-------', synced_clock() - start_time)

    if cfg.mode != 'train':
        result['rgb'] = torch.clamp(result['rgb'], 0., 1.)
    print('time15-------', synced_clock() - start_time)

    return result
time11------- 0.37116456031799316
time12------- 0.4089925289154053
time13------- 0.4396076202392578
time14------- 0.4396934509277344
time15------- 0.4397237300872803
耗时函数 300ms
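parse_camera（time11 ≈ 0.371 s 实际覆盖 set_visibility 与 parse_camera 两步；CUDA 内核异步执行，若计时点未调用 torch.cuda.synchronize()，该段还可能包含此前尚未完成的 GPU 工作）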
def parse_camera(self, camera: Camera):
    """Bind ``camera`` to the model and rebuild the per-frame scene graph.

    Stores the camera and its frame metadata on ``self``. Selects the
    object actors whose ``[start_timestamp, end_timestamp]`` window
    contains the camera timestamp and that are visible. Assigns every
    included sub-model (background first, then objects in ``obj_list``
    order) an inclusive ``[first, last]`` index range into the combined
    Gaussian set. When at least one object is included, it also
    precomputes a rotation quaternion, a translation and a flip flag for
    every object Gaussian.

    Args:
        camera: Camera to bind. ``camera.meta`` must provide ``'frame'``,
            ``'frame_idx'`` and ``'is_val'``, plus ``'timestamp'`` when
            ``self.include_obj`` is set. ``camera.ego_pose`` is indexed as
            a matrix whose ``[:3, :3]`` block is a rotation and whose
            ``[:3, 3]`` column is a translation.

    Side effects (attributes written on ``self``):
        viewpoint_camera, frame, frame_idx, frame_is_val, num_gaussians,
        graph_obj_list and graph_gaussian_range. Only when graph_obj_list
        is non-empty: obj_rots, obj_trans and flip_mask.
    """
    # set camera
    self.viewpoint_camera = camera

    # set background mask
    self.background.set_background_mask(camera)

    self.frame = camera.meta['frame']
    self.frame_idx = camera.meta['frame_idx']
    self.frame_is_val = camera.meta['is_val']
    self.num_gaussians = 0

    # background
    # Visibility and Gaussian counts are read once here and reused for the
    # index ranges and poses below instead of re-querying get_xyz per loop.
    background_visible = self.get_visibility('background')
    if background_visible:
        num_gaussians_bkgd = self.background.get_xyz.shape[0]
        self.num_gaussians += num_gaussians_bkgd

    # object (build scene graph)
    self.graph_obj_list = []
    graph_obj_sizes = []  # Gaussian count of each entry in graph_obj_list
    if self.include_obj:
        timestamp = camera.meta['timestamp']
        for obj_name in self.obj_list:
            obj_model: GaussianModelActor = getattr(self, obj_name)
            start_timestamp, end_timestamp = obj_model.start_timestamp, obj_model.end_timestamp
            if timestamp >= start_timestamp and timestamp <= end_timestamp and self.get_visibility(obj_name):
                num_gaussians_obj = obj_model.get_xyz.shape[0]
                self.graph_obj_list.append(obj_name)
                graph_obj_sizes.append(num_gaussians_obj)
                self.num_gaussians += num_gaussians_obj

    # set index range: inclusive [first, last] position of each sub-model in
    # the concatenated Gaussian set (background first, then objects).
    self.graph_gaussian_range = dict()
    idx = 0
    if background_visible:
        self.graph_gaussian_range['background'] = [idx, idx + num_gaussians_bkgd - 1]
        idx += num_gaussians_bkgd
    for obj_name, num_gaussians_obj in zip(self.graph_obj_list, graph_obj_sizes):
        self.graph_gaussian_range[obj_name] = [idx, idx + num_gaussians_obj - 1]
        idx += num_gaussians_obj

    # NOTE(review): when no object is included, obj_rots / obj_trans /
    # flip_mask are not reassigned and keep the previous camera's values;
    # confirm every consumer checks graph_obj_list before using them.
    if len(self.graph_obj_list) > 0:
        # The ego pose is identical for every object, so its rotation is
        # converted to a quaternion once rather than once per object.
        ego_pose = self.viewpoint_camera.ego_pose
        ego_pose_rot = matrix_to_quaternion(ego_pose[:3, :3].unsqueeze(0)).squeeze(0)
        ego_pose_rot_batch = ego_pose_rot.unsqueeze(0)

        self.obj_rots = []
        self.obj_trans = []
        for obj_name, num_gaussians_obj in zip(self.graph_obj_list, graph_obj_sizes):
            obj_model: GaussianModelActor = getattr(self, obj_name)
            track_id = obj_model.track_id
            obj_rot = self.actor_pose.get_tracking_rotation(track_id, self.viewpoint_camera)
            obj_trans = self.actor_pose.get_tracking_translation(track_id, self.viewpoint_camera)
            # Compose the tracked object pose with the ego pose.
            obj_rot = quaternion_raw_multiply(ego_pose_rot_batch, obj_rot.unsqueeze(0)).squeeze(0)
            obj_trans = ego_pose[:3, :3] @ obj_trans + ego_pose[:3, 3]

            # Broadcast the single per-object pose to all of its Gaussians
            # (expand returns a view; memory is only materialized by cat).
            self.obj_rots.append(obj_rot.expand(num_gaussians_obj, -1))
            self.obj_trans.append(obj_trans.unsqueeze(0).expand(num_gaussians_obj, -1))

        self.obj_rots = torch.cat(self.obj_rots, dim=0)
        self.obj_trans = torch.cat(self.obj_trans, dim=0)

        # Per-Gaussian flip flags: all False for deformable objects or when
        # flip_prob == 0; otherwise each Gaussian is flagged with probability
        # flip_prob. Kept as its own pass after all pose queries so the
        # random draws happen in object order, independent of the pose code.
        self.flip_mask = []
        for obj_name in self.graph_obj_list:
            obj_model: GaussianModelActor = getattr(self, obj_name)
            if obj_model.deformable or self.flip_prob == 0:
                flip_mask = torch.zeros_like(obj_model.get_xyz[:, 0]).bool()
            else:
                flip_mask = torch.rand_like(obj_model.get_xyz[:, 0]) < self.flip_prob
            self.flip_mask.append(flip_mask)
        self.flip_mask = torch.cat(self.flip_mask, dim=0)