首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

如何使用导入的MNIST数据集?

MNIST数据集是一个广泛使用的手写数字识别数据集,包含了60000个训练样本和10000个测试样本,每个样本是一个28x28像素的灰度图像,代表一个手写数字(0到9)。以下是如何导入和使用MNIST数据集的基本步骤:

1. 导入MNIST数据集

使用Python和TensorFlow/Keras

代码语言:txt
复制
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 查看数据集的基本信息
print("训练样本数量:", len(x_train))
print("测试样本数量:", len(x_test))

使用Python和PyTorch

代码语言:txt
复制
import torch
from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor()])

# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

2. 数据预处理

归一化

为了提高模型的训练效果,通常需要对图像数据进行归一化处理。

代码语言:txt
复制
# TensorFlow/Keras
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# PyTorch
# transform中已经包含了ToTensor(),它会自动将像素值归一化到[0, 1]

调整数据维度

确保数据维度符合模型的输入要求。

代码语言:txt
复制
# TensorFlow/Keras
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))

# PyTorch
# transform中已经包含了ToTensor(),它会自动调整维度

3. 构建模型

使用TensorFlow/Keras

代码语言:txt
复制
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

使用PyTorch

代码语言:txt
复制
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()

4. 训练模型

TensorFlow/Keras

代码语言:txt
复制
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)

PyTorch

代码语言:txt
复制
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

5. 评估模型

TensorFlow/Keras

代码语言:txt
复制
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc}')

PyTorch

代码语言:txt
复制
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Test accuracy: {100 * correct / total}%')

应用场景

MNIST数据集常用于以下几个方面:

  • 教学和入门:适合初学者学习和实践深度学习基础概念。
  • 模型基准测试:用于评估新算法或模型的性能。
  • 研究:在计算机视觉和机器学习研究中作为标准数据集。

可能遇到的问题及解决方法

  1. 内存不足:如果数据集太大,可能会导致内存不足。可以通过分批次加载数据来解决。
  2. 数据不平衡:某些数字的样本数量可能比其他数字少,可以通过数据增强或重新采样来解决。
  3. 模型过拟合:可以通过增加正则化、使用Dropout层或增加训练数据来解决。

通过以上步骤,你可以成功导入和使用MNIST数据集进行手写数字识别任务。

页面内容是否对你有帮助?
有帮助
没帮助

相关·内容

MNIST数据集的导入与预处理

MNIST数据集 MNIST数据集简介 MNIST数据集,是一组由美国高中生和人口调查局员工手写的70000个数字的图片。每张图像都用其代表的数字标记。...MNIST数据集的获取 MNIST数据集网上流传的大体上有两类,不过两者有些不同,第一种是每幅图片大小是2828的,第二种是每幅图片大小是3232的,官网下载的是哪种不作细究,因为可以通过更简单的数据获取方法.../ 在本实验中可以这样进行MNIST数据集的导入 from sklearn.datasets import fetch_openml mnist = fetch_openml("mnist_784")...28*28的尺寸,其它数据集也可以使用类似导入方式,但要去官网搜该数据集的命名方式。...老版本导入数据集叫fetch_data,在sklearn2.0版本之后已无法使用。 数据截取 为什么要数据的截取? 对于KNN来说,将MNIST的6-7万数据全扔进去会导致运行极其缓慢。

