公众号:尤而小屋 作者:Peter 编辑:Peter
大家好,我是Peter~
本文介绍一个完整的数据挖掘实战项目,主要内容包含:项目背景、数据探索与可视化、数据预处理、特征互信息计算、模型训练与网格搜索调参、模型评估以及模型可解释性分析。
肺癌是全球范围内最常见的癌症之一,也是导致癌症相关死亡的主要原因。早期发现和诊断对于提高患者的生存率和治疗效果至关重要。
随着电子健康记录的普及,大量的医疗数据被数字化存储,包括患者的临床信息、影像学资料和生物标志物等,为机器学习模型的训练提供了丰富的数据资源。
通过机器学习模型对肺癌进行自动识别和分类,可以帮助医生更准确地诊断肺癌,尤其是在早期阶段,从而提高治疗效果。
项目地址:https://www.kaggle.com/code/michaelbryantds/lung-cancer-classification/notebook
导入建模所需要的各种库,包含数据处理、可视化、scikit-learn建模、模型可解释性
In 1:
import pandas as pd
import numpy as np
from numpy import mean, std
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import eli5 # pip install eli5,shap
from eli5.sklearn import PermutationImportance
import shap
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
# 分类模型评估指标
from sklearn import metrics
from sklearn.metrics import accuracy_score,precision_score, matthews_corrcoef, confusion_matrix, classification_report
import scikitplot as skplt # pip install scikit-plot
from sklearn.preprocessing import MinMaxScaler
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
# 忽略警告
import warnings
warnings.filterwarnings("ignore")
In 2:
df = pd.read_csv("survey lung cancer.csv")
df.head()
Out2:
GENDER | AGE | SMOKING | YELLOW_FINGERS | ANXIETY | PEER_PRESSURE | CHRONIC DISEASE | FATIGUE | ALLERGY | WHEEZING | ALCOHOL CONSUMING | COUGHING | SHORTNESS OF BREATH | SWALLOWING DIFFICULTY | CHEST PAIN | LUNG_CANCER | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | M | 69 | 1 | 2 | 2 | 1 | 1 | 2 | 1 | 2 | 2 | 2 | 2 | 2 | 2 | YES |
1 | M | 74 | 2 | 1 | 1 | 1 | 2 | 2 | 2 | 1 | 1 | 1 | 2 | 2 | 2 | YES |
2 | F | 59 | 1 | 1 | 1 | 2 | 1 | 2 | 1 | 2 | 1 | 2 | 2 | 1 | 2 | NO |
3 | M | 63 | 2 | 2 | 2 | 1 | 1 | 1 | 1 | 1 | 2 | 1 | 1 | 2 | 2 | NO |
4 | F | 63 | 1 | 2 | 1 | 1 | 1 | 1 | 1 | 2 | 1 | 2 | 2 | 1 | 1 | NO |
查看数据的基本信息:
1、整体的数据量
In 3:
df.shape # 1、整体的数据量
Out3:
(309, 16)
2、数据字段信息:
In 4:
df.columns # 字段名称
Out4:
Index(['GENDER', 'AGE', 'SMOKING', 'YELLOW_FINGERS', 'ANXIETY',
'PEER_PRESSURE', 'CHRONIC DISEASE', 'FATIGUE ', 'ALLERGY ', 'WHEEZING',
'ALCOHOL CONSUMING', 'COUGHING', 'SHORTNESS OF BREATH',
'SWALLOWING DIFFICULTY', 'CHEST PAIN', 'LUNG_CANCER'],
dtype='object')
In 5:
df.dtypes # 字段的不同数据类型
Out5:
GENDER object
AGE int64
SMOKING int64
YELLOW_FINGERS int64
ANXIETY int64
PEER_PRESSURE int64
CHRONIC DISEASE int64
FATIGUE int64
ALLERGY int64
WHEEZING int64
ALCOHOL CONSUMING int64
COUGHING int64
SHORTNESS OF BREATH int64
SWALLOWING DIFFICULTY int64
CHEST PAIN int64
LUNG_CANCER object
dtype: object
可以看到本次数据中主要是字符型object和数值型int64的类型。
3、数据基本信息
In 6:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 309 entries, 0 to 308
Data columns (total 16 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 GENDER 309 non-null object
1 AGE 309 non-null int64
2 SMOKING 309 non-null int64
3 YELLOW_FINGERS 309 non-null int64
4 ANXIETY 309 non-null int64
5 PEER_PRESSURE 309 non-null int64
6 CHRONIC DISEASE 309 non-null int64
7 FATIGUE 309 non-null int64
8 ALLERGY 309 non-null int64
9 WHEEZING 309 non-null int64
10 ALCOHOL CONSUMING 309 non-null int64
11 COUGHING 309 non-null int64
12 SHORTNESS OF BREATH 309 non-null int64
13 SWALLOWING DIFFICULTY 309 non-null int64
14 CHEST PAIN 309 non-null int64
15 LUNG_CANCER 309 non-null object
dtypes: int64(14), object(2)
memory usage: 38.8+ KB
4、数据描述统计信息
In 7:
df.describe()
Out7:
AGE | SMOKING | YELLOW_FINGERS | ANXIETY | PEER_PRESSURE | CHRONIC DISEASE | FATIGUE | ALLERGY | WHEEZING | ALCOHOL CONSUMING | COUGHING | SHORTNESS OF BREATH | SWALLOWING DIFFICULTY | CHEST PAIN | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 | 309.000000 |
mean | 62.673139 | 1.563107 | 1.569579 | 1.498382 | 1.501618 | 1.504854 | 1.673139 | 1.556634 | 1.556634 | 1.556634 | 1.579288 | 1.640777 | 1.469256 | 1.556634 |
std | 8.210301 | 0.496806 | 0.495938 | 0.500808 | 0.500808 | 0.500787 | 0.469827 | 0.497588 | 0.497588 | 0.497588 | 0.494474 | 0.480551 | 0.499863 | 0.497588 |
min | 21.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
25% | 57.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
50% | 62.000000 | 2.000000 | 2.000000 | 1.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 1.000000 | 2.000000 |
75% | 69.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 |
max | 87.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 | 2.000000 |
In 8:
# 数值型和分类型
numerical = ["AGE"]
categorical = ['GENDER','SMOKING','YELLOW_FINGERS','ANXIETY','PEER_PRESSURE',
'CHRONIC DISEASE','FATIGUE ','ALLERGY ','WHEEZING','ALCOHOL CONSUMING',
'COUGHING','SHORTNESS OF BREATH','SWALLOWING DIFFICULTY',
'CHEST PAIN','LUNG_CANCER']
绘制不同数值型字段的直方图
In 9:
# Draw one histogram per column listed in `numerical`.
for col in numerical:
    plt.hist(df[col])
    plt.xlabel(col)
    plt.ylabel('Number of People')
    # Show inside the loop so that every column gets its own figure.
    plt.show()
In 10:
df["GENDER"].value_counts().index
Out10:
Index(['M', 'F'], dtype='object', name='GENDER')
In 11:
df["GENDER"].value_counts()
Out11:
GENDER
M 162
F 147
Name: count, dtype: int64
In 12:
# # categorical
# for i in df[categorical].columns:
# sns.barplot(x=df[categorical][i].value_counts().index, # 索引index即名称
# y=df[categorical][i].value_counts() # 数量统计
# )
# plt.xlabel(i)
# plt.ylabel("Number of People")
# plt.show()
In 13:
# Bar chart of the value counts of every categorical column, on a 5x3 grid.
fig, axs = plt.subplots(5, 3, figsize=(15, 20))
axs = axs.flatten()
for i, col in enumerate(categorical):
    counts = df[col].value_counts()  # computed once, used for both axes
    sns.barplot(x=counts.index, y=counts, ax=axs[i])
    axs[i].set_title(col)
# Adjust the spacing between the subplots.
plt.tight_layout()
plt.show()
pairplot显示多个变量之间的成对关系
In 14:
# Pairwise plots of the columns of df, coloured by the LUNG_CANCER value.
sns.pairplot(df, hue="LUNG_CANCER")
# NOTE(review): pairplot presumably draws its own hue legend; this call acts on
# the current axes only - confirm it is still needed.
plt.legend()
plt.show()
为了方便后续的建模,对数据进行预处理:
In 15:
categorical.remove("LUNG_CANCER") # 目标字段
In 16:
df[categorical] = df[categorical].astype("object") # 强制转成字符型
In 17:
X = df.copy()
y = X.pop("LUNG_CANCER") # 提取目标字段
y
Out17:
0 YES
1 YES
2 NO
3 NO
4 NO
...
304 YES
305 YES
306 YES
307 YES
308 YES
Name: LUNG_CANCER, Length: 309, dtype: object
In 18:
X.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 309 entries, 0 to 308
Data columns (total 15 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 GENDER 309 non-null object
1 AGE 309 non-null int64
2 SMOKING 309 non-null object
3 YELLOW_FINGERS 309 non-null object
4 ANXIETY 309 non-null object
5 PEER_PRESSURE 309 non-null object
6 CHRONIC DISEASE 309 non-null object
7 FATIGUE 309 non-null object
8 ALLERGY 309 non-null object
9 WHEEZING 309 non-null object
10 ALCOHOL CONSUMING 309 non-null object
11 COUGHING 309 non-null object
12 SHORTNESS OF BREATH 309 non-null object
13 SWALLOWING DIFFICULTY 309 non-null object
14 CHEST PAIN 309 non-null object
dtypes: int64(1), object(14)
memory usage: 36.3+ KB
In 19:
y.info()
<class 'pandas.core.series.Series'>
RangeIndex: 309 entries, 0 to 308
Series name: LUNG_CANCER
Non-Null Count Dtype
-------------- -----
309 non-null object
dtypes: object(1)
memory usage: 2.5+ KB
In 20:
X_model = X.copy()
In 21:
# Integer-encode every object-typed column; factorize() returns the codes and
# the unique values, and only the codes are kept.
for col in X_model.select_dtypes("object"):
    X_model[col], _ = X_model[col].factorize()
# AGE is the only column listed as numerical; store it as float64.
X_model["AGE"] = X_model["AGE"].astype("float64")
In 22:
X_model.dtypes
Out22:
GENDER int64
AGE float64
SMOKING int64
YELLOW_FINGERS int64
ANXIETY int64
PEER_PRESSURE int64
CHRONIC DISEASE int64
FATIGUE int64
ALLERGY int64
WHEEZING int64
ALCOHOL CONSUMING int64
COUGHING int64
SHORTNESS OF BREATH int64
SWALLOWING DIFFICULTY int64
CHEST PAIN int64
dtype: object
In 23:
# Flag the integer-coded columns of X_model as discrete features.
# Comparing the dtypes against the builtin `int` is platform dependent: where
# the default NumPy integer is 32-bit, int64 columns do not match and every
# column comes out as False. Testing for "any integer dtype" avoids that.
# NOTE: outputs recorded with the old comparison (all False) will differ.
discrete_features = X_model.dtypes.apply(pd.api.types.is_integer_dtype)
In 24:
discrete_features
Out24:
GENDER False
AGE False
SMOKING False
YELLOW_FINGERS False
ANXIETY False
PEER_PRESSURE False
CHRONIC DISEASE False
FATIGUE False
ALLERGY False
WHEEZING False
ALCOHOL CONSUMING False
COUGHING False
SHORTNESS OF BREATH False
SWALLOWING DIFFICULTY False
CHEST PAIN False
dtype: bool
In 25:
from sklearn.feature_selection import mutual_info_classif
In 26:
def calculate_mic_scores(X_model, y, discrete_features, random_state=None):
    """Return the mutual-information score of each feature against the target.

    Args:
        X_model: Numerically encoded feature table; its column names become
            the index of the result.
        y: Target labels, one per row of ``X_model``.
        discrete_features: Forwarded to ``mutual_info_classif`` to mark which
            features are discrete.
        random_state: Optional seed forwarded to ``mutual_info_classif`` so
            that repeated calls give the same scores. The default ``None``
            keeps the previous unseeded behaviour.

    Returns:
        A ``pandas.Series`` named ``"MIC SCORES"``, indexed by feature name
        and sorted from the highest score to the lowest.
    """
    scores = mutual_info_classif(
        X_model,
        y,
        discrete_features=discrete_features,
        random_state=random_state,
    )
    mic = pd.Series(scores, name="MIC SCORES", index=X_model.columns)
    return mic.sort_values(ascending=False)
In 27:
mic = calculate_mic_scores(X_model,y,discrete_features)
mic
Out27:
ALLERGY 0.059814
CHEST PAIN 0.058419
ALCOHOL CONSUMING 0.042346
COUGHING 0.040891
SMOKING 0.033635
SWALLOWING DIFFICULTY 0.027580
AGE 0.025290
ANXIETY 0.011860
WHEEZING 0.007738
PEER_PRESSURE 0.005157
CHRONIC DISEASE 0.003491
FATIGUE 0.003473
GENDER 0.000000
YELLOW_FINGERS 0.000000
SHORTNESS OF BREATH 0.000000
Name: MIC SCORES, dtype: float64
In 28:
def plot_mic(scores):
    """Draw a horizontal bar chart of mutual-information scores.

    The bars are drawn through ``pyplot`` on the current figure, in ascending
    score order, and labelled with the feature names.

    Args:
        scores: ``pandas.Series`` of scores indexed by feature name.
    """
    scores = scores.sort_values(ascending=True)
    positions = np.arange(len(scores))  # one y-position per bar
    plt.barh(positions, scores)
    plt.yticks(positions, list(scores.index))
    plt.title("Mutual Information Scores")
In 29:
plt.figure(dpi=100, figsize=(8,5))
plot_mic(mic) # 上面计算的mic得分代入
In 30:
X = pd.concat([X[numerical], pd.get_dummies(X[categorical])],axis=1)
X.head()
Out30:
AGE | GENDER_F | GENDER_M | SMOKING_1 | SMOKING_2 | YELLOW_FINGERS_1 | YELLOW_FINGERS_2 | ANXIETY_1 | ANXIETY_2 | PEER_PRESSURE_1 | ... | ALCOHOL CONSUMING_1 | ALCOHOL CONSUMING_2 | COUGHING_1 | COUGHING_2 | SHORTNESS OF BREATH_1 | SHORTNESS OF BREATH_2 | SWALLOWING DIFFICULTY_1 | SWALLOWING DIFFICULTY_2 | CHEST PAIN_1 | CHEST PAIN_2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 69 | False | True | True | False | False | True | False | True | True | ... | False | True | False | True | False | True | False | True | False | True |
1 | 74 | False | True | False | True | True | False | True | False | True | ... | True | False | True | False | False | True | False | True | False | True |
2 | 59 | True | False | True | False | True | False | True | False | False | ... | True | False | False | True | False | True | True | False | False | True |
3 | 63 | False | True | False | True | False | True | False | True | True | ... | False | True | True | False | True | False | False | True | False | True |
4 | 63 | True | False | True | False | False | True | True | False | True | ... | True | False | False | True | False | True | True | False | True | False |
5 rows × 29 columns
In 31:
feature_names = X.columns
In 32:
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, train_size=0.75, random_state=1) # 根据y中的比例进行划分
In 33:
mm = MinMaxScaler()
In 34:
X_train[numerical] = mm.fit_transform(X_train[numerical])
X_test[numerical] = mm.transform(X_test[numerical])
In 35:
lr = LogisticRegression(max_iter=2000)
cv = cross_val_score(lr, X_train,y_train,cv=5)
# 逻辑回归模型在5折交叉验证下的均值和标准差
print(mean(cv), "±", np.std(cv))
0.9308048103607771 ± 0.0158848177577188
In 36:
rf = RandomForestClassifier(random_state = 1)
cv = cross_val_score(rf,X_train,y_train,cv=5)
print(mean(cv), "±", std(cv))
0.9350601295097132 ± 0.013760232121339587
In 37:
svc = SVC(probability = True)
cv = cross_val_score(svc,X_train,y_train,cv=5)
print(mean(cv), "±", std(cv))
0.9351526364477335 ± 0.023485469326540147
网格搜索(Grid Search)是一种在机器学习中用于模型超参数优化的方法。它通过遍历所有的超参数组合来找到最佳的参数设置,从而使得模型在给定的任务上达到最优的性能。
网格搜索的基本思想是为每个超参数设定一个范围或者列表,然后尝试所有可能的组合。具体来说,算法会为每个超参数生成一个候选值列表,然后将这些列表进行笛卡尔积运算,生成所有可能的参数组合。
之后,算法会使用这些组合来训练模型,并通过交叉验证等方式评估每个模型的性能。最终,算法会选择表现最好的参数组合作为最优解。
In 38:
def clf_performance(model, model_name):
    """Print the best result of a fitted hyper-parameter search.

    Args:
        model: Fitted search object exposing ``best_score_``, ``best_index_``,
            ``best_params_`` and ``cv_results_['std_test_score']`` (the
            ``GridSearchCV`` results in this file are passed here).
        model_name: Label printed on the first line.
    """
    print("模型名称:", model_name)
    # Entry of the std_test_score array that belongs to the best candidate.
    best_std = model.cv_results_['std_test_score'][model.best_index_]
    print(f"Best Score: {model.best_score_!s} ± {best_std!s}")
    print(f"Best Parameters: {model.best_params_!s}")
In 39:
lr = LogisticRegression()
# 参数设置
param_grid = {"max_iter":[15000],
"C":np.arange(0.1,0.6,0.1)
}
# 对逻辑回归模型的网格搜索调参
clf_lr = GridSearchCV(lr, param_grid=param_grid, cv=5, n_jobs=-1)
best_clf_lr = clf_lr.fit(X_train, y_train)
clf_performance(best_clf_lr, "Logistic Regression")
模型名称: Logistic Regression
Best Score: 0.9264569842738206 ± 0.01712752904500742
Best Parameters: {'C': 0.30000000000000004, 'max_iter': 15000}
In 40:
rf = RandomForestClassifier(random_state=1)
# Hyper-parameter grid to search.
param_grid = {
    "n_estimators": np.arange(8, 20, 2),
    "bootstrap": [True, False],
    "max_depth": [10],
    # The list used to contain the misspelled value "atuo" as well. That is not
    # a valid max_features setting, so candidates using it could not be fitted.
    # The intended "auto" meant the same as "sqrt" for classifiers and is no
    # longer accepted by recent scikit-learn releases, so only "sqrt" is kept.
    "max_features": ["sqrt"],
    "min_samples_leaf": np.arange(2, 6, 1),
    "min_samples_split": np.arange(2, 6, 1),
}
clf_rf = GridSearchCV(rf, param_grid=param_grid, cv=5, n_jobs=-1)
best_clf_rf = clf_rf.fit(X_train, y_train)
clf_performance(best_clf_rf, "Random Forest")
模型名称: Random Forest
Best Score: 0.9481961147086032 ± 0.021593980755705847
Best Parameters: {'bootstrap': False, 'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 8}
In 41:
svc = SVC(probability=True, random_state=1)
param_grid = {'kernel': ['linear', 'poly', 'sigmoid','rbf'], # 核函数选择
'gamma': [1, 1e-1, 1e-2, 1e-3, 1e-4], # 核函数参数
'C': np.arange(40,70,5) # 正则化参数
}
clf_svc = GridSearchCV(svc, param_grid = param_grid, cv = 5, n_jobs = -1)
best_clf_svc = clf_svc.fit(X_train,y_train)
clf_performance(best_clf_svc,'Support Vector Classifier')
模型名称: Support Vector Classifier
Best Score: 0.9438482886216466 ± 0.016747588503435138
Best Parameters: {'C': 50, 'gamma': 1, 'kernel': 'linear'}
使用基于网格搜索找到的最佳参数组合进行建模,然后对模型进行评估:
In 42:
lr = LogisticRegression(C=0.3, max_iter=15000) # 最佳参数组合
lr.fit(X_train, y_train) # 模型训练
y_pred_lr = lr.predict(X_test) # 模型预测
# 准确率
print('LogisticRegression test accuracy: {}'.format(accuracy_score(y_test, y_pred_lr)))
LogisticRegression test accuracy: 0.8461538461538461
1、计算特征的重要性程度
In 43:
lr_coef = pd.DataFrame([lr.coef_[0]], columns=feature_names) # X中的特征名称
sorted_lr = lr_coef.iloc[:, np.argsort(lr_coef.loc[0])]
sorted_lr
Out43:
SWALLOWING DIFFICULTY_1 | ALLERGY _1 | ALCOHOL CONSUMING_1 | COUGHING_1 | PEER_PRESSURE_1 | FATIGUE _1 | YELLOW_FINGERS_1 | WHEEZING_1 | ANXIETY_1 | CHEST PAIN_1 | ... | CHEST PAIN_2 | ANXIETY_2 | WHEEZING_2 | YELLOW_FINGERS_2 | FATIGUE _2 | PEER_PRESSURE_2 | COUGHING_2 | ALCOHOL CONSUMING_2 | ALLERGY _2 | SWALLOWING DIFFICULTY_2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -0.641878 | -0.613902 | -0.568027 | -0.530684 | -0.491953 | -0.474459 | -0.469793 | -0.409697 | -0.404706 | -0.378352 | ... | 0.378346 | 0.404699 | 0.409691 | 0.469787 | 0.474453 | 0.491947 | 0.530678 | 0.568021 | 0.613895 | 0.641872 |
1 rows × 29 columns
In 44:
plt.figure(figsize=(16,7))
sns.barplot(x=sorted_lr.columns, y=sorted_lr.iloc[0,:])
plt.xticks(rotation = 90) # 旋转角度
plt.xlabel('Features') # x-y轴名称
plt.ylabel('LogisticRegression Coefficients')
plt.show()
2、绘制混淆矩阵(有轴名称和显示数值):
In 45:
matrix = confusion_matrix(y_test, y_pred_lr)
# Normalise each row so a cell holds the fraction of that true class.
matrix = matrix.astype("float") / matrix.sum(axis=1)[:, np.newaxis]
plt.figure(figsize=(12, 7))
sns.set(font_scale=1.4)
sns.heatmap(matrix,
            annot=True,              # write the value inside each cell
            annot_kws={'size': 10},  # font size of the cell annotations
            linewidths=0.2,          # width of the lines between cells
            vmin=0,                  # colour-bar range
            vmax=1)
# confusion_matrix orders the classes by sorted label, i.e. 'NO' before 'YES',
# so the "No Lung cancer" name has to come first.
class_names = ["No Lung cancer", "Lung cancer"]
# Heatmap cells are centred on 0.5, 1.5, ...; use the centres on both axes.
tick_marks = np.arange(len(class_names)) + 0.5
plt.xticks(tick_marks, class_names, rotation=25)
plt.yticks(tick_marks, class_names, rotation=0)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix for LogisticRegression')
plt.show()
3、计算分类得分报告classification_report:
In 46:
print('LogisticRegression: ')
print(classification_report(y_test, y_pred_lr))
LogisticRegression:
precision recall f1-score support
NO 0.33 0.20 0.25 10
YES 0.89 0.94 0.91 68
accuracy 0.85 78
macro avg 0.61 0.57 0.58 78
weighted avg 0.82 0.85 0.83 78
4、绘制ROC曲线:
In 47:
from sklearn import metrics
pred_prob_lr = lr.predict_proba(X_test) # predicted class probabilities
# Column 1 is used as the score of the 'YES' class.
# NOTE(review): presumably lr.classes_ is ['NO', 'YES'] - confirm.
fpr_lr, tpr_lr, thresholds_lr = metrics.roc_curve(y_test, pred_prob_lr[:,1],pos_label='YES')
fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(fpr_lr, tpr_lr, label='LogisticRegression')
# Dashed diagonal, drawn in axes coordinates from corner to corner.
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.legend()
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for LogisticRegression Classifiers')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.show()
5、计算AUC值:
In 48:
print('AUC of LogisticRegression: {}'.format(metrics.auc(fpr_lr, tpr_lr)))
AUC of LogisticRegression: 0.8764705882352941
In 49:
# 1、最佳参数组合建模
svc = SVC(probability = True, random_state = 1,C= 50, gamma = 1, kernel= 'linear') # 最佳参数组合
svc.fit(X_train,y_train)
y_pred_svc = svc.predict(X_test)
# SVC模型的准确率
print('SVC test accuracy: {}'.format(accuracy_score(y_test, y_pred_svc)))
SVC test accuracy: 0.8974358974358975
In 50:
# 2. Plot the row-normalised confusion matrix of the SVC predictions.
matrix = confusion_matrix(y_test, y_pred_svc)
# Normalise each row so a cell holds the fraction of that true class.
matrix = matrix.astype("float") / matrix.sum(axis=1)[:, np.newaxis]
plt.figure(figsize=(12, 7))
sns.set(font_scale=1.4)
sns.heatmap(matrix,
            annot=True,              # write the value inside each cell
            annot_kws={'size': 10},  # font size of the cell annotations
            linewidths=0.2,          # width of the lines between cells
            vmin=0,                  # colour-bar range
            vmax=1)
# confusion_matrix orders the classes by sorted label, i.e. 'NO' before 'YES',
# so the "No Lung cancer" name has to come first.
class_names = ["No Lung cancer", "Lung cancer"]
# Heatmap cells are centred on 0.5, 1.5, ...; use the centres on both axes.
tick_marks = np.arange(len(class_names)) + 0.5
plt.xticks(tick_marks, class_names, rotation=25)
plt.yticks(tick_marks, class_names, rotation=0)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix for SVC')
plt.show()
输出分类报告和绘制ROC曲线:
In 51:
# 3、打印分类报告
print('SVC: ')
print(classification_report(y_test, y_pred_svc))
SVC:
precision recall f1-score support
NO 0.62 0.50 0.56 10
YES 0.93 0.96 0.94 68
accuracy 0.90 78
macro avg 0.78 0.73 0.75 78
weighted avg 0.89 0.90 0.89 78
In 52:
# 4. Plot the ROC curve of the SVC model.
from sklearn import metrics
pred_prob_svc = svc.predict_proba(X_test) # predicted class probabilities
# Column 1 is used as the score of the 'YES' class.
# NOTE(review): presumably svc.classes_ is ['NO', 'YES'] - confirm.
fpr_svc, tpr_svc, thresholds_svc = metrics.roc_curve(y_test, pred_prob_svc[:,1],pos_label='YES')
fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(fpr_svc, tpr_svc, label='SVC')
# Dashed diagonal, drawn in axes coordinates from corner to corner.
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.legend()
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for SVC ')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.show()
最后计算AUC值:
In 53:
print('AUC of SVC: {}'.format(metrics.auc(fpr_svc, tpr_svc)))
AUC of SVC: 0.9058823529411765
In 54:
import eli5
from eli5.sklearn import PermutationImportance
In 55:
# Permutation importance of the svc model, computed from the test split.
# NOTE(review): presumably svc is treated as already fitted and is not
# retrained by this fit() call - confirm against the eli5 documentation.
perm = PermutationImportance(svc, random_state=1).fit(X_test, y_test) # scores svc on the test set
# List every feature (top=len(feature_names)) with its weight.
eli5.show_weights(perm, feature_names=list(feature_names),top=len(feature_names))
In 56:
import shap
explainer = shap.KernelExplainer(svc.predict_proba, X_train)
pred_data = pd.DataFrame(X_test)
pred_data.columns = feature_names
data_for_prediction = pred_data
Using 231 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
In 57:
shap_values = explainer.shap_values(data_for_prediction)
shap.initjs()
# Depending on the shap release, a multi-output model yields either a list
# with one (samples, features) array per class or a single
# (samples, features, classes) array. Plain `shap_values[1]` is only right for
# the list layout, so select class index 1 explicitly in both cases.
if isinstance(shap_values, list):
    class1_shap_values = shap_values[1]
else:
    class1_shap_values = shap_values[..., 1]
shap.summary_plot(class1_shap_values, data_for_prediction)
需要 jupyter notebook 源码和本文数据的同学,请直接联系小编,非诚勿扰~
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。