首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >机器学习入门 8-6 验证数据集与交叉验证

机器学习入门 8-6 验证数据集与交叉验证

作者头像
触摸壹缕阳光
发布2019-12-30 16:14:58
发布2019-12-30 16:14:58
1.8K0
举报
本系列是《玩转机器学习教程》一个整理的视频笔记。本小节探讨将数据集划分训练集和测试集的局限性,进而引出验证集,为了解决验证集随机性的问题,引入了交叉验证和留一法,并进一步探讨网格搜索背后的意义,最后通过编程实现调参选择模型的整个过程。

01

划分训练集和测试集的局限性

前几个小节通过引入过拟合和欠拟合的概念,让大家理解使用train_test_split方法划分出测试集的意义。

将所有样本数据都当成训练集,如果此时模型发生了过拟合,我们并不会发觉。因为在这种情况下模型在训练集上误差会非常非常的小,会觉得训练得到的模型很好,但是实际上模型的泛化能力很差,我们真正关注的就是模型的泛化能力。所以将数据集划分成训练集和测试集两部分:

  1. 训练数据集用于训练模型;
  2. 测试数据集用来评估训练得到模型的性能,简单来说就是通过测试集来判断训练得到模型的好坏;

如果此时模型发生了过拟合,也就是在训练集上表现的很好,可是在测试集上表现的不好。发生过拟合需要执行下面步骤。

  1. 重新调整模型的参数(一般是超参数),重新得到不同复杂度的新模型;
  2. 在训练集上训练新模型;
  3. 在测试集上对训练得到的新模型评估性能;

当然往往不可能一次就能找到性能最好的模型,上面三个步骤通常需要循环往复。

这里调整模型的参数通常指的就是超参数。在kNN算法中超参数有k值和p值,在多项式回归中degree阶数也是超参数。我们需要调整这些超参数得到一个在测试集上表现比较好的模型,这样的模型泛化能力比较强,这样我们就更有信心当把这个模型放入真实生产环境中面对未知数据的时候能够有更好的表现。

但是将数据集划分为训练集和测试集这种方式真的靠谱吗???

02

验证集的引入

将数据集划分为训练集和测试集当然要比只使用训练集来得到模型靠谱合理的多。

严格来说还是存在一些有问题不靠谱的地方。这个问题就在于,通过测试集来评估模型的好坏,如果模型在测试集上表现的不好,此时就需要重新调节模型参数,调节参数的方向变成了拟合测试集,很有可能最终得到的模型过拟合了这个测试集,可以想象一些虽然是使用训练集训练得到模型,但是每次都是通过测试集来看这个模型的好坏,一旦发现模型不好就需要换一份参数重新进行训练,其实这个过程模型在一定程度上围绕着测试集打转。也就是说我们想办法找到一组参数,这组参数是我们用训练集训练得到的模型在测试集上效果最好,但是由于这个测试集是已知的,此时就相当于在针对这组测试集进行调参,那么模型很有可能在测试集过拟合。

如果解决这个问题呢?解决方法就是将整个数据集划分为三个部分,这三个部分分别是训练集、验证集以及测试集。

  1. 训练集用于训练模型;
  2. 验证集是之前测试集所做的事情,当我们训练好模型后将验证集输入到这个模型中,看看相应的效果怎么样,如果效果不好的话,就需要换参数重新训练模型,直到我们找到一组参数,这组参数使得模型针对验证集来说已经达到最优;
  3. 测试集用于测试模型最终的性能;

上面的这个过程,测试集不再参与模型的创建了,而训练集和测试集都参与了模型的创建。

  1. 训练集用于训练模型,验证集用于评判,如果评判的结果不好的话,就需要重新选择参数训练模型,训练集和验证集都参与了模型的创建;
  2. 测试集对于模型来说完全不可知的,相当于是我们在模拟真正的真实环境中模拟的一个完全未知的数据。测试集不参与模型的创建,当我们最终确定好了模型之后再把测试集输入到这个最终模型中得到模型最终的性能;

验证集是作为调整超参数使用的数据集,而测试集是作为衡量最终模型性能的数据集。我们使用一组参数的模型在训练集上进行训练,训练得到的模型一旦在验证集上表现的不好,我们就需要重新换参数(调参数),所以验证集相当于调整超参数使用,对于选出来最终的模型,在测试集上进行测试得到的结果更加准确,可以评判模型的性能。

验证集通常被称为Validation Dataset,将数据集划分为训练集、验证集以及测试集的代码和将数据集划分为训练集和测试集的实现相似。

03

交叉验证

