首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >TensorFlow决策森林构建GBDT(Python)

TensorFlow决策森林构建GBDT(Python)

作者头像
算法进阶
发布于 2022-06-02 04:32:32
发布于 2022-06-02 04:32:32
87000
代码可运行
举报
文章被收录于专栏:算法进阶算法进阶
运行总次数:0
代码可运行

一、Deep Learning is Not All You Need

尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:

深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。

二、树模型的使用

对于决策树等模型的使用,通常是要到scikit-learn、xgboost、lightgbm等机器学习库调用, 这和深度学习库是独立割裂的,不太方便树模型与神经网络的模型融合。

一个好消息是,Google 开源了 TensorFlow 决策森林(TF-DF),为基于树的模型和神经网络提供统一的接口,可以直接用TensorFlow调用树模型。决策森林(TF-DF)简单来说就是用TensorFlow封装了常用的随机森林(RF)、梯度提升(GBDT)等算法,其底层算法是基于C++的 Yggdrasil 决策森林 (YDF)实现的。

三、TensorFlow构建GBDT实践

TF-DF安装很简单pip install -U tensorflow_decision_forests,有个遗憾是目前只支持Linux环境,如果本地用不了将代码复制到 Google Colab 试试~

  • 本例的数据集用的癌细胞分类的数据集,首先加载下常用的模块及数据集:
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import numpy as np  
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
tf.random.set_seed(123)

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score,roc_curve

dataset_cancer = datasets.load_breast_cancer()    # 加载癌细胞数据集

#print(dataset_cancer['DESCR'])

df = pd.DataFrame(dataset_cancer.data, columns=dataset_cancer.feature_names)  

df['label'] = dataset_cancer.target

print(df.shape)

df.head()
  • 划分数据集,并简单做下数据EDA分析:
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# holdout验证法: 按37划分测试集 训练集
x_train, x_test= train_test_split(df, test_size=0.3)

# EDA分析:数据统计指标
x_train.describe(include='all')
  • 构建TensorFlow的GBDT模型:TD-DF 一个非常方便的地方是它不需要对数据进行任何预处理。它会自动处理数字和分类特征,以及缺失值,我们只需要将df转换为 TensorFlow 数据集,如下一些超参数设定:

模型方面的树的一些常规超参数,类似于scikit-learn的GBDT

此外,还有带有正则化(dropout、earlystop)、损失函数(focal-loss)、效率方面(goss基于梯度采样)等优化方法:

构建模型、编译及训练,一步到位:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 模型参数
model_tf = tfdf.keras.GradientBoostedTreesModel(loss="BINARY_FOCAL_LOSS")

# 模型训练
model_tf.compile()
model_tf.fit(x=train_ds,validation_freq=0.1)
  • 评估模型效果
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
## 模型评估
可以看到test的准确率已经都接近1,可以再那个困难的数据任务试试~
evaluation = model_tf.evaluate(test_ds,return_dict=True)
probs = model_tf.predict(test_ds)
fpr, tpr, _ = roc_curve(x_test.label, probs)
plt.plot(fpr, tpr)
plt.title('ROC curve')
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.xlim(0,)
plt.ylim(0,)
plt.show()
print(evaluation)
  • 模型解释性 GBDT等树模型还有另外一个很大的优势是解释性,这里TF-DF也有实现。模型情况及特征重要性可以通过print(model_tf.summary())打印出来,

特征重要性支持了几种不同的方法评估:

MEAN_MIN_DEPTH指标。平均最小深度越小,较低的值意味着大量样本是基于此特征进行分类的,变量越重要。

NUM_NODES指标。它显示了给定特征被用作分割的次数,类似split。此外还有其他指标就不一一列举了。

我们还可以打印出模型的具体决策的树结构,通过运行tfdf.model_plotter.plot_model_in_colab(model_tf, tree_idx=0, max_depth=10),整个过程还是比较清晰的。

小结

基于TensorFlow的TF-DF的树模型方法,我们可以方便训练树模型(特别对于熟练TensorFlow框架的同学),更进一步,也可以与TensorFlow的神经网络模型做效果对比、树模型与神经网络模型融合、利用异构模型先特征表示学习再输入模型(如GBDT+DNN、DNN embedding+GBDT),进一步了解可见如下参考文献。