1.7K20
  • 使用Python解析MNIST数据集

    前言 最近在学习Keras,要使用到LeCun大神的MNIST手写数字数据集,直接从官网上下载了4个压缩包: ?...MNIST数据集 解压后发现里面每个压缩包里有一个idx-ubyte文件,没有图片文件在里面。回去仔细看了一下官网后发现原来这是IDX文件格式,是一种用来存储向量与多维度矩阵的文件格式。...解析脚本 根据以上解析规则,我使用了Python里的struct模块对文件进行读写(如果不熟悉struct模块的可以看我的另一篇博客文章《Python中对字节流/二进制流的操作:struct模块简易使用教程...12:param idx3_ubyte_file: idx3文件路径 13:return: np.array类型对象 14""" 15return data 针对MNIST数据集的解析脚本如下:...11数据集下载地址为http://yann.lecun.com/exdb/mnist。 12相关格式转换见官网以及代码注释。

    1.3K40

    详解 MNIST 数据集

    大家好,又见面了,我是你们的朋友全栈君。 MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”....不妨新建一个文件夹 – mnist, 将数据集下载到 mnist 以后, 解压即可: 图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法. import...训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示....通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本....25 个不同形态: 另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集.

    2.4K10

    MNIST手写数据集

    除了图像数据,MNIST数据集还提供了对应的标签数据,标签是0到9之间的数字,表示图像上的手写数字。下载和导入数据在Python中,可以使用一些机器学习库来下载和导入MNIST数据集。...mnist# 下载和导入MNIST数据集(train_images, train_labels), (test_images, test_labels) = mnist.load_data()数据可视化为了更好地理解...实际应用场景MNIST手写数据集在实际应用中有很多用途。一些常见的应用场景包括:数字识别:使用MNIST数据集训练机器学习模型,以实现对手写数字的识别。...mnist.load_data()​​函数从Keras中下载并导入MNIST数据集。...缺点虽然MNIST数据集在机器学习社区中被广泛使用,但也存在一些缺点:简单性:MNIST数据集相对简单,并且面临的挑战较小。

    79100

    Imagenet数据集_mnist数据集介绍

    Imagenet数据集是目前深度学习图像领域应用得非常多的一个领域,关于图像分类、定位、检测等研究工作大多基于此数据集展开。...Imagenet数据集文档详细,有专门的团队维护,使用非常方便,在计算机视觉领域研究论文中应用非常广,几乎成为了目前深度学习图像领域算法性能检验的“标准”数据集。...Imagenet数据集有1400多万幅图片,涵盖2万多个类别; 其中有超过百万的图片有明确的类别标注和图像中物体位置的标注。...Number of synsets with SIFT features: 1000 Number of images with SIFT features: 1.2million Imagenet数据集是一个非常优秀的数据集...,但是标注难免会有错误,几乎每年都会对错误的数据进行修正或是删除,建议下载最新数据集并关注数据集更新。

    99120

    详解 MNIST 数据集

    MNIST 数据集已经是一个被"嚼烂"了的数据集, 很多教程都会对它"下手", 几乎成为一个 "典范". 不过有些人可能对它还不是很了解, 下面来介绍一下....测试集(test set) 也是同样比例的手写数字数据. 不妨新建一个文件夹 -- mnist, 将数据集下载到 mnist 以后, 解压即可: ?...训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示....通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本....7 另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集.

    2.3K20

    使用Google的Quickdraw创建MNIST样式数据集!

    对于那些运行深度学习模型的人来说,MNIST是无处不在的。手写数字的数据集有许多用途,从基准测试的算法(在数千篇论文中引用)到可视化,比拿破仑的1812年进军更为普遍。...图纸如下所示: 构建您自己的QuickDraw数据集 我想了解您如何使用这些图纸并创建自己的MNIST数据集。...这是一个简短的python gist ,我用来阅读.npy文件并将它们组合起来创建一个可以用来替代MNIST的含有80,000个图像的数据集。...它们以hdf5格式保存,这种格式是跨平台的,经常用于深度学习。 用QuickDraw代替MNIST 我使用这个数据集代替MNIST。...在Keras 教程中,使用Python中的自动编码器进行一些工作。下图显示了顶部的原始图像,并使用自动编码器在底部显示重建的图像。 接下来我使用了一个R语言的变分自编码器的数据集。

    1.7K80

    MNIST数据集 & CIFAR10数据集

    大家好,又见面了,我是你们的朋友全栈君。 MNIST数据集 MNIST数据集是分类任务中最简单、最常用的数据集。...人为的手写了0-9数字的图片 MNIST大概有7w张 MNIST数据值都是灰度图,所以图像的通道数只有一个 因为MNIST数据集是专门为深度学习来的,所以其数据集格式和我们常见的很不一样...,但是在Pytorch/Tensorflow中有函数可以很容易的读取,如果用普通Python来读取则不是那么容易 CIFAR10数据集 http://www.cs.toronto.edu/~...kriz/cifar.html CIFAR10数据集比MNIST要复杂一些....CIFAR10是真实数据集,MNIST是人为构建的 CIFAR10是32*32的 有CIFAR-10和CIFAR-100 CIFAR-10图片的10种类别,每一类大概有6000张 一共6w

    66910

    完整教程:使用caffe测试mnist数据集

    这篇原创笔记来自铁粉zhupc,感谢为大家提供的这份caffe测试mnist数据集的精彩总结。...gpu版本,如何编译安装的百度上教程基本可用,笔者在windows跟ubuntu都编译成功了。...首先,我们需要下mnist数据集,在进入到data文件夹下,有个获取数据的脚本 caffe/data/mnist/get_mnist.sh,执行完成后会得到下面几个文件,通过名字判断可知道分别是测试集与训练集的样本与标签...Lmdb是一种数据库,查询和插入非常高效,caffe使用lmdb作为数据源,同时caffe也支持hdf5文件。 Caffe搭建网络是基于prototxt文件,超参数也在里面配置。...以及最大迭代次数,文件末尾也可以自由的定义使用GPU或者CPU,snapshot_prefix指的是快照生成的路径,这里要配置好。

    1.2K60

    MNIST数据集的格式转换

    以前直接用的是sklearn或者TensorFlow提供的mnist数据集,已经转换为矩阵形式的数据格式。...但是sklearn体用的数据集合并不全,一共只有3000+图,每个图是8*8的大小,但是原始数据并不是这样的。...MNIST数据集合的原始网址为:http://yann.lecun.com/exdb/mnist/ 进入官网,发现有4个文件,分别对应训练集、测试集的图像和标签: ?...官网给的数据集合并不是原始的图像数据格式,而是编码后的二进制格式: 图像的编码为: ?...典型的head+data模式:前16个字节分为4个整型数据,每个4字节,分别代表:数据信息des、图像数量(img_num),图像行数(row)、图像列数(col),之后的数据全部为像素,每row*col

    2.3K50

    使用KNN识别MNIST手写数据集(手写,不使用KNeighborsClassifier)

    大家好,又见面了,我是你们的朋友全栈君。 数据集 提取码:mrfr 浏览本文前请先搞懂K近邻的基本原理:最简单的分类算法之一:KNN(原理解析+代码实现) 算法实现步骤: 数据处理。...每一个数字都是一个32X32维的数据,如下所示: knn中邻居一词指的就是距离相近。我们要想计算两个样本之间的距离,就必须将每一个数字变成一个向量。...具体做法就是将32X32的数据每一行接在一起,形成一个1X1024的数据,这样我们就可以计算欧式距离。...计算测试数据到所有训练数据的距离,并按照从小到大排序,选出前K个 根据距离计算前K个样本的权重 将相同的训练样本的权重加起来,返回权重最大样本的标签 代码实现: import os def load_data...manifold/digits/trainingDigits') distance = [] #存储测试数据到所有训练数据的距离 for i in range(len(

    28010

    MNIST数据集手写数字分类

    目录0.编程环境1、下载并解压数据集2、完整代码3、数据准备4、数据观察4.1 查看变量mnist的方法和属性4.2 对比三个集合4.3 mnist.train.images观察4.4 查看手写数字图5...MNIST数据集下载链接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密码: wa9p 下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹...行代码导入tensorflow库,取别名tf; 第4行代码人从tensorflow.examples.tutorials.mnist库中导入input_data文件; 本文作者使用anaconda集成开发环境...4、数据观察本章内容主要是了解变量mnist中的数据内容,并掌握变量mnist中的方法使用。...4.3 mnist.train.images观察查看mnist.train.images的数据类型和矩阵形状。

    2.8K20

    PyTorch 揭秘 :构建MNIST数据集

    火种二:动态计算图的强大 PyTorch使用动态计算图(Dynamic Computation Graph),也就是说,图的构建是在代码运行时动态进行的,这允许你进行更为直观的模型构建和调试。...这让PyTorch在处理可变长度的输入,如不同长度的文本序列或时间序列数据时,显得游刃有余。动态图的特性也使得在网络中嵌入复杂的控制流成为可能,比如循环和条件语句,这些都是静态图难以做到的。...火种四:实践举例 看一个实际的例子,如何用PyTorch来构建一个卷积神经网络(CNN)来识别手写数字,也就是著名的MNIST数据集: python import torch.optim as optim...running_loss = 0.0 print('Finished Training') # 保存模型参数 torch.save(net.state_dict(), 'mnist_cnn.pth...小结 PyTorch 以其简洁性、强大的动态计算图和活跃的社区支持让学习和研发都变得轻松。我们还通过构建一个CNN模型来识别MNIST数据集中的手写数字,讲述了整个模型的设计、训练和评估过程。

    24510

    手写KNN识别MNIST数据集

    数据集[1] 提取码:mrfr 浏览本文前请先搞懂K近邻的基本原理:深入浅出KNN算法 算法实现步骤: 1.数据处理。...每一个数字都是一个32X32维的数据,如下所示: KNN中邻居一词指的就是距离相近。我们要想计算两个样本之间的距离,就必须将每一个数字变成一个向量。...具体做法就是将32X32的数据每一行接在一起,形成一个1X1024的数据,这样我们就可以计算欧式距离。...2.计算测试数据到所有训练数据的距离,并按照从小到大排序,选出前K个 3.根据距离计算前K个样本的权重4.将相同的训练样本的权重加起来,返回权重最大样本的标签 代码实现: import os def...(K, test_data[i][j])) if __name__ == '__main__': test() References [1] 数据集: https://pan.baidu.com

    39710
    领券