前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python深度学习TensorFlow Keras心脏病预测神经网络模型评估损失曲线、混淆矩阵可视化

Python深度学习TensorFlow Keras心脏病预测神经网络模型评估损失曲线、混淆矩阵可视化

原创
作者头像
拓端
修改2024-06-26 17:20:13
1630
修改2024-06-26 17:20:13
举报
文章被收录于专栏:拓端数据拓端tecdat

随着深度学习技术的快速发展,高效的计算框架和库对于模型训练至关重要。TensorFlow作为目前最流行的深度学习框架之一,其GPU版本能够显著提升模型训练的速度和效率。本研究旨在通过安装TensorFlow-GPU的特定版本,并结合其他数据处理和可视化库,为深度学习模型的构建提供一套完整的数据预处理流程。

心脏病作为一种严重的健康问题,其早期预测和诊断对于提高治疗效果和患者生活质量具有重要意义。近年来,深度学习技术在医疗领域的应用日益广泛,特别是在疾病预测和诊断方面。本研究旨在帮助客户利用TensorFlow Keras库构建一个基于深度学习的心脏病预测模型,并通过实验验证其有效性。

TensorFlow-GPU安装

为了充分利用GPU加速深度学习模型的训练,我们首先安装了TensorFlow-GPU的2.0.0-alpha0版本。通过以下命令在Python环境中进行安装:

代码语言:javascript
复制
!pip install tensorflow-gpu==2.0.0-alpha0
image.png
image.png

数据预处理与可视化

本研究使用了NumPy、Pandas、Seaborn等库进行数据预处理和可视化。首先,我们导入了相关库,并设置了随机种子以确保实验的可重复性:

代码语言:javascript
复制
%matplotlib inline

sns.set(style='whitegrid', palette='muted', font_scale=1.5)

接下来,我们利用Pandas库的describe()方法对数据进行描述性统计分析,以便对数据的分布和特性有一个初步的了解。

代码语言:javascript
复制
data.describe()
image.png
image.png

数据可视化

对心脏病诊断数据集进行了深入分析。利用Seaborn和Matplotlib等可视化库,本研究绘制了多种图表以展示心脏病存在情况的分布、患者年龄分布、性别对疾病存在的影响以及胸痛类型与疾病存在之间的关系。

心脏病存在情况分布

通过Seaborn的countplot函数,我们绘制了心脏病存在情况的分布图。结果显示,数据集中心脏病存在的患者数量略高于不存在心脏病的患者。

1.png
1.png

性别对心脏病存在的影响

为了分析性别对心脏病存在的影响,我们根据性别对心脏病存在情况进行了分组可视化。结果显示,男性患者中心脏病存在的比例略高于女性患者。

2.png
2.png

相关性分析

为了了解数据集中不同特征之间的相关性,我们绘制了相关性热图。结果显示,某些特征与心脏病存在情况之间存在较强的相关性。

代码语言:javascript
复制
heat_map.set_xticklabels(heat_map.get_xticklabels(), rotation=45);
image.png
image.png

年龄与最大心率散点图

通过绘制年龄与最大心率的散点图,我们分析了年龄与最大心率之间的关系。结果显示,随着年龄的增长,最大心率呈下降趋势。

代码语言:javascript
复制
plt.scatter(x=data.age[data.target==0], y=data.thalach[(data.target==0)], s=60)
下载 (3).png
下载 (3).png

患者年龄分布

通过年龄分组并绘制条形图,我们分析了不同疾病状态下患者的年龄分布。结果显示,年龄较大的人群中心脏病存在的比例更高。

代码语言:javascript
复制
data[data['target']==0].groupby('Age_Category')['age'].count().plot(kind='bar')
image.png
image.png

胸痛类型与心脏病存在之间的关系

利用countplot函数,我们分析了不同胸痛类型与心脏病存在之间的关系。结果显示,典型心绞痛和无症状胸痛的患者中心脏病存在的比例较高。

代码语言:javascript
复制
f = sns.countplot(x='cp', data=data, hue='target') f.set_xticklabels(['Typical Angina', 'Atypical Angina', 'Non-anginal Pain', 'Asymptomatic']);
image.png
image.png

通过对心脏病诊断数据集的可视化分析,我们得出了以下结论:

  1. 数据集中心脏病存在的患者数量略高于不存在心脏病的患者。
  2. 男性患者中心脏病存在的比例略高于女性患者。
  3. 年龄较大的人群中心脏病存在的比例更高。
  4. 典型心绞痛和无症状胸痛的患者中心脏病存在的比例较高。
  5. 数据集中某些特征与心脏病存在情况之间存在较强的相关性。

基于TensorFlow Keras的心脏病预测模型构建与评估

