LiftSplatShoot中用于处理多视角图像数据并生成鸟瞰图(调试)

根据LSS代码中explore.py 中来显示出来图像
https://github.com/nv-tlabs/lift-splat-shoot/blob/d74598cb51101e2143097ab270726a561f81f8fd/src/explore.py#L249

安装
pytorch

pip install nuscenes-devkit tensorboardX efficientnet_pytorch==0.7.0

下载LSS的代码

git clone https://github.com/nv-tlabs/lift-splat-shoot.git

同时下载mini数据集

https://www.nuscenes.org/.

模型
模型下载

Visualize Predictions

python main.py viz_model_preds mini/trainval --modelf=MODEL_LOCATION --dataroot=NUSCENES_ROOT --map_folder=NUSCENES_MAP_ROOT

实现:

首先安装必要的库

import torch
import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
import matplotlib.patches as mpatches

from src.tool import gen_dx_bx, denormalize_img, add_ego
from src.models import compile_model

对于 nuscenes dataset 的数据设定

H=900
W=1600
resize_lim=(0.193, 0.225)
final_dim=(128, 352)
bot_pct_lim=(0.0, 0.22)
rot_lim=(-5.4, 5.4)
rand_flip=True

xbound=[-50.0, 50.0, 0.5]
ybound=[-50.0, 50.0, 0.5]
zbound=[-10.0, 10.0, 20.0]
dbound=[4.0, 45.0, 1.0]

grid_conf = {
    'xbound': xbound,
    'ybound': ybound,
    'zbound': zbound,
    'dbound': dbound,
}
cams = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
    'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']

data_aug_conf = {
                'resize_lim': resize_lim,
                'final_dim': final_dim,
                'rot_lim': rot_lim,
                'H': H, 'W': W,
                'rand_flip': rand_flip,
                'bot_pct_lim': bot_pct_lim,
                'cams': cams,
                'Ncams': 5,
            }

初始化模型

# initialize lss model
model = compile_model(grid_conf, data_aug_conf, outC=1)

然后打印出模型下载

    print('loading', modelf)
    model.load_state_dict(torch.load(modelf))
    model.to(device)

打印模型

