pytorch写网络时的检查方法以及tensor转图片

本文介绍了在PyTorch中编写网络时的验证方法,通过直接调用Dataset的__getitem__方法进行测试,并展示了如何将Tensor转换为图片显示。在过程中,作者探讨了ToPILImage函数的使用以及对代码结构的疑惑。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

先说第一个,刚刚接触pytorch,自己写网络感觉没人教感觉不容易,所以哪怕写完一个小的函数也要自己设置输入,去验证一下。
前面读取数据的小函数写的确实没啥毛病,但是当我在pytorch里要用到Dataset的时候一写就牵扯到很多函数,好像没办法验证我这个写的有什么问题,或者说不明白我一个理想输入进去,输出是个啥,如果输出不对应那不是白忙活了。
然后我在引用我重写的Dataset类的时候发现pycharm后面竟然能引导到__init__、getitem、__len__等函数,这就好办多了啊!
所以我在主函数里准备好数据以后就调用了一下__getitem__这个函数,然后给它定义一个索引,果然可以正常输出(老哥我是真滴强,一次通过,虽然写的很简单吧哈哈)。
我把主题代码段贴一下:

#这是我重写的Dataset
class DataFormFolder(data.Dataset):
    def __init__(self, args, input_data_process=None, target_data_process=None):
        super(DataFormFolder, self).__init__()
        self.input_filenames_list = get_hr_train_name(args)
        self.target_filenames_list = get_lr_train_name(args)
        self.input_data_process = input_data_process
        self.target_data_process = target_data_process

    def __getitem__(self, item):
        input_img = load_img(self.input_filenames_list[item])
        target_img = load_img(self.target_filenames_list[item])
        if self.input_data_process:
            input_data = self.input_data_process(input_img)
        if self.target_data_process:
            target_data = self.target_data_process(target_img)
        return input_data, target_data

    def __len__(self):
        return len(self.input_filenames_list)
    # 下面是输入这个重写类的接口
    def getting_train_data(args):
    crop_size = calculate_valid_crop_size(256, 4)
    return DataFormFolder(args,
                          input_data_process=input_data_process(crop_size, 4),
                          target_data_process=target_data_process(crop_size))
    # 最后我用主函数调用就可以了
     def main():
    parser = get_parser()
    args = parser.parse_args()
    train_data = getting_train_data(args)
    img = ToPILImage()(train_data.__getitem__(1)[0]).convert('RGB')
    print(train_data.__getitem__(1)[0].shape)

类似这样,可以看到我最后两句都直接用了重写类的方法。
这里要注意,Dataset输出的类型是tensor,tensor的尺寸查看要用

tensor.shape()

下面是想看到你索引的数据的输出的方法:
这里又印出来一个问题:对于我随便索引的一个图片输入,这个重写类输出的是一个tensor类型的参数,那么我要看到这个图是个啥样子,总不能直接看矩阵啊!所以要用到Teosor转PIL.Image的方法。
这里有一篇文章介绍的很好,包括numpy pil和tensor的相互转换,可以直接看这个。
然后我百度了一下,发现torchversion工具箱有这个功能啊,虽然之前的文章diss了他一下,但是还是挺香的。可以看到我最后用的ToPILImage就是啦,可以加上image.show()看到你索引的图片。
但是到这里我又有一点不明白。

img = ToPILImage()(train_data.__getitem__(1)[0]).convert('RGB')

就是这一句,为什么ToPILImage()的这个括号不扩在train_data.getitem(1)[0]里面,而是单独再用一个括号括出来呢?我一开始以为这句话有问题,经过验证确实没毛病,但是我还是不能理解,所以我进行了下面的尝试:

img = ToPILImage(train_data.__getitem__(1)[0])
img = img.convert('RGB)

以为自己很机智可以用笨方法来重写,但是发现也是不行。
应该是我的python理解不够,有哪位大哥可以指出来就好了!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值