前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >kaggle实战-基于机器学习的中风病人预测

kaggle实战-基于机器学习的中风病人预测

作者头像
皮大大
发布2023-08-25 11:01:14
1.1K0
发布2023-08-25 11:01:14
举报

基于随机森林、逻辑回归、SVM的中风病人预测

原数据地址:https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset?datasetId=1120859&sortBy=voteCount&select=healthcare-dataset-stroke-data.csv

导入库

代码语言:javascript
复制
import numpy as np
import pandas as pd

# 绘图
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.gridspec as grid_spec
import seaborn as sns
plt.style.use("fivethirtyeight")

import plotly.express as px
import plotly.graph_objs as go

# 采样
from imblearn.over_sampling import SMOTE

# 数据标准化、分割、交叉验证
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler,LabelEncoder
from sklearn.model_selection import train_test_split,cross_val_score

# 各种模型
from sklearn.linear_model import LinearRegression,LogisticRegression
from sklearn.tree import DecisionTreeRegressor,DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# 模型评价
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, precision_score, f1_score
import warnings
warnings.filterwarnings('ignore')

数据基本信息

先把数据导进来,查看数据的基本信息

下面我们查看数据基本信息

In [3]:

代码语言:javascript
复制
df.shape

Out[3]:

代码语言:javascript
复制
(5110, 12)

In [4]:

代码语言:javascript
复制
df.dtypes

Out[4]:

代码语言:javascript
复制
id                     int64
gender                object
age                  float64
hypertension           int64
heart_disease          int64
ever_married          object  # 字符型
work_type             object
Residence_type        object
avg_glucose_level    float64
bmi                  float64
smoking_status        object
stroke                 int64
dtype: object

In [5]:

代码语言:javascript
复制
df.describe()  # 描述统计信息

字段分布

gender统计

In [6]:

代码语言:javascript
复制
plt.figure(1, figsize=(12,5))

sns.countplot(y="gender", data=df)
plt.show()

age分布

In [7]:

代码语言:javascript
复制
px.violin(y=df["age"])
代码语言:javascript
复制
fig = px.histogram(df,
                   x="age",
                   color_discrete_sequence=['firebrick'])

fig.show()

ever_married

In [9]:

代码语言:javascript
复制
plt.figure(1, figsize=(12,5))

sns.countplot(y="ever_married", data=df)

plt.show()

本数据集中的结婚人士大约是未结婚的两倍。

work-type

查看不同工作状态的人员数量

In [10]:

代码语言:javascript
复制
plt.figure(1, figsize=(12,8))

sns.countplot(y="work_type", data=df)

plt.show()

Residence_type

In [11]:

代码语言:javascript
复制
plt.figure(1, figsize=(12,8))

sns.countplot(y="Residence_type", data=df)

plt.show()

avg_glucose_level

血糖水平的分布

In [12]:

代码语言:javascript
复制
fig = px.histogram(df,
                   x="avg_glucose_level",
                   color_discrete_sequence=['firebrick'])

fig.show()

可以看到大部分人的血糖还是在100以下,说明是正常的

bmi

bmi指标的分布情况

In [13]:

代码语言:javascript
复制
fig = px.histogram(df,
                   x="bmi",
                   color_discrete_sequence=['firebrick'])

fig.show()

bmi指标的均值大约在28左右,呈现一定的正态分布

smoking_status

抽烟情况的统计

In [14]:

代码语言:javascript
复制
plt.figure(1, figsize=(12,8))

sns.countplot(y="smoking_status", data=df)

plt.show()

可以看到抽烟或者曾经抽烟的人相对来说是少一些的

缺失值情况

缺失值统计

In [15]:

代码语言:javascript
复制
df.isnull().sum()

Out[15]:

代码语言:javascript
复制
id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64

In [16]:

代码语言:javascript
复制
201 / len(df)  # 缺失比例

Out[16]:

代码语言:javascript
复制
0.03933463796477495

缺失值可视化

In [17]:

代码语言:javascript
复制
plt.title('Missing Value Status',fontweight='bold')
ax = sns.heatmap(df.isna().sum().to_frame(),
                 annot=True,
                 fmt='d',
                 cmap='vlag')

ax.set_xlabel('Amount Missing')

plt.show()
代码语言:javascript
复制
import missingno as mso
mso.bar(df,color="blue")
plt.show()

缺失值处理

使用决策树回归来预测缺失值的BMI值:通过年龄、性别和现有的bmi值来进行预测填充

In [19]:

