先说第一个,刚刚接触pytorch,自己写网络感觉没人教感觉不容易,所以哪怕写完一个小的函数也要自己设置输入,去验证一下。
前面读取数据的小函数写的确实没啥毛病,但是当我在pytorch里要用到Dataset的时候一写就牵扯到很多函数,好像没办法验证我这个写的有什么问题,或者说不明白我一个理想输入进去,输出是个啥,如果输出不对应那不是白忙活了。
然后我在引用我重写的Dataset类的时候发现pycharm后面竟然能引导到__init__、getitem、__len__等函数,这就好办多了啊!
所以我在主函数里准备好数据以后就调用了一下__getitem__这个函数,然后给它定义一个索引,果然可以正常输出(老哥我是真滴强,一次通过,虽然写的很简单吧哈哈)。
我把主题代码段贴一下:
#这是我重写的Dataset
class DataFormFolder(data.Dataset):
def __init__(self, args, input_data_process=None, target_data_process=None):
super(DataFormFolder, self).__init__()
self.input_filenames_list = get_hr_train_name(args)
self.target_filenames_list = get_lr_train_name(args)
self.input_data_process = input_data_process
self.target_data_process = target_data_process
def __getitem__(self, item):
input_img = load_img(self.input_filenames_list[item])
target_img = load_img(self.target_filenames_list[item])
if self.input_data_process:
input_data = self.input_data_process(input_img)
if self.target_data_process:
target_data = self.target_data_process(target_img)
return input_data, target_data
def __len__(self):
return len(self.input_filenames_list)
# 下面是输入这个重写类的接口
def getting_train_data(args):
crop_size = calculate_valid_crop_size(256, 4)
return DataFormFolder(args,
input_data_process=input_data_process(crop_size, 4),
target_data_process=target_data_process(crop_size))
# 最后我用主函数调用就可以了
def main():
parser = get_parser()
args = parser.parse_args()
train_data = getting_train_data(args)
img = ToPILImage()(train_data.__getitem__(1)[0]).convert('RGB')
print(train_data.__getitem__(1)[0].shape)
类似这样,可以看到我最后两句都直接用了重写类的方法。
这里要注意,Dataset输出的类型是tensor,tensor的尺寸查看要用
tensor.shape()
下面是想看到你索引的数据的输出的方法:
这里又印出来一个问题:对于我随便索引的一个图片输入,这个重写类输出的是一个tensor类型的参数,那么我要看到这个图是个啥样子,总不能直接看矩阵啊!所以要用到Teosor转PIL.Image的方法。
这里有一篇文章介绍的很好,包括numpy pil和tensor的相互转换,可以直接看这个。
然后我百度了一下,发现torchversion工具箱有这个功能啊,虽然之前的文章diss了他一下,但是还是挺香的。可以看到我最后用的ToPILImage就是啦,可以加上image.show()看到你索引的图片。
但是到这里我又有一点不明白。
img = ToPILImage()(train_data.__getitem__(1)[0]).convert('RGB')
就是这一句,为什么ToPILImage()的这个括号不扩在train_data.getitem(1)[0]里面,而是单独再用一个括号括出来呢?我一开始以为这句话有问题,经过验证确实没毛病,但是我还是不能理解,所以我进行了下面的尝试:
img = ToPILImage(train_data.__getitem__(1)[0])
img = img.convert('RGB)
以为自己很机智可以用笨方法来重写,但是发现也是不行。
应该是我的python理解不够,有哪位大哥可以指出来就好了!