将数据集划分为训练集、验证集和测试集还是有一个问题。这个问题在于随机,由于每一次验证的数据集都是随机的从原来的数据集中切分出来的,通过之前的分析也可以看出来,和将数据集划分为训练集和测试集类似,训练的模型有可能过拟合验证集,但是我们只有一份验证集,一旦这一份验证集里有比较极端的样本就可能会导致模型最终的结果不准确。

为了解决这个问题,有了交叉验证(Cross Validation)这样的方式。在我们调整模型参数的时候,交叉验证相对来讲是看模型性能比较正规比较标准的方式。

那什么叫做交叉验证呢???

对于训练的数据来说,通常将训练的数据分成k份。比如现在把它分成三份,将这三份叫做A,B和C,接下来要做的事情就是让A,B和C分别做验证集。

  1. A做验证集,把BC合起来当做训练集;
  2. B做验证集,把AC合起来当做训练集;
  3. C做验证集,把AB合起来当做训练集;

每一个训练数据集和验证集的搭配就会产生一个模型。比如将训练的数据集分成三份的话,通过这样的方式就可以得到三个模型,这三个模型每一个模型在验证集上都会求出一个性能指标,把这些性能指标的平均值作为最终衡量当前算法得到的模型的性能标准。

当然了这里将训练的数据分成三份,实际上可以把训练的数据分成k份,通常k为3,5和10。将训练的数据分成k份,相应的会得到k个模型,将这k个模型的均值作为最终的结果。如果得到的最终结果不够好的话,还需要调整一下参数,然后继续将训练的数据分成k份,每一组都会得到相应模型的性能,将这k个模型的性能指标平均作为最终的结果。

由于交叉验证方式中有一个求平均的过程,所以不会由于某一份验证集中有什么极端样本导致最终训练出来的模型有过大的偏差,所以这样做比将数据集划分训练集和测试集以及将数据集划分为训练集、验证集和训练集(只设立一个验证集)两种方式要靠谱的多。也正是因为如此,通常在调参的时候要使用交叉验证的方式。

接下来通过具体的编程实现来看一下如何使用这种交叉验证的方式来进行调参。

04

编程实现交叉验证

  • Steps1:本小节使用digits手写识别数据集。
  • Steps2:首先使用之前一直使用的train_test_split方法。

使用kNN算法来对手写数字进行识别,在这个过程中训练模型相应的进行调参。对于kNN算法而言,超参数一共有两个:

  1. k值,也就是k近邻算法中几个相邻的元素当成邻居进行投票,选择范围设定为2到10之间;
  2. p值,也就是k近邻算法计算的距离,选定的范围为1到5之间;

kNN算法中的9个k值和5个p值一共有45种不同的参数组合,对应的每一组都创建一个kNN对象,这里主要调整k值和p值,因此将weights的值固定为"distance",每一次都在训练集上进行训练,最后在测试集上评估模型的性能,通过比对每组参数得到的性能指标,选出在测试集上性能最好的性能指标,并输出相应的k值和p值。当然由于前面使用train_test_split的时候指定了random_state随机种子为666,因此如果将种子设置为666的话,得到的最好的超参数和上面得到的结果是一致的。

上面就是我们之前一直使用的train_test_split方式进行超参数的调整。

  • Steps3:接下来使用交叉验证的方式进行超参数的调整。

对于交叉验证的过程不从底层进行实现了,其实如果底层实现也很简单,利用numpy数组的切片就能够轻松的实现交叉验证。这里直接调用sklearn的model_selection下的cross_val_score方法即可,只需要传入相应的算法以及训练的数据(将来会被划分为训练集和验证集)就会自动进行交叉验证的过程,返回k个模型中每个模型的准确率,这里使用cross_val_score默认k折为3,因此默认返回拥有三个数的数组,当然在sklearn中cross_val_score在如何分组这件事情上使用了一些比较复杂的技巧。

Steps2和Steps3分别使用了train_test_split以及交叉验证的方式进行调参选择模型,两次最合适的模型参数分别为:

  1. 两种方法得到的Best k和Best p是不一样的,通常在这种情况下,更相信通过交叉验证得到的这组参数,因为在train_test_split中得到的这组参数很有可能过拟合了在train_test_split中分离出来的测试集;
  2. 在交叉验证中得到的最佳分数0.982是低于在train_test_split中得到的最佳分数0.986,这是因为在交叉验证的过程中,通常不会过拟合某一组的验证数据,所以平均来讲计算得到的分数会稍微低一些;

