根据 LSS 代码中的 explore.py 来显示图像
https://github.com/nv-tlabs/lift-splat-shoot/blob/d74598cb51101e2143097ab270726a561f81f8fd/src/explore.py#L249
安装
pytorch
pip install nuscenes-devkit tensorboardX efficientnet_pytorch==0.7.0
下载LSS的代码
git clone https://github.com/nv-tlabs/lift-splat-shoot.git
同时下载mini数据集
https://www.nuscenes.org/
模型
模型下载
Visualize Predictions
python main.py viz_model_preds mini/trainval --modelf=MODEL_LOCATION --dataroot=NUSCENES_ROOT --map_folder=NUSCENES_MAP_ROOT
实现:
首先导入必要的库
import torch
import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
import matplotlib.patches as mpatches
# Fixed: the LSS repo module is src/tools.py (plural), not src/tool.py.
from src.tools import gen_dx_bx, denormalize_img, add_ego
from src.models import compile_model
对于 nuscenes dataset 的数据设定
# nuScenes image geometry and augmentation settings (LSS defaults).
H, W = 900, 1600                  # raw camera image size
resize_lim = (0.193, 0.225)       # random-resize scale range
final_dim = (128, 352)            # network input size after resize/crop
bot_pct_lim = (0.0, 0.22)         # bottom-crop range (fraction of image)
rot_lim = (-5.4, 5.4)             # random-rotation range
rand_flip = True

# BEV grid: [min, max, step] per axis; dbound holds the depth bins.
xbound = [-50.0, 50.0, 0.5]
ybound = [-50.0, 50.0, 0.5]
zbound = [-10.0, 10.0, 20.0]
dbound = [4.0, 45.0, 1.0]

grid_conf = dict(xbound=xbound, ybound=ybound, zbound=zbound, dbound=dbound)

cams = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
        'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']

