|
from sklearn.ensemble import RandomForestRegressor |
|
from sklearn.model_selection import GridSearchCV |
|
from sklearn.metrics import ( |
|
mean_squared_error, |
|
mean_absolute_error, |
|
root_mean_squared_error, |
|
) |
|
from xgboost import XGBRegressor |
|
from lightgbm import LGBMRegressor |
|
from constants import HES |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def choose_model(model: str = "rf", **params):
    """
    Choose the model to use for training.

    Parameters:
        model (str): The model to use: "rf" for Random Forest, "xgb" for
            XGBoost or "lgbm" for LightGBM.
        **params: Optional keyword arguments forwarded unchanged to the
            estimator's constructor (for example a dict of hyper-parameters
            found by ``get_best_params``). With no extra arguments the
            estimator is built with its library defaults, as before.

    Returns:
        model: A new, unfitted RandomForestRegressor, XGBRegressor or
        LGBMRegressor.

    Raises:
        ValueError: If ``model`` is not one of "rf", "xgb" or "lgbm".
    """
    if model == "rf":
        return RandomForestRegressor(**params)
    if model == "xgb":
        return XGBRegressor(**params)
    if model == "lgbm":
        return LGBMRegressor(**params)
    # Keep the original "Invalid model name" prefix, but say what was
    # rejected and what is accepted.
    raise ValueError(
        f"Invalid model name {model!r}; expected 'rf', 'xgb' or 'lgbm'"
    )
|
|
|
|
|
def training_testing_data(
    df: pd.DataFrame,
    train_start: str,
    train_end: str,
    test_start: str,
    test_end: str,
    train_start_bis: str = None,
    train_end_bis: str = None,
    target: str = "cpih_medical",
    columns: list = None,
) -> tuple:
    """
    Split the data into training and testing sets by date range.

    Every date bound is inclusive and is compared against ``df["date"]``.

    Parameters:
        df (pd.DataFrame): The DataFrame containing the data. It must have a
            "date" column plus the ``target`` and ``columns`` columns.
        train_start (str): The start date for the training set, in the
            format "YYYY-MM-DD".
        train_end (str): The end date for the training set, in the format
            "YYYY-MM-DD".
        test_start (str): The start date for the testing set, in the format
            "YYYY-MM-DD".
        test_end (str): The end date for the testing set, in the format
            "YYYY-MM-DD".
        train_start_bis (str): Optional start date of a second training
            window, in the format "YYYY-MM-DD". Must be given together with
            ``train_end_bis``.
        train_end_bis (str): Optional end date of the second training
            window, in the format "YYYY-MM-DD".
        target (str): The target column. Defaults to "cpih_medical".
        columns (list): The columns to use as features. Defaults to
            ``["cpih", *HES]``, built fresh on every call.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)``.

    Raises:
        ValueError: If only one of ``train_start_bis`` / ``train_end_bis``
            is provided (previously the second window was silently dropped).
    """
    if columns is None:
        # Built per call instead of as a default argument, so no list object
        # is shared between calls.
        columns = ["cpih", *HES]

    if (train_start_bis is None) != (train_end_bis is None):
        raise ValueError(
            "train_start_bis and train_end_bis must be provided together"
        )

    dates = df["date"]
    # Series.between is inclusive on both ends, matching `>= start & <= end`.
    train_mask = dates.between(train_start, train_end)
    if train_start_bis is not None:
        train_mask = train_mask | dates.between(train_start_bis, train_end_bis)
    test_mask = dates.between(test_start, test_end)

    train = df[train_mask]
    test = df[test_mask]
    return train[columns], train[target], test[columns], test[target]
