OccFormer code analysis: loading_nusc_panoptic_occ.py


The code comes from https://github.com/zhangyp15/OccFormer

Comments have been added to the key parts of the code.

import numpy as np
import yaml, os
import torch
import numba as nb

from mmdet.datasets.builder import PIPELINES
from .loading_nusc_occ import voxel_transform

import pdb

'''
Instructions from PanopticNuScenes:

A ground truth label file named {token}_panoptic.npz is provided for each sample in the Panoptic nuScenes dataset. 
A .npz file contains the panoptic label array (uint16 format) of the corresponding points in a pointcloud. 
The panoptic label of each point is: (general class index * 1000 + instance index).
Note here general class index (32 classes in total) rather than the challenge class index (16 classes in total) is used. 
For example, a ground truth instance from car class (general class index = 17), and with assigned car instance index 1, 
    will have a ground truth panoptic label of 1000 * 17 + 1 = 17001 in the .npz file. 
Since these ground truth panoptic labels are generated from annotated bounding boxes, 
    points that are included in more than 1 bounding box will be ignored, and assigned with panoptic label 0: 
    class index 0 and instance index 0. For points from stuff classes, their panoptic labels will be general class index * 1000.

To align with thing classes, you may think of the stuff classes as sharing an instance index of 0 across all points.
To load a ground truth file, you can use:
    from nuscenes.utils.data_io import load_bin_file
    label_file_path = /data/sets/nuscenes/panoptic/v1.0-mini/{token}_panoptic.npz
    panoptic_label_arr = load_bin_file(label_file_path, 'panoptic')

'''
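
# Worked example of the label encoding above: a car point (general class 17)
# with instance index 1 carries panoptic label 17 * 1000 + 1 = 17001, and the
# two parts can be recovered with integer operations:
#   semantic = 17001 // 1000  # -> 17 (general class index)
#   instance = 17001 % 1000   # -> 1  (instance index)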

@PIPELINES.register_module()
class LoadNuscPanopticOccupancyAnnotations(object):
    def __init__(
            self,
            data_root='/root/autodl-tmp',
            is_train=False,
            is_test_submit=False,
            grid_size=None, 
            point_cloud_range=None,
            bda_aug_conf=None,
            unoccupied_id=17,
            cls_metas='nuscenes.yaml',
        ):
        
        self.is_train = is_train
        self.is_test_submit = is_test_submit
        self.cls_metas = cls_metas
        with open(cls_metas, 'r') as stream:
            nusc_cls_metas = yaml.safe_load(stream)
            self.learning_map = nusc_cls_metas['learning_map']
        
        self.data_root = data_root
        self.bda_aug_conf = bda_aug_conf
        
        # voxel settings
        self.grid_size = np.array(grid_size)
        print("self.grid_size",self.grid_size)
        self.point_cloud_range = np.array(point_cloud_range)
        # for semantickitti, the transform center is not (0, 0, 0), which makes the transformation a bit more complex
        self.transform_center = (self.point_cloud_range[:3] + self.point_cloud_range[3:]) / 2
        self.unoccupied_id = unoccupied_id

        # create full-resolution occupancy labels
        self.voxel_size = (self.point_cloud_range[3:] - self.point_cloud_range[:3]) / self.grid_size
        print("self.voxel_size",self.voxel_size)
    
    def sample_3d_augmentation(self):
        """Generate 3d augmentation values based on bda_config."""
        rotate_bda = np.random.uniform(*self.bda_aug_conf['rot_lim'])
        scale_bda = np.random.uniform(*self.bda_aug_conf['scale_lim'])
        flip_dx = np.random.uniform() < self.bda_aug_conf['flip_dx_ratio']
        flip_dy = np.random.uniform() < self.bda_aug_conf['flip_dy_ratio']
        flip_dz = np.random.uniform() < self.bda_aug_conf.get('flip_dz_ratio', 0.0)
        
        return rotate_bda, scale_bda, flip_dx, flip_dy, flip_dz
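
    # A minimal illustrative bda_aug_conf (the values here are assumptions, not
    # the repo's actual config); the keys match those read above:
    #   bda_aug_conf = dict(
    #       rot_lim=(-22.5, 22.5),    # rotation angle range in degrees
    #       scale_lim=(0.95, 1.05),   # scaling factor range
    #       flip_dx_ratio=0.5,        # probability of flipping along x
    #       flip_dy_ratio=0.5,        # probability of flipping along y
    #       flip_dz_ratio=0.0,        # optional; .get() defaults it to 0.0
    #   )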

    def __call__(self, results):
        # for test-submission of nuScenes Panoptic Segmentation 
        if self.is_test_submit:
            imgs, rots, trans, intrins, post_rots, post_trans, gt_depths, sensor2sensors = results['img_inputs']
            bda_rot = torch.eye(3).float()
            results['img_inputs'] = (imgs, rots, trans, intrins, post_rots, post_trans, bda_rot, gt_depths, sensor2sensors)
            
            pts_filename = results['pts_filename']
            points = np.fromfile(pts_filename, dtype=np.float32, count=-1).reshape(-1, 5)[..., :3]
            points_label = np.zeros((points.shape[0], 1)) # placeholder
            lidarseg = np.concatenate([points, points_label], axis=-1)
            results['points_occ'] = torch.from_numpy(lidarseg).float()
            
            return results
       
        # drop into the debugger here to step through the lines below
        pdb.set_trace()
        panopticseg_file = results['lidarseg'].replace('lidarseg', 'panoptic')
        panopticseg_file = panopticseg_file.replace('.bin', '.npz')
        panopticseg_file = os.path.join(self.data_root, panopticseg_file)
        panoptic_labels = np.load(panopticseg_file)['data']
        # general class index = panoptic label // 1000 (note: // 1e3 produces a float array)
        semantic_labels = panoptic_labels // 1e3
        
        # perform learning mapping: remap the 32 general classes to the 16
        # challenge classes, one class at a time (np.unique yields each class once)
        for semantic_label in np.unique(semantic_labels):
            # boolean mask selecting all points of this class
            semantic_mask = (semantic_labels == semantic_label)
            # look up the mapped (challenge) class id for this general class
            mapped_id = self.learning_map[int(semantic_label)]
            # update the semantic id of these points
            semantic_labels[semantic_mask] = mapped_id
            
            # recover the instance index, then rebuild the panoptic label as
            # mapped class id * 1000 + instance index
            cls_panoptic_labels = panoptic_labels[semantic_mask] - semantic_label * 1e3
            panoptic_labels[semantic_mask] = mapped_id * 1e3 + cls_panoptic_labels
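            # Example (assuming the standard nuscenes.yaml mapping, where general
            # class 17 "vehicle.car" maps to challenge class 4): panoptic label
            # 17001 (class 17, instance 1) becomes 4 * 1000 + 1 = 4001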
        
        # reshape the labels into a single column
        panoptic_labels = panoptic_labels.reshape(-1, 1)
        points = np.fromfile(results['pts_filename'], dtype=np.float32, count=-1).reshape(-1, 5)[..., :3]
        # concatenate points and labels into 4 columns: (x, y, z, panoptic label)
        lidarseg = np.concatenate([points, panoptic_labels], axis=-1)
        
        if self.is_train:
            # sample the augmentation parameters that define the BDA rotation matrix
            rotate_bda, scale_bda, flip_dx, flip_dy, flip_dz = self.sample_3d_augmentation()
            # voxel_transform composes rotation, scaling and flips into bda_rot.
            # For example, with rotate_bda = 0, scale_bda = 1, flip_dx = True and
            # flip_dy = True, the result is
            # [[-1.,  0.,  0.],
            #  [ 0., -1.,  0.],
            #  [ 0.,  0.,  1.]]
            # since flipping along x negates the first coordinate and flipping
            # along y negates the second
            _, bda_rot = voxel_transform(None, rotate_bda, scale_bda, flip_dx, flip_dy, flip_dz)
        else:
            bda_rot = torch.eye(3).float()
        
        # transform the points with the BDA matrix
        points = points @ bda_rot.t().numpy()
        # write the transformed coordinates back into lidarseg
        lidarseg[:, :3] = points
        
        ''' create voxel labels from lidarseg '''
        eps = 1e-5
        # clip points into the valid range, then compute each point's voxel index
        # (np.int was removed in recent NumPy versions, so use np.int64 instead)
        points_grid_ind = np.floor((np.clip(lidarseg[:, :3], self.point_cloud_range[:3],
                self.point_cloud_range[3:] - eps) - self.point_cloud_range[:3]) / self.voxel_size).astype(np.int64)
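        # Worked example (hypothetical numbers): a point at x = 0.0 with range
        # minimum -51.2 and voxel size 0.4 lands in cell
        # floor((0.0 - (-51.2)) / 0.4) = floor(128.0) = 128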
        # pair each point's voxel index with its label
        label_voxel_pair = np.concatenate([points_grid_ind, lidarseg[:, -1:]], axis=1)
        # multi-level sort on (x, y, z): np.lexsort takes its LAST key as the
        # primary key, so rows are ordered by z, then y, then x; this makes points
        # that share a voxel contiguous, which nb_process_label below relies on
        label_voxel_pair = label_voxel_pair[np.lexsort((points_grid_ind[:, 0], points_grid_ind[:, 1], points_grid_ind[:, 2])), :]
        label_voxel_pair = label_voxel_pair.astype(np.uint16)
        
        # label semantics: 0 = noise, 1-16 = normal classes, 17 = unoccupied
        # empty_id = 17 * 1000 = 17000
        empty_id = int(self.unoccupied_id * 1e3)
        processed_label = np.ones(self.grid_size, dtype=np.uint16) * empty_id
        processed_label = nb_process_label(processed_label, label_voxel_pair)
        # print("processed_label", processed_label)
        # convert label_0 to label_65535 (ignored)
        processed_label[processed_label == 0] = 65535
        # convert empty to label id 0
        processed_label[processed_label == empty_id] = 0
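        # Resulting encoding (illustrative): 0 = empty/unoccupied,
        # mapped_class_id * 1000 + instance for occupied voxels,
        # 65535 = ignore (voxels whose winning label was 0, i.e. noise)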

        ''' save results '''
        results['gt_occ'] = torch.from_numpy(processed_label.astype(np.float32)).long()
        results['points_occ'] = torch.from_numpy(lidarseg).float()
        imgs, rots, trans, intrins, post_rots, post_trans, gt_depths, sensor2sensors = results['img_inputs']
        results['img_inputs'] = (imgs, rots, trans, intrins, post_rots, post_trans, bda_rot, gt_depths, sensor2sensors)
        return results

# assign each voxel the label that occurs most often among the points inside it (majority vote)
@nb.jit('u2[:,:,:](u2[:,:,:], u2[:,:])', nopython=True, cache=True, parallel=False)
def nb_process_label(processed_label, sorted_label_voxel_pair):
    ignore_index = 0
    # the counter is indexed by the uint16 label value, so it needs 256 * 256 = 65536 slots
    label_size = 256 * 256
    counter = np.zeros((label_size,), dtype=np.uint16)
    # count the label of the first point
    counter[sorted_label_voxel_pair[0, 3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0, :3]
    for i in range(1, sorted_label_voxel_pair.shape[0]):
        cur_ind = sorted_label_voxel_pair[i, :3]
        if not np.all(np.equal(cur_ind, cur_sear_ind)):
            # a new voxel begins: zero the noise count so label 0 (noise) cannot
            # win the vote, commit the majority label, then reset the counter
            counter[ignore_index] = 0
            processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
            counter = np.zeros((label_size,), dtype=np.uint16)
            cur_sear_ind = cur_ind
        
        counter[sorted_label_voxel_pair[i, 3]] += 1
    
    # commit the majority label of the last voxel
    processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
    
    return processed_label
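
To see the majority vote in action, here is a minimal standalone sketch (my own illustration, not part of the repo): four synthetic points fall into two voxels of a tiny 2x2x2 grid, and nb_process_label assigns each voxel its most frequent label.

import numpy as np

# assumes nb_process_label (defined above) is available in the same module
grid = np.full((2, 2, 2), 17000, dtype=np.uint16)  # every voxel starts as "empty"
pairs = np.array([
    [0, 0, 0, 5],   # voxel (0,0,0) sees label 5 twice and label 7 once
    [0, 0, 0, 5],
    [0, 0, 0, 7],
    [1, 1, 1, 3],   # voxel (1,1,1) sees a single label 3
], dtype=np.uint16)  # rows sharing a voxel are contiguous, as after np.lexsort

out = nb_process_label(grid, pairs)
print(out[0, 0, 0])  # 5: the majority label wins the vote
print(out[1, 1, 1])  # 3
print(out[0, 1, 1])  # 17000: voxels without points keep the "empty" id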