10 Exercise Solution: Regression with Trees (LOS Dataset)
If using Colab, run this cell first. Otherwise, skip this cell.
!pip install ydata_profiling
!pip install catboost
!pip install --upgrade scikit-learn
!git clone https://github.com/hsma-programme/h6_4e_boosted_trees.git
%cd /content/h6_4e_boosted_trees/exercises_colab
10.1 Core
We’re going to work with a dataset to try to predict patient length of stay.
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# import the relevant models from Sklearn, XGBoost, CatBoost and LightGBM
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from ydata_profiling import ProfileReport
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from catboost import CatBoostRegressor
from sklearn.ensemble import HistGradientBoostingRegressor
# import any other libraries you need
import numpy as np
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    r2_score,
    root_mean_squared_error,
)
Open the data dictionary in the los_dataset folder and take a look at what data is available.
Next, load in the dataframe containing the LOS data.
los_df = pd.read_csv("../datasets/los_dataset/LengthOfStay.csv", index_col="eid")
View the dataframe.
los_df.head()
eid | vdate | rcount | gender | dialysisrenalendstage | asthma | irondef | pneum | substancedependence | psychologicaldisordermajor | depress | ... | sodium | glucose | bloodureanitro | creatinine | bmi | pulse | respiration | secondarydiagnosisnonicd9 | facid | lengthofstay |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 8/29/2012 | 0 | F | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 140.361132 | 192.476918 | 12.0 | 1.390722 | 30.432418 | 96 | 6.5 | 4 | B | 3 |
2 | 5/26/2012 | 5+ | F | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 136.731692 | 94.078507 | 8.0 | 0.943164 | 28.460516 | 61 | 6.5 | 1 | A | 7 |
3 | 9/22/2012 | 1 | F | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 133.058514 | 130.530524 | 12.0 | 1.065750 | 28.843812 | 64 | 6.5 | 2 | B | 3 |
4 | 8/9/2012 | 0 | F | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 138.994023 | 163.377028 | 12.0 | 0.906862 | 27.959007 | 76 | 6.5 | 1 | A | 1 |
5 | 12/20/2012 | 0 | F | 0 | 0 | 0 | 1 | 0 | 1 | 0 | ... | 138.634836 | 94.886654 | 11.5 | 1.242854 | 30.258927 | 67 | 5.6 | 2 | E | 4 |
5 rows × 26 columns
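The preview shows several columns that are not plain numbers (for example `vdate`, `gender` and `facid`), and `rcount` contains the value `5+`, so pandas will most likely read it as text. The sketch below is one way to spot such columns before modelling. `preview_df` is a hypothetical stand-in built from a few columns of the five rows shown above, not the full dataset; in the notebook the same checks would be run on `los_df`.

```python
import pandas as pd

# Hypothetical stand-in built from a few columns of the five preview rows above.
# In the notebook, run the same checks on los_df instead.
preview_df = pd.DataFrame(
    {
        "vdate": ["8/29/2012", "5/26/2012", "9/22/2012", "8/9/2012", "12/20/2012"],
        "rcount": ["0", "5+", "1", "0", "0"],
        "gender": ["F", "F", "F", "F", "F"],
        "facid": ["B", "A", "B", "A", "E"],
        "lengthofstay": [3, 7, 3, 1, 4],
    },
    index=pd.Index([1, 2, 3, 4, 5], name="eid"),
)

# Columns that are not numeric will need encoding (or dropping) before
# they can be passed to most of the regressors imported earlier.
non_numeric_cols = preview_df.select_dtypes(exclude="number").columns.tolist()
print(non_numeric_cols)

# rcount looks numeric at first glance, but the "5+" category makes it text.
rcount_values = preview_df["rcount"].unique().tolist()
print(rcount_values)
```

How each of these columns is handled (one-hot encoding, ordinal mapping, or dropping) is a modelling choice; this check only lists the columns that need a decision.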
ProfileReport(los_df)