|
import gradio as gr |
|
import joblib |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
import tensorflow as tf |
|
|
|
|
|
# Regressors offered in the "Model" dropdown, in display order.
# The four scikit-style estimators are restored with joblib (i.e. unpickled;
# only ever point this at trusted, locally produced files). The MLP is a Keras
# model stored as an .h5 file.
best_models = [
    (display_name, joblib.load(model_path))
    for display_name, model_path in (
        ("Linear Regression", 'best_Linear_Regression_model.pkl'),
        ("Random Forest", 'best_Random_Forest_model.pkl'),
        ("Ridge Regression", 'best_Ridge_Regression_model.pkl'),
        ("Decision Tree", 'best_Decision_Tree_model.pkl'),
    )
]
best_models.append(("MLP", tf.keras.models.load_model('best_mlp_model.h5')))

# Preprocessing objects fitted at training time; predict() applies the scaler
# first and then the PCA projection before calling a model.
scaler = joblib.load('scaler.pkl')
pca = joblib.load('pca.pkl')
|
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable names of the ten raw fields collected by the form.
attribute_names = [
    'Age',
    'Height',
    'Weight',
    'Diabetes',
    'Simvastatin',
    'Amiodarone',
    'INR',
    'Gender',
    'Race',
    'VKORC1_genotype',
]

# One-hot race columns in the exact order the fitted scaler expects them.
# NOTE(review): these look like pd.get_dummies(prefix='Race') output from the
# training data -- confirm. 'Race_Other' and 'Race_other' are two distinct
# categories and must both be kept.
race_columns = [
    'Race_Asian',
    'Race_Black',
    'Race_Black African',
    'Race_Black Caribbean',
    'Race_Black or African American',
    'Race_Black other',
    'Race_Caucasian',
    'Race_Chinese',
    'Race_Han Chinese',
    'Race_Hispanic',
    'Race_Indian',
    'Race_Intermediate',
    'Race_Japanese',
    'Race_Korean',
    'Race_Malay',
    'Race_Other',
    'Race_Other (Black British)',
    'Race_Other (Hungarian)',
    'Race_Other Mixed Race',
    'Race_White',
    'Race_other',
]
|
|
|
def _encode_features(raw_features):
    """Turn the ten raw form values into the numeric feature list the scaler expects.

    Order of ``raw_features``: Age, Height, Weight, Diabetes, Simvastatin,
    Amiodarone, INR, Gender, Race, VKORC1_genotype.

    The result is the first eight fields (Yes/No flags and Gender mapped to
    1/0), then one indicator per entry of ``race_columns``, then two
    VKORC1 indicators (A/G, G/G).
    """
    (age, height, weight, diabetes, simvastatin, amiodarone,
     inr, gender, race, vkorc1_genotype) = raw_features

    def yes_no(value):
        return 1 if value == 'Yes' else 0

    base = [
        age,
        height,
        weight,
        yes_no(diabetes),
        yes_no(simvastatin),
        yes_no(amiodarone),
        inr,
        1 if gender == 'Male' else 0,
    ]

    # The dropdown yields bare labels such as 'Asian', while the one-hot
    # columns are named 'Race_Asian'. Match against the prefixed name,
    # otherwise no column ever matches and race is encoded as all zeros.
    race_column = f"Race_{race}"
    race_encoded = [1 if col == race_column else 0 for col in race_columns]

    vkorc1_encoded = [
        1 if vkorc1_genotype == 'A/G' else 0,
        1 if vkorc1_genotype == 'G/G' else 0,
    ]
    return base + race_encoded + vkorc1_encoded


def predict(*args):
    """Predict a therapeutic warfarin dose from the Gradio form inputs.

    Positional arguments arrive in the order of the Interface's ``inputs``
    list: the ten raw features (Age ... VKORC1_genotype), then the selected
    model's display name. The features are encoded, scaled with ``scaler``,
    projected with ``pca``, and passed to the chosen entry of ``best_models``.

    Returns:
        dict mapping ``"Therapeutic Dosage"`` to the first prediction
        formatted as ``"<value> mg"`` with two decimals.

    Raises:
        KeyError: if the model name is not one of the names in ``best_models``.
    """
    *raw_features, model_name = args
    model = dict(best_models)[model_name]

    input_data = np.array(_encode_features(raw_features)).reshape(1, -1)
    input_data_scaled = scaler.transform(input_data)
    input_data_pca = pca.transform(input_data_scaled)

    print("Input data PCA:", input_data_pca)

    # Keras returns shape (1, 1) and scikit-learn regressors return (1,).
    # Flattening handles both without special-casing the model name.
    prediction = np.ravel(model.predict(input_data_pca))

    return {"Therapeutic Dosage": f"{prediction[0]:.2f} mg"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_names = [name for name, _ in best_models]

# Build the race choices from the one-hot column names, so the dropdown labels
# always correspond to a column in race_columns ('Race_Asian' -> 'Asian').
race_choices = [col[len('Race_'):] for col in race_columns]

# Top-level components replace the deprecated gr.inputs / gr.outputs
# namespaces, which were removed in Gradio 4. Slider steps are pinned to the
# increments the legacy gr.inputs.Slider auto-derived from each range, so they
# do not depend on version-specific defaults.
age = gr.Slider(minimum=10, maximum=90, step=0.1, label='Age')
height = gr.Slider(minimum=124.968, maximum=202.0, step=0.1, label='Height')
weight = gr.Slider(minimum=30.0, maximum=237.7, step=1, label='Weight')
diabetes = gr.Dropdown(choices=['Yes', 'No'], label='Diabetes')
simvastatin = gr.Dropdown(choices=['Yes', 'No'], label='Simvastatin')
amiodarone = gr.Dropdown(choices=['Yes', 'No'], label='Amiodarone')
inr = gr.Slider(minimum=0.8, maximum=6.1, step=0.01, label='INR')
gender = gr.Dropdown(choices=['Male', 'Female'], label='Gender')
race = gr.Dropdown(choices=race_choices, label='Race')
vkorc1_genotype = gr.Dropdown(choices=['A/G', 'G/G'], label='VKORC1_genotype')
model_name = gr.Dropdown(choices=model_names, label="Model")

output = gr.Textbox(label="Therapeutic Dosage")

# The input order must match the positional order predict() unpacks.
iface = gr.Interface(
    fn=predict,
    inputs=[
        age, height, weight, diabetes, simvastatin, amiodarone,
        inr, gender, race, vkorc1_genotype, model_name,
    ],
    outputs=output,
    title="Warfarin Dose Prediction",
    description="Select a model and enter user features to predict the therapeutic dose.",
)
# Keep `iface` bound to the Interface itself, not to launch()'s return value.
iface.launch(debug=True)
|
|
|
|
|
|