当前位置: 首页 > news >正文

R6:LSTM实现糖尿病探索与预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、实验目的:

学习使用LSTM对糖尿病进行探索预测

二、实验环境:

  • 语言环境:python 3.8
  • 编译器:Jupyter notebook
  • 深度学习环境:Pytorch
    • torch==2.4.0+cu124
    • torchvision==0.19.0+cu124

三、数据预处理

逻辑回归在二分类问题中应用广泛;KNN(K 近邻算法)、SVM(支持向量机)、决策树、贝叶斯分类器、随机森林和 XGBoost(极端梯度提升树)都是常见的用于结构化数据分类的算法。

本次实验我们采用 LSTM(长短期记忆网络)进行分类预测。LSTM 主要用于处理序列数据,虽然在一些特定情况下可以对序列数据进行分类,但对于一般的二维结构化数据,上述提到的传统分类算法通常更加合适。二维结构化数据通常指表格形式的数据,每一行代表一个样本,每一列代表一个特征,对于这类数据,传统的机器学习分类算法在计算效率和可解释性方面往往具有优势。

在这里插入图片描述

1. 设置GPU、导入数据

#设置GPU 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision,torch device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device 
#导入数据
import numpy   as np
import pandas  as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['savefig.dpi'] = 500 #图片像素
plt.rcParams['figure.dpi'] = 500 #分辨率plt.rcParams['font.sans-serif'] = ['SimHei'] #用来正常显示中文标签import warnings
warnings.filterwarnings('ignore')DataFrame = pd.read_excel('diabetes.xls')
DataFrame.head()

在这里插入图片描述

DataFrame.shape
(1006, 16)    

2. 数据检查

#查看数据是否有缺失值
print('数据缺失值--------------------------')
print(DataFrame.isnull().sum())

在这里插入图片描述

#查看数据是否有重复值
print('数据重复值--------------------------')
print('数据集的重复值为:'f'{DataFrame.duplicated().sum()}')

在这里插入图片描述

3. 数据分布分析

feature_map = { '年龄': '年龄','高密度脂蛋白胆固醇': '高密度脂蛋白胆固醇','低密度脂蛋白胆固醇': '低密度脂蛋白胆固醇','极低密度脂蛋白胆固醇': '极低密度脂蛋白胆固醇','甘油三酯': '甘油三酯','总胆固醇': '总胆固醇','脉搏': '脉搏','舒张压':'舒张压','高血压史':'高血压史','尿素氮':'尿素氮','尿酸':'尿酸','肌酐':'肌酐','体重检查结果':'体重检查结果'}plt.figure(figsize=(15,10))for i, (col, col_name) in enumerate(feature_map.items(), 1):plt.subplot(3,5,i)sns.boxplot(x=DataFrame['是否糖尿病'], y=DataFrame[col])plt.title(f'{col_name}的箱线图', fontsize=14)plt.ylabel('数值', fontsize=12)plt.grid(axis='y', linestyle='--', alpha=0.7)plt.tight_layout()
plt.show()

在这里插入图片描述
以下是分析箱线图的方法,并以年龄的箱线图为例进行介绍:

一、认识箱线图的组成部分

  1. 箱体:箱体的上下边界分别代表数据的上四分位数(Q3)和下四分位数(Q1)。箱体中间的线通常代表中位数。
  2. whiskers(须):从箱体延伸出去的线段,代表数据的范围。一般来说,须的长度是由一些特定的规则决定的,常见的是 1.5 倍的四分位距(IQR,即 Q3 - Q1)。超出须范围的数据点被视为异常值,可能会以单独的点显示。

