前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【演化计算】Evolutionary Forest——基于演化算法的自动特征工程框架

【演化计算】Evolutionary Forest——基于演化算法的自动特征工程框架

作者头像
演化计算与人工智能
发布2021-06-10 10:30:46
7800
发布2021-06-10 10:30:46
举报

本文转载自知乎号 “震灵”

在传统的机器学习领域,构建鲁棒且有意义的特征可以显著改善最终模型的性能。尤其是随着深度学习的发展,特征自动构建已经不再是一件新鲜事。但是,在传统机器学习领域,尤其是数据量不足的时候,基于深度学习的特征构建算法往往难以取得满意的效果。此外,深度学习的黑盒特性也影响了深度学习算法在金融和医疗领域的应用。因此,本文旨在探索一种新的基于演化算法的自动特征构建算法(Evolutionary Forest)在特征工程方面的效果。为了简单起见,我选择了scikit-learn包中的一个问题作为案例研究问题。这项任务被称为“diabetes”,其目标是预测一年后该疾病的进展情况。

首先我们需要安装自动特征构建框架Evolutionary Forest,目前该框架可以直接从PIP进行安装,也可以从GitHub上手动下载源码安装。

代码语言:javascript
复制
pip install -U evolutionary_forest

代码语言:javascript
复制
git clone https://github.com/zhenlingcn/EvolutionaryForest.git
cd EvolutionaryForest
pip install -e .

安装完成后,就可以开始模型训练了,我们将数据分成训练集和测试集,分别训练随机森林和Evolutionary Forest,并在测试集上进行测试。

代码语言:javascript
复制
import random

import numpy as np

from lightgbm import LGBMRegressor
from sklearn.datasets import load_diabetes
from sklearn.ensemble import ExtraTreesRegressor, AdaBoostRegressor, GradientBoostingRegressor, RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor
from catboost import CatBoostRegressor

from evolutionary_forest.utils import get_feature_importance, plot_feature_importance, feature_append
from evolutionary_forest.forest import cross_val_score, EvolutionaryForestRegressor

random.seed(0)
np.random.seed(0)