|
|
|
|
|
def train_model(model, X_train, y_train, X_test, y_test) -> tuple:
    """
    Train the model on the training data and evaluate it on the testing data.

    After fitting, if the model exposes ``feature_importances_`` (the tree
    ensembles built by ``choose_model`` do), the importances are printed in
    descending order. Estimators without that attribute are still evaluated
    instead of failing with AttributeError.

    Parameters:
        model: The model to train. It must provide ``fit``, ``predict`` and
            ``score``.
        X_train: The features for the training set. Its ``.columns`` label
            the importance report.
        y_train: The target for the training set.
        X_test: The features for the testing set.
        y_test: The target for the testing set.

    Returns:
        tuple: ``(r2, mae, mse, rmse)`` on the testing set, where ``r2`` is
        the value of ``model.score(X_test, y_test)`` (the R^2 score for
        scikit-learn-style regressors).
    """
    model.fit(X_train, y_train)

    importances = getattr(model, "feature_importances_", None)
    if importances is not None:
        importance_df = pd.DataFrame(
            {"Feature": X_train.columns, "Importance": importances}
        ).sort_values(by="Importance", ascending=False)
        print(f"Feature Importance: \n{importance_df}")

    y_pred = model.predict(X_test)
    r2 = model.score(X_test, y_test)
    mae = mean_absolute_error(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    rmse = root_mean_squared_error(y_test, y_pred)
    return r2, mae, mse, rmse
|
|
|
|
|
def get_best_params(
    model,
    X_train,
    y_train,
    X_test,
    y_test,
    param_grid,
    *,
    cv=5,
    n_jobs=-1,
    scoring=None,
):
    """
    Find the best hyperparameters for the model using grid search.

    The search uses only the training data. ``X_test`` and ``y_test`` are
    accepted for call-site symmetry with ``train_model`` but are not used.

    Parameters:
        model: The model to tune.
        X_train: The features for the training set.
        y_train: The target for the training set.
        X_test: Unused, kept for backward compatibility.
        y_test: Unused, kept for backward compatibility.
        param_grid: The hyperparameters to search over.
        cv: Cross-validation strategy passed to ``GridSearchCV``. Defaults to
            5 folds, as before. A time-ordered splitter may suit date-indexed
            data better; choose per use case.
        n_jobs: Parallel jobs for ``GridSearchCV``. Defaults to -1, as before.
        scoring: Scoring passed to ``GridSearchCV``. None keeps the
            estimator's default ``score``, as before.

    Returns:
        dict: The best hyperparameters found for the model.
    """
    grid_search = GridSearchCV(
        model, param_grid, cv=cv, n_jobs=n_jobs, scoring=scoring
    )
    grid_search.fit(X_train, y_train)
    return grid_search.best_params_
|
|
|
|
|
def plot(
    df: pd.DataFrame,
    model,
    test_start: str,
    test_end: str,
    target: str,
    features: list,
    save_dir: str = "quanti/data",
):
    """
    Plot the predicted and actual values of the target variable.

    The figure contains three layers:
    - scatter points and a dashed line for the model's predictions over the
      test window;
    - scatter points for the actual values over the test window;
    - a line for the actual values over the whole DataFrame.

    The figure is saved as a PNG in ``save_dir``, then shown and closed.

    Parameters:
        df (pd.DataFrame): The DataFrame containing the data. It must have a
            "date" column plus the ``target`` and ``features`` columns.
        model: A fitted model with a ``predict`` method.
        test_start (str): The start date for the testing set, in the format
            "YYYY-MM-DD" (inclusive).
        test_end (str): The end date for the testing set, in the format
            "YYYY-MM-DD" (inclusive).
        target (str): The target column.
        features (list): The features to use.
        save_dir (str): Directory for the saved PNG. It is created if missing.
            Defaults to "quanti/data", the previously hard-coded location.
    """
    from pathlib import Path

    # Compute the test-window selection once and reuse it.
    test_df = df[df["date"].between(test_start, test_end)]
    dates_new = test_df["date"]
    actual_new = test_df[target]
    predicted_new = model.predict(test_df[features])
    # Legend labels previously hard-coded "2022/2023" regardless of window.
    period = f"{test_start} to {test_end}"

    fig = plt.figure(figsize=(12, 6))
    plt.scatter(
        dates_new,
        predicted_new,
        color="blue",
        alpha=0.7,
        label=f"Predicted CPIH ({period})",
        zorder=2,
    )
    plt.scatter(
        dates_new,
        actual_new,
        color="orange",
        alpha=0.7,
        label=f"Actual CPIH ({period})",
        zorder=2,
    )
    plt.plot(
        dates_new, predicted_new, color="blue", linestyle="--", alpha=0.7, zorder=1
    )
    plt.plot(
        df["date"],
        df[target],
        color="green",
        alpha=0.8,
        label="Actual CPIH (All Years)",
        zorder=0,
    )

    plt.title(f"CPIH Medical: Test on {test_start} to {test_end}", fontsize=14)
    plt.xlabel("Date", fontsize=12)
    plt.ylabel("CPIH Medical", fontsize=12)
    plt.xticks(rotation=45)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)

    plt.tight_layout()
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / f"CPIH Medical Test {test_start} to {test_end}.png")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
|
|