在 TensorFlow 中,你可以使用 tf.data.Dataset
API 来加载、分割和切片 MNIST 数据集。MNIST 数据集是一个经典的手写数字识别数据集,包含 60,000 个训练样本和 10,000 个测试样本。TensorFlow 提供了方便的工具来加载和处理这个数据集。
以下是一个完整的示例,展示了如何使用 TensorFlow 数据集加载方法对 MNIST 数据集进行分割和切片。
首先,确保你已经安装了 TensorFlow。如果还没有安装,可以使用以下命令进行安装:
pip install tensorflow
使用 tf.keras.datasets.mnist
模块来加载 MNIST 数据集。
import tensorflow as tf
# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# 打印数据集的形状
print(f'Train images shape: {train_images.shape}')
print(f'Train labels shape: {train_labels.shape}')
print(f'Test images shape: {test_images.shape}')
print(f'Test labels shape: {test_labels.shape}')
将 NumPy 数组转换为 tf.data.Dataset
对象。
# 创建训练数据集
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
# 创建测试数据集
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
假设你想将训练数据集分割为训练集和验证集。你可以使用 Dataset.take
和 Dataset.skip
方法来实现。
# 定义分割比例
validation_split = 0.1
num_train_samples = int((1 - validation_split) * len(train_images))
num_validation_samples = len(train_images) - num_train_samples
# 分割训练集和验证集
train_dataset = train_dataset.take(num_train_samples)
validation_dataset = train_dataset.skip(num_train_samples)
你可以使用 Dataset.batch
方法来切片数据集,以便在训练过程中使用小批量数据。
# 定义批量大小
batch_size = 32
# 切片数据集
train_dataset = train_dataset.batch(batch_size)
validation_dataset = validation_dataset.batch(batch_size)
test_dataset = test_dataset.batch(batch_size)
在训练之前,你可能需要对数据进行预处理,例如归一化。
def preprocess(image, label):
image = tf.cast(image, tf.float32) / 255.0
return image, label
# 应用预处理函数
train_dataset = train_dataset.map(preprocess)
validation_dataset = validation_dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)
现在你可以使用这些数据集来训练一个简单的模型。
# 创建一个简单的模型
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(train_dataset, epochs=5, validation_data=validation_dataset)
# 评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f'Test accuracy: {test_acc}')
领取专属 10元无门槛券
手把手带您无忧上云