首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >pytorch下载CIFAR10数据集[通俗易懂]

pytorch下载CIFAR10数据集[通俗易懂]

作者头像
全栈程序员站长
发布2022-06-25 15:32:29
发布2022-06-25 15:32:29
1.2K0
举报

大家好,又见面了,我是你们的朋友全栈君。

代码语言:javascript
复制
import torch 
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor
    ]), download=True)
    cifar_train = DataLoader(cifar_train,batch_size=batchse,shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor
    ]), download=True)
    cifar_teat = DataLoader(cifar_train,batch_size=batchse,shuffle=True)

    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)


if __name__ == "__main__":
    main()

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/152097.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档