X, y = load_diabetes(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
r = RandomForestRegressor()
r.fit(x_train, y_train)
print(r2_score(y_test, r.predict(x_test)))
r = EvolutionaryForestRegressor(max_height=8, normalize=True, select='AutomaticLexicase',
                                mutation_scheme='weight-plus-cross-global',
                                gene_num=10, boost_size=100, n_gen=100, base_learner='DT',
                                verbose=True)
r.fit(x_train, y_train)
print(r2_score(y_test, r.predict(x_test)))

输出结果:

代码语言:javascript
复制
随机森林R2: 0.26491330931789137
                                            fitness                                                           size                     
-------     ------------------------------------------------------------------------    ----------------------------------------
gen nevals  avg         gen max         min         nevals  std         avg     gen max min nevals  std    
0   50      0.123885    0   0.537842    -0.126433   50      0.135692    42.68   0   58  32  50      5.27045
1   50      0.0770607   1   0.341941    -0.179468   50      0.117896    42.48   1   54  34  50      4.57051
2   50      0.0367042   2   0.263943    -0.140117   50      0.0960126   43.8    2   54  32  50      4.5299 
3   50      0.0288158   3   0.235371    -0.21743    50      0.0954819   42.56   3   54  34  50      5.20062
4   50      0.00906959  4   0.174993    -0.152483   50      0.0801536   41.84   4   52  34  50      4.81398
5   50      -0.0177332  5   0.149236    -0.171106   50      0.0815517   43.92   5   54  34  50      4.59495
6   50      -0.0358982  6   0.152073    -0.198653   50      0.0769967   44.8    6   52  36  50      3.6    
7   50      -0.0267816  7   0.133931    -0.189818   50      0.0758869   44.84   7   56  36  50      4.27252
8   50      -0.0260382  8   0.137885    -0.18096    50      0.0772178   45.36   8   56  36  50      4.37154
9   50      -0.0227268  9   0.175584    -0.1303     50      0.0559889   45.76   9   64  38  50      5.42055
10  50      -0.0401644  10  0.112817    -0.162578   50      0.0714563   45.64   10  64  34  50      6.52   
11  50      -0.0362288  11  0.123167    -0.190264   50      0.067495    47      11  64  40  50      5.47357
12  50      -0.0351611  12  0.18934     -0.171688   50      0.0756482   47.76   12  58  38  50      5.47927
13  50      -0.0298361  13  0.213579    -0.216762   50      0.0908174   48.64   13  62  40  50      5.35821
14  50      -0.0488558  14  0.10724     -0.237322   50      0.0715246   49.88   14  66  38  50      5.69083
15  50      -0.0615554  15  0.104763    -0.196072   50      0.0692981   49.48   15  60  40  50      5.15069
16  50      -0.0683978  16  0.120128    -0.235416   50      0.0828015   52      16  66  40  50      5.76888
17  50      -0.0660739  17  0.132585    -0.220159   50      0.0796116   52.32   17  64  38  50      6.98982
18  50      -0.0793876  18  0.148632    -0.238571   50      0.0725794   52.64   18  64  40  50      5.99253
19  50      -0.103851   19  0.0779257   -0.241246   50      0.0656209   53.36   19  68  38  50      6.84035
20  50      -0.117092   20  0.0741851   -0.248124   50      0.0712469   52.08   20  66  38  50      7.09884
21  50      -0.126118   21  0.0892651   -0.278597   50      0.0721668   53.16   21  70  38  50      7.53753
22  50      -0.110964   22  0.0687603   -0.229036   50      0.0627026   53.16   22  68  36  50      7.40908
23  50      -0.105468   23  0.0780661   -0.220073   50      0.0706625   53.32   23  70  38  50      8.07574
24  50      -0.0953438  24  0.069703    -0.256387   50      0.0653205   54.48   24  76  40  50      7.29449
25  50      -0.0842262  25  0.066852    -0.259633   50      0.0742788   58.08   25  74  46  50      7.3752 
26  50      -0.09351    26  0.0749092   -0.190871   50      0.0626105   58.12   26  82  44  50      7.21565
27  50      -0.0841089  27  0.0975908   -0.182683   50      0.0627661   57.64   27  72  40  50      6.77277
28  50      -0.0992423  28  0.0337034   -0.267054   50      0.0608403   58      28  74  38  50      7.87909
29  50      -0.0995033  29  0.0944528   -0.232369   50      0.0611414   58.52   29  76  38  50      9.59842
30  50      -0.0846225  30  0.0923268   -0.237576   50      0.074057    60.68   30  78  50  50      6.59224
31  50      -0.108175   31  0.0540676   -0.294169   50      0.0831479   61.08   31  84  42  50      8.27729
32  50      -0.120433   32  0.0263439   -0.251897   50      0.0586961   64.08   32  90  46  50      9.43152
33  50      -0.0940024  33  0.0235497   -0.232917   50      0.0678665   63.28   33  90  46  50      9.13683
34  50      -0.0980132  34  0.0486761   -0.258975   50      0.0667732   65.68   34  94  46  50      9.96281
35  50      -0.0942408  35  0.0841642   -0.215611   50      0.0556214   64.4    35  80  42  50      9.64158
36  50      -0.102879   36  0.0529161   -0.247432   50      0.0661325   63.32   36  88  40  50      11.1758
37  50      -0.109848   37  0.0735743   -0.268532   50      0.0680665   64.68   37  94  40  50      11.0172
38  50      -0.102651   38  0.0456731   -0.248441   50      0.0679962   68.2    38  92  40  50      11.9415
39  50      -0.11965    39  0.00341094  -0.217797   50      0.055611    68.96   39  90  52  50      10.0099
40  50      -0.123702   40  0.0013399   -0.237731   50      0.049941    71.96   40  98  52  50      10.677 
41  50      -0.132916   41  0.0778124   -0.282513   50      0.0647553   71.16   41  92  52  50      8.46725
42  50      -0.109516   42  0.0560338   -0.251807   50      0.0537345   72.68   42  98  54  50      10.195 
43  50      -0.10434    43  0.0221852   -0.230932   50      0.0578682   73.96   43  98  56  50      10.2019
44  50      -0.106631   44  0.0775565   -0.241648   50      0.0701587   75      44  98  54  50      10.927 
45  50      -0.106065   45  0.00969612  -0.244165   50      0.0620504   76.4    45  96  54  50      10.7926
46  50      -0.107319   46  0.0452202   -0.260366   50      0.063689    81.6    46  106 58  50      11.7712
47  50      -0.10029    47  0.0533131   -0.261053   50      0.0677185   82.56   47  108 54  50      13.3868
48  50      -0.104824   48  0.0186688   -0.278224   50      0.0708134   84.92   48  108 60  50      13.4772
49  50      -0.115506   49  0.0708258   -0.241785   50      0.0710336   84.08   49  114 60  50      12.7057
50  50      -0.110368   50  0.0332325   -0.217109   50      0.0626142   91.08   50  122 64  50      14.5008
51  50      -0.109537   51  0.0444241   -0.243925   50      0.0705436   92.6    51  118 64  50      13.6953
52  50      -0.120179   52  0.030236    -0.243537   50      0.064675    94.36   52  122 58  50      12.8293
53  50      -0.10772    53  0.102777    -0.296489   50      0.0776605   94.88   53  118 70  50      11.2865
54  50      -0.133547   54  0.00743194  -0.264296   50      0.0625067   96.6    54  134 74  50      11.3719
55  50      -0.133541   55  -0.016464   -0.26914    50      0.0617844   94.84   55  128 64  50      13.6299
56  50      -0.11475    56  0.10123     -0.256939   50      0.0705579   97.76   56  130 70  50      12.8165
57  50      -0.127505   57  0.0443451   -0.259618   50      0.0705481   97.72   57  132 68  50      14.0628
58  50      -0.113637   58  0.0865859   -0.23016    50      0.0727976   100.16  58  136 66  50      14.0618
59  50      -0.122624   59  0.0480375   -0.248926   50      0.0670382   97.76   59  136 64  50      14.1853
60  50      -0.13619    60  0.125065    -0.266734   50      0.0757874   100.04  60  134 72  50      13.1848
61  50      -0.142652   61  0.0382405   -0.310567   50      0.0705014   104.72  61  150 68  50      17.9722
62  50      -0.115978   62  0.034415    -0.265388   50      0.0742441   105.2   62  138 76  50      14.3722
63  50      -0.138148   63  -0.0265558  -0.264836   50      0.0585654   107.36  63  148 66  50      19.4214
64  50      -0.138618   64  0.0304411   -0.293099   50      0.0698058   108.2   64  150 62  50      19.4864
65  50      -0.143101   65  -0.00123757 -0.265788   50      0.0569489   107.12  65  152 56  50      18.3876
66  50      -0.1426     66  0.00688022  -0.283889   50      0.0644535   109.16  66  160 72  50      20.6033
67  50      -0.144704   67  -0.0344605  -0.289808   50      0.0541972   114.96  67  154 72  50      21.6638
68  50      -0.149151   68  -0.0335756  -0.286015   50      0.0589805   116.76  68  154 76  50      19.4849
69  50      -0.152714   69  0.0106457   -0.294819   50      0.0608671   117.16  69  168 82  50      20.0164
70  50      -0.15481    70  -0.0516891  -0.254148   50      0.0601247   117.56  70  180 82  50      21.6861
71  50      -0.154969   71  -0.0160051  -0.278613   50      0.0579621   128.48  71  210 84  50      27.7865
72  50      -0.164418   72  -0.0599303  -0.245451   50      0.0458175   128.24  72  196 90  50      22.7952
73  50      -0.168461   73  -0.0226711  -0.271976   50      0.0516079   130.36  73  190 78  50      25.7952
74  50      -0.1426     74  -0.00680237 -0.22915    50      0.0540989   135.28  74  180 82  50      25.6648
75  50      -0.156227   75  -0.0503103  -0.250355   50      0.0489163   136.2   75  198 82  50      26.4098
76  50      -0.145399   76  0.0343498   -0.271271   50      0.0655932   146.08  76  190 82  50      26.551 
77  50      -0.136322   77  0.030991    -0.270662   50      0.0640328   146.88  77  208 102 50      23.0136
78  50      -0.14898    78  -0.000534747    -0.235962   50      0.0467585   146.16  78  220 100 50      24.3896
79  50      -0.155908   79  -0.0287941      -0.25179    50      0.0523467   152.08  79  216 100 50      26.8833
80  50      -0.147683   80  0.00358838      -0.265754   50      0.0604739   158.8   80  228 116 50      26.9132
81  50      -0.154123   81  -0.0109347      -0.290986   50      0.0666492   160.68  81  230 106 50      29.0506
82  50      -0.166035   82  -0.0102758      -0.297705   50      0.0566674   164.28  82  234 118 50      29.3449
83  50      -0.166368   83  -0.0290866      -0.28154    50      0.0616977   168.64  83  246 114 50      27.304 
84  50      -0.180013   84  0.0411345       -0.28144    50      0.064723    170.6   84  248 122 50      27.6586
85  50      -0.160058   85  -0.00114839     -0.267153   50      0.0669397   167.36  85  228 130 50      21.4865
86  50      -0.191222   86  -0.0235472      -0.298564   50      0.0630311   171.12  86  222 126 50      24.9869
87  50      -0.177559   87  -0.0133237      -0.291066   50      0.0590346   178.08  87  226 132 50      19.1372
88  50      -0.173883   88  -0.0757583      -0.290256   50      0.0560975   180.72  88  228 132 50      21.5175
89  50      -0.182142   89  -0.0800455      -0.297681   50      0.0509136   185.52  89  234 148 50      18.2584
90  50      -0.169139   90  -0.0183723      -0.289145   50      0.0614699   188.36  90  228 150 50      19.6965
91  50      -0.184399   91  -0.0389598      -0.297164   50      0.0602288   188.64  91  234 142 50      21.2619
92  50      -0.189729   92  -0.0711759      -0.286582   50      0.0485411   194.52  92  240 140 50      22.1623
93  50      -0.18623    93  -0.0565673      -0.296043   50      0.0491493   190.52  93  240 142 50      23.596 
94  50      -0.189683   94  -0.0829889      -0.302007   50      0.0557292   193.64  94  278 156 50      24.8699
95  50      -0.171148   95  -0.0487378      -0.284869   50      0.0533105   190.08  95  250 160 50      21.7769
96  50      -0.181613   96  -0.0807617      -0.318067   50      0.0521236   191.2   96  252 134 50      22.7086
97  50      -0.173216   97  -0.0744686      -0.292126   50      0.0528409   192.64  97  236 126 50      20.2995
98  50      -0.173024   98  -0.0440228      -0.286258   50      0.052287    192.44  98  238 130 50      22.4117
99  50      -0.178127   99  -0.0364317      -0.310203   50      0.055504    192.72  99  252 146 50      23.3127
100 50      -0.157361   100 0.0178264       -0.359337   50      0.0690032   187.68  100 260 126 50      28.1804
EvolutionaryForest R2: 0.34862896686986733

基于上述结果,我们可以看到Evolutionary Forest优于传统的随机森林。然而,我们不应该仅仅满足于有一个更好的模型。事实上,该框架的一个更重要的目标是获得更多优质的可解释特征,从而提高主流机器学习模型的性能。因此,我们可以基于impurity reduction计算特征的重要性,然后根据这些重要性分数对所有特征进行排序。为了清晰起见,目前只显示前15个最重要的特征。

代码语言:javascript
复制
feature_importance_dict = get_feature_importance(r)
plot_feature_importance(feature_importance_dict)

在创建特征重要性图之后,我们可以尝试利用这些有用的特征,并探究这些特征是否能够真正地改进现有模型的性能。为了简单起见,我们放弃使用原来的特征,只保留构造好的特征。

代码语言:javascript
复制
code_importance_dict = get_feature_importance(r, simple_version=False)
new_X = feature_append(r, X, list(code_importance_dict.keys())[:20], only_new_features=True)
new_train = feature_append(r, x_train, list(code_importance_dict.keys())[:20], only_new_features=True)
new_test = feature_append(r, x_test, list(code_importance_dict.keys())[:20], only_new_features=True)
new_r = RandomForestRegressor()
new_r.fit(new_train, y_train)
print(r2_score(y_test, new_r.predict(new_test)))

输出结果:

代码语言:javascript
复制
基于新特征的随机森林 R2: 0.2865161546726096

从结果中可以看出,所构建的特征确实能够带来性能的提高,说明了所构建的特征的有效性。然而,一个更有趣的问题是,这些特征是否只能用于这个随机森林模型呢,或者说它们是否也可以应用于其他机器学习模型呢?因此,在下一节中,我们将尝试看看这些特征是否可以用来改善现有的最先进的机器学习算法的性能。

代码语言:javascript
复制
regressor_list = ['RF', 'ET', 'AdaBoost', 'GBDT', 'DART', 'XGBoost', 'LightGBM', 'CatBoost']

scores_base = []
scores = []

for regr in regressor_list:
    regressor = {
        'RF': RandomForestRegressor(n_jobs=1, n_estimators=100),
        'ET': ExtraTreesRegressor(n_estimators=100),
        'AdaBoost': AdaBoostRegressor(n_estimators=100),
        'GBDT': GradientBoostingRegressor(n_estimators=100),
        'DART': LGBMRegressor(n_jobs=1, n_estimators=100, boosting_type='dart'),
        'XGBoost': XGBRegressor(n_jobs=1, n_estimators=100),
        'LightGBM': LGBMRegressor(n_jobs=1, n_estimators=100),
        'CatBoost': CatBoostRegressor(n_estimators=100, thread_count=1,
                                      verbose=False, allow_writing_files=False),
    }[regr]
    score = cross_val_score(regressor, X, y)
    print(regr, score, np.mean(score))
    scores_base.append(np.mean(score))
    score = cross_val_score(regressor, new_X, y)
    print(regr, score, np.mean(score))
    scores.append(np.mean(score))
scores_base = np.array(scores_base)
scores = np.array(scores)

输出结果:

代码语言:javascript
复制
RF [0.40687788 0.48232282 0.44269645 0.35267621 0.44181526] 0.4252777225424557
RF [0.50793898 0.53050996 0.52949131 0.42782289 0.48283776] 0.4957201806972883
ET [0.36793279 0.5021531  0.4315543  0.40709356 0.45861526] 0.43346980380606287
ET [0.46677294 0.51676178 0.51322419 0.40318946 0.46436162] 0.4728619983346178
AdaBoost [0.37194555 0.45511817 0.41798425 0.41877328 0.42923058] 0.4186103652544749
AdaBoost [0.3823143  0.51884104 0.45531521 0.41891795 0.45295098] 0.44566789858668293
GBDT [0.3369759  0.51456239 0.42577374 0.33095893 0.4255554 ] 0.40676527219349135
GBDT [0.48299994 0.47270435 0.53152066 0.43823941 0.45024527] 0.4751419254707957
DART [0.35379204 0.4339262  0.40526565 0.29617651 0.40656592] 0.37914526592589093
DART [0.49346642 0.45434591 0.51810906 0.38240628 0.45425024] 0.4605155806458937
XGBoost [0.19069273 0.31696014 0.38186465 0.15942315 0.30005706] 0.2697995450013659
XGBoost [0.45217748 0.45463161 0.41229883 0.26416623 0.38149529] 0.3929538866861097
LightGBM [0.35463506 0.44812537 0.36198867 0.27123459 0.4335444 ] 0.3739056175995374
LightGBM [0.5458251  0.49609922 0.48624857 0.3708606  0.45993293] 0.4717932842432391
CatBoost [0.32511248 0.50369762 0.42298773 0.33447697 0.50308684] 0.41787232702003096
CatBoost [0.49084212 0.55488901 0.52494819 0.4255866  0.46956737] 0.49316665855778447

基于上述结果,我们可以得出结论,自动构建的特征提高了所有模型的性能。尤其值得注意的是,自动构建的特征大幅度改进了XGBoost和随机森林的性能。基于这个实验的结果,我们可以得出结论,Evolutionary Forest不仅是一种有效的回归模型构建方法,可以构建一个强大的回归模型,也作为一个自动特征生成方法,可以用于生成可解释的特征以及提高现有机器学习系统的性能。自动构建的特征引起的改善如下图所示。

代码语言:javascript
复制
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="white", font_scale=1.5)
width = 0.4
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(regressor_list, scores_base, width, label='Original Features')
difference = scores - scores_base
print(np.where(difference > 0, 'g', 'y'))
ax.bar(regressor_list, difference, width, bottom=scores_base,
       label='Constructed Features',
       color=np.where(difference > 0, 'r', 'y'))
ax.set_ylabel('Score ($R^2$)')
ax.set_title('Effect of Feature Construction')
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

最后,基于这个简单的例子,我们验证了Evolutionary Forest能够发现有用的特征,并可以用来改进现有的机器学习系统。然而,需要注意的是,即使发现的特征提高了验证得分,也存在过拟合的风险。因此,在实际应用中,我们应该对获得的模型进行检验,以确保新构建模型的有效性。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-06-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DrawSky 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档