torch/utils/data/_utils/dataloader.py
通常在使用pytorch训练神经网络时,DataLoader模块是整个网络训练过程中的基础前提且尤为重要,其主要作用是根据传入接口的参数将训练集分为若干个大小为batch size的batch以及其他一些细节上的操作。一个典型的数据加载以及batch训练过程如下:(其中的args后面会详细解释)
loader = torch.utils.data.DataLoader(args)
for data, label in loader:
training
这次主要解读DataLoader模块的源码(Pytorch版本为1.8.0),在解读源码之前首先需要明确好几个概念。
iterable和iterator的区别
两者从字面翻译层面来看分别为可迭代和迭代器的意思。这两者概念很相近,然而在底层实现上面有些区别。
iterable: 表示某个对象是可迭代的,底层只实现了__iter__方法
iterator: 表示某个对象是迭代器,底层不仅实现了__iter__,同时还实现了__next__方法
举个例子:python语言中的list、dict、str都是可迭代的,即可以使用for循环。而对于迭代器,不仅可以使用for循环来实现遍历,还可以通过next()函数来获取下一个元素。可以通过iter()函数来将一个可迭代对象转换为一个迭代器。DataLoader是可迭代的,不是迭代器。
#先获取iterator对象
it = iter([1,2,3,4,5])
while True:
try:
#获取下一个值
x