二、分析年龄箱线图的具体步骤

  1. 观察中位数:

    • 首先找到箱体中间的线,它代表了年龄数据的中位数。如果这条线在箱线图的中间位置附近,说明数据分布相对较为对称;如果偏向箱体的上边界或下边界,则说明数据可能存在偏斜。
    • 假设年龄箱线图中,中位数线靠近箱体上边界,这可能意味着年龄数据整体上偏大,即大部分人的年龄较高。
  2. 分析箱体长度:

    • 箱体的长度反映了数据的离散程度。如果箱体较短,说明数据比较集中;如果箱体较长,说明数据的分散程度较大。
    • 例如,如果年龄箱线图的箱体较短,说明年龄数据相对集中在一个较小的范围内。
  3. 观察须的长度:

    • 须的长度可以让你了解数据的整体范围。较长的须表示数据的范围较大;较短的须可能意味着数据比较集中在一个较小的区间内。
    • 如果年龄箱线图的须较长,说明年龄数据的跨度较大,可能有一些年龄较大或较小的极端值。
  4. 检查异常值:

    • 异常值通常以单独的点显示在箱线图之外。观察异常值的数量和分布,可以了解数据中是否存在极端情况。
    • 如果年龄箱线图中有一些异常值,需要进一步分析这些异常值的来源,例如是否是由于数据录入错误或者特殊的个体情况导致的。
  5. 比较不同组别的箱线图:

    • 如果有多个组别的年龄箱线图,可以比较它们的中位数、箱体长度、须的长度和异常值情况,以了解不同组之间年龄分布的差异。
    • 例如,比较糖尿病患者和非糖尿病患者的年龄箱线图,看是否存在明显的差异。如果糖尿病患者的年龄箱线图中位数较高,箱体较长,可能说明糖尿病患者的年龄普遍较大。

通过以上步骤,你可以对年龄箱线图进行较为全面的分析,了解年龄数据的分布特征和潜在的问题。对于其他变量的箱线图,也可以采用类似的方法进行分析。

df_corr = DataFrame.drop(['卡号'],axis=1).corr()
plt.figure(figsize=(12,10))
plt.title('相关性热图')
sns.heatmap(df_corr,annot=True)
plt.show()

在这里插入图片描述

四、LSTM模型

#数据集构建from sklearn.preprocessing import StandardScaler# '高密度脂蛋白胆固醇'字段与糖尿病负相关,故而在 X 中去掉该字段
X = DataFrame.drop(['卡号','是否糖尿病','高密度脂蛋白胆固醇'],axis=1)
y = DataFrame['是否糖尿病']# sc_X    = StandardScaler()
# X = sc_X.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2,random_state=1)
train_X.shape, train_y.shapefrom torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(train_X, train_y),batch_size=64,shuffle=False)
test_dl  = DataLoader(TensorDataset(test_X, test_y),batch_size=64,shuffle=False)
#定义模型
class model_lstm(nn.Module):def __init__(self):super(model_lstm, self).__init__()self.lstm0 = nn.LSTM(input_size=13,  hidden_size=200, num_layers=1, batch_first=True)self.lstm1 = nn.LSTM(input_size=200, hidden_size=200, num_layers=1, batch_first=True)self.fc0   = nn.Linear(200, 2)def forward(self, x):out, hidden1 = self.lstm0(x)out, _       = self.lstm1(out, hidden1)out          = self.fc0(out)return outmodel = model_lstm().to(device)
model

在这里插入图片描述

五、训练模型

#定义训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
#定义测试函数
def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss
#训练模型
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs     = 30train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)

在这里插入图片描述

六、模型评估

#Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、总结

分析数据可知存在有一定的过拟合迹象:

  • 随着训练的进行,训练准确率不断上升,而测试准确率在前期长时间停滞,后期虽然有所上升,但上升幅度小于训练准确率。这表明模型在训练集上的学习能力较强,但在测试集上的泛化能力相对较弱。
  • 训练损失持续下降,而测试损失在下降过程中出现了波动,并且在后期与训练损失的差距有一定程度的扩大。这也暗示模型可能过度拟合了训练数据,导致在测试集上的表现不如在训练集上的表现稳定。
  • 实验中尝试通过提高学习率至1e-3,可以将预测准确率提升到71.3%,而提高训练轮数则始终难以收敛。而在构建数据集部分,可以看到注释部分的代码为数据的标准化处理。
  • 除此之外,还可以考虑采用正则化方法、增加数据量、早停法等技术来缓解过拟合问题。