model.eval()
LiftSplatShoot(
  (camencode): CamEncode(
    (trunk): EfficientNet(
      (_conv_stem): Conv2dStaticSamePadding(
        3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_blocks): ModuleList(
        (0): MBConvBlock(
          (_depthwise_conv): Conv2dStaticSamePadding(
            32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
            (static_padding): ZeroPad2d((1, 1, 1, 1))
          )
          (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (_se_reduce): Conv2dStaticSamePadding(
            32, 8, kernel_size=(1, 1), stride=(1, 1)
            (static_padding): Identity()
          )
          (_se_expand): Conv2dStaticSamePadding(
            8, 32, kernel_size=(1, 1), stride=(1, 1)
            (static_padding): Identity()
          )
          (_project_conv): Conv2dStaticSamePadding(
            32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
...
      (3): ReLU(inplace=True)
      (4): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

画图的代码

# 禁用梯度计算,节省显存
with torch.no_grad():
    # 遍历数据加载器中的每一个批次
    for batchi, (imgs, rots, trans, intrins, post_rots, post_trans, binimgs) in enumerate(loader):
        # 将数据移到设备上,并传入模型,获取输出
        out = model(imgs.to(device),
                    rots.to(device),
                    trans.to(device),
                    intrins.to(device),
                    post_rots.to(device),
                    post_trans.to(device),
                    )
        # 将输出进行sigmoid激活,并移动到CPU
        out = out.sigmoid().cpu()

        # 遍历批次中的每一张图片
        for si in range(imgs.shape[0]):
            # 清空当前的绘图
            plt.clf()
            # 遍历每张图片中的不同视角图像
            for imgi, img in enumerate(imgs[si]):
                # 创建子图
                ax = plt.subplot(gs[1 + imgi // 3, imgi % 3])
                # 反归一化图像
                showimg = denormalize_img(img)
                # 翻转底部的图像
                if imgi > 2:
                    showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
                # 显示图像
                plt.imshow(showimg)
                plt.axis('off')
                # 在图像上添加注释
                plt.annotate(cams[imgi].replace('_', ' '), (0.01, 0.92), xycoords='axes fraction')

            # 创建顶层的子图
            ax = plt.subplot(gs[0, :])
            # 隐藏坐标轴
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            # 设置子图边框颜色和宽度
            plt.setp(ax.spines.values(), color='b', linewidth=2)
            # 添加图例
            plt.legend(handles=[
                mpatches.Patch(color=(0.0, 0.0, 1.0, 1.0), label='Output Vehicle Segmentation'),
                mpatches.Patch(color='#76b900', label='Ego Vehicle'),
                mpatches.Patch(color=(1.00, 0.50, 0.31, 0.8), label='Map (for visualization purposes only)')
            ], loc=(0.01, 0.86))
            # 显示输出图像
            plt.imshow(out[si].squeeze(0), vmin=0, vmax=1, cmap='Blues')

            # 绘制静态地图(提高可视化效果)
            rec = loader.dataset.ixes[counter]
            plot_nusc_map(rec, nusc_maps, loader.dataset.nusc, scene2map, dx, bx)
            # 设置显示范围
            plt.xlim((out.shape[3], 0))
            plt.ylim((0, out.shape[3]))
            # 添加自车位置
            add_ego(bx, dx)

            # 保存图像
            imname = f'eval{batchi:06}_{si:03}.jpg'
            print('saving', imname)
            plt.savefig(imname)
            counter += 1

得到原来图片

        # 遍历批次中的每一张图片
        for si in range(imgs.shape[0]):
            # 清空当前的绘图
            plt.clf()
            # 遍历每张图片中的不同视角图像
            for imgi, img in enumerate(imgs[si]):
                # 创建子图
                ax = plt.subplot(gs[1 + imgi // 3, imgi % 3])
                # 反归一化图像
                showimg = denormalize_img(img)
                # 翻转底部的图像
                if imgi > 2:
                    showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
                # 显示图像
                plt.imshow(showimg)
                plt.axis('off')
                # 在图像上添加注释
                plt.annotate(cams[imgi].replace('_', ' '), (0.01, 0.92), xycoords='axes fraction')

车辆分割可视化结构

            # 创建顶层的子图
            ax = plt.subplot(gs[0, :])
            # 隐藏坐标轴
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            # 设置子图边框颜色和宽度
            plt.setp(ax.spines.values(), color='b', linewidth=2)
            # 添加图例
            plt.legend(handles=[
                mpatches.Patch(color=(0.0, 0.0, 1.0, 1.0), label='Output Vehicle Segmentation'),
                mpatches.Patch(color='#76b900', label='Ego Vehicle'),
                mpatches.Patch(color=(1.00, 0.50, 0.31, 0.8), label='Map (for visualization purposes only)')
            ], loc=(0.01, 0.86))
            # 显示输出图像
            plt.imshow(out[si].squeeze(0), vmin=0, vmax=1, cmap='Blues')


静态地图的效果

            # 绘制静态地图(提高可视化效果)
            rec = loader.dataset.ixes[counter]
            plot_nusc_map(rec, nusc_maps, loader.dataset.nusc, scene2map, dx, bx)
            # 设置显示范围
            plt.xlim((out.shape[3], 0))
            plt.ylim((0, out.shape[3]))
            plt.imshow(out_mean[0], cmap='viridis')
            plt.colorbar()
            plt.axis('off')


图像坐标系向ego坐标系进行坐标转化

                # 遍历每张图片中的不同视角图像
                for imgi, img in enumerate(imgs[si]):
                    # 将点从自车坐标系转换到摄像头坐标系
                    ego_pts = ego_to_cam(pts[si], rots[si, imgi], trans[si, imgi], intrins[si, imgi])
                    # 获取仅在图像内的点的掩码
                    mask = get_only_in_img_mask(ego_pts, H, W)
                    # 对点进行后旋转和平移变换
                    plot_pts = post_rots[si, imgi].matmul(ego_pts) + post_trans[si, imgi].unsqueeze(1)

                    # 创建子图
                    ax = plt.subplot(gs[imgi // 3, imgi % 3])
                    # 反归一化图像
                    showimg = denormalize_img(img)
                    # 显示图像
                    plt.imshow(showimg)
                    # 如果显示激光雷达点云
                    if show_lidar:
                        plt.scatter(plot_pts[0, mask], plot_pts[1, mask], c=ego_pts[2, mask],
                                s=5, alpha=0.1, cmap='jet')
                    # 关闭坐标轴显示
                    plt.axis('off')

                    # 切换到最终的绘图轴
                    plt.sca(final_ax)
                    # 绘制图像中的点
                    plt.plot(img_pts[si, imgi, :, :, :, 0].view(-1), img_pts[si, imgi, :, :, :, 1].view(-1))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值