常用可视化命令
from PIL import Image
import numpy as np
# Dump a mask to disk for a quick visual check.
# NOTE(review): assumes gt_masks is a tensor-like object exposing .cpu() whose
# values lie in [0, 1] and which converts to a single 2-D array — TODO confirm.
# Plain ASCII quotes are required here; typographic quotes are a SyntaxError.
mask_array = np.array(gt_masks.cpu() * 255).astype('uint8')
img = Image.fromarray(mask_array).convert('RGB')
img.save('picture.jpg')
About数据
数据集官网:https://captain-whu.github.io/DOAI2019/dataset.html
数据集工具包:https://github.com/CAPTAIN-WHU/DOTA_devkit
DOTA-v1.5数据集一共有16个类别,包含40万个带注释的对象实例。
训练集:1411张
验证集:458张
16个类别分别是:飞机,轮船,储罐,棒球场,网球场,篮球场,地面跑道,港口,桥梁,小型车辆,大型车辆,直升机,环形交叉路口,足球场,游泳池和集装箱起重机。
plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, small vehicle, large vehicle, helicopter, roundabout, soccer ball field, swimming pool and container crane.
数据集工具包
原图像像素非常大,要训练就得切割成小的patch,数据集工具包有提供相关处理的代码。
图片可视化
也就是官方代码提供的DOTA.py,该脚本可以对你想要的类别的图片进行可视化。
# The code is used for visualization, inspired by cocoapi
# Licensed under the Simplified BSD License [see bsd.txt]
import os
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Circle
import numpy as np
import dota_utils as util
from collections import defaultdict
import cv2
def _isArrayLike(obj):
if type(obj) == str:
return False
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
class DOTA:
def __init__(self, basepath):
    """Record the dataset directories and build the lookup indexes.

    :param basepath: dataset root; the 'labelTxt' and 'images' sub-directories
        are joined onto it
    """
    label_dir = os.path.join(basepath, 'labelTxt')
    image_dir = os.path.join(basepath, 'images')
    # Label-file paths as returned by util.GetFileFromThisRootDir
    # (original note: absolute path of each image's .txt file).
    label_files = util.GetFileFromThisRootDir(label_dir)

    self.basepath = basepath
    self.labelpath = label_dir
    self.imagepath = image_dir
    self.imgpaths = label_files
    # One id per label file (original note: the file-name prefix, e.g. P1506).
    self.imglist = list(map(util.custombasename, label_files))
    # Both indexes start empty and are filled by createIndex():
    #   catToImgs: category name -> ids of images that have such an object
    #   ImgToAnns: image id -> objects parsed from that image's label file
    self.catToImgs = defaultdict(list)
    self.ImgToAnns = defaultdict(list)
    self.createIndex()
def createIndex(self):
    """Fill ``ImgToAnns`` and ``catToImgs`` from every file in ``imgpaths``.

    Each label file is parsed with ``util.parse_dota_poly``. The resulting
    objects are stored under the file's id, and the id is appended to the
    list of every category named by those objects (once per object, so an id
    can occur several times in a category's list).
    """
    images_by_category = self.catToImgs
    for label_file in self.imgpaths:
        # Per the original note, each parsed object is a dict with 'name'
        # (category), 'poly' (corner coordinates) and 'area'; only 'name'
        # is read here.
        annotations = util.parse_dota_poly(label_file)
        image_id = util.custombasename(label_file)
        self.ImgToAnns[image_id] = annotations
        for annotation in annotations:
            images_by_category[annotation['name']].append(image_id)
def getImgIds(self, catNms=()):
    """Return the ids of the images that contain every requested category.

    :param catNms: one category name, or a collection of category names; an
        empty collection (the default) selects every image
    :return: list of image ids. With no categories this is a copy of
        ``imglist``; otherwise it holds the ids present under all of the
        given categories, in no particular order. An unknown category
        yields an empty list.
    """
    catNms = catNms if _isArrayLike(catNms) else [catNms]
    if len(catNms) == 0:
        # Hand out a copy so callers cannot alter the instance's own list.
        return list(self.imglist)
    # .get() is used instead of indexing: indexing the defaultdict with an
    # unknown category would insert an empty entry during a read-only query.
    ids_per_category = [set(self.catToImgs.get(cat, ())) for cat in catNms]
    return list(set.intersection(*ids_per_category))
def loadAnns(self, catNms=(), imgId=None, difficult=None):
    """Return the objects of one image, optionally filtered by category.

    :param catNms: one category name, or a collection of category names; an
        empty collection (the default) disables filtering
    :param imgId: id of the image whose objects are wanted
    :param difficult: accepted for compatibility; not read by this method
    :return: list of objects; empty when ``imgId`` is not in the index
    """
    catNms = catNms if _isArrayLike(catNms) else [catNms]
    # .get() is used instead of indexing: indexing the defaultdict with an
    # unknown id (including the default None) would insert an empty entry.
    objects = self.ImgToAnns.get(imgId, [])
    if len(catNms) == 0:
        return objects
    return [obj for obj in objects if obj['name'] in catNms]
def showAnns(self, objects, imgId, range):
    """Draw the given objects on top of their image on the current axes.

    Every object gets a random colour that is used both for a translucent
    fill and for the outline of its polygon; a red circle marks the first
    point of each polygon.

    :param objects: objects to show; each must provide a 'poly' entry
    :param imgId: img to show
    :param range: display range in the img (not read by the current body)
    :return:
    """
    image = self.loadImgs(imgId)[0]
    plt.imshow(image)
    plt.axis('off')
    axes = plt.gca()
    axes.set_autoscale_on(False)

    marker_radius = 5
    outlines, palette, markers = [], [], []
    for annotation in objects:
        # Random RGB triple with every channel in [0.4, 1.0).
        rgb = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
        corners = annotation['poly']
        outlines.append(Polygon(corners))
        palette.append(rgb)
        start = corners[0]
        markers.append(Circle((start[0], start[1]), marker_radius))

    # Layers are created and added one after another: fill, outline, markers.
    layer_specs = (
        (outlines, dict(facecolors=palette, linewidths=0, alpha=0.4)),
        (outlines, dict(facecolors='none', edgecolors=palette, linewidths=2)),
        (markers, dict(facecolors='red')),
    )
    for patches, options in layer_specs:
        axes.add_collection(PatchCollection(patches, **options))
    # The figure is neither shown nor saved here; that is left to the caller.
def loadImgs(self, imgids