参考文献:https://www.tensorflow.org/decision_forests/ https://keras.io/examples/structured_data/classification_with_tfdf/

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-05-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法进阶 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
最喜欢随机森林?周志华团队DF21后,TensorFlow开源决策森林库TF-DF
在人工智能发展史上,各类算法可谓层出不穷。近十几年来,深层神经网络的发展在机器学习领域取得了显著进展。通过构建分层或「深层」结构,模型能够在有监督或无监督的环境下从原始数据中学习良好的表征,这被认为是其成功的关键因素。
机器之心
2021/06/08
9490
机器学习(12)——随机森林集成学习随机森林
前言:前面已经介绍了的几种算法,应该对算法有了一个基本的认识了,本章主要是在前面已经学到的基础上,对前面的算法模型进行整合操作,训练出效果更好的分类器模型。 集成学习 集成学习的思想是将若干个学习器(分类器&回归器)组合之后产生一个新学习器。弱分类器( weak learner)指那些分类准确率只稍微好于随机猜测的分类器( errorrate<0.5);集成算法的成功在于保证弱分类器的多样性( Diversity)。而且集成不稳定的算法也能够得到一个比较明显的性能提升。 常见的集成学习思想有: (1)投
DC童生
2018/04/27
2.6K1
机器学习(12)——随机森林集成学习随机森林
原理+代码,总结了 11 种回归模型
本文所用数据说明:所有模型使用数据为股市数据,与线性回归模型中的数据一样,可以做参考,此处将不重复给出。
数据STUDIO
2021/06/24
4.8K0
基于随机森林模型的心脏病人预测分类
今天给大家分享一个新的kaggle案例:基于随机森林模型(RandomForest)的心脏病人预测分类。本文涉及到的知识点主要包含:
皮大大
2022/02/22
2.1K0
基于随机森林模型的心脏病人预测分类
异常检测算法速览(Python代码)
异常检测是通过数据挖掘方法发现与数据集分布不一致的异常数据,也被称为离群点、异常值检测等等。
算法进阶
2022/06/01
1K0
异常检测算法速览(Python代码)
3大树模型实战乳腺癌预测分类
本文从特征的探索分析出发,经过特征工程和样本均衡性处理,使用决策树、随机森林、梯度提升树对一份女性乳腺癌的数据集进行分析和预测建模。
皮大大
2023/08/25
5780
面试腾讯,基础考察太细致。。。
在不平衡数据集中,某些类别的样本数量远多于其他类别,这会导致模型更倾向于预测多数类,而忽略少数类。
Python编程爱好者
2024/06/04
2210
面试腾讯,基础考察太细致。。。
医学图像 | 使用深度学习实现乳腺癌分类(附python演练)
乳腺癌是全球第二常见的女性癌症。2012年,它占所有新癌症病例的12%,占所有女性癌症病例的25%。
磐创AI
2019/09/24
2.7K0
医学图像 | 使用深度学习实现乳腺癌分类(附python演练)
从深度学习到深度森林方法(Python)
目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV)、自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular data的表现也是稍弱的),而在其他涉及符号建模、离散建 模、混合建模的任务上,深度神经网络的性能并没有那么好。
算法进阶
2022/06/02
6580
从深度学习到深度森林方法(Python)
数学建模~~~预测方法--决策树模型
第一维度:介绍基本的概念,以及这个决策树的分类和相关的这个算法和基尼系数的计算方法,通过给定这个用户的数据预测这个用户是否会离职;
阑梦清川
2025/02/24
1620
数学建模~~~预测方法--决策树模型
字节一面,差点跪在 GBDT !!
这些天有一个同学在字节一面的时候,在 GBDT 交流的时候,感觉差点点挂掉。好在后面的面试中表现还算可以。
Python编程爱好者
2024/07/12
1770
字节一面,差点跪在 GBDT !!
每个Kaggle冠军的获胜法门:揭秘Python中的模型集成
选自Dataquest 作者:Sebastian Flennerhag 机器之心编译 集成方法可将多种机器学习模型的预测结果结合在一起,获得单个模型无法匹敌的精确结果,它已成为几乎所有 Kaggle 竞赛冠军的必选方案。那么,我们该如何使用 Python 集成各类模型呢?本文作者,曼彻斯特大学计算机科学与社会统计学院的在读博士 Sebastian Flennerhag 对此进行了一番简述。 在 Python 中高效堆叠模型 集成(ensemble)正在迅速成为应用机器学习最热门和流行的方法。目前,几乎每一
机器之心
2018/05/10
3.3K0
常用机器学习代码汇总
皮大大
2023/08/25
5010
TensorFlow Eager 教程
大家好! 在本教程中,我们将使用 TensorFlow 的命令模式构建一个简单的前馈神经网络。 希望你会发现它很有用! 如果你对如何改进代码有任何建议,请告诉我。
ApacheCN_飞龙
2022/12/02
1K0
TensorFlow Eager 教程
6大监督学习方法:实现毒蘑菇分类
本文是kaggle案例分享的第3篇,赛题的名称是:Mushroom Classification,Safe to eat or deadly poison? 数据来自UCI:https://archi
皮大大
2021/12/15
2.3K0
6大监督学习方法:实现毒蘑菇分类
机器学习——决策树模型:Python实现
决策树模型既可以做分类分析(即预测分类变量值),也可以做回归分析(即预测连续变量值),分别对应的模型为分类决策树模型(DecisionTreeClassifier)及回归决策树模型(DecisionTreeRegressor)。
全栈程序员站长
2022/11/03
1.4K0
【python】在【机器学习】与【数据挖掘】中的应用:从基础到【AI大模型】
在大数据时代,数据挖掘与机器学习成为了各行各业的核心技术。Python作为一种高效、简洁且功能强大的编程语言,得到了广泛的应用。
小李很执着
2024/06/15
4110
python_sklearn库的使用
train_test_split()可以将数据按比例随机分为训练集和测试集;参数如下:
全栈程序员站长
2022/11/02
5850
数据分析入门系列教程-决策树实战
在学习了上一节决策树的原理之后,你有没有想动手实践下的冲动呢,今天我们就来用决策树进行项目实战。
周萝卜
2020/10/30
9530
数据分析入门系列教程-决策树实战
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题
使用机器学习模型识别欺诈性信用卡交易,这样可以确保客户不会为未曾购买的商品承担费用。
机器学习司猫白
2025/01/21
2240
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题
推荐阅读
相关推荐
最喜欢随机森林?周志华团队DF21后,TensorFlow开源决策森林库TF-DF
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验