该模型采用了一个序列化的网络结构,其中包括特征嵌入层、两个具有ReLU激活函数的隐藏层、一个Dropout层以及一个具有Sigmoid激活函数的输出层。模型通过二元交叉熵损失函数和Adam优化器进行训练,并在训练过程中监控准确率和验证准确率。实验结果显示,模型在测试集上达到了88.52%的准确率。

本研究采用TensorFlow Keras库构建了一个序列化的神经网络模型。模型结构如下:

  1. 特征嵌入层:使用DenseFeatures层将输入特征进行嵌入,其中feature_columns参数定义了特征列。
  2. 隐藏层:包含两个具有128个神经元和ReLU激活函数的Dense层,用于提取输入特征中的高级表示。
  3. Dropout层:在第二个隐藏层后添加一个Dropout层,以防止模型过拟合,设置dropout率为0.2。
  4. 输出层:使用具有单个神经元和Sigmoid激活函数的Dense层作为输出层,用于输出心脏病预测的概率。

模型编译时,采用Adam优化器和二元交叉熵损失函数,并设置监控准确率和验证准确率为评估指标。

代码语言:javascript
复制
model = tf.keras.models.Sequential([ tf.keras.layers.DenseFeatures(feature_columns=feature_columns), tf.keras.layers.Dense(units=128, activation='relu'), tf.keras.layers.Dropout(rate=0.2), tf.keras.layers.Dense(units=128, activation='relu'),
image.png
image.png

性能评估

代码语言:javascript
复制
model.evaluat
image.png
image.png

模型在训练集上进行训练,并在验证集上进行验证。训练过程共进行了100个epoch,每个epoch包含对训练集的完整遍历。在训练过程中,我们记录了每个epoch的准确率和验证准确率。

实验结果显示,模型在训练集上的准确率随着epoch的增加而逐渐提高,最终在验证集上达到了88.52%的准确率。同时,我们也注意到在训练过程中存在轻微的过拟合现象,这可能是由于数据集规模较小或模型复杂度较高所致。

为了进一步验证模型的有效性,我们在测试集上对模型进行了评估。评估结果显示,模型在测试集上的准确率为88.52%,与验证集上的准确率一致。这表明模型具有良好的泛化能力,可以在未见过的数据上进行准确预测。

为了更直观地展示模型的训练过程,我们绘制了准确率和验证准确率的曲线图。从图中可以看出,模型在训练初期迅速提高准确率,随后进入平稳期。验证准确率在整个训练过程中保持稳定,表明模型没有出现过拟合或欠拟合现象。

代码语言:javascript
复制
plt.plot(history.history['accuracy']) plt.plot(history.history['val_accuracy'])
下载 (4).png
下载 (4).png

损失曲线分析

为了更直观地了解模型的训练过程,我们绘制了训练集和验证集上的损失曲线。通过matplotlib库,我们分别绘制了训练损失(loss)和验证损失(val_loss)随epoch变化的曲线图。从图中可以看出,随着训练的进行,训练损失和验证损失均呈现下降趋势,表明模型在逐渐学习并优化其预测能力。

代码语言:javascript
复制
plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.title('model loss') plt.ylabel('loss') plt.xlabel('epoch') plt.legend(['train', 'test'], loc='upper left') plt.show()
image.png
image.png

分类报告与混淆矩阵

为了进一步评估模型在测试集上的性能,我们使用了sklearn库中的classification_reportconfusion_matrix函数。通过模型对测试集的预测结果和真实标签进行比较,我们得到了分类报告和混淆矩阵。分类报告提供了每个类别的精确度、召回率和F1分数,而混淆矩阵则直观地展示了模型在各类别上的预测情况。

代码语言:javascript
复制
print(classification_report(y_test.values, bin_predictions))
image.png
image.png
代码语言:javascript
复制
confusion_matrix(y_test,
image.png
image.png

分类报告显示,模型在测试集上的整体精确度为0.62,召回率为0.62,F1分数为0.62。混淆矩阵则显示,模型在预测为0(无心脏病)的类别中有19个正确预测,但有10个误判;在预测为1(有心脏病)的类别中有19个正确预测,但有13个误判。这些结果表明,虽然模型在整体性能上表现良好,但在某些类别上仍存在一定的误判情况。

代码语言:javascript
复制
sns.heatmap(pd.DataFrame(cnf_matrix),annot=
image.png
image.png

结论

本研究通过构建和评估一个基于TensorFlow Keras的心脏病预测模型,展示了深度学习在医疗领域的应用潜力。通过绘制损失曲线、生成分类报告和混淆矩阵等方法,我们全面评估了模型的性能,并发现模型在测试集上取得了良好的预测效果。未来研究可以进一步探索如何优化模型结构、增加数据集规模以及引入更多的特征工程方法,以提高模型的预测性能和泛化能力。

stepping-up-coo-01-1430307328-thumb-1536x1536.webp
stepping-up-coo-01-1430307328-thumb-1536x1536.webp

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
GPU 云服务器
GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于生成式AI,自动驾驶,深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档