Data Mining in Practice: Classifying Lung Cancer Patients with Machine Learning

Author: 皮大大
Published: 2024-04-09 09:57:22

WeChat official account: 尤而小屋 · Author: Peter · Editor: Peter

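Hi everyone, this is Peter.

This article walks through a complete data mining project. It covers:

  • Exploratory data analysis (EDA)
  • Data encoding and factorization
  • Importance-based feature selection
  • Feature scaling
  • Cross-validation
  • Grid search
  • Classification model evaluation
  • Model interpretability with eli5 and shap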

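Introduction

Lung cancer is one of the most common cancers worldwide and a leading cause of cancer-related death. Early detection and diagnosis are critical to improving survival and treatment outcomes.

As electronic health records have become widespread, large volumes of medical data, including clinical information, imaging, and biomarkers, are now stored digitally and provide rich material for training machine learning models.

Machine learning models that automatically identify and classify lung cancer can help doctors diagnose it more accurately, especially at an early stage, and thereby improve treatment outcomes.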

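1 Importing Libraries

Project link: https://www.kaggle.com/code/michaelbryantds/lung-cancer-classification/notebook

Import the libraries needed for modeling: data processing, visualization, scikit-learn modeling, and model interpretability.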

In 1:

import pandas as pd
import numpy as np
from numpy import mean, std

from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

import eli5  # pip install eli5 shap
from eli5.sklearn import PermutationImportance
import shap  

from sklearn.model_selection import GridSearchCV 
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
# Classification metrics
from sklearn import metrics
from sklearn.metrics import accuracy_score,precision_score, matthews_corrcoef, confusion_matrix, classification_report
import scikitplot as skplt  # pip install scikit-plot
from sklearn.preprocessing import MinMaxScaler
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

2 Loading the Data

In 2:

df = pd.read_csv("survey lung cancer.csv")
df.head()

Out2:

   GENDER  AGE  SMOKING  YELLOW_FINGERS  ANXIETY  PEER_PRESSURE  CHRONIC DISEASE  FATIGUE  ALLERGY  WHEEZING  ALCOHOL CONSUMING  COUGHING  SHORTNESS OF BREATH  SWALLOWING DIFFICULTY  CHEST PAIN  LUNG_CANCER
0       M   69        1               2        2              1                1        2        1         2                  2         2                    2                      2           2          YES
1       M   74        2               1        1              1                2        2        2         1                  1         1                    2                      2           2          YES
2       F   59        1               1        1              2                1        2        1         2                  1         2                    2                      1           2           NO
3       M   63        2               2        2              1                1        1        1         1                  2         1                    1                      2           2           NO
4       F   63        1               2        1              1                1        1        1         2                  1         2                    2                      1           1           NO

Check the basic information of the data:

1. Overall size of the dataset

In 3:

df.shape  # 1. overall size of the dataset

Out3:

(309, 16)

2. Column information:

In 4:

df.columns  # column names

Out4:

Index(['GENDER', 'AGE', 'SMOKING', 'YELLOW_FINGERS', 'ANXIETY',
       'PEER_PRESSURE', 'CHRONIC DISEASE', 'FATIGUE ', 'ALLERGY ', 'WHEEZING',
       'ALCOHOL CONSUMING', 'COUGHING', 'SHORTNESS OF BREATH',
       'SWALLOWING DIFFICULTY', 'CHEST PAIN', 'LUNG_CANCER'],
      dtype='object')

In 5:

df.dtypes  # data type of each column

Out5:

GENDER                   object
AGE                       int64
SMOKING                   int64
YELLOW_FINGERS            int64
ANXIETY                   int64
PEER_PRESSURE             int64
CHRONIC DISEASE           int64
FATIGUE                   int64
ALLERGY                   int64
WHEEZING                  int64
ALCOHOL CONSUMING         int64
COUGHING                  int64
SHORTNESS OF BREATH       int64
SWALLOWING DIFFICULTY     int64
CHEST PAIN                int64
LUNG_CANCER              object
dtype: object

The dataset contains two kinds of columns: string-valued object columns and numeric int64 columns.

3. Basic information about the data

In 6:

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 309 entries, 0 to 308
Data columns (total 16 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   GENDER                 309 non-null    object
 1   AGE                    309 non-null    int64 
 2   SMOKING                309 non-null    int64 
 3   YELLOW_FINGERS         309 non-null    int64 
 4   ANXIETY                309 non-null    int64 
 5   PEER_PRESSURE          309 non-null    int64 
 6   CHRONIC DISEASE        309 non-null    int64 
 7   FATIGUE                309 non-null    int64 
 8   ALLERGY                309 non-null    int64 
 9   WHEEZING               309 non-null    int64 
 10  ALCOHOL CONSUMING      309 non-null    int64 
 11  COUGHING               309 non-null    int64 
 12  SHORTNESS OF BREATH    309 non-null    int64 
 13  SWALLOWING DIFFICULTY  309 non-null    int64 
 14  CHEST PAIN             309 non-null    int64 
 15  LUNG_CANCER            309 non-null    object
dtypes: int64(14), object(2)
memory usage: 38.8+ KB

4. Descriptive statistics

In 7:

df.describe()

Out7:

              AGE     SMOKING  YELLOW_FINGERS     ANXIETY  PEER_PRESSURE  CHRONIC DISEASE     FATIGUE     ALLERGY    WHEEZING  ALCOHOL CONSUMING    COUGHING  SHORTNESS OF BREATH  SWALLOWING DIFFICULTY  CHEST PAIN
count  309.000000  309.000000      309.000000  309.000000     309.000000       309.000000  309.000000  309.000000  309.000000         309.000000  309.000000           309.000000             309.000000  309.000000
mean    62.673139    1.563107        1.569579    1.498382       1.501618         1.504854    1.673139    1.556634    1.556634           1.556634    1.579288             1.640777               1.469256    1.556634
std      8.210301    0.496806        0.495938    0.500808       0.500808         0.500787    0.469827    0.497588    0.497588           0.497588    0.494474             0.480551               0.499863    0.497588
min     21.000000    1.000000        1.000000    1.000000       1.000000         1.000000    1.000000    1.000000    1.000000           1.000000    1.000000             1.000000               1.000000    1.000000
25%     57.000000    1.000000        1.000000    1.000000       1.000000         1.000000    1.000000    1.000000    1.000000           1.000000    1.000000             1.000000               1.000000    1.000000
50%     62.000000    2.000000        2.000000    1.000000       2.000000         2.000000    2.000000    2.000000    2.000000           2.000000    2.000000             2.000000               1.000000    2.000000
75%     69.000000    2.000000        2.000000    2.000000       2.000000         2.000000    2.000000    2.000000    2.000000           2.000000    2.000000             2.000000               2.000000    2.000000
max     87.000000    2.000000        2.000000    2.000000       2.000000         2.000000    2.000000    2.000000    2.000000           2.000000    2.000000             2.000000               2.000000    2.000000

3 Exploratory Data Analysis (EDA)

In 8:

# Numerical and categorical columns
# ('FATIGUE ' and 'ALLERGY ' carry a trailing space in the CSV header)
numerical = ["AGE"]

categorical = ['GENDER','SMOKING','YELLOW_FINGERS','ANXIETY','PEER_PRESSURE',
               'CHRONIC DISEASE','FATIGUE ','ALLERGY ','WHEEZING','ALCOHOL CONSUMING',
               'COUGHING','SHORTNESS OF BREATH','SWALLOWING DIFFICULTY',
               'CHEST PAIN','LUNG_CANCER']

3.1 Histogram (hist)

Plot a histogram for each numerical column.

In 9:

# numerical

for i in df[numerical].columns:
    plt.hist(df[numerical][i])
    plt.xticks()
    plt.xlabel(i)
    plt.ylabel('Number of People')
    plt.show()

3.2 Bar Plot (barplot)

In 10:

df["GENDER"].value_counts().index

Out10:

Index(['M', 'F'], dtype='object', name='GENDER')

In 11:

df["GENDER"].value_counts()

Out11:

GENDER
M    162
F    147
Name: count, dtype: int64

In 12:

# # categorical

# for i in df[categorical].columns:
#     sns.barplot(x=df[categorical][i].value_counts().index,  # index holds the category labels
#                 y=df[categorical][i].value_counts()  # count per category
#                )
    
#     plt.xlabel(i)
#     plt.ylabel("Number of People")
#     plt.show()

In 13:

# Bar plots with seaborn, one subplot per categorical column
fig, axs = plt.subplots(5, 3, figsize=(15, 20))
axs = axs.flatten()

for i, col in enumerate(categorical):
    sns.barplot(x=df[col].value_counts().index, 
                y=df[col].value_counts(),  
                ax=axs[i]
               )
    axs[i].set_title(col)

# Adjust the spacing between subplots
plt.tight_layout()
plt.show()

3.3 Pair Plot (pairplot)

pairplot shows the pairwise relationships between multiple variables.

In 14:

sns.pairplot(df, hue="LUNG_CANCER")  # pairplot adds its own legend for hue
plt.show()

4 Data Preprocessing

Preprocess the data to make the later modeling steps easier:

In 15:

categorical.remove("LUNG_CANCER")  # the target column

In 16:

df[categorical] = df[categorical].astype("object")  # cast to object dtype so the columns are treated as categorical

4.1 Features & Target

In 17:

X = df.copy()
y = X.pop("LUNG_CANCER")  # extract the target column
y

Out17:

0      YES
1      YES
2       NO
3       NO
4       NO
      ... 
304    YES
305    YES
306    YES
307    YES
308    YES
Name: LUNG_CANCER, Length: 309, dtype: object

In 18:

X.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 309 entries, 0 to 308
Data columns (total 15 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   GENDER                 309 non-null    object
 1   AGE                    309 non-null    int64 
 2   SMOKING                309 non-null    object
 3   YELLOW_FINGERS         309 non-null    object
 4   ANXIETY                309 non-null    object
 5   PEER_PRESSURE          309 non-null    object
 6   CHRONIC DISEASE        309 non-null    object
 7   FATIGUE                309 non-null    object
 8   ALLERGY                309 non-null    object
 9   WHEEZING               309 non-null    object
 10  ALCOHOL CONSUMING      309 non-null    object
 11  COUGHING               309 non-null    object
 12  SHORTNESS OF BREATH    309 non-null    object
 13  SWALLOWING DIFFICULTY  309 non-null    object
 14  CHEST PAIN             309 non-null    object
dtypes: int64(1), object(14)
memory usage: 36.3+ KB

In 19:

y.info()
<class 'pandas.core.series.Series'>
RangeIndex: 309 entries, 0 to 308
Series name: LUNG_CANCER
Non-Null Count  Dtype 
--------------  ----- 
309 non-null    object
dtypes: object(1)
memory usage: 2.5+ KB

4.2 Data Encoding: factorize

In 20:

X_model = X.copy()

In 21:

for col in X_model.select_dtypes("object"):
    X_model[col], _ = X_model[col].factorize()   # map each category to an integer code
    
X_model["AGE"] = X_model["AGE"].astype("float64")  # cast to float64
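To make the encoding step concrete, here is a minimal sketch of what `factorize` returns on a toy column. The values are made up for illustration and are not taken from the survey data. Codes are assigned in order of first appearance, so the mapping depends on the row order.

```python
import pandas as pd

# Toy column, made up for illustration only
gender = pd.Series(["M", "M", "F", "M", "F"])

# factorize returns integer codes plus the unique values they refer to
codes, uniques = pd.factorize(gender)

print(codes.tolist())  # [0, 0, 1, 0, 1]
print(list(uniques))   # ['M', 'F']
```

Because the codes are arbitrary integers, they are suitable for the mutual-information ranking below but carry no meaningful order.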

In 22:

X_model.dtypes

Out22:

GENDER                     int64
AGE                      float64
SMOKING                    int64
YELLOW_FINGERS             int64
ANXIETY                    int64
PEER_PRESSURE              int64
CHRONIC DISEASE            int64
FATIGUE                    int64
ALLERGY                    int64
WHEEZING                   int64
ALCOHOL CONSUMING          int64
COUGHING                   int64
SHORTNESS OF BREATH        int64
SWALLOWING DIFFICULTY      int64
CHEST PAIN                 int64
dtype: object

In 23:

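# Flag the integer-typed columns of X_model as discrete features.
# Note: `== int` only matches int64 where NumPy's default integer is 64-bit;
# comparing against "int64" avoids this platform dependence.

discrete_features = X_model.dtypes == int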

In 24:

discrete_features

Out24:

GENDER                   False
AGE                      False
SMOKING                  False
YELLOW_FINGERS           False
ANXIETY                  False
PEER_PRESSURE            False
CHRONIC DISEASE          False
FATIGUE                  False
ALLERGY                  False
WHEEZING                 False
ALCOHOL CONSUMING        False
COUGHING                 False
SHORTNESS OF BREATH      False
SWALLOWING DIFFICULTY    False
CHEST PAIN               False
dtype: bool
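The all-False result above suggests that the `== int` comparison did not match the int64 columns in the author's environment. This is an inference: it happens where NumPy's default integer is 32-bit. A platform-independent sketch, using a toy frame with made-up values in place of `X_model`, could look like this:

```python
import pandas as pd
from pandas.api.types import is_integer_dtype

# Toy frame standing in for X_model (values are made up)
toy = pd.DataFrame({
    "GENDER": pd.Series([0, 1, 0], dtype="int64"),
    "AGE": pd.Series([69.0, 74.0, 59.0], dtype="float64"),
    "SMOKING": pd.Series([0, 1, 0], dtype="int64"),
})

# True for every integer-typed column, whatever the platform's default int width
discrete_mask = toy.dtypes.apply(is_integer_dtype)
print(discrete_mask.tolist())  # [True, False, True]
```

With such a mask, the factorized columns would be treated as discrete and only AGE as continuous, which would change the mutual-information scores reported below.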

4.3 Feature Selection

In 25:

from sklearn.feature_selection import mutual_info_classif

In 26:

def calculate_mic_scores(X_model, y, discrete_features):
    # Mutual information between each feature in X_model and the target y
    mic = mutual_info_classif(X_model, y, discrete_features=discrete_features)
    mic = pd.Series(mic, name="MIC SCORES", index=X_model.columns)  # convert to a Series
    mic = mic.sort_values(ascending=False)  # highest score first
    return mic

In 27:

mic = calculate_mic_scores(X_model,y,discrete_features)
mic

Out27:

ALLERGY                  0.059814
CHEST PAIN               0.058419
ALCOHOL CONSUMING        0.042346
COUGHING                 0.040891
SMOKING                  0.033635
SWALLOWING DIFFICULTY    0.027580
AGE                      0.025290
ANXIETY                  0.011860
WHEEZING                 0.007738
PEER_PRESSURE            0.005157
CHRONIC DISEASE          0.003491
FATIGUE                  0.003473
GENDER                   0.000000
YELLOW_FINGERS           0.000000
SHORTNESS OF BREATH      0.000000
Name: MIC SCORES, dtype: float64
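To get a feel for the scale of these scores, here is a small sketch on synthetic data (not the survey data). One made-up feature copies the target 90% of the time, the other is independent of it. Scores are in nats, and a score near zero means the feature says almost nothing about the target.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_classif

rng = np.random.default_rng(0)
n = 2000

# Synthetic binary target and two binary features (all made up):
# "informative" copies the target 90% of the time, "noise" is independent of it.
target = rng.integers(0, 2, size=n)
flip = rng.random(n) < 0.1
informative = np.where(flip, 1 - target, target)
noise = rng.integers(0, 2, size=n)

X_toy = np.column_stack([informative, noise])
scores = mutual_info_classif(X_toy, target, discrete_features=True, random_state=0)

print(scores)  # the first score should be clearly larger than the second
```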

4.4 Plotting Feature Importance

In 28:

def plot_mic(scores):
    scores = scores.sort_values(ascending=True)
    width = np.arange(len(scores))
    
    ticks = list(scores.index)
    plt.barh(width,scores)
    plt.yticks(width,ticks)
    plt.title("Mutual Information Scores")

In 29:

plt.figure(dpi=100, figsize=(8,5))
plot_mic(mic)   # plot the MIC scores computed above

5 Modeling

5.1 Splitting the Data

In 30:

X = pd.concat([X[numerical], pd.get_dummies(X[categorical])],axis=1)
X.head()

Out30:

   AGE  GENDER_F  GENDER_M  SMOKING_1  SMOKING_2  YELLOW_FINGERS_1  YELLOW_FINGERS_2  ANXIETY_1  ANXIETY_2  PEER_PRESSURE_1  ...  ALCOHOL CONSUMING_1  ALCOHOL CONSUMING_2  COUGHING_1  COUGHING_2  SHORTNESS OF BREATH_1  SHORTNESS OF BREATH_2  SWALLOWING DIFFICULTY_1  SWALLOWING DIFFICULTY_2  CHEST PAIN_1  CHEST PAIN_2
0   69     False      True       True      False             False              True      False       True             True  ...                False                 True       False        True                  False                   True                    False                     True         False          True
1   74     False      True      False       True              True             False       True      False             True  ...                 True                False        True       False                  False                   True                    False                     True         False          True
2   59      True     False       True      False              True             False       True      False            False  ...                 True                False       False        True                  False                   True                     True                    False         False          True
3   63     False      True      False       True             False              True      False       True             True  ...                False                 True        True       False                   True                  False                    False                     True         False          True
4   63      True     False       True      False             False              True       True      False             True  ...                 True                False       False        True                  False                   True                     True                    False          True         False

5 rows × 29 columns
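Each two-level column becomes a pair of complementary dummy columns (for example SMOKING_1 and SMOKING_2), so half of the 28 dummy columns are redundant. As an optional variation that the original notebook does not use, `drop_first=True` keeps one column per feature. A minimal sketch on a made-up frame:

```python
import pandas as pd

# Toy frame with two categorical columns (values made up)
toy = pd.DataFrame({
    "GENDER": ["M", "F", "M"],
    "SMOKING": [1, 2, 2],
}).astype("object")

full = pd.get_dummies(toy)                      # one column per category level
compact = pd.get_dummies(toy, drop_first=True)  # drops the first level of each column

print(list(full.columns))     # ['GENDER_F', 'GENDER_M', 'SMOKING_1', 'SMOKING_2']
print(list(compact.columns))  # ['GENDER_M', 'SMOKING_2']
```

The full encoding is harmless for tree models; for linear models it produces mirrored coefficient pairs that describe the same effect twice.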

In 31:

feature_names = X.columns

In 32:

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, train_size=0.75, random_state=1)  # stratify=y keeps the class ratio of y in both splits
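The effect of `stratify` is easiest to see on labels alone. The sketch below uses toy labels with an illustrative imbalance (the counts are an assumption, not read from the notebook) and checks that the share of the minority class stays nearly the same in both splits.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Imbalanced toy labels: 270 "YES" and 39 "NO" (illustrative counts)
y_toy = np.array(["YES"] * 270 + ["NO"] * 39)
X_toy = np.arange(len(y_toy)).reshape(-1, 1)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, stratify=y_toy, train_size=0.75, random_state=1
)

# The share of "NO" labels should be close in the full data and in both splits
for name, labels in [("all", y_toy), ("train", y_tr), ("test", y_te)]:
    print(name, len(labels), round(float(np.mean(labels == "NO")), 3))
```

Without stratification, a small test set could by chance contain very few minority samples, which makes per-class metrics unstable.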

5.2 Feature Scaling (MinMaxScaler)

In 33:

mm = MinMaxScaler()

In 34:

X_train[numerical] = mm.fit_transform(X_train[numerical])
X_test[numerical] = mm.transform(X_test[numerical])
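The scaler is fitted on the training split only and then reused on the test split. A short sketch with made-up ages shows what that means: the transform is (x - min_train) / (max_train - min_train), so a test value outside the training range lands outside [0, 1].

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Made-up ages; the scaler learns min and max from the training part only
age_train = np.array([[21.0], [57.0], [62.0], [87.0]])
age_test = np.array([[54.0], [90.0]])

scaler = MinMaxScaler().fit(age_train)

# Same result as (x - min_train) / (max_train - min_train)
manual = (age_test - age_train.min()) / (age_train.max() - age_train.min())
scaled = scaler.transform(age_test)

print(scaled.ravel())  # 54 maps to 0.5; 90 maps to a value slightly above 1
```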

5.3 Cross-Validation (cross_val_score)

5.3.1 LogisticRegression

In 35:

lr = LogisticRegression(max_iter=2000)
cv = cross_val_score(lr, X_train,y_train,cv=5)

# Mean and standard deviation of the logistic regression accuracy over 5-fold cross-validation
print(mean(cv), "±", np.std(cv))
0.9308048103607771 ± 0.0158848177577188

5.3.2 RandomForestClassifier

In 36:

rf = RandomForestClassifier(random_state = 1)
cv = cross_val_score(rf,X_train,y_train,cv=5)

print(mean(cv), "±", std(cv))
0.9350601295097132 ± 0.013760232121339587

5.3.3 Support Vector Classifier

In 37:

svc = SVC(probability = True)
cv = cross_val_score(svc,X_train,y_train,cv=5)

print(mean(cv), "±", std(cv))
0.9351526364477335 ± 0.023485469326540147
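All three scores above are plain accuracy, which can look flattering when one class dominates, as it does here. As a hedged variation that is not part of the original notebook, the sketch below runs the same kind of cross-validation on synthetic imbalanced data with explicit stratified folds and also reports balanced accuracy (the mean of the per-class recalls).

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic imbalanced data (about 87% positives), not the survey data
X_toy, y_toy = make_classification(
    n_samples=300, n_features=10, weights=[0.13, 0.87], random_state=1
)

folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
model = LogisticRegression(max_iter=2000)

acc = cross_val_score(model, X_toy, y_toy, cv=folds, scoring="accuracy")
bal = cross_val_score(model, X_toy, y_toy, cv=folds, scoring="balanced_accuracy")

print("accuracy:         ", acc.mean(), "±", acc.std())
print("balanced accuracy:", bal.mean(), "±", bal.std())
```

A large gap between the two numbers is a sign that the minority class is being predicted poorly.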

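5.4 Grid Search (GridSearchCV)

Grid search is a hyperparameter optimization method in machine learning. It evaluates every combination of the candidate hyperparameter values and keeps the setting under which the model performs best on the given task.

The idea is to specify a range or list of candidate values for each hyperparameter and then try all possible combinations: the Cartesian product of the candidate lists yields the full set of parameter combinations.

Each combination is then used to train a model, whose performance is evaluated with cross-validation or a similar procedure. The combination with the best score is selected as the final setting.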
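The Cartesian product described above can be inspected directly with scikit-learn's `ParameterGrid`. The grid values in this sketch are chosen for illustration; counting the candidates before launching a search is a quick way to estimate its cost.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# A small illustrative grid
param_grid = {
    "kernel": ["linear", "rbf"],
    "gamma": [1, 1e-1, 1e-2],
    "C": np.arange(40, 70, 5),  # 40, 45, ..., 65 -> 6 values
}

grid = ParameterGrid(param_grid)
n_candidates = len(grid)    # 2 * 3 * 6 = 36 combinations
n_fits = n_candidates * 5   # each candidate is fitted once per fold with cv=5

print(n_candidates, n_fits)  # 36 180
print(list(grid)[0])         # one candidate as a plain dict
```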

In 38:

def clf_performance(model, model_name):
    print("模型名称:", model_name)  # "模型名称" means "model name"
    
    # Inside the double-quoted f-string, the dictionary key must use single quotes
    print(f"Best Score: {str(model.best_score_)} ± {str(model.cv_results_['std_test_score'][model.best_index_])}")
    print(f"Best Parameters: {str(model.best_params_)}")

5.4.1 LogisticRegression

In 39:

lr = LogisticRegression()

# Parameter grid
param_grid = {"max_iter":[15000],
              "C":np.arange(0.1,0.6,0.1)
             }


# Grid search over the logistic regression hyperparameters
clf_lr = GridSearchCV(lr, param_grid=param_grid, cv=5, n_jobs=-1)
best_clf_lr = clf_lr.fit(X_train, y_train)
clf_performance(best_clf_lr, "Logistic Regression")
模型名称: Logistic Regression
Best Score: 0.9264569842738206 ± 0.01712752904500742
Best Parameters: {'C': 0.30000000000000004, 'max_iter': 15000}

5.4.2 RandomForestClassifier

In 40:

rf = RandomForestClassifier(random_state=1)

# Parameter grid to search
param_grid = {
    "n_estimators":np.arange(8,20,2),
    "bootstrap":[True, False],
    "max_depth":[10],
    # "auto" (removed in scikit-learn 1.3) was equivalent to "sqrt" for classifiers
    "max_features":["sqrt"],
    "min_samples_leaf":np.arange(2,6,1),
    "min_samples_split":np.arange(2,6,1)
}

clf_rf = GridSearchCV(rf, param_grid=param_grid,cv=5,n_jobs=-1)
best_clf_rf = clf_rf.fit(X_train, y_train)
clf_performance(best_clf_rf, "Random Forest")
模型名称: Random Forest
Best Score: 0.9481961147086032 ± 0.021593980755705847
Best Parameters: {'bootstrap': False, 'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 8}

5.4.3 Support Vector Classifier

In 41:

svc = SVC(probability=True, random_state=1)

param_grid = {'kernel': ['linear', 'poly', 'sigmoid','rbf'], # kernel function
              'gamma': [1, 1e-1, 1e-2, 1e-3, 1e-4],  # kernel coefficient (ignored by the linear kernel)
              'C': np.arange(40,70,5)  # regularization parameter (smaller C means stronger regularization)
             }

clf_svc = GridSearchCV(svc, param_grid = param_grid, cv = 5, n_jobs = -1)
best_clf_svc = clf_svc.fit(X_train,y_train)
clf_performance(best_clf_svc,'Support Vector Classifier')
模型名称: Support Vector Classifier
Best Score: 0.9438482886216466 ± 0.016747588503435138
Best Parameters: {'C': 50, 'gamma': 1, 'kernel': 'linear'}
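One detail worth knowing: in section 5.2 the scaler was fitted on the whole training set before these searches ran, so each validation fold had already influenced the scaling. With a single scaled column the effect is probably small, but the usual way to avoid it is to put the scaler inside a `Pipeline`. The sketch below shows the idea on synthetic data; the step names `scale` and `svc` are arbitrary labels chosen here.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

# Synthetic stand-in data, not the survey data
X_toy, y_toy = make_classification(n_samples=200, n_features=8, random_state=1)

pipe = Pipeline([
    ("scale", MinMaxScaler()),      # refitted on the training part of every fold
    ("svc", SVC(random_state=1)),
])

# Parameters of a pipeline step are addressed as "<step name>__<parameter>"
param_grid = {"svc__kernel": ["linear", "rbf"], "svc__C": [1, 10]}

search = GridSearchCV(pipe, param_grid=param_grid, cv=5)
search.fit(X_toy, y_toy)

print(search.best_params_)
print(search.best_score_)
```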

5.5 Model Evaluation

Refit each model with the best parameter combination found by grid search, then evaluate it on the test set:

5.5.1 LogisticRegression

In 42:

lr = LogisticRegression(C=0.3, max_iter=15000)  # best parameter combination

lr.fit(X_train, y_train) # train the model
y_pred_lr = lr.predict(X_test)  # predict on the test set

# Accuracy
print('LogisticRegression test accuracy: {}'.format(accuracy_score(y_test, y_pred_lr)))
LogisticRegression test accuracy: 0.8461538461538461

1. Inspect feature importance (the model coefficients)

In 43:

lr_coef = pd.DataFrame([lr.coef_[0]], columns=feature_names)  # label the coefficients with the feature names of X
sorted_lr = lr_coef.iloc[:, np.argsort(lr_coef.loc[0])]
sorted_lr

Out43:

   SWALLOWING DIFFICULTY_1  ALLERGY _1  ALCOHOL CONSUMING_1  COUGHING_1  PEER_PRESSURE_1  FATIGUE _1  YELLOW_FINGERS_1  WHEEZING_1  ANXIETY_1  CHEST PAIN_1  ...  CHEST PAIN_2  ANXIETY_2  WHEEZING_2  YELLOW_FINGERS_2  FATIGUE _2  PEER_PRESSURE_2  COUGHING_2  ALCOHOL CONSUMING_2  ALLERGY _2  SWALLOWING DIFFICULTY_2
0                -0.641878   -0.613902            -0.568027   -0.530684        -0.491953   -0.474459         -0.469793   -0.409697  -0.404706     -0.378352  ...      0.378346   0.404699    0.409691          0.469787    0.474453         0.491947    0.530678             0.568021    0.613895                 0.641872

1 rows × 29 columns

In 44:

plt.figure(figsize=(16,7))
sns.barplot(x=sorted_lr.columns, y=sorted_lr.iloc[0,:])

plt.xticks(rotation = 90) # rotate the tick labels
plt.xlabel('Features')  # axis labels
plt.ylabel('LogisticRegression Coefficients')
plt.show()

2. Plot the confusion matrix (with axis labels and cell values):

In 45:

matrix = confusion_matrix(y_test, y_pred_lr)
matrix = matrix.astype("float") / matrix.sum(axis=1)[:, np.newaxis]  # normalize each row (true class)

plt.figure(figsize=(12,7))  # figure size
sns.set(font_scale=1.4)
sns.heatmap(matrix,  # row-normalized confusion matrix
            annot=True, # show the value in each cell
            annot_kws={'size':10}, # font size of the annotations
            linewidths=0.2, # width of the lines between cells
            vmin=0,  # limits of the color bar
            vmax=1)

# confusion_matrix sorts the labels, so index 0 is 'NO' and index 1 is 'YES'
class_names = ["No Lung cancer", "Lung cancer"]
tick_marks = np.arange(len(class_names)) + 0.5  # cell centers
plt.xticks(tick_marks, class_names, rotation=25)
plt.yticks(tick_marks, class_names, rotation=0)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix for LogisticRegression')
plt.show()
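The row normalization and the label order can also be handled by `confusion_matrix` itself. The sketch below uses made-up labels: `labels=` fixes the row and column order explicitly, and `normalize="true"` divides each row by the number of true samples in that class.

```python
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration
y_true = ["NO", "NO", "YES", "YES", "YES", "YES"]
y_hat = ["NO", "YES", "YES", "YES", "YES", "NO"]

# Rows are true classes, columns are predicted classes, both in the order given by labels=
cm = confusion_matrix(y_true, y_hat, labels=["NO", "YES"], normalize="true")
print(cm.tolist())  # [[0.5, 0.5], [0.25, 0.75]]
```

Passing `labels=` explicitly makes it harder to attach the wrong class name to a row when plotting.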

3. Print the classification report (classification_report):

In 46:

print('LogisticRegression: ')
print(classification_report(y_test, y_pred_lr))
LogisticRegression: 
              precision    recall  f1-score   support

          NO       0.33      0.20      0.25        10
         YES       0.89      0.94      0.91        68

    accuracy                           0.85        78
   macro avg       0.61      0.57      0.58        78
weighted avg       0.82      0.85      0.83        78
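The report can be reproduced by hand from a confusion matrix. The counts below are inferred from the report itself (support 10 and 68, NO recall 0.20, NO precision 0.33); the notebook does not print them, so treat them as a reconstruction.

```python
import numpy as np

# Rows: true NO, true YES; columns: predicted NO, predicted YES.
# Counts inferred from the report above, not printed by the notebook.
cm = np.array([[2, 8],
               [4, 64]])

precision = np.diag(cm) / cm.sum(axis=0)  # correct / everything predicted as that class
recall = np.diag(cm) / cm.sum(axis=1)     # correct / everything truly in that class
f1 = 2 * precision * recall / (precision + recall)
accuracy = np.trace(cm) / cm.sum()

print(np.round(precision, 2).tolist())  # [0.33, 0.89]
print(np.round(recall, 2).tolist())     # [0.2, 0.94]
print(np.round(f1, 2).tolist())         # [0.25, 0.91]
print(round(float(accuracy), 2))        # 0.85
```

Read this way, the 0.85 accuracy hides the fact that only 2 of the 10 patients without lung cancer are classified correctly.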

4. Plot the ROC curve:

In 47:

from sklearn import metrics

# Predicted class probabilities; column 1 corresponds to lr.classes_[1], i.e. 'YES'
pred_prob_lr = lr.predict_proba(X_test)

fpr_lr, tpr_lr, thresholds_lr = metrics.roc_curve(y_test, pred_prob_lr[:,1],pos_label='YES')

fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(fpr_lr, tpr_lr, label='LogisticRegression')
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.legend()
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for LogisticRegression Classifiers')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.show()

5. Compute the AUC:

In 48:

print('AUC of LogisticRegression: {}'.format(metrics.auc(fpr_lr, tpr_lr)))
AUC of LogisticRegression: 0.8764705882352941

5.5.2 Support Vector Classifier

In 49:

# 1. Fit the model with the best parameter combination

svc = SVC(probability = True, random_state = 1,C= 50, gamma = 1, kernel= 'linear')  # best parameter combination
svc.fit(X_train,y_train)
y_pred_svc = svc.predict(X_test)

# Accuracy of the SVC model
print('SVC test accuracy: {}'.format(accuracy_score(y_test, y_pred_svc)))
SVC test accuracy: 0.8974358974358975

In 50:

# 2. Plot the confusion matrix

matrix = confusion_matrix(y_test, y_pred_svc)
matrix = matrix.astype("float") / matrix.sum(axis=1)[:, np.newaxis]  # normalize each row (true class)

plt.figure(figsize=(12,7))  # figure size
sns.set(font_scale=1.4)
sns.heatmap(matrix,  # row-normalized confusion matrix
            annot=True, # show the value in each cell
            annot_kws={'size':10}, # font size of the annotations
            linewidths=0.2, # width of the lines between cells
            vmin=0,  # limits of the color bar
            vmax=1)

# confusion_matrix sorts the labels, so index 0 is 'NO' and index 1 is 'YES'
class_names = ["No Lung cancer", "Lung cancer"]
tick_marks = np.arange(len(class_names)) + 0.5  # cell centers
plt.xticks(tick_marks, class_names, rotation=25)
plt.yticks(tick_marks, class_names, rotation=0)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix for SVC')
plt.show()

Print the classification report and plot the ROC curve:

In 51:

# 3. Print the classification report

print('SVC: ')
print(classification_report(y_test, y_pred_svc))
SVC: 
              precision    recall  f1-score   support

          NO       0.62      0.50      0.56        10
         YES       0.93      0.96      0.94        68

    accuracy                           0.90        78
   macro avg       0.78      0.73      0.75        78
weighted avg       0.89      0.90      0.89        78

In 52:

# 4. Plot the ROC curve

from sklearn import metrics

pred_prob_svc = svc.predict_proba(X_test)  # predicted class probabilities

fpr_svc, tpr_svc, thresholds_svc = metrics.roc_curve(y_test, pred_prob_svc[:,1],pos_label='YES')

fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(fpr_svc, tpr_svc, label='SVC')
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.legend()
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for SVC ')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.show()

Finally, compute the AUC:

In 53:

print('AUC of SVC: {}'.format(metrics.auc(fpr_svc, tpr_svc)))
AUC of SVC: 0.9058823529411765
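Both AUC values above are computed from the ROC curve points with `metrics.auc(fpr, tpr)`. As a cross-check, `metrics.roc_auc_score` computes the same quantity directly from labels and scores. The sketch below uses made-up labels and probabilities to show that the two routes agree.

```python
import numpy as np
from sklearn import metrics

# Made-up labels and predicted probabilities of "YES"
y_true = np.array(["NO", "YES", "NO", "YES", "YES", "NO"])
p_yes = np.array([0.10, 0.90, 0.60, 0.70, 0.40, 0.30])

fpr, tpr, _ = metrics.roc_curve(y_true, p_yes, pos_label="YES")
auc_from_curve = metrics.auc(fpr, tpr)

# For binary string labels, roc_auc_score treats the greater label ("YES") as positive
auc_direct = metrics.roc_auc_score(y_true, p_yes)

print(auc_from_curve, auc_direct)  # both equal 8/9: 8 of the 9 (YES, NO) pairs are ranked correctly
```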

6 Model Interpretability

6.1 eli5

In 54:

import eli5
from eli5.sklearn import PermutationImportance

In 55:

perm = PermutationImportance(svc, random_state=1).fit(X_test, y_test)  # permutation importance of the fitted svc, computed on the test set
eli5.show_weights(perm, feature_names=list(feature_names),top=len(feature_names))
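If eli5 is not available in your environment, scikit-learn ships a comparable tool, `sklearn.inspection.permutation_importance`. The sketch below is an alternative on synthetic data, not the notebook's own method: it shuffles one feature at a time on held-out data and records how much the accuracy drops.

```python
from sklearn.datasets import make_classification
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Synthetic stand-in data, not the survey data
X_toy, y_toy = make_classification(
    n_samples=300, n_features=6, n_informative=2, n_redundant=0,
    shuffle=False, random_state=1,
)
X_tr, X_te, y_tr, y_te = train_test_split(X_toy, y_toy, stratify=y_toy, random_state=1)

model = SVC(kernel="linear", random_state=1).fit(X_tr, y_tr)

# Shuffle one column at a time on held-out data and record the drop in accuracy
result = permutation_importance(model, X_te, y_te, n_repeats=10, random_state=1)

for idx in result.importances_mean.argsort()[::-1]:
    print(f"feature_{idx}: {result.importances_mean[idx]:.3f} ± {result.importances_std[idx]:.3f}")
```

Features whose shuffling barely changes the score contribute little to the model's predictions on that data.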

6.2 shap

In 56:

import shap

explainer = shap.KernelExplainer(svc.predict_proba, X_train)
pred_data = pd.DataFrame(X_test)
pred_data.columns = feature_names
data_for_prediction = pred_data
Using 231 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.

In 57:

shap_values = explainer.shap_values(data_for_prediction)
shap.initjs()

shap.summary_plot(shap_values[1], data_for_prediction)

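If you would like the Jupyter notebook source and the data used in this article, please contact the author directly. Serious inquiries only.

Original content notice: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

For infringement claims, please contact cloudcommunity@tencent.com for removal.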