代码语言:javascript
复制
dt_bmi = Pipeline(steps=[("scale",StandardScaler()), # 数据标准化
                         ("lr",DecisionTreeRegressor(random_state=42))
                        ])

In [20]:

代码语言:javascript
复制
X = df[["age","gender","bmi"]].copy()

dic = {"Male":0, "Female":1, "Other":-1}

X["gender"] = X["gender"].map(dic).astype(np.uint8)
X.head()

取出非缺失值的部分进行训练:

代码语言:javascript
复制
# 缺失值部分
missing = X[X.bmi.isna()]

# 非缺失值部分
X = X[~X.bmi.isna()]
y = X.pop("bmi")
代码语言:javascript
复制
# 模型训练

dt_bmi.fit(X,y)

Out[22]:

代码语言:javascript
复制
Pipeline(steps=[('scale', StandardScaler()),
                ('lr', DecisionTreeRegressor(random_state=42))])

In [23]:

代码语言:javascript
复制
# 模型预测

y_pred = dt_bmi.predict(missing[["age","gender"]])
y_pred[:5]

Out[23]:

代码语言:javascript
复制
array([29.87948718, 30.55609756, 27.24722222, 30.84186047, 33.14666667])

将预测的值转成Series,并且注意索引号:

代码语言:javascript
复制
predict_bmi = pd.Series(y_pred, index=missing.index)
predict_bmi

Out[24]:

代码语言:javascript
复制
1       29.879487
8       30.556098
13      27.247222
19      30.841860
27      33.146667
          ...
5039    32.716000
5048    28.313636
5093    31.459322
5099    28.313636
5105    28.476923
Length: 201, dtype: float64

填充到原来的df数据中:

In [25]:

代码语言:javascript
复制
df.loc[missing.index, "bmi"] = predict_bmi

进行上面的预测和填充之后,我们再次查看缺失值情况,发现已经没有任何缺失值:

In [26]:

代码语言:javascript
复制
df.isnull().sum()

Out[26]:

代码语言:javascript
复制
id                   0
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
dtype: int64

数据EDA

In [27]:

代码语言:javascript
复制
variables = [variable for variable in df.columns if variable not in ['id','stroke']]

# 除去id号和是否中风外的全部字段
variables

Out[27]:

代码语言:javascript
复制
['gender',
 'age',
 'hypertension',
 'heart_disease',
 'ever_married',
 'work_type',
 'Residence_type',
 'avg_glucose_level',
 'bmi',
 'smoking_status']

连续型变量

In [28]:

代码语言:javascript
复制
conts = ['age','avg_glucose_level','bmi']

for cont in conts:
    plt.figure(1, figsize=(15,6))
    sns.distplot(df[cont])

    plt.show()
几点结论:
  • 年龄age:整体分布比较均衡,不同年龄段的人数差异小
  • 血糖水平:主要集中在100以下
  • bmi指标:呈现一定的正态分布

中风和未中风

上面我们查看了连续型变量的分布情况;可以看到bmi呈现明显的左偏态的分布。下面我们对比中风和未中风的情况:

代码语言:javascript
复制
conts = ['age','avg_glucose_level','bmi']

for cont in conts:
    plt.figure(1, figsize=(15,12))
    sns.displot(data=df,
                x=cont,
                hue="stroke",
                kind="kde")
plt.show()

从3个密度图中能够观察到:从上面的密度图中可以看出来:对于是否中风,年龄age是一个最主要的因素

对比不同年龄段的血糖和BMI指数

In [30]:

代码语言:javascript
复制
px.scatter(df,x="age",
           y="avg_glucose_level",
           color="stroke",
           trendline='ols'
          )

年龄和血糖、bmi关系

代码语言:javascript
复制
px.scatter(df,x="age",
           y="bmi",
           color="stroke",
           trendline='ols'
          )

年龄和患病几率

从散点分布图中看到:年龄可能真的是一个比较重要的因素,和BMI以及平均的血糖水平有着一定的关系。

可能随着年龄的增长,风险在增加。果真如此吗?

In [32]:

代码语言:javascript
复制
background_color = "#fafafa"

fig = plt.figure(figsize=(12, 6),
                 dpi=160,
                 facecolor=background_color)

gs = fig.add_gridspec(2, 1)
gs.update(wspace=0.11, hspace=0.5)

ax0 = fig.add_subplot(gs[0, 0])
ax0.set_facecolor(background_color)

# 字段类型转化
df['age'] = df['age'].astype(int)

rate = []
for i in range(df['age'].min(), df['age'].max()):
    rate.append(df[df['age'] < i]['stroke'].sum() / len(df[df['age'] < i]['stroke']))  # sum求和就是中风人数 / 总人数

