Research article in clinical medicine: 'HEART DISEASE PREDICTION WITH LOGISTIC REGRESSION AND RANDOM FOREST MODEL'

CC BY

https://doi.org/10.29013/ELBLS-21-1.2-24-33

Tang Diane, YK Pao School, Shanghai, China E-mail: 2928599264@qq.com; xxjnicole@hotmail.com

HEART DISEASE PREDICTION WITH LOGISTIC REGRESSION AND RANDOM FOREST MODEL

Abstract: Heart disease is the leading cause of death for men, women, and people of most racial and ethnic groups in the United States (Centers for Disease Control and Prevention, 2018) [3]. The early prognosis of cardiovascular diseases (CVDs) can inspire better lifestyle choices among high-risk patients and in turn reduce the risk of CVD. This research aims to pinpoint the most relevant risk factors of heart disease, to build a predictive model for the overall risk of heart disease using logistic regression, and to compare its performance to a random forest model with hyperparameter tuning. Due to the imbalance of the data and the nature of this study, the Average Precision score, the Recall score, and the area under the Receiver Operating Characteristic curve are all important metrics. The Synthetic Minority Oversampling Technique (SMOTE) combined with the logistic regression model performed the best among the various techniques implemented in this study.

Keywords: heart disease, logistic regression, random forest, hyperparameter tuning.

1. Introduction

Cardiovascular diseases (CVDs) are a group of disorders of the heart and blood vessels. Coronary heart disease (CHD) is the most common type of CVD. The World Health Organization has estimated that 17.9 million people died from CVDs in 2016, representing 31% of all global deaths. Of these deaths, 85% are due to heart attack and stroke (World Health Organization [15]).

Heart disease is the leading cause of death for men, women, and people of most racial and ethnic groups in the United States (Centers for Disease Control and Prevention, 2018) [3].

Coronary heart disease (CHD) is the most common type of heart disease, killing 365,914 people in 2017 (Benjamin E. J. [2]). About 655,000 Americans die from heart disease each year, which is 1 in every 4 deaths (Virani S. S. [13]).

The early prognosis of CVDs can inspire better lifestyle choices among high-risk patients and in turn reduce the risk of CVD.

2. Exploratory Data Analysis and Data Preprocessing

The dataset used in this study is from an ongoing cardiovascular study on residents of the town of Framingham, Massachusetts (Framingham heart study, n.d.) [4]. The dataset includes 4,238 records and 16 columns. Each of the 15 feature variables is a potential risk factor. The risk factors are demographic, behavioral, and medical (Table 1).

The histogram plot of the 16 variables is shown in Figure 1. All columns are already numerical, so one-hot encoding is not needed. From Figure 1, BPMeds, currentSmoker, diabetes, male, prevalentHyp, prevalentStroke, and TenYearCHD are binary: 0 represents negative cases and 1 represents positive cases.

The classification variable (TenYearCHD) is the patient's 10-year risk of coronary heart disease (CHD). A count plot of the classification variable made using Seaborn is shown in Figure 2 (Seaborn Countplot, n.d.) [11; 12]. A count plot shows the counts of observations in each categorical bin using bars. From Figure 2 we can see that the classification variable is imbalanced, which could cause a classification model to over-favor predictions of the class or classes with the overwhelming majority of the data, crippling the model's ability to classify the small minority.
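To make the risk of imbalance concrete, the sketch below uses the test-set class counts reported later in the Support column of Table 2 (733 negatives, 115 positives) and a hypothetical baseline that always predicts class 0. The baseline is an illustration only and is not one of the models in this study.

```python
# Test-set class counts taken from the Support column of Table 2.
n_negative = 733   # TenYearCHD = 0
n_positive = 115   # TenYearCHD = 1

# Hypothetical baseline: a "model" that always predicts the majority class (0).
y_true = [0] * n_negative + [1] * n_positive
y_pred = [0] * len(y_true)

# Accuracy: share of all predictions that are correct.
correct = sum(t == p for t, p in zip(y_true, y_pred))
accuracy = correct / len(y_true)

# Recall for the positive class: share of actual positives that were found.
found_positives = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
recall_positive = found_positives / n_positive

minority_share = n_positive / len(y_true)

print(f"minority share:  {minority_share:.3f}")
print(f"accuracy:        {accuracy:.3f}")
print(f"positive recall: {recall_positive:.3f}")
```

The baseline scores a high accuracy while finding none of the positive cases, which is why accuracy alone is a poor guide on this dataset.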

Table 1. - Feature variables in the dataset

| Category | Feature variable | Description | Data type |
|---|---|---|---|
| Demographic | Sex | Male or female | Nominal |
| Demographic | Age | Age of the patient | Continuous |
| Demographic | Education | Education level of the patient | Continuous |
| Behavioral | Current smoker | Whether or not the patient is a current smoker | Nominal |
| Behavioral | Cigs per day | The number of cigarettes that the person smoked on average in one day | Continuous |
| Medical (history) | BP Meds | Whether or not the patient was on blood pressure medication | Nominal |
| Medical (history) | Prevalent stroke | Whether or not the patient had previously had a stroke | Nominal |
| Medical (history) | Prevalent Hyp | Whether or not the patient was hypertensive | Nominal |
| Medical (history) | Diabetes | Whether or not the patient had diabetes | Nominal |
| Medical (current) | Tot Chol | Total cholesterol level | Continuous |
| Medical (current) | Sys BP | Systolic blood pressure | Continuous |
| Medical (current) | Dia BP | Diastolic blood pressure | Continuous |
| Medical (current) | BMI | Body Mass Index | Continuous |
| Medical (current) | Heart Rate | Heart rate | Continuous |
| Medical (current) | Glucose | Glucose level | Continuous |

Figure 1. Histogram plot of feature variables

Figure 2. Count plot of classification variable (0 means the patient won't have CHD in ten years, 1 means the patient will have CHD in ten years)

Figure 3 presents the distribution plots of the 15 feature variables, color coded by the outcome variable (TenYearCHD). Blue curves in each subplot represent the variable's distribution when the outcome variable is negative, and red curves represent the variable's distribution when the outcome variable is positive. The distribution plot made by the Seaborn library combines the matplotlib histogram function with the Seaborn kernel density estimate (KDE) plot and rugplot (marginal distributions) functions.

Figure 3. Distribution plot of 15 feature variables, colored by the classification variable ("TenYearCHD"): blue curves for negative cases, red curves for positive cases (panels include cigsPerDay, BPMeds, prevalentStroke, prevalentHyp, BMI, heartRate, and glucose)

In each distribution plot (Seaborn distplot, n.d.) [12], the x-axis shows bins (ranges) of the variable and the y-axis is the probability density given by the kernel density estimation. The total area under each density curve is 1. From Figure 3, males, older people, less educated people, and smokers all have higher probabilities of having CHD.

Missing values in the dataset were imputed using the K Nearest Neighbors (KNN) imputer from Scikit-learn. Scikit-learn is a Python module integrating a wide range of state-of-the-art machine learning algorithms for medium-scale supervised and unsupervised problems (Pedregosa [10]). This imputer replaces each missing value with the mean value from the k nearest neighbors found in the dataset. By default, it finds those neighbors with a Euclidean distance metric that skips missing coordinates (KNN imputation, n.d.) [7].
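The imputation step can be sketched with Scikit-learn's KNNImputer on a tiny array. The array and the choice of n_neighbors=2 are assumptions made for illustration; the value of k used in the study is not stated.

```python
import numpy as np
from sklearn.impute import KNNImputer

# Toy data with two missing entries (np.nan); not taken from the Framingham data.
X = np.array([
    [1.0, 2.0, np.nan],
    [3.0, 4.0, 3.0],
    [np.nan, 6.0, 5.0],
    [8.0, 8.0, 7.0],
])

# n_neighbors=2 is an illustrative choice.
imputer = KNNImputer(n_neighbors=2)
X_imputed = imputer.fit_transform(X)

# Each missing entry becomes the mean of that column over the 2 nearest rows.
print(X_imputed)
```

In this toy case the missing value in the first row is filled with 4.0 (the mean of 3.0 and 5.0 from its two nearest rows), and the missing value in the third row is filled with 5.5.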

MinMaxScaler from Scikit-learn was used to rescale the variables to a common range (0 to 1 by default) to ensure the proper performance of the machine learning models. After scaling, the data was separated into two groups using the train_test_split function from Scikit-learn: the training sample with 80% of the data and the testing sample with 20%. The training set was used to train the model, and the test set was used to test the accuracy and performance of the model.
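A minimal sketch of the scaling and 80/20 split follows, on synthetic data. The data, the random seeds, and the use of stratify=y are assumptions for illustration; the study does not say whether the split was stratified.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Synthetic stand-in for the feature matrix: 100 rows, 3 columns, 15% positives.
rng = np.random.default_rng(0)
X = rng.normal(loc=50.0, scale=10.0, size=(100, 3))
y = np.array([0] * 85 + [1] * 15)

# Rescale every column to the [0, 1] range.
X_scaled = MinMaxScaler().fit_transform(X)

# 80% training, 20% testing; stratify keeps the class ratio in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, stratify=y, random_state=0
)

print(X_scaled.min(axis=0), X_scaled.max(axis=0))
print(X_train.shape, X_test.shape)
```

The sketch follows the order described above (scale, then split). A common alternative is to fit the scaler on the training portion only and reuse it on the test portion, so that no information from the test set influences preprocessing.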

3. Models and Results

3.1 Logistic Regression Model

The first model trained on the training set was the logistic regression model from Scikit-learn. Since the classification variable is categorical and binary, binomial logistic regression is a natural first step. It is based on the sigmoid function, whose output is a probability. The key argument "class_weight" was set to "balanced" to compensate for the imbalanced dataset. The classification report, confusion matrix, cross-validated roc_auc score, and average precision score (aps) were calculated, and the precision-recall curve is shown in Figure 4. The cross-validated roc_auc and aps scores for the logistic regression are 0.725 and 0.3448.
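The setup described above might look roughly like the following sketch. The synthetic data from make_classification, the 5-fold cross-validation, and max_iter=1000 are assumptions; only class_weight="balanced" and the two scoring metrics come from the text.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Synthetic imbalanced data (about 15% positives) standing in for the real dataset.
X, y = make_classification(
    n_samples=1000, n_features=15, n_informative=5,
    weights=[0.85, 0.15], random_state=0,
)

# class_weight="balanced" reweights classes inversely to their frequency.
model = LogisticRegression(class_weight="balanced", max_iter=1000)

# Cross-validated area under the ROC curve and average precision score.
roc_auc = cross_val_score(model, X, y, cv=5, scoring="roc_auc").mean()
aps = cross_val_score(model, X, y, cv=5, scoring="average_precision").mean()

print(f"roc_auc: {roc_auc:.3f}, aps: {aps:.3f}")
```

The numbers printed here describe the synthetic data only and are not comparable to the scores reported in this study.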

Table 2 shows the classification report of the logistic regression model. The classification report shows the main performance metrics of the machine learning model for each class, including the precision, recall, and F1-score, their macro and weighted averages, the support (number of samples in each class), and the overall accuracy (Kohli [8]). Table 3 shows the confusion matrix of the logistic regression model. A confusion matrix presents the counts of true positives, false positives, true negatives, and false negatives.

Table 2.- Logistic regression base model classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.93 | 0.67 | 0.78 | 733 |
| 1.0 | 0.25 | 0.70 | 0.37 | 115 |
| Accuracy | | | 0.68 | 848 |
| Macro avg | 0.59 | 0.68 | 0.57 | 848 |
| Weighted avg | 0.84 | 0.68 | 0.73 | 848 |

Table 3.- Logistic regression base model confusion matrix

| | Predicted 0 | Predicted 1 |
|---|---|---|
| Actual 0 | True Negative (TN) 493 | False Positive (FP) 240 |
| Actual 1 | False Negative (FN) 35 | True Positive (TP) 80 |

- TP: positive samples predicted as positive. TN, FP, and FN are defined analogously.

- Accuracy - Percentage of correct predictions: $\frac{TP + TN}{TP + FP + TN + FN}$

- Precision - Percentage of positive predictions that are correct: $\frac{TP}{TP + FP}$

- Recall (Sensitivity / True Positive Rate) - Percentage of the positive cases that were predicted correctly: $\frac{TP}{TP + FN}$

- F1-score - Harmonic mean of precision and recall: $\frac{2}{\frac{1}{\text{precision}} + \frac{1}{\text{recall}}} = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}$

- Specificity (1 - False Positive Rate) - Percentage of the negative cases that were predicted correctly: $\frac{TN}{TN + FP}$

- 'macro avg' - Unweighted average of each metric over the classes.

- 'weighted avg' - Average of each metric weighted by class support, which accounts for class imbalance.
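As a worked check of these definitions, the sketch below applies them to the four counts in Table 3, read with class 1 (CHD) as the positive class: 80 true positives, 240 false positives, 35 false negatives, and 493 true negatives. This reading is an assumption, chosen because it is the one consistent with the class 1 row of Table 2.

```python
# Counts from Table 3, with class 1 (CHD) treated as the positive class.
tp, fp, fn, tn = 80, 240, 35, 493

accuracy = (tp + tn) / (tp + fp + tn + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
specificity = tn / (tn + fp)

for name, value in [
    ("accuracy", accuracy),
    ("precision", precision),
    ("recall", recall),
    ("f1", f1),
    ("specificity", specificity),
]:
    print(f"{name:<12}{value:.2f}")
```

Rounded to two decimals, this gives accuracy 0.68, precision 0.25, recall 0.70, and F1 0.37, matching Table 2; the specificity of 0.67 equals the recall reported for class 0.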

Accuracy is a useful measure only when the dataset is symmetric (the false negative and false positive counts are close) and false negatives and false positives have similar costs. F1 is a better choice when the class distribution is uneven and the costs of false positives and false negatives differ. Precision measures how sure you are of your predicted positives, whilst recall measures how sure you are that you are not missing any positives (Ghoneim [5]).

In the case of predicting heart disease risk, false positives are far more acceptable than false negatives, so recall is an important metric in this study. For a model whose purpose is to predict illness, leaving a high-risk person labeled healthy is far less desirable than labeling some healthy people positive for the illness.

Cross-validation is a model validation method that tests whether the model is overfitted to the specific training and testing samples, which could result in bad performance with any other dataset that the model has not seen yet (Shaikh [3]).

Figure 4 shows the precision-recall curve, which focuses mainly on the performance of the positive class; this is crucial when dealing with imbalanced classes. In the PR space, the goal is to be in the upper-right-hand corner. The top right corner (1, 1) means that we classified all positives as positive (Recall = 1) and that everything we classified as positive is a true positive (Precision = 1); the latter translates to zero false positives (Azevedo, n.d.) [1]. The Average Precision Score is essentially the area under the precision-recall curve.


Both roc_auc and average_precision_score were used to evaluate the model. roc_auc is the area under the ROC curve. The ROC curve plots the TPR (True Positive Rate / Recall / Sensitivity) on the y-axis against the FPR (False Positive Rate / 1 - Specificity) on the x-axis. roc_auc ranges from 0 to 1; a value of 1 means the model separates the positive and negative cases perfectly, and 0.5 corresponds to random ranking.
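Both scores can be computed with Scikit-learn's roc_auc_score and average_precision_score. The four labels and scores below are a toy example chosen so the results can be checked by hand; they are not data from this study.

```python
from sklearn.metrics import average_precision_score, roc_auc_score

# Toy labels and predicted probabilities for the positive class.
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]

# ROC AUC: share of (positive, negative) pairs ranked in the right order.
# Here 3 of the 4 pairs are ordered correctly.
roc_auc = roc_auc_score(y_true, y_score)

# Average precision: precision at each threshold, weighted by the gain in recall.
# Here 0.5 * 1.0 + 0.5 * (2 / 3).
aps = average_precision_score(y_true, y_score)

print(f"roc_auc: {roc_auc:.4f}")
print(f"aps:     {aps:.4f}")
```

Scikit-learn computes average precision as a step-wise sum over thresholds rather than by trapezoidal integration, so it can differ slightly from a numerically integrated area under the plotted curve.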

Figure 4. Precision Recall curve

3.2 Logistic Regression Hyperparameter Tuning

Grid search is an approach to parameter tuning that methodically builds and evaluates a model for each combination of algorithm parameters specified in a grid. GridSearchCV from Scikit-learn tries all combinations of those parameters and evaluates the results using cross-validation and the scoring metric provided. In the end, it outputs the best parameters for the dataset. It works for both regression and classification machine learning algorithms.

Grid search was used to tune the hyperparameter C in the logistic regression model, and the best C is 0.4. The resulting model's classification report and confusion matrix are shown in Tables 4 and 5. The cross-validated roc_auc score and aps score are 0.7256 and 0.352136.
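A grid search over C might be set up as in the sketch below. The candidate values of C, the synthetic data, the 5-fold cross-validation, and scoring="roc_auc" are assumptions for illustration; the study reports only that the best C was 0.4.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic imbalanced data standing in for the real dataset.
X, y = make_classification(
    n_samples=600, n_features=15, n_informative=5,
    weights=[0.85, 0.15], random_state=0,
)

# Hypothetical grid of inverse regularization strengths to try.
param_grid = {"C": [0.01, 0.1, 0.4, 1.0, 10.0]}

search = GridSearchCV(
    LogisticRegression(class_weight="balanced", max_iter=1000),
    param_grid=param_grid,
    scoring="roc_auc",   # metric used to rank the candidates
    cv=5,                # 5-fold cross-validation per candidate
)
search.fit(X, y)

print(search.best_params_, round(search.best_score_, 4))
```

After fitting, search.best_estimator_ holds a model refit on all the supplied data with the winning parameters, and search.cv_results_ holds the scores of every candidate.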

Table 4.- Logistic regression after GridSearchCV classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.94 | 0.67 | 0.78 | 733 |
| 1.0 | 0.25 | 0.71 | 0.37 | 115 |
| Accuracy | | | 0.67 | 848 |
| Macro avg | 0.59 | 0.69 | 0.58 | 848 |
| Weighted avg | 0.84 | 0.67 | 0.72 | 848 |

Table 5.- Logistic regression after GridSearchCV confusion matrix

| | Predicted 0 | Predicted 1 |
|---|---|---|
| Actual 0 | TN 489 | FP 244 |
| Actual 1 | FN 33 | TP 82 |

3.3 Logistic Regression Imbalance Sampling

Synthetic Minority Oversampling Technique (SMOTE) is an approach to the construction of classifiers from imbalanced datasets (Chawla N. V., [9]). SMOTE works by utilizing a k-nearest neighbor algorithm to create synthetic data. SMOTE first chooses a random sample from the minority class, then finds that sample's k nearest neighbors (Wijaya, n.d.) [14]. A synthetic sample is then created between the chosen sample and one randomly selected neighbor among the k nearest. In this case, SMOTE is imported from the imbalanced-learn API (imbalanced-learn API, n.d.) [6]. It creates synthetic positive cases to match the number of negative cases in the dataset. The resulting model's classification report is shown in Table 6 (logistic regression, SMOTE). The cross-validated roc_auc score and aps score are 0.7284 and 0.7073. The aps score improved significantly.

Table 6.- Logistic regression (SMOTE) classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.93 | 0.67 | 0.78 | 733 |
| 1.0 | 0.25 | 0.70 | 0.37 | 115 |
| Accuracy | | | 0.67 | 848 |
| Macro avg | 0.59 | 0.69 | 0.57 | 848 |
| Weighted avg | 0.84 | 0.67 | 0.72 | 848 |

Another strategy was to resample the minority cases (y=1) to match the majority cases (y=0) with Scikit-learn's resample function. The resulting classification report is shown in Table 7 (logistic regression, resample). The cross-validated roc_auc score and aps score did not improve.

Table 7. - Logistic regression (resample) classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.94 | 0.67 | 0.78 | 733 |
| 1.0 | 0.25 | 0.70 | 0.37 | 115 |
| Accuracy | | | 0.68 | 848 |
| Macro avg | 0.59 | 0.69 | 0.58 | 848 |
| Weighted avg | 0.84 | 0.68 | 0.73 | 848 |
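The SMOTE procedure described in Section 3.3 can be illustrated with a simplified NumPy sketch: pick a random minority sample, find its k nearest minority neighbors, and interpolate towards one of them. The function name smote_like_oversample, the toy points, and the default k=5 are hypothetical choices for illustration; this is not the imbalanced-learn implementation used in the study.

```python
import numpy as np


def smote_like_oversample(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic rows by interpolating between minority samples."""
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    k = min(k, len(X_min) - 1)
    synthetic = np.empty((n_new, X_min.shape[1]))
    for i in range(n_new):
        # 1. Pick a random minority sample.
        base = X_min[rng.integers(len(X_min))]
        # 2. Find its k nearest minority neighbours; position 0 of the sorted
        #    order is the sample itself (distance 0), so it is skipped.
        distances = np.linalg.norm(X_min - base, axis=1)
        neighbour_idx = np.argsort(distances)[1:k + 1]
        # 3. Pick one neighbour at random and step a random fraction towards it.
        neighbour = X_min[rng.choice(neighbour_idx)]
        gap = rng.random()
        synthetic[i] = base + gap * (neighbour - base)
    return synthetic


# Toy minority class with 6 points; suppose the majority class has 20.
X_minority = np.array([
    [1.0, 1.0], [1.5, 2.0], [2.0, 1.2],
    [2.5, 2.5], [3.0, 1.8], [1.2, 2.8],
])
n_majority = 20

X_new = smote_like_oversample(X_minority, n_new=n_majority - len(X_minority))
X_minority_balanced = np.vstack([X_minority, X_new])
print(X_minority_balanced.shape)
```

Because every synthetic point lies on a segment between two existing minority points, the new samples stay inside the region already occupied by the minority class rather than being exact copies, which is the main difference from plain resampling with replacement.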

3.4 Random Forest Model

Random forest is an ensemble learning method that consists of a large number of individual decision trees. Each individual tree in the random forest outputs a class prediction, and the class with the most votes becomes the model's prediction, as shown in Figure 5 (Yiu, [16]). The idea behind ensemble learning is that a large number of relatively uncorrelated models operating as a committee will outperform any of the individual constituent models.

Figure 5. Visualization of a random forest model making a prediction (Yiu, [16])

The resulting classification report is shown in Table 6 (random forest base model). The cross-validated roc_auc and aps scores are 0.689 and 0.297, which are not higher than those of logistic regression.

Table 6.- Random forest base model classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.87 | 0.99 | 0.93 | 733 |
| 1.0 | 0.60 | 0.08 | 0.14 | 115 |
| Accuracy | | | 0.87 | 848 |
| Macro avg | 0.74 | 0.54 | 0.53 | 848 |
| Weighted avg | 0.84 | 0.87 | 0.82 | 848 |

3.5 Random Forest Model Hyperparameter Tuning

Grid Search CV was used again to tune the hyperparameters of the random forest model. The criterion for hyperparameter tuning was to achieve the highest roc_auc. The best roc_auc score was 0.7178 and the corresponding best parameters are {'class_weight': 'balanced', 'criterion': 'entropy', 'max_depth': 20, 'max_features': 'auto', 'max_leaf_nodes': 650, 'min_samples_split': 0.07, 'n_estimators': 100}. The classification report with the best parameters is shown in Table 7 (random forest, hyperparameter tuned). The recall improved drastically from the random forest base model.

Table 7.- Random forest hyperparameter tuned classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.92 | 0.65 | 0.76 | 733 |
| 1.0 | 0.22 | 0.64 | 0.33 | 115 |
| Accuracy | | | 0.65 | 848 |
| Macro avg | 0.57 | 0.65 | 0.55 | 848 |
| Weighted avg | 0.83 | 0.65 | 0.70 | 848 |

The roc_auc measures the overall performance of a model, but in this case, to find as many true positive heart disease cases as possible, models with a higher recall score (for class 1) are more desirable. Taking recall into consideration, grid search was used again to choose the most suitable model. Among the models with a reasonably high roc_auc score (>= 0.71), the models with the highest recall scores were picked.

The results of the grid search were stored in an attribute ("cv_results_") that includes every parameter combination and its corresponding scores. There are 244 models with a roc_auc score higher than 0.71; the best cross-validated recall score among them is 0.634. The cross-validated roc_auc and aps scores are 0.713 and 0.325. The corresponding classification report is shown in Table 8.

Table 8.- Random forest best recall classification report

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| 0.0 | 0.93 | 0.67 | 0.78 | 733 |
| 1.0 | 0.24 | 0.67 | 0.35 | 115 |
| Accuracy | | | 0.67 | 848 |
| Macro avg | 0.58 | 0.67 | 0.57 | 848 |
| Weighted avg | 0.84 | 0.67 | 0.72 | 848 |

4. Conclusion

It is difficult to manually determine the odds of getting heart disease based on risk factors. However, machine learning techniques are useful for predicting the outcome from existing data.

From the data, males, older people, less educated people, and smokers all have higher probabilities of having CHD. The random forest model did not perform better than the logistic regression model except for the accuracy score. Since the data is severely imbalanced, the area under the Receiver Operating Characteristic curve and the Average Precision Score were both used to evaluate the models' performances. Recall (Sensitivity) is another important metric for this project. Since the model is predicting heart disease, a false negative (ignoring the probability of disease when there actually is one) is more dangerous than a false positive. In other words, a model that is more sensitive than specific is desirable for this project. The Synthetic Minority Oversampling Technique (SMOTE) in combination with the logistic regression model achieved the highest Average Precision score, area under the Receiver Operating Characteristic curve, and Recall score.

References:

1. Azevedo C. (n.d.). On ROC and Precision-Recall curves. Retrieved from towards data science: URL: https://towardsdatascience.com/on-roc-and-precision-recall-curves-c23e9b63820c

2. Benjamin E. J., Muntner P., et al. Heart disease and stroke statistics-2019 update: a report from the American Heart Association. Circulation. 2019.

3. Centers for Disease Control and Prevention. Underlying Cause of Death, 1999-2018. CDC WONDER Online Database. 2018.

4. Framingham heart study. (n.d.). Retrieved from URL: https://www.kaggle.com/amanajmera1/framingham-heart-study-dataset/data

5. Ghoneim S. Accuracy, Recall, Precision, F-Score & Specificity, which to optimize on? 2019. Retrieved from towards data science: URL: https://towardsdatascience.com/accuracy-recall-precision-f-score-specificity-which-to-optimize-on-867d3f11124

6. Imbalanced-learn API. (n.d.). Retrieved from imbalanced-learn.org: URL: https://imbalanced-learn.org/stable/api.html

7. KNN imputation. (n.d.). Retrieved from Medium: URL: medium.com/@kyawsawhtoon/a-guide-to-knn-imputation-95e2dc496e

8. Kohli S. Understanding a Classification Report For Your Machine Learning Model. 2019, November 18. Retrieved from Medium: URL: https://medium.com/@kohlishivam5522/understanding-a-classification-report-for-your-machine-learning-model-88815e2ce397

9. Chawla N. V., Bowyer K. W., Hall L. O., Kegelmeyer W. P. SMOTE: Synthetic Minority Over-sampling Technique. Journal of Artificial Intelligence Research, 2002.- P. 321-357.

10. Pedregosa F. et al. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 2011.- P. 2825-2830.

11. Seaborn Countplot. (n.d.). Retrieved from URL: https://seaborn.pydata.org/generated/seaborn.countplot.html

12. Seaborn distplot. (n.d.). Retrieved from URL: https://seaborn.pydata.org/generated/seaborn.distplot.html

13. Virani S. S., Alonso A., Benjamin E. J. Heart disease and stroke statistics-2020 update: a report from the American Heart Association. Circulation. 2020.

14. Wijaya C. Y. (n.d.). 5 Smote Techniques for Oversampling your Imbalance Data. Retrieved from towards data science: URL: https://towardsdatascience.com/5-smote-techniques-for-oversampling-your-imbalance-data-b8155bdbe2b5

15. World Health Organization. (2017, May 17). Newsroom / Fact sheets / Cardiovascular diseases (CVDs). Retrieved from World Health Organization: URL: https://www.who.int/news-room/fact-sheets/detail/cardiovascular-diseases-(cvds)

16. Yiu T. Understanding random forest. 2019, June 12. Retrieved from towards data science: URL: https://towardsdatascience.com/understanding-random-forest-58381e0602d2