data_aug_conf = dict(
    resize_lim=resize_lim,
    final_dim=final_dim,
    rot_lim=rot_lim,
    H=H, W=W,
    rand_flip=rand_flip,
    bot_pct_lim=bot_pct_lim,
    cams=cams,
    Ncams=5,
)
初始化模型
# initialize lss model
# outC=1: a single BEV output channel (vehicle-segmentation logits).
model = compile_model(grid_conf, data_aug_conf, outC=1)
然后加载下载好的模型权重
# Load the trained weights from `modelf` and move the model to `device`.
# map_location avoids a device-mismatch error when the checkpoint was
# saved on a different device than the one we run on (e.g. GPU -> CPU).
print('loading', modelf)
model.load_state_dict(torch.load(modelf, map_location=device))
model.to(device)
打印模型
# Switch to inference mode (disables dropout, freezes batch-norm statistics).
model.eval()
LiftSplatShoot(
(camencode): CamEncode(
(trunk): EfficientNet(
(_conv_stem): Conv2dStaticSamePadding(
3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
(static_padding): ZeroPad2d((1, 1, 1, 1))
)
(_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
(_blocks): ModuleList(
(0): MBConvBlock(
(_depthwise_conv): Conv2dStaticSamePadding(
32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
(static_padding): ZeroPad2d((1, 1, 1, 1))
)
(_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
(_se_reduce): Conv2dStaticSamePadding(
32, 8, kernel_size=(1, 1), stride=(1, 1)
(static_padding): Identity()
)
(_se_expand): Conv2dStaticSamePadding(
8, 32, kernel_size=(1, 1), stride=(1, 1)
(static_padding): Identity()
)
(_project_conv): Conv2dStaticSamePadding(
32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
...
(3): ReLU(inplace=True)
(4): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
)
)
)
画图的代码
# Disable gradient tracking: inference only, saves memory.
with torch.no_grad():
    # Iterate over every batch produced by the data loader.
    for batchi, (imgs, rots, trans, intrins, post_rots, post_trans, binimgs) in enumerate(loader):
        # Move inputs to the device and run the LSS model to get the BEV output.
        out = model(imgs.to(device),
                    rots.to(device),
                    trans.to(device),
                    intrins.to(device),
                    post_rots.to(device),
                    post_trans.to(device),
                    )
        # Sigmoid -> per-cell probability in [0, 1]; back to CPU for plotting.
        out = out.sigmoid().cpu()

        # One figure per sample in the batch.
        for si in range(imgs.shape[0]):
            plt.clf()
            # Draw the six camera views in the lower grid rows.
            for imgi, img in enumerate(imgs[si]):
                ax = plt.subplot(gs[1 + imgi // 3, imgi % 3])
                # Undo the normalization applied by the data loader.
                showimg = denormalize_img(img)
                # Mirror the rear cameras so left/right match the BEV layout.
                if imgi > 2:
                    showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
                plt.imshow(showimg)
                plt.axis('off')
                # Label each view with its camera name.
                plt.annotate(cams[imgi].replace('_', ' '), (0.01, 0.92), xycoords='axes fraction')

            # Top row: bird's-eye-view segmentation panel.
            ax = plt.subplot(gs[0, :])
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            # Blue frame around the BEV panel.
            plt.setp(ax.spines.values(), color='b', linewidth=2)
            plt.legend(handles=[
                mpatches.Patch(color=(0.0, 0.0, 1.0, 1.0), label='Output Vehicle Segmentation'),
                mpatches.Patch(color='#76b900', label='Ego Vehicle'),
                mpatches.Patch(color=(1.00, 0.50, 0.31, 0.8), label='Map (for visualization purposes only)')
            ], loc=(0.01, 0.86))
            # out[si] is the single-channel BEV prediction; show probabilities.
            plt.imshow(out[si].squeeze(0), vmin=0, vmax=1, cmap='Blues')

            # Overlay the static nuScenes map (for visualization only).
            rec = loader.dataset.ixes[counter]
            plot_nusc_map(rec, nusc_maps, loader.dataset.nusc, scene2map, dx, bx)
            # Flip the x-axis so the BEV orientation matches the camera layout.
            plt.xlim((out.shape[3], 0))
            plt.ylim((0, out.shape[3]))
            # Draw the ego-vehicle footprint.
            add_ego(bx, dx)

            # Save one image per sample.
            imname = f'eval{batchi:06}_{si:03}.jpg'
            print('saving', imname)
            plt.savefig(imname)
            counter += 1
得到原来图片
# For each sample in the batch, draw the six raw camera images.
for si in range(imgs.shape[0]):
    # Start a fresh figure for this sample.
    plt.clf()
    for imgi, img in enumerate(imgs[si]):
        # Camera views occupy rows 1+ of the grid (row 0 holds the BEV panel).
        ax = plt.subplot(gs[1 + imgi // 3, imgi % 3])
        # Undo the normalization applied by the data loader.
        showimg = denormalize_img(img)
        # Mirror the rear cameras so left/right match the BEV layout.
        if imgi > 2:
            showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
        plt.imshow(showimg)
        plt.axis('off')
        # Label each view with its camera name.
        plt.annotate(cams[imgi].replace('_', ' '), (0.01, 0.92), xycoords='axes fraction')
车辆分割可视化结构
# Top panel: bird's-eye-view vehicle-segmentation output.
ax = plt.subplot(gs[0, :])
# Hide tick marks; keep the blue frame as a visual border.
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
plt.setp(ax.spines.values(), color='b', linewidth=2)
plt.legend(handles=[
    mpatches.Patch(color=(0.0, 0.0, 1.0, 1.0), label='Output Vehicle Segmentation'),
    mpatches.Patch(color='#76b900', label='Ego Vehicle'),
    mpatches.Patch(color=(1.00, 0.50, 0.31, 0.8), label='Map (for visualization purposes only)')
], loc=(0.01, 0.86))
# out[si] is the single-channel BEV prediction; squeeze the channel
# dimension and display the probabilities in [0, 1].
plt.imshow(out[si].squeeze(0), vmin=0, vmax=1, cmap='Blues')
静态地图的效果
# Overlay the static nuScenes map under the segmentation (visualization only).
rec = loader.dataset.ixes[counter]
plot_nusc_map(rec, nusc_maps, loader.dataset.nusc, scene2map, dx, bx)
# Flip the x-axis so the BEV orientation matches the camera layout.
plt.xlim((out.shape[3], 0))
plt.ylim((0, out.shape[3]))
# NOTE(review): `out_mean` is not defined anywhere in this snippet —
# presumably a channel-averaged BEV map; confirm before running.
plt.imshow(out_mean[0], cmap='viridis')
plt.colorbar()
plt.axis('off')
图像坐标系向ego坐标系进行坐标转化
# Project points from the ego frame into each camera image and draw them.
for imgi, img in enumerate(imgs[si]):
    # Transform ego-frame points into this camera's frame
    # (despite the name, `ego_pts` holds camera-frame coordinates).
    ego_pts = ego_to_cam(pts[si], rots[si, imgi], trans[si, imgi], intrins[si, imgi])
    # Keep only points that project inside the H x W image.
    mask = get_only_in_img_mask(ego_pts, H, W)
    # Apply the data-augmentation post-rotation/translation so the points
    # line up with the augmented (resized/cropped) image.
    plot_pts = post_rots[si, imgi].matmul(ego_pts) + post_trans[si, imgi].unsqueeze(1)
    ax = plt.subplot(gs[imgi // 3, imgi % 3])
    # Undo the normalization applied by the data loader.
    showimg = denormalize_img(img)
    plt.imshow(showimg)
    # Optionally scatter the lidar points, colored by depth (z in camera frame).
    if show_lidar:
        plt.scatter(plot_pts[0, mask], plot_pts[1, mask], c=ego_pts[2, mask],
                    s=5, alpha=0.1, cmap='jet')
    plt.axis('off')

# Switch to the final (summary) axis and plot the image-frustum points.
plt.sca(final_ax)
plt.plot(img_pts[si, imgi, :, :, :, 0].view(-1), img_pts[si, imgi, :, :, :, 1].view(-1))