sns.lineplot(data=rate,color='#0f4c81',ax=ax0)

for s in ["top","right","left"]:
    ax0.spines[s].set_visible(False)

ax0.tick_params(axis='both',
                which='major',
                labelsize=8)

ax0.tick_params(axis=u'both',
                which=u'both',
                length=0)

ax0.text(-3,
         0.055,
         'Risk Increase by Age',
         fontsize=18,
         fontfamily='serif',
         fontweight='bold')

ax0.text(-3,0.047,
         'As age increase, so too does risk of having a stroke',
         fontsize=14,
         fontfamily='serif')


plt.show()

上面的图形说明了两点:

  1. 年龄越大,中风的几率的确越来越高
  2. 中风的几率是非常低的(y轴的值很低),这是由于中风和未中风的样本不均衡造成的

原数据5000个样本中只有249个中风样本,比例接近1:20

样本不均衡

代码语言:javascript
复制
from pywaffle import Waffle

fig = plt.figure(
    figsize=(7, 2),
    dpi=150,
    facecolor=background_color,
    FigureClass=Waffle,
    rows=1,
    values=[1, 19],
    colors=['#0f4c81', "lightgray"],
    characters='⬤',
    font_size=16,
    vertical=True,
)

# 主标题
fig.text(0.035,0.78,
         'Stroked People in our dataset',
         fontfamily='serif',
         fontsize=10,
         fontweight='bold')
# 子标题
fig.text(0.035,
         0.65,
         '1:20 [249 out of 5000]',
         fontfamily='serif',
         fontsize=10)

plt.show()

属性分布

整体变量情况

首先我们剔除gender中为Other的情况

In [34]:

代码语言:javascript
复制
str_only = df[df['stroke'] == 1]   # 中风
no_str_only = df[df['stroke'] == 0]  # 未中风

In [35]:

代码语言:javascript
复制
len(str_only)

Out[35]:

代码语言:javascript
复制
249

In [36]:

代码语言:javascript
复制
# 剔除other
no_str_only = no_str_only[(no_str_only['gender'] != 'Other')]

下面的代码是比较在不同的属性下中风和未中风的情况:

代码语言:javascript
复制
fig = plt.figure(figsize=(22,15))
gs = fig.add_gridspec(3, 3)
gs.update(wspace=0.35, hspace=0.27)

# 生成9个子图
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1])
ax2 = fig.add_subplot(gs[0, 2])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[1, 2])
ax6 = fig.add_subplot(gs[2, 0])
ax7 = fig.add_subplot(gs[2, 1])
ax8 = fig.add_subplot(gs[2, 2])

# 背景色
background_color = "#f6f6f6"
fig.patch.set_facecolor(background_color)

## 1、Age

ax0.grid(color='gray',
         linestyle=':',
         axis='y',
         zorder=0,  dashes=(1,5))

# 中风和未中风
positive = pd.DataFrame(str_only["age"])
negative = pd.DataFrame(no_str_only["age"])
# kde密度图
sns.kdeplot(positive["age"],  # 中风数据
            ax=ax0,  # 指定子图
            color="#0f4c81",  # 颜色
            shade=True,  # 阴影
            ec='black',  # 边缘色
            label="positive"  # label
           )

sns.kdeplot(negative["age"], # 未中风
            ax=ax0,
            color="#9bb7d4",
            shade=True,
            ec='black',
            label="negative")

