Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >pytorch-Train-Val-Test划分(上)

pytorch-Train-Val-Test划分(上)

作者头像
用户6719124
发布于 2019-11-17 13:54:06
发布于 2019-11-17 13:54:06
3.9K00
代码可运行
举报
运行总次数:0
代码可运行

本节介绍的是Train/Val/Test部分的划分,合理的划分会有效地减少under-fitting和over-fitting现象。

我们以数字识别为例,正常一个数据集我们要划分出来训练部分和测设部分,如下图所示

如上图,左侧橘色部分作为训练部分,神经网络在该区域内不停地学习,将特征转入到函数中,学习好后得到一个函数模型。随后将上图右面白色区域的测试部分导入到该模型中,进行accuracy和loss的验证。

通过不断地测试可以查看模型是否调整到一个最佳的参数,及结果是否发生over-fitting现象。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 训练-测试代码写法
train_loader = torch.utils.data.Dataloader(
# 一般使用DataLoader函数来让机器学习或测试
    datasets.MNIST('../data', train=True, download=True,
# 使用 train=True 或 train=False来进行数据集的划分
#  train=True时为训练集,相反不是训练集(即为测试集)
                   transform=transform.Compose([
                       transforms.ToTensor(),
                       transforms.Normlaize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.Dataloader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transform.Compose([
                       transforms.ToTensor(),
                       transforms.Normlaize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

这里注意,正常情况下数据集是要有validation(验证集)的,若没有设置,即将test和val集合并为一个。

前面讲解了如何对数据集进行划分,那么如何进行循环学习验证测试呢?

代码如下

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):

# 这里的data和target一般作为backward用
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 每次循环都查看一次是否发生over-fitting现象
# 如果发生了over-fitting现象,我们便将最后一次
# 模型的状态函数作为最终的模型版本

    test_loss = 0
    correct = 0
for data, target in test_loader:
data = data.view(-1, 28*28)
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

以一个实际例子的train error和test error来举例作图

由图看出在train进行到第5次后,test error便达到一个较低的位置。而后随着训练次数的增加,test error会逐渐增加,发生over-fitting现象。

我们将训练次数在5次的点叫做check-point,神经网络会记住该点的参数值。再拿该点所对应的参数做一个实际的预测。

但正常下除了提供神经网络学习的train set和挑选最佳参数的test set外,一般还要有validation set。但val set数据要代替test set的功能,而test数据则要交给客户,进行实际验证,正常情况下test set数据是不加入到神经网络学习测试中的。

若将val set 和 test set 数据都加入到学习或测试部分,则会欺骗客户,使得客户无法拿到最佳的模型。所以正常情况下客户会抽走一部分数据作为test set不让神经网络得到,以此来验证模型的效果。

在kaggle比赛中也会发生这种情况。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-11-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Train / Val / Test划分
合理的Train/Test集划分会有效地减少under-fitting和over-fitting现象
mathor
2020/01/15
2.3K0
pytorch-Train-Val-Test划分(下)
上半节讲到一个数据集正常需要划分为train、validation和test三个数据集,那么具体到代码中是如何实现的?
用户6719124
2019/11/17
3.5K0
轻松学Pytorch–Visdom可视化
Visdom是Facebook专为PyTorch开发的实时可视化工具包,其作用相当于TensorFlow中的Tensorboard,灵活高效且界面美观,下面就一起来学习下如何使用吧!如果想更多了解关于Visdom的使用可以参考官方
OpenCV学堂
2020/06/04
1.9K0
pytorch:实践MNIST手写数字识别
在datasets.MNIST的中可以设置download=True,这样设置,系统会自动在root里面检测MNIST数据文件,如果存在则不下载,如果不存在则自动联网下载。我尝试自动联网下载,结果十几分钟之后,下载一半之后报错,网络出现问题。于是翻阅其它资源,将其手动下载下来添加到minst文件夹中自动创建的raw文件夹中。 (如果你也需要这个数据集,可以在微信公众号“我有一计”内回复“数据集”,即可获取下载链接)
zstar
2022/06/14
4920
MNIST手写数字识别
掌握利用卷积神经网络CNN实现对MNIST手写数字的识别。一个简单的神经网络实验
不去幼儿园
2024/12/03
4840
MNIST手写数字识别
PyTorch入门笔记-手写数字实战01
对 MNIST 手写数字识别进行分类大致分为四个步骤,这四个步骤也是训练大多数深度学习模型的基本步骤:
触摸壹缕阳光
2020/11/05
1.1K0
PyTorch入门笔记-手写数字实战01
Pytorch实现简单的数字识别(上)
使用深度学习神经网络对数字识别,大体需要4个步骤:①读取数据。②建立模型。③训练。④测试、验证。
用户6719124
2019/11/18
1.6K0
16,8和4位浮点数是如何工作的
50年前Kernighan、Ritchie和他们的C语言书的第一版开始,人们就知道单精度“float”类型有32位大小,双精度类型有64位大小。还有一种具有扩展精度的80位“长双精度”类型,这些类型几乎涵盖了浮点数据处理的所有需求。但是在最近几年,尤其是今年LLM的兴起,为了减小模型的存储和内存占用,开发人员开始尽可能地缩小浮点类型。
deephub
2023/10/06
2.8K0
16,8和4位浮点数是如何工作的
pytorch中的数据索引
在PyTorch中,数据索引是指在处理张量(Tensor)时访问或操作特定元素的过程。索引在数据处理和深度学习中是非常常见且重要的操作,它允许我们以各种方式访问数据集中的元素,执行数据的切片、提取、过滤等操作。
GeekLiHua
2025/01/21
1940
pytorch中的数据索引
用PyTorch实现MNIST手写数字识别(非常详细)
MNIST可以说是机器学习入门的hello word了!导师一般第一个就让你研究MNIST,研究透了,也算基本入门了。好的,今天就来扯一扯学一学。
小锋学长生活大爆炸
2020/08/13
2.1K0
用PyTorch实现MNIST手写数字识别(非常详细)
PyTorch中的LeNet-5入门
LeNet-5是一个经典的卷积神经网络(CNN)模型,由Yann LeCun等人在1998年提出。它在手写数字识别任务上取得了很好的性能,并被广泛应用于图像分类问题。本文将介绍如何使用PyTorch实现LeNet-5模型,并在MNIST手写数字数据集上进行训练和测试。
大盘鸡拌面
2023/10/18
6050
Pytorch实现简单的数字识别(下)
但要注意loss的降低程度不能代表神经网络结构模型的好坏,应该将最终的正确率结果作为验证模型优劣的工具。
用户6719124
2019/11/18
7080
手写数字识别基本思路
问题 什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验? 方法 mnist数据集 MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含
算法与编程之美
2023/08/22
3210
手写数字识别基本思路
【项目实战】MNIST 手写数字识别(上)
本文将介绍如何在 PyTorch 中构建一个简单的卷积神经网络,并训练它使用 MNIST 数据集识别手写数字,这将可以被看做是图像识别的 “Hello, World!”;
sidiot
2023/08/31
6280
【项目实战】MNIST 手写数字识别(上)
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
之前我们讨论的问题都是二分类居多,对于二分类问题,我们若求得p(0),南无p(1)=1-p(0),还是比较容易的,但是本节我们将引入多分类,那么我们所求得就转化为p(i)(i=1,2,3,4…),同时我们需要满足以上概率中每一个都大于0;且总和为1。
小馒头学Python
2024/04/24
3.2K0
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
Pytorch打怪路(二)pytorch进行mnist训练和测试
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Teeyohuang/article/details/79242946
TeeyoHuang
2019/05/25
1.9K0
Microsoft NNI入门
【GiantPandaCV导语】Neural Network Intelligence 是一个工具包,可以有效帮助用户设计并调优汲取学习模型的神经网络架构,以及超参数。具有易于使用、可扩展、灵活、高效的特点。本文主要讲NNI基础的概念以及一个训练MNIST的入门教程。本文首发于GiantPandaCV,未经允许,不得转载。
BBuf
2021/03/09
1.7K0
Microsoft NNI入门
多分类问题线性层和训练部分代码的构建
如下图网络是一个十个输出(十分类问题) 首先建立三个线性层 import torch import torch.nn.functional as F # 先建立三个线性层 784=>200=>20
mathor
2020/01/14
7180
多分类问题线性层和训练部分代码的构建
PyTorch的简单实现
PyTorch 的关键数据结构是张量,即多维数组。其功能与 NumPy 的 ndarray 对象类似,如下我们可以使用 torch.Tensor() 创建张量。如果你需要一个兼容 NumPy 的表征,或者你想从现有的 NumPy 对象中创建一个 PyTorch 张量,那么就很简单了。
代码的路
2022/06/18
2K0
PyTorch的简单实现
Pytorch-多分类问题神经层和训练部分代码的构建
这里完成了tensor的建立和forward过程,下面介绍train(训练)部分。
用户6719124
2019/11/17
8240
推荐阅读
相关推荐
Train / Val / Test划分
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验