在深度学习中,卷积神经网络用于处理图像。为了构建训练神经网络,我们需要处理大量图像。有几种方法可以在 PyTorch 中加载计算机视觉数据集,具体取决于数据集的格式和项目的具体要求。
一种流行的方法是使用内置的 PyTorch 数据集类,例如 torchvision.datasets
。它提供了一种方便的方法来加载和预处理常见的计算机视觉数据集,例如CIFAR-10和ImageNet。比如,要加载 CIFAR-10 数据集,你可以使用以下代码:
import torchvision.datasets as datasets
# Download the cifar Dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
上面的代码将下载 CIFAR-10 数据集并将其保存在' ./data '目录中。
另一种方法是使用torch.utils.data.DataLoader
类来加载数据。当数据在本地机器中并且你希望拥有数据增强功能和数据混洗能力以及指定批处理大小的能力时,这种方法更有用。它具有自定义数据加载顺序、批处理、单进程或多进程数据加载等优点。
这里我们可以使用torchvision的transform.Compose
函数对图像进行旋转、翻转、规范化并将其转换为张量形式。
from torchvision import transforms
from torch.utils.data import DataLoader
# Image Transformation
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35])
])
# Load the dataset with transformation
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)
# Make the batch of size 16
train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)
查看训练和测试数据
#Train Dataset
print(train_loader.dataset)
#Test Dataset
print(test_loader.dataset)
输出为:
Dataset CIFAR10
Number of datapoints: 50000
Root location: ./data
Split: Train
StandardTransform
Transform: Compose(
RandomHorizontalFlip(p=0.5)
RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
ToTensor()
Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
)
Dataset CIFAR10
Number of datapoints: 10000
Root location: ./data
Split: Test
StandardTransform
Transform: Compose(
RandomHorizontalFlip(p=0.5)
RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
ToTensor()
Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
)
绘制图像:
# Iteration
inputs, Class = next(iter(train_loader))
#Define the class names
class_name ={0:'airplane',
1:'automobile',
2:'bird',
3:'cat',
4:'deer',
5:'dog',
6:'frog',
7:'horse',
8:'ship',
9:'truck'
}
#Plot the figure
plt.figure(figsize=(30,16), dpi=1000)
for i in range(32):
plt.subplot(4,8,i+1)
plt.imshow(inputs[i].numpy().transpose((1, 2, 0)))
plt.axis('off')
plt.title(class_name[int(Class[i])])
plt.show()
输出为: