Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >PyTorch 中自定义数据集的读取方法

PyTorch 中自定义数据集的读取方法

原创
作者头像
陶陶name
发布于 2022-05-12 00:51:02
发布于 2022-05-12 00:51:02
1.1K0
举报
文章被收录于专栏:陶陶计算机陶陶计算机

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。 1. 自定义数据集的方法: 首先创建一个Dataset类

在这里插入图片描述
在这里插入图片描述
在代码中: def init() 一些初始化的过程写在这个函数下 def len() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系 def getitem() 回数据和标签补充代码 上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了# -*- coding: utf-8 -*- # @Author : 胡子旋 # @Email :1017190168@qq.com import torch import os,glob import visdom import time import torchvision import random,csv from torch.utils.data import Dataset,DataLoader from torchvision import transforms from PIL import Image class pokemom(Dataset): def __init__(self,root,resize,mode,): super(pokemom,self).__init__() # 保存参数 self.root=root self.resize=resize # 给每一个类做映射 self.name2label={} # "squirtle":0 ,"pikachu":1…… for name in sorted(os.listdir(os.path.join(root))): # 过滤掉文件夹 if not os.path.isdir(os.path.join(root,name)): continue # 保存在表中;将最长的映射作为最新的元素的label的值 self.name2label[name]=len(self.name2label.keys()) print(self.name2label) # 加载文件 self.images,self.labels=self.load_csv('images.csv') # 裁剪数据 if mode=='train': self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合 self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合 elif mode=='val': self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方 self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))] else: self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾 self.labels = self.labels[int(0.8 * len(self.labels)):] # image+label 的路径 def load_csv(self,filename): # 将所有的图片加载进来 # 如果不存在的话才进行创建 if not os.path.exists(os.path.join(self.root,filename)): images=[] for name in self.name2label.keys(): images+=glob.glob(os.path.join(self.root,name,'*.png')) images+=glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) print(len(images),images) # 1167 'pokeman\\bulbasaur\\00000000.png' # 将文件以上述的格式保存在csv文件内 random.shuffle(images) with open(os.path.join(self.root,filename),mode='w',newline='') as f: writer=csv.writer(f) for img in images: # 'pokeman\\bulbasaur\\00000000.png' name=img.split(os.sep)[-2] label=self.name2label[name] writer.writerow([img,label]) print("write into csv into :",filename) # 如果存在的话就直接的跳到这个地方 images,labels=[],[] with open(os.path.join(self.root, filename)) as f: reader=csv.reader(f) for row in reader: # 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象 img,label=row # 将label转码为int类型 label=int(label) images.append(img) labels.append(label) # 保证images和labels的长度是一致的 assert len(images)==len(labels) return images,labels # 返回数据的数量 def __len__(self): return len(self.images) # 返回的是被裁剪之后的关系 def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) std = torch.tensor(std).unsqueeze(1).unsqueeze(1) # print(mean.shape, std.shape) x = x_hat * std + mean return x # 返回idx的数据和当前图片的label def __getitem__(self,idx): # idex-[0-总长度] # retrun images,labels # 将图片,label的路径取出来 # 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png' # 然而label得到的则是 0,1,2 这样的整形的格式 img,label=self.images[idx],self.labels[idx] tf=transforms.Compose([ lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据 # 进行数据加强 transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))), # 随机旋转 transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度 # 中心裁剪 transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂 transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]) img=tf(img) label=torch.tensor(label) return img,label def main(): # 验证工作 viz=visdom.Visdom() db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看 # 可视化样本 x,y=next(iter(db)) print('sample:',x.shape,y.shape,y) viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x')) # 加载batch_size的数据 loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8) for x,y in loader: viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch')) viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y')) # 每一次加载后,休息10s time.sleep(10) if __name__ == '__main__': main()

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
05-PyTorch自定义数据集Datasets、Loader和tranform
对于机器学习中的许多不同问题,我们采取的步骤都是相似的。PyTorch 有许多内置数据集,用于大量机器学习基准测试。除此之外也可以自定义数据集,本问将使用我们自己的披萨、牛排和寿司图像数据集,而不是使用内置的 PyTorch 数据集。具体来说,我们将使用 torchvision.datasets 以及我们自己的自定义 Dataset 类来加载食物图像,然后我们将构建一个 PyTorch 计算机视觉模型,希望对三种物体进行分类。
renhai
2023/11/24
1.2K0
05-PyTorch自定义数据集Datasets、Loader和tranform
深度学习实战之手写签名识别(100%准确率、语音播报)
在完成了上述的环境搭建后,即可进入到准备阶段了。这里准备的有数据集的准备、以及相关代码的主备。
陶陶name
2022/05/13
1.7K0
深度学习实战之垃圾分类
垃圾分类,指按一定规定或标准将垃圾分类储存、分类投放和分类搬运,从而转变成公共资源的一系列活动的总称。分类的目的是提高垃圾的资源价值和经济价值,力争物尽其用;然而我们在日常生活中认为对垃圾分类还是有些不知所措的,对干垃圾、湿垃圾……分的不是很清楚,由此我们就想到了使用深度学习的方法进行分类。简介 本篇博文主要会带领大家进行数据的预处理、网络搭建、模型训练、模型测试 1. 获取数据集 这里笔者已经为大家提供了一个比较完整的数据集,所以大家不必再自己去收集数据了 数据集链接:https://pan.baidu
陶陶name
2022/05/13
6620
pytorch学习笔记(六):自定义Datasets
本文介绍了如何自定义PyTorch Datasets,通过实例化CustomDataset类并继承自torch.utils.data.Dataset类,并重写了__init__、__getitem__和__len__方法,来实现自定义的数据集。通过这种方法,可以更好地控制数据集的准备和加载过程,并可以根据具体的应用场景进行定制。同时,还介绍了MNIST数据集的例子,通过继承自torch.utils.data.Dataset类,实现了对该数据集的准备和加载,并演示了自定义数据集的方法和技巧。
ke1th
2018/01/02
1.7K0
Transfer Learning
通过网络上收集宝可梦的图片,制作图像分类数据集。我收集了5种宝可梦,分别是皮卡丘,超梦,杰尼龟,小火龙,妙蛙种子
mathor
2020/02/17
4610
【时空序列预测实战】详解时空序列常用数据集之MovingMnist数据集(demo代码)
这篇文章我们主要介绍MovingMnist数据集,做这个方向的research是逃不过这个数据集的使用的
石晓文
2020/11/09
2.3K0
【时空序列预测实战】详解时空序列常用数据集之MovingMnist数据集(demo代码)
Pytorch打怪路(三)Pytorch创建自己的数据集2
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Teeyohuang/article/details/82108203
TeeyoHuang
2019/05/25
1K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.6K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
收藏 | PyTorch Cookbook:常用代码段集锦
链接 | https://zhuanlan.zhihu.com/p/59205847
AI算法修炼营
2020/06/03
7540
收藏 | PyTorch Cookbook:常用代码段集锦
手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
详细介绍了 VGGNet 的网络结构,今天我们将使用 PyTorch 来复现VGGNet网络,并用VGGNet模型来解决一个经典的Kaggle图像识别比赛问题。
红色石头
2022/04/14
9140
手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
轻松学Pytorch-迁移学习实现表面缺陷检查
大家好,我又又好久没有更新这个系列了,主要原因归根结底只有一个懒,所谓 一勤天下无难事,百思心中有良谋。以后还争取每周更新,这次隔了一周没有更新,对不起大家了。今天给大家更新的是如何基于torchvision自带的模型完成图像分类任务的迁移学习,前面我们已经完成了对对象检测任务的迁移学习,这里补上针对图像分类任务的迁移学习,官方的文档比较啰嗦,看了之后其实可操作性很低,特别是对于初学者,估计看了之后就发懵的那种。本人重新改写了一波,代码简洁易懂,然后把训练结果导出ONNX,使用OpenCV DNN调用部署,非常实用!废话不多说了,少吹水!
OpenCV学堂
2020/09/22
1.6K0
轻松学Pytorch-迁移学习实现表面缺陷检查
使用关键点进行小目标检测
【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heatmap两种方式来回归关键点,是一个很简单基础的项目,代码量很小,可供新手学习。
BBuf
2020/09/10
9840
[pytorch] 一种加速dataloder的方法
一位不错的小伙给的代码 (前同事)。 这里实现主要是使用:nvidia.dali 代码如下: from __future__ import division import torch import types import joblib import collections import numpy as np import pandas as pd from random import shuffle from nvidia.dali.pipeline import Pipeline import nvi
MachineLP
2020/02/25
1.5K0
resnet34 pytorch_pytorch环境搭建
导师的课题需要用到图片分类;入门萌新啥也不会,只需要实现这个功能,给出初步效果,不需要花太多时间了解内部逻辑。经过一周的摸索,建好环境、pytorch,终于找到整套的代码和数据集,实现了一个小小的分类。记录一下使用方法,避免后续使用时遗忘。感谢各位大佬的开源代码和注释!
全栈程序员站长
2022/09/27
8630
resnet34 pytorch_pytorch环境搭建
基于交通灯数据集的端到端分类
抓住11月的尾巴,这里写上昨天做的一个DL的作业吧,作业很简单,基于交通灯的图像分类,但这确是让你从0构建深度学习系统的好例子,很多已有的数据集都封装好了,直接调用,这篇文章将以pytorch这个深度学习框架一步步搭建分类系统。
努力努力再努力F
2018/12/01
1.7K0
轻松学pytorch – 使用多标签损失函数训练卷积网络
大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。
OpenCV学堂
2020/07/16
1.2K0
轻松学Pytorch-实现自定义对象检测器
大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。
OpenCV学堂
2020/07/30
8850
轻松学Pytorch-实现自定义对象检测器
VGG16 训练猫狗数据集
准备数据应该是一件比较麻烦的过程,所以一般都去找那种公开的数据集。在网上找到的可以用于猫狗分类的数据集有 Kaggle 的 “Dogs vs. Cats”数据集,还有牛津大学提供的 Oxford-IIIT Pet 数据集,包含猫和狗的图片,都是非常适合做猫狗分类任务的公开数据集。
繁依Fanyi
2025/03/24
1970
PyTorch数据Pipeline标准化代码模板
PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyTorch在视觉和NLP领域的顶级会议上已呈一统之势。
机器视觉CV
2019/12/17
1.7K0
PyTorch数据Pipeline标准化代码模板
PyTorch 自定义数据集
准备 COCO128[1] 数据集,其是 COCO[2] train2017 前 128 个数据。按 YOLOv5 组织的目录:
GoCoding
2021/05/06
8720
PyTorch 自定义数据集
推荐阅读
相关推荐
05-PyTorch自定义数据集Datasets、Loader和tranform
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档