ax0.yaxis.set_major_locator(mtick.MultipleLocator(2))
ax0.set_ylabel('')
ax0.set_xlabel('')
ax0.text(-20, # 文本信息设置
         0.0465,
         'Age',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")


# 2、Smoking
# 不同状态的人数
positive = pd.DataFrame(str_only["smoking_status"].value_counts())
# 比例情况
positive["Percentage"] = positive["smoking_status"].apply(lambda x: x/sum(positive["smoking_status"])*100)

negative = pd.DataFrame(no_str_only["smoking_status"].value_counts())
negative["Percentage"] = negative["smoking_status"].apply(lambda x: x/sum(negative["smoking_status"])*100)

ax1.text(0, 4,
         'Smoking Status',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")
ax1.barh(positive.index,
         positive['Percentage'],
         color="#0f4c81",
         zorder=3,
         height=0.7)
ax1.barh(negative.index,
         negative['Percentage'],
         color="#9bb7d4",
         zorder=3,
         ec='black',
         height=0.3)
ax1.xaxis.set_major_formatter(mtick.PercentFormatter())
ax1.xaxis.set_major_locator(mtick.MultipleLocator(10))

# gender
# 1、统计人数
positive = pd.DataFrame(str_only["gender"].value_counts())
# 2、转成比例
positive["Percentage"] = positive["gender"].apply(lambda x: x/sum(positive["gender"])*100)
negative = pd.DataFrame(no_str_only["gender"].value_counts())
negative["Percentage"] = negative["gender"].apply(lambda x: x/sum(negative["gender"])*100)

x = np.arange(len(positive))
ax2.text(-0.4, 68.5,
         'Gender',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")

ax2.grid(color='gray',
         linestyle=':',
         axis='y',
         zorder=0,
         dashes=(1,5))
ax2.bar(x,
        height=positive["Percentage"],
        zorder=3,
        color="#0f4c81",
        width=0.4)

ax2.bar(x+0.4,
        height=negative["Percentage"],
        zorder=3,
        color="#9bb7d4",
        width=0.4)
ax2.set_xticks(x + 0.4 / 2)
ax2.set_xticklabels(['Male','Female'])
ax2.yaxis.set_major_formatter(mtick.PercentFormatter())
ax2.yaxis.set_major_locator(mtick.MultipleLocator(10))

for i,j in zip([0, 1], positive["Percentage"]):
    ax2.annotate(f'{j:0.0f}%',
                 xy=(i, j/2),
                 color='#f6f6f6',
                 horizontalalignment='center',
                 verticalalignment='center')

for i,j in zip([0, 1], negative["Percentage"]):
    ax2.annotate(f'{j:0.0f}%',
                 xy=(i+0.4, j/2),
                 color='#f6f6f6',
                 horizontalalignment='center',
                 verticalalignment='center')


# heart_disease

positive = pd.DataFrame(str_only["heart_disease"].value_counts())
positive["Percentage"] = positive["heart_disease"].apply(lambda x: x/sum(positive["heart_disease"])*100)
negative = pd.DataFrame(no_str_only["heart_disease"].value_counts())
negative["Percentage"] = negative["heart_disease"].apply(lambda x: x/sum(negative["heart_disease"])*100)

x = np.arange(len(positive))
ax3.text(-0.3, 110,
         'Heart Disease',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")
ax3.grid(color='gray',
         linestyle=':',
         axis='y',
         zorder=0,
         dashes=(1,5))

ax3.bar(x,
        height=positive["Percentage"],
        zorder=3,
        color="#0f4c81",
        width=0.4)

ax3.bar(x+0.4,
        height=negative["Percentage"],
        zorder=3,
        color="#9bb7d4",
        width=0.4)

ax3.set_xticks(x + 0.4 / 2)
ax3.set_xticklabels(['No History','History'])
ax3.yaxis.set_major_formatter(mtick.PercentFormatter())
ax3.yaxis.set_major_locator(mtick.MultipleLocator(20))

for i,j in zip([0, 1], positive["Percentage"]):
    ax3.annotate(f'{j:0.0f}%',
                 xy=(i, j/2),
                 color='#f6f6f6',
                 horizontalalignment='center',
                 verticalalignment='center')
for i,j in zip([0, 1], negative["Percentage"]):
    ax3.annotate(f'{j:0.0f}%',
                 xy=(i+0.4, j/2),
                 color='#f6f6f6',
                 horizontalalignment='center',
                 verticalalignment='center')


# ## AX4 - TITLE

ax4.spines["bottom"].set_visible(False)
ax4.tick_params(left=False, bottom=False)
ax4.set_xticklabels([])
ax4.set_yticklabels([])
ax4.text(0.5, 0.6, 'Can we see patterns for\n\n patients in our data?', horizontalalignment='center', verticalalignment='center',
         fontsize=22, fontweight='bold', fontfamily='serif', color="#323232")

ax4.text(0.15,0.57,"Stroke", fontweight="bold", fontfamily='serif', fontsize=22, color='#0f4c81')
ax4.text(0.41,0.57,"&", fontweight="bold", fontfamily='serif', fontsize=22, color='#323232')
ax4.text(0.49,0.57,"No-Stroke", fontweight="bold", fontfamily='serif', fontsize=22, color='#9bb7d4')


# Glucose

ax5.grid(color='gray',
         linestyle=':',
         axis='y',
         zorder=0,
         dashes=(1,5))
positive = pd.DataFrame(str_only["avg_glucose_level"])
negative = pd.DataFrame(no_str_only["avg_glucose_level"])
sns.kdeplot(positive["avg_glucose_level"],
            ax=ax5,
            color="#0f4c81",
            ec='black',
            shade=True,
            label="positive")

sns.kdeplot(negative["avg_glucose_level"],
            ax=ax5,
            color="#9bb7d4",
            ec='black',
            shade=True,
            label="negative")

ax5.text(-55, 0.01855,
         'Avg. Glucose Level',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")
ax5.yaxis.set_major_locator(mtick.MultipleLocator(2))
ax5.set_ylabel('')
ax5.set_xlabel('')



# bmi

ax6.grid(color='gray',
         linestyle=':',
         axis='y',
         zorder=0,
         dashes=(1,5))

positive = pd.DataFrame(str_only["bmi"])
negative = pd.DataFrame(no_str_only["bmi"])
sns.kdeplot(positive["bmi"],
            ax=ax6,
            color="#0f4c81",
            ec='black',
            shade=True,
            label="positive")

sns.kdeplot(negative["bmi"],
            ax=ax6,
            color="#9bb7d4",
            ec='black',
            shade=True,
            label="negative")
ax6.text(-0.06,
         0.09,
         'BMI',
         fontsize=14,
         fontweight='bold',
         fontfamily='serif',
         color="#323232")
ax6.yaxis.set_major_locator(mtick.MultipleLocator(2))
ax6.set_ylabel('')
ax6.set_xlabel('')


# Work Type

positive = pd.DataFrame(str_only["work_type"].value_counts())
positive["Percentage"] = positive["work_type"].apply(lambda x: x/sum(positive["work_type"])*100)
positive = positive.sort_index()

negative = pd.DataFrame(no_str_only["work_type"].value_counts())
negative["Percentage"] = negative["work_type"].apply(lambda x: x/sum(negative["work_type"])*100)
negative = negative.sort_index()

ax7.bar(negative.index, height=negative["Percentage"], zorder=3, color="#9bb7d4", width=0.05)
ax7.scatter(negative.index, negative["Percentage"], zorder=3,s=200, color="#9bb7d4")
ax7.bar(np.arange(len(positive.index))+0.4, height=positive["Percentage"], zorder=3, color="#0f4c81", width=0.05)
ax7.scatter(np.arange(len(positive.index))+0.4, positive["Percentage"], zorder=3,s=200, color="#0f4c81")

ax7.yaxis.set_major_formatter(mtick.PercentFormatter())
ax7.yaxis.set_major_locator(mtick.MultipleLocator(10))
ax7.set_xticks(np.arange(len(positive.index))+0.4 / 2)
ax7.set_xticklabels(list(positive.index),rotation=0)
ax7.text(-0.5, 66, 'Work Type', fontsize=14, fontweight='bold', fontfamily='serif', color="#323232")

# hypertension

positive = pd.DataFrame(str_only["hypertension"].value_counts())
positive["Percentage"] = positive["hypertension"].apply(lambda x: x/sum(positive["hypertension"])*100)
negative = pd.DataFrame(no_str_only["hypertension"].value_counts())
negative["Percentage"] = negative["hypertension"].apply(lambda x: x/sum(negative["hypertension"])*100)

x = np.arange(len(positive))
ax8.text(-0.45, 100, 'Hypertension', fontsize=14, fontweight='bold', fontfamily='serif', color="#323232")
ax8.grid(color='gray', linestyle=':', axis='y', zorder=0,  dashes=(1,5))
ax8.bar(x, height=positive["Percentage"], zorder=3, color="#0f4c81", width=0.4)
ax8.bar(x+0.4, height=negative["Percentage"], zorder=3, color="#9bb7d4", width=0.4)
ax8.set_xticks(x + 0.4 / 2)
ax8.set_xticklabels(['No History','History'])
ax8.yaxis.set_major_formatter(mtick.PercentFormatter())
ax8.yaxis.set_major_locator(mtick.MultipleLocator(20))
for i,j in zip([0, 1], positive["Percentage"]):
    ax8.annotate(f'{j:0.0f}%',xy=(i, j/2), color='#f6f6f6', horizontalalignment='center', verticalalignment='center')
for i,j in zip([0, 1], negative["Percentage"]):
    ax8.annotate(f'{j:0.0f}%',xy=(i+0.4, j/2), color='#f6f6f6', horizontalalignment='center', verticalalignment='center')


# tidy up

for s in ["top","right","left"]:
    for i in range(0,9):
        locals()["ax"+str(i)].spines[s].set_visible(False)

for i in range(0,9):
        locals()["ax"+str(i)].set_facecolor(background_color)
        locals()["ax"+str(i)].tick_params(axis=u'both', which=u'both',length=0)
        locals()["ax"+str(i)].set_facecolor(background_color)


plt.show()

建模

模型baseline

In [38]:

代码语言:javascript
复制
len(str_only)

Out[38]:

代码语言:javascript
复制
249

In [39]:

代码语言:javascript
复制
249 / len(df)

Out[39]:

代码语言:javascript
复制
0.0487279843444227

说明总共有249个人是中风的。本数据的总人数是len(df),根据下面的表达式能够得到本次模型的baseline。

也就说,对于阳性中风患者的召回率,一个好的目标是4.8%。

字段编码

对4个字符型的字段进行编码工作:

In [40]:

代码语言:javascript
复制
df['gender'] = df['gender'].replace({'Male':0,
                                     'Female':1,
                                     'Other':-1}
                                   ).astype(np.uint8)

df['Residence_type'] = df['Residence_type'].map({'Rural':0,
                                                 'Urban':1}
                                               ).astype(np.uint8)

df['work_type'] = df['work_type'].map({'Private':0,
                                       'Self-employed':1,
                                       'Govt_job':2,
                                       'children':-1,
                                       'Never_worked':-2}
                                     ).astype(np.uint8)

df['ever_married'] = df['ever_married'].map({'No':0,'Yes':1}).astype(np.uint8)

df.head()

抽烟状态的独热码转换:

In [41]:

代码语言:javascript
复制
df["smoking_status"].value_counts()

Out[41]:

代码语言:javascript
复制
never smoked       1892
Unknown            1544
formerly smoked     885
smokes              789
Name: smoking_status, dtype: int64

In [42]:

代码语言:javascript
复制
df = df.join(pd.get_dummies(df["smoking_status"]))
df.drop("smoking_status",axis=1,inplace=True)

数据分割

In [43]:

代码语言:javascript
复制
# 选取特征
X  = df.drop("stroke",axis=1)
# 目标变量
y = df['stroke']
from sklearn.model_selection import train_test_split

# 3-7比例
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.3, random_state=42)

上采样

前文中提到,本案例中风和未中风的数据比例接近1:20,在这里我们采样基于SMOTE的上采样方法

In [44]:

代码语言:javascript
复制
oversample = SMOTE()
X_train_smote, y_train_smote = oversample.fit_resample(X_train, y_train.ravel())

In [45]:

代码语言:javascript
复制
len(y_train_smote)

Out[45]:

代码语言:javascript
复制
2914

In [46]:

代码语言:javascript
复制
len(X_train_smote)

Out[46]:

代码语言:javascript
复制
2914

建模

采用3种不同的分类模型来建立模型:Random Forest, SVM, Logisitc Regression

In [47]:

代码语言:javascript
复制
rf_pipeline = Pipeline(steps = [('scale',StandardScaler()), # 标准化
                                ('RF',RandomForestClassifier(random_state=42))]  # 模型
                      )
svm_pipeline = Pipeline(steps = [('scale',StandardScaler()),
                                 ('SVM',SVC(random_state=42))])
logreg_pipeline = Pipeline(steps = [('scale',StandardScaler()),
                                    ('LR',LogisticRegression(random_state=42))])

10折交叉验证

In [48]:

代码语言:javascript
复制
rf_cv = cross_val_score(rf_pipeline,
                        X_train_smote,
                        y_train_smote,
                        cv=10,
                        scoring='f1' # 模型得分评价指标
                       )

svm_cv = cross_val_score(svm_pipeline,
                         X_train_smote,
                         y_train_smote,
                         cv=10,
                         scoring='f1'
                        )

logreg_cv = cross_val_score(logreg_pipeline,
                            X_train_smote,
                            y_train_smote,
                            cv=10,
                            scoring='f1'
                           )

3种模型得分对比

In [49]:

代码语言:javascript
复制
print('随机森林:', rf_cv.mean())
print('支持向量机:',svm_cv.mean())
print('逻辑回归:', logreg_cv.mean())
随机森林: 0.9628909366701726
支持向量机: 0.9363667907023254
逻辑回归: 0.8859930523017683

很明显:随机森林表现的最好!

模型训练fit

In [50]:

代码语言:javascript
复制
rf_pipeline.fit(X_train_smote,y_train_smote)

svm_pipeline.fit(X_train_smote,y_train_smote)

logreg_pipeline.fit(X_train_smote,y_train_smote)

Out[50]:

代码语言:javascript
复制
Pipeline(steps=[('scale', StandardScaler()),
                ('LR', LogisticRegression(random_state=42))])

In [51]:

代码语言:javascript
复制
# 3种模型预测

rf_pred =rf_pipeline.predict(X_test)
svm_pred = svm_pipeline.predict(X_test)
logreg_pred = logreg_pipeline.predict(X_test)

评价指标

In [52]:

代码语言:javascript
复制
# 1、混淆矩阵

rf_cm  = confusion_matrix(y_test, rf_pred )
svm_cm = confusion_matrix(y_test, svm_pred)
logreg_cm  = confusion_matrix(y_test, logreg_pred)

In [53]:

代码语言:javascript
复制
print(rf_cm)
print("----")
print(svm_cm)
print("----")
print(logreg_cm)
[[3338   66]
 [ 164    9]]
----
[[3196  208]
 [ 148   25]]
----
[[3138  266]
 [ 116   57]]

In [54]:

代码语言:javascript
复制
# 2、F_1得分
# F1分数可以看作是模型准确率和召回率的一种加权平均,它的最大值是1,最小值是0,值越大意味着模型越好

rf_f1  = f1_score(y_test,rf_pred)
svm_f1 = f1_score(y_test,svm_pred)
logreg_f1  = f1_score(y_test,logreg_pred)

In [55]:

代码语言:javascript
复制
print('RF mean :',rf_f1)
print('SVM mean :',svm_f1)
print('LR mean :',logreg_f1)
RF mean : 0.07258064516129033
SVM mean : 0.1231527093596059
LR mean : 0.22983870967741934

随机森林模型的分类报告:

In [56]:

代码语言:javascript
复制
from sklearn.metrics import plot_confusion_matrix, classification_report

print(classification_report(y_test,rf_pred))

print('Accuracy Score: ',accuracy_score(y_test,rf_pred))
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      3404
           1       0.12      0.05      0.07       173

    accuracy                           0.94      3577
   macro avg       0.54      0.52      0.52      3577
weighted avg       0.91      0.94      0.92      3577

Accuracy Score:  0.9357003075202683

随机森林模型调参

基于网格搜索的参数调优:

In [57]:

代码语言:javascript
复制
from sklearn.model_selection import GridSearchCV

n_estimators =[64,100,128,200]
max_features = [2,3,5,7]
bootstrap = [True,False]

param_grid = {'n_estimators':n_estimators,
             'max_features':max_features,
             'bootstrap':bootstrap}

rfc = RandomForestClassifier()

In [58]:

代码语言:javascript
复制
grid = GridSearchCV(rfc,param_grid)

grid.fit(X_train,y_train)

Out[58]:

代码语言:javascript
复制
GridSearchCV(estimator=RandomForestClassifier(),
             param_grid={'bootstrap': [True, False],
                         'max_features': [2, 3, 5, 7],
                         'n_estimators': [64, 100, 128, 200]})

In [59]:

代码语言:javascript
复制
grid.best_params_  # 找到最优的参数

Out[59]:

代码语言:javascript
复制
{'bootstrap': False, 'max_features': 3, 'n_estimators': 200}

In [60]:

代码语言:javascript
复制
# 再次建立随机森林模型

rfc = RandomForestClassifier(
    max_features=3,
    n_estimators=200,
    bootstrap=False)

rfc.fit(X_train_smote,y_train_smote)

rfc_tuned_pred = rfc.predict(X_test)

In [61]:

代码语言:javascript
复制
# 新的分类报告得分

print(classification_report(y_test,rfc_tuned_pred))

print('Accuracy Score: ',accuracy_score(y_test,rfc_tuned_pred))
print('F1 Score: ',f1_score(y_test,rfc_tuned_pred))
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      3404
           1       0.05      0.02      0.03       173

    accuracy                           0.94      3577
   macro avg       0.50      0.50      0.50      3577
weighted avg       0.91      0.94      0.92      3577

Accuracy Score:  0.9362594352809617
F1 Score:  0.025641025641025644

逻辑回归模型调参

In [62]:

代码语言:javascript
复制
penalty = ['l1','l2']
C = [0.001, 0.01, 0.1, 1, 10, 100]

log_param_grid = {'penalty': penalty,
                  'C': C}

logreg = LogisticRegression()
grid = GridSearchCV(logreg,log_param_grid)

In [63]:

代码语言:javascript
复制
grid.fit(X_train_smote,y_train_smote)

Out[63]:

代码语言:javascript
复制
GridSearchCV(estimator=LogisticRegression(),
             param_grid={'C': [0.001, 0.01, 0.1, 1, 10, 100],
                         'penalty': ['l1', 'l2']})

In [64]:

代码语言:javascript
复制
grid.best_params_

Out[64]:

代码语言:javascript
复制
{'C': 1, 'penalty': 'l2'}

In [65]:

代码语言:javascript
复制
logreg_pipeline = Pipeline(steps = [('scale',StandardScaler()),
                                    ('LR',LogisticRegression(C=1,penalty='l2',random_state=42))])

logreg_pipeline.fit(X_train_smote,y_train_smote)

Out[65]:

代码语言:javascript
复制
Pipeline(steps=[('scale', StandardScaler()),
                ('LR', LogisticRegression(C=1, random_state=42))])

In [66]:

代码语言:javascript
复制
logreg_new_pred   = logreg_pipeline.predict(X_test) # 新预测

In [67]:

代码语言:javascript
复制
print(classification_report(y_test,logreg_new_pred))

print('Accuracy Score: ',accuracy_score(y_test,logreg_new_pred))
print('F1 Score: ',f1_score(y_test,logreg_new_pred))
              precision    recall  f1-score   support

           0       0.96      0.92      0.94      3404
           1       0.18      0.33      0.23       173

    accuracy                           0.89      3577
   macro avg       0.57      0.63      0.59      3577
weighted avg       0.93      0.89      0.91      3577

Accuracy Score:  0.8932065977075762
F1 Score:  0.22983870967741934

支持向量机调参

In [68]:

代码语言:javascript
复制
svm_param_grid = {
            'C': [0.1, 1, 10, 100, 1000],
            'gamma': [1, 0.1, 0.01, 0.001, 0.0001],
            'kernel': ['rbf']}

svm = SVC(random_state=42)

grid = GridSearchCV(svm, svm_param_grid)

In [69]:

代码语言:javascript
复制
grid.fit(X_train_smote,y_train_smote)

Out[69]:

代码语言:javascript
复制
GridSearchCV(estimator=SVC(random_state=42),
             param_grid={'C': [0.1, 1, 10, 100, 1000],
                         'gamma': [1, 0.1, 0.01, 0.001, 0.0001],
                         'kernel': ['rbf']})

In [70]:

代码语言:javascript
复制
grid.best_params_

Out[70]:

代码语言:javascript
复制
{'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}

In [71]:

代码语言:javascript
复制
svm_pipeline = Pipeline(steps = [('scale',StandardScaler()),('SVM',SVC(C=100,gamma=0.0001,kernel='rbf',random_state=42))])

svm_pipeline.fit(X_train_smote,y_train_smote)

svm_tuned_pred   = svm_pipeline.predict(X_test)

In [72]:

代码语言:javascript
复制
print(classification_report(y_test,svm_tuned_pred))

print('Accuracy Score: ',accuracy_score(y_test,svm_tuned_pred))
print('F1 Score: ',f1_score(y_test,svm_tuned_pred))
              precision    recall  f1-score   support

           0       0.96      0.93      0.94      3404
           1       0.16      0.27      0.20       173

    accuracy                           0.90      3577
   macro avg       0.56      0.60      0.57      3577
weighted avg       0.92      0.90      0.91      3577

Accuracy Score:  0.8951635448700028
F1 Score:  0.19700214132762314

结论

  1. 在交叉验证的过程中,随机森林表现的最好。
  2. 3种模型的对比:随机森林的精度最好,但是F1-score缺失最低的
  3. 模型可能特点:更能预测哪些人将会中风,而不是哪些人不会中风
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-6-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 基于随机森林、逻辑回归、SVM的中风病人预测
  • 导入库
  • 数据基本信息
  • 字段分布
    • gender统计
      • age分布
        • ever_married
          • work-type
            • Residence_type
              • avg_glucose_level
                • bmi
                  • smoking_status
                  • 缺失值情况
                    • 缺失值统计
                      • 缺失值可视化
                        • 缺失值处理
                        • 数据EDA
                          • 连续型变量
                          • 中风和未中风
                            • 对比不同年龄段的血糖和BMI指数
                            • 年龄和血糖、bmi关系
                            • 年龄和患病几率
                              • 样本不均衡
                              • 属性分布
                              • 整体变量情况
                              • 建模
                                • 模型baseline
                                  • 字段编码
                                    • 数据分割
                                      • 上采样
                                        • 建模
                                          • 10折交叉验证
                                            • 3种模型得分对比
                                              • 模型训练fit
                                                • 评价指标
                                                • 随机森林模型调参
                                                • 逻辑回归模型调参
                                                • 支持向量机调参
                                                • 结论
                                                领券
                                                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档