先来看看说说主要步骤。
1、引入数据
2、训练模型
3、预测
1、引入数据,采用经典的iris数据
Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
iris以鸢尾花的特征作为数据来源,常用在分类操作中。该数据集由3种不同类型的鸢尾花的50个样本数据构成。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的。
该数据集包含了5个属性:
& Sepal.Length(花萼长度),单位是cm;
& Sepal.Width(花萼宽度),单位是cm;
& Petal.Length(花瓣长度),单位是cm;
& Petal.Width(花瓣宽度),单位是cm;
& 种类:Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),以及Iris Virginica(维吉尼亚鸢尾)。
首先引入Iris到scikit-learn里
这个数据集是别人已经给我们整理好的,我们只需要导入就可以,现在读者一定对这个数据还是感到很好奇,那么我们一步一步来查看数据集是怎样的。
如图所示,首先我们把iris的原始数据赋值给iris,接着查看了iris数据集的变量分别是花瓣长度,花瓣宽度,萼片长度和萼片宽度,在接着查看了iris的种类,分别为那三种英文名,就不翻译了。接着查看了第一个数据集的值,他们四个数这是第一个样本的花瓣长度,花瓣宽度,萼片长度和萼片宽度。接着是看了第一个样本的种类是“setosa”
为了更直观的看到数据,可以用一下命令把数据打印出来。看看数据庐山真面目,下图没显示全部。我还把他保存到了电脑本地,方便以后用。
这数据,UCI上面也有,打开网站后第一个就是。
到此我们已经有了数据,但是有一个点,就是有关训练集和测试集。我们需要把数据分一些数来。
简单的理解,我们把三种类型的花,分别抽出一个样本,等我们模型建立好以后,用这三个数据去测试我们的模型,看他是否能准确预测出我们花的种类。下面是分别提取三个数的代码,分别提取了第1个,第51个,第101个数据,这三个数据不在参与我们模型的训练。而是后面用来检验模型的。
2、训练模型
训练模型就比较简单,上次课也有说过。
我们导入模型,然后用训练集的数据进行训练模型,这就得到了模型了。
3、预测
预测同样很简单,采用predict命令,然后把我提出来的三个样本数据,带入模型,并查看模型的预测效果
上面一条命令是查看这三个样本数据原本的花的类型,分别是0,1,2,即是“Iris Setosa(山鸢尾)用0表示的、Iris Versicolour(杂色鸢尾)用1表示的,以及Iris Virginica(维吉尼亚鸢尾)用2表示”
下面一条命令这是把这三个样本数据导入模型,让模型分别这三种花的类型,结果显示和原本的种类是一样的,这就说明模型都预测对了。
决策树可视化
前面的步骤正式了模型还不错,但是有一个疑问,这个模型是怎么建立的,为什么能准确判断出花的种类?
利用可视化的决策树,能帮助我们理解模型建立的奥秘。
由于可视化要借助与 graphviz 这个包,另外graphviz 是一个画图的软件,可以自行查找相关资料了解。
引入包的前提是需要安装这个包。可以使用 pip install graphviz 进行安装,如果你使用的也是 anaconda 这是要在anaconda Prompt 命令窗口中使用此命令。
运行以上代码后可以得到下图:
这就是决策树,其实不难理解,当一个样本数据进入时,我们进入这课树的顶端,首先判断其X3 ,这个X3就是Petal.Length(花瓣长度) 当你的样本数据进去时他会首先判断第一个条件,Sepal.Length(花萼长度)是否小于0.8,如果是True 这判断为Iris Setosa(山鸢尾),如果False ,则看Sepal.Length(花萼长度)是否小于1.75.....
这就是决策树的原来,可以带如三个测试数据判断,看看结果一样吗?
请关注公众号:糖馅味。一起学习数据分析,成为数据科学家!
领取专属 10元无门槛券
私享最新 技术干货