在pytorch中,提供了一种十分方便的数据读取机制,即使用torch.utils.data.Dataset与torch.utils.data.DataLoader组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个batch数据,并能在输出时对数据进行相应的预处理或数据增强等操作。
1、torch.utils.data.Dataset
torch.utils.data.Dataset是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__()和__getitem__()这两个魔法方法。
- __len__():返回的是数据集的大小。我们构建的数据集是一个对象,而数据集不像序列类型(列表、元组、字符串)那样可以直接用len()来获取序列的长度,魔法方法__len__()的目的就是方便像序列那样直接获取对象的长度。如果A是一个类,a是类A的实例化对象,当A中定义了魔法方法__len__(),len(a)则返回对象的大小。
- __getitem__():实现索引数据集中的某一个数据。我们知道,序列可以通过索引的方法获取序列中的任意元素,__getitem__()则实现了能够通过索引的方法获取对象中的任意元素。此外,我们可以在__getitem__()中实现数据预处理。
示例1
import torch
from torch.utils.data import Dataset
class TensorDataset(Dataset):
"""
TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
实现将一组Tensor数据对封装成Tensor数据集
能够通过index得到数据集的数据,能够通过len,得到数据集大小
"""
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def __len__(self):
return self.data_tensor.size(0)
# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)
# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)
# 可使用索引调用数据
print(tensor_dataset[1])
# 输出:(tensor([-1.0351, -0.1004, 0.9168]), tensor(0.4977))
# 获取数据集大小
print(len(tensor_dataset))
# 输出:4
示例2
import os
from PIL import Image
from torch.utils.data import Dataset
class PatchDataset(Dataset):
def __init__(self, data_dir, transforms=None):
"""
:param data_dir: 数据集所在路径
:param transform: 数据预处理
"""
self.data_info = self.get_img_info(data_dir)
self.transforms = transforms
def __getitem__(self, item):
path_img, label = self.data_info[item]
image = Image.open(path_img).convert('RGB')
if self.transforms is not None:
image = self.transforms(image)
return image, label
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
path_dir = os.path.join(data_dir, 'train_dataset.txt')
data_info = []
with open(path_dir) as file:
lines = file.readlines()
for line in lines:
data_info.append(line.strip('\n').split(' '))
return data_info
其中, train_dataset.txt中的内容为:
2. torch.utils.data.DataLoader
作用:DataLoader将Dataset对象或自定义数据类的对象封装成一个迭代器,这个迭代器可以迭代输出Dataset的内容。同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程。
__init__()中的几个重要的输入:
- dataset:这个就是pytorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。
- batch_size:根据具体情况设置即可。
- shuffle:随机打乱顺序,一般在训练数据中会采用。
- collate_fn:是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。
- batch_sampler:从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。sampler:从代码可以看出,其和shuffle是互斥的,一般默认即可。
- num_workers:从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。
- pin_memory:注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。
- timeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
代码示例(接示例1)
tensor_dataloader = DataLoader(tensor_dataset, # 封装的对象
batch_size=2, # 输出的batch size
shuffle=True, # 随机输出
num_workers=0) # 只有1个进程
# 以for循环形式输出
for data, target in tensor_dataloader:
print(data, target)
输出结果:
tensor([[ 0.7745, 0.2186, 0.1231],
[-0.1307, 1.5778, -1.2906]]) tensor([0.3749, 0.4659])
tensor([[-0.1605, 0.9359, 0.1314],
[-1.1694, 1.0986, -0.9927]]) tensor([0.8071, 0.8997])