本节课继续主要介绍CIFAR10数据集的读取
cifar_train = DataLoader(cifar_train, batch_size=batchsz, )
# 按照其要求,这里的参数需要有batch_size,
# 在该部分代码前面定义batch_size
def main():
batchsz=32
# 这个batch_size数值不宜太大也不宜过小
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
transforms.Resize((32, 32)),
# .Compose相当于一个数据转换的集合
# 进行数据转换,首先将图片统一为32*32
transforms.ToTensor()
# 将数据转化到Tensor中
]), download=True)
# 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中
cifar_train = DataLoader(cifar_train, batch_size=batchsz, )
# 按照其要求,这里的参数需要有batch_size,
# 在该部分代码前面定义batch_size
这里设置了batch_size=32,对于一般硬件配置来说32是个较合理的数值,若硬件性能够强可设更高。
后面再加shuffle=True,使数据加载的随机化
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
# 再使数据加载的随机化
其他参数这里暂时不进行讲解
下面将这部分代码复制粘贴一下,将里面的train更改为test、train=False等。
这两部分书写好后,代码为
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
def main():
batchsz=32
# 这个batch_size数值不宜太大也不宜过小
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
transforms.Resize((32, 32)),
# .Compose相当于一个数据转换的集合
# 进行数据转换,首先将图片统一为32*32
transforms.ToTensor()
# 将数据转化到Tensor中
]), download=True)
# 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
# 按照其要求,这里的参数需要有batch_size,
# 在该部分代码前面定义batch_size
# 再使数据加载的随机化
cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
if __name__ == '__main__':
main()
进行到这里可以先输出部分信息进行查看
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
def main():
batchsz=32
# 这个batch_size数值不宜太大也不宜过小
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
transforms.Resize((32, 32)),
# .Compose相当于一个数据转换的集合
# 进行数据转换,首先将图片统一为32*32
transforms.ToTensor()
# 将数据转化到Tensor中
]), download=True)
# 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
# 按照其要求,这里的参数需要有batch_size,
# 在该部分代码前面定义batch_size
# 再使数据加载的随机化
cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
x, label = iter(cifar_train).next()
# 通过.iter方法输出一个数据进行查看
print('s.shape:', x.shape, 'label.shape:', label.shape)
# 输出shape进行查看
if __name__ == '__main__':
main()
保存,先跑一下代码
运行后输出出现下面字段说明开始下载数据
待下载完成后输出为
s.shape: torch.Size([32, 3, 32, 32]) label.shape: torch.Size([32])
数据恰好为3通道、32*32的size
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!