在划分数据集过程中添加标准化处理可以提升测试数据集准确率的原因主要有以下几点:

一、消除量纲影响

  1. 不同特征往往具有不同的量纲和尺度。例如,一个特征可能取值范围在 0 到 100 之间,而另一个特征可能取值在 0 到 1 之间。这会使得在某些算法中,具有较大数值范围的特征对模型的影响更大,从而可能导致模型偏向于这些特征,而忽略了其他重要特征的作用。
  2. 标准化处理将数据的各个特征转换到相同的尺度上,通常使得特征的均值为 0,标准差为 1。这样可以确保每个特征在模型中具有相对平等的影响力,避免了因量纲差异而导致的不公平性。

二、加速模型收敛

  1. 许多优化算法在处理标准化后的数据时能够更快地收敛。例如,梯度下降算法在标准化的数据上能够更有效地确定下降的方向和步长,因为数据的分布更加稳定,不会因为特征的尺度差异而导致梯度在不同方向上的变化幅度差异巨大。
  2. 当数据经过标准化后,模型在训练过程中可以更稳定地更新参数,减少了因数据尺度不一致而引起的震荡,从而更快地找到最优解,这也有助于提高模型在测试集上的准确率。

三、提高模型的泛化能力

  1. 标准化可以使模型对不同单位和尺度的输入数据具有更好的适应性,从而提高模型的泛化能力。如果模型在训练时只适应了特定尺度的数据集,那么在面对测试集上不同尺度的数据时,可能表现不佳。
  2. 标准化处理可以减少异常值对模型的影响。异常值在未标准化的数据中可能会对模型产生较大的干扰,而经过标准化后,异常值的影响相对减小,模型能够更加关注数据的整体分布特征,从而提高在测试集上的准确率。

http://www.mrgr.cn/news/64428.html

相关文章:

  • @Async(“asyncTaskExecutor“) 注解介绍
  • 记录新建wordpress站的实践踩坑:wordpress 上传源码新建站因权限问题导致无法访问、配置新站建站向导以及插件主题上传配置的解决办法
  • HTTP服务器测试与优化
  • 如何在Oracle数据库中获取版本信息。
  • QEMU学习之路(4)— Xilinx开源项目systemctlm-cosim-demo安装与使用
  • RDT——清华开源的双臂机器人扩散大模型:先预训练后微调,支持语言、图像、动作多种输入
  • 基于微信小程序的校园失物招领系统的研究与实现(V4.0)
  • 0-1规划的求解
  • Java 中 HashMap集合使用
  • wireshark抓包查看langchain的ChatOpenAI接口发送和接收的数据
  • next项目app router 中layout命名规范
  • ViT面试知识点
  • Google Guava 发布订阅模式/生产消费者模式 使用详情
  • SpringMVC的执行流程以及运行原理
  • 单链表OJ题(3):合并两个有序链表、链表分割、链表的回文结构
  • Oracle视频基础1.4.2练习
  • FFmpeg 4.3 音视频-多路H265监控录放C++开发十. 多线程控制帧率。
  • 大学新生入门编程的最佳选择:为什么我推荐Python?
  • RSI是指在5G通信技术中用于标识小区的特定参数
  • Spring框架中的AOP是什么?如何使用AOP实现切面编程和拦截器功能?
  • 3.2链路聚合
  • P3-2.【结构化程序设计】第二节——知识要点:多分支选择语句
  • 2024年系统架构师---下午题目真题
  • php开发实战分析(8):优化MySQL分页查询与数量统计,提升数据库性能
  • sql在hive和阿里云maxComputer的区别
  • 合并区间 leetcode56