现在得到了使用交叉验证计算最终得到最好的k和p,那最终的准确率就是交叉验证得到的0.982吗?当然不是了,这里交叉验证过程就是为了拿到最好的k和p值而已,当我们拿到了这组参数之后,就可以用这组参数创建适用当前数据最佳的kNN。

用k=2,p=2这组通过交叉验证找到的kNN分类器,对X_train和y_train整体进行拟合训练,然后使用X_test和y_test验证最终模型准确率的结果为0.98。此时就可以说,我们用交叉验证的方式或者更准确的说使用三交叉验证的方式(因为交叉验证的过程中每次将训练的数据分成三份),用三交叉验证的方式找到了kNN算法最佳的参数组合k = 2,p = 2,此时我们模型分类的准确度是98%。

在这里得到的最终的分类准确度是根据X_test和y_test计算出来的,X_test和y_test在真正寻找最佳模型的过程中是完全没有的,对于我们整个模型来说是完全陌生的,也就是说我们使用一组模型完全没有见过的数据来测量他最终准确率是怎么样的,此时的准确率是值得相信的。

不过说了这么多,其实之前进行网格搜索的时候已经使用了交叉验证,只不过交叉验证的过程被sklearn封装在网格搜索中。很有可能当时并没有意识到,接下来实现网格搜索的过程,使用GridSearchCV实现网格搜索,这里的CV就是Cross Validation交叉验证。

网格搜索的参数和之前搜索的范围一样,运行会有下面的输出信息。

代码语言:javascript
复制
Fitting 3 folds for each of 45 candidates, totalling 135 fits

这里的3 folds就是指网格搜索中每一次使用交叉验证的方式进行搜索,都会将训练集分成三份。而此时的参数组合k值9种 * p值5种 = 45种组合,因此网格搜索需要对45组参数进行搜索,每组参数又要生成三个模型来计算它们性能的平均值,加在一起总共需要135次训练。

下面就可以输出网格搜索的结果:

可以看出通过网格搜索输出的最佳模型分数和前面使用交叉验证得到的最佳模型的分数是一致的,都是0.9823,并且网格搜索和前面交叉验证得到的最佳参数以及最终在测试集上得到的分数都是一致的。

通过上面可以看出来,sklearn中封装的GridSearchCV本身就实现了用交叉验证的方式来进行参数搜索的方法。

整个过程对于cross_val_score函数默认将训练的数据分成三份,如果想要分成其他数值的份数,只需要传入cv参数并指定即可。

代码语言:javascript
复制
cross_val_score(knn_clf, X_train, y_train, cv = 5) # 分成5份

最终返回的数组中就有五个数值,对应的就是训练了五个模型,每个模型对应的分数是多少。

当然对于GridSearchCV来说也可以传入cv参数,如果指定cv参数为5的话,进行网格搜索的时候,每一次交叉验证都会将我们的数据集分成五份。

代码语言:javascript
复制
grid_search = GridSearchCV(knn_clf, param_grid, verbose=1, cv=5)

最后总结一下:

我们说了使用交叉验证的方式在调参的过程中评价模型的准确度更加靠谱。通常把训练集分成k份,具体来说又称为k-folds cross validation,其中每一个fold就是把训练集分成k份,每一份都可以叫做一个fold,通过之前程序的运行可以体会到这种k-folds cross validation方法的缺点就是由于每一次都要训练k个模型,所以整体调参的性能慢了k倍,这个k值越大整体调参的过程就会越耗时,但是通常最后找到的参数可以更加的信赖。

04

留一法 LOO-CV

在极端情况下,k-folds cross validation可以变成留一法(LOO-CV)这种交叉验证方式。

如果训练数据集一共有m个样本,留一法就是将和m个样本分成m份。换句话说,每一次都将(m - 1)份样本用于训练,然后去看剩下的一个样本当成验证集进行测试,将m个模型的结果综合起来进行平均,作为衡量当前参数下这个模型对应预测的准确度。这样做的优点显而易见,完全不受随机的影响,因为使用k-folds cross validation,即使我们把样本分成k份,其实这k份无论如何划分都是有若干种可能的,也会有随机带来的影响,但是留一法完全不受随机的影响,也是最接近模型真正的性能指标,显然留一法最大的缺点就是计算量巨大。比如对于手写数字识别的数据集而言,有上千个样本,每训练一个参数就需要训练上千个模型,显而易见时间开销实在是太大了,所以如果你的计算资源没有那么富裕的话,不要使用留一法。

虽然留一法计算量巨大,但是在很多学术研究论文中为了最终结果的严谨性有可能会使用这种留一法。在下一小节,会继续进行总结,同时介绍如果发现自己训练的模型有过拟合倾向的时候,到底应该如何去做。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档