pytorch系列教程(二)-数据预处理Dataset、DataLoader、Transform

本文详细介绍了PyTorch中数据加载和预处理的关键组件:Dataset、DataLoader和Transform。通过实例演示了如何使用这些组件从文件读取数据,应用图像变换,并创建迭代器以批量加载数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

pytorch对于怎么样把数据放进神经网络训练有一套非常成熟的机制,我们只需要按照流程即可,这个流程只要是涉及了Dataset、DataLoader和Transform
这篇博客参考了:
(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
(第三篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
  

步骤

还是拿前一篇文章的例子pytorch系列教程(一)-训练和测试模型流程来讲述如何把数据放到神经网络中训练的

#前一篇文章的代码
train_datasets = MyDataset()           # 第一步:构造Dataset对象
train_dataloader = DataLoader(train_datasets)# 第二步:通过DataLoader来构造迭代对象

以yolov1网络为例子,data/train.txt中的内容
在这里插入图片描述
下面来看看 MyDataset()中要做什么

 
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
 
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms
 
import random


class MyDataset()(Dataset):
    def __init__(self,transform):
        self.transform = transform
        self.img_list = []
        self.labels=np.zeros((7,7,5*NUM_BBOX+len(CLASSES)))
        
        #data/train.txt存放的是训练数据的路径
        with open("data/train.txt", 'r') as f: 
             #将训练数据的路径放到img_list中
             self.img_list = [x.strip() for x in f]
 
    def __getitem__(self, idx):
        ###################################
        #idx最大值为len(self.img_list)-1
        img = cv2.imread(self.img_list[idx])
        if self.transform:
          #将数据转成tensor
          img = self.transform(img)
          label= self.transform(label)
        return img,label

    def __len__(self):
        #返回训练数据的长度
        return len(self.img_list)

def main(): 
     transform=transforms.Compose([
                transforms.ToTensor()
            ])
     train_datasets = MyDataset(transform)           # 第一步:构造Dataset对象
     train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 第二步:通过DataLoader来构造迭代对象
     
if __name__ == '__main__':
     main()

  

总结

看完代码之后总结一下
1、构造一个类继承Dataset,例如 MyDataset(Dataset)。在类中
首先新建一个变量用来存放训练数据或者标签的路径,例如self.img_list
然后在类中重写 getitem(self, index)和__len__(self)

  • getitem(self, index)中主要做的是返回的img和label。这个img和label都已经是tensor形式
  • len(self)中主要做的是返回训练数据的总长度

  
Dataset详解
Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中Dataset类中的两个私有成员函数必须被重载,否则将会触发错误提示:

  • def getitem(self, index):
  • def len(self):
  • def init(self):

构造函数一般情况下我们也是要自己定义的,但是不是强制性的。
其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。这个Dataset抽象父类的定义如下:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

  
2、利用transforms.Compose添加图像变换
transforms中的图像变换操作大全

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

  
3、通过DataLoader来构造迭代对象
看一下DataLoader的定义

class DataLoader(object):
    __initialized = False
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
    def __setattr__(self, attr, val):
    def __iter__(self):
    def __len__(self):
    
Arguments:
        dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.
        batch_size (int, optional): 每一个batch加载多少组样本,即指定batch_size,默认是 1 
        shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值