|
import gradio as gr |
|
from model_utils import load_all_models, predict_with_model |
|
|
|
|
|
# Loaded once, at import time. load_all_models() returns a pair:
#   models          - lookup of model identifier -> model object passed to
#                     predict_with_model
#   model_features  - lookup of model identifier -> the features that model
#                     expects (presumably an ordered list of feature names;
#                     TODO confirm against model_utils)
# Both are indexed by the values of MODEL_MAPPING in make_prediction.
models, model_features = load_all_models()
|
|
|
|
|
# Display name (tab title / FEATURES key) -> identifier of the trained model
# inside `models` and `model_features`. Every identifier follows the pattern
# "<display name>_random_forest_model"; insertion order is the tab order.
MODEL_MAPPING = {
    display_name: f"{display_name}_random_forest_model"
    for display_name in (
        "Death",
        "Binary diagnosis",
        "Necessity of transplantation",
        "Progressive disease",
    )
}

# Reverse lookup: model identifier -> display name.
INVERSE_MODEL_MAPPING = {
    model_id: display_name for display_name, model_id in MODEL_MAPPING.items()
}
|
|
|
|
|
# Input features shown in each prediction tab, keyed by the display names used
# in MODEL_MAPPING. The list order is the order of the tab's input fields and
# therefore the order in which values reach make_prediction.
# NOTE(review): this order presumably has to match model_features[...] for the
# corresponding model (only the count is checked at prediction time) — confirm.
FEATURES = {
    "Death": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "RadioWorsening2y",
        "Severity of telomere shortening - Transform 4",
        "Progressive disease",
    ],
    "Binary diagnosis": [
        "Pedigree",
        "Age at diagnosis",
        "Antifibrotic Drug",
        "Prednisone",
        "Mycophenolate",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
    ],
    "Necessity of transplantation": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "FVC (L) 1 year after diagnosis",
        "FVC (%) 1 year after diagnosis",
        "DLCO (%) 1 year after diagnosis",
        "RadioWorsening2y",
    ],
    "Progressive disease": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "FVC (L) 1 year after diagnosis",
        "FVC (%) 1 year after diagnosis",
        "DLCO (%) 1 year after diagnosis",
        "RadioWorsening2y",
        "Genetic mutation studied in patient",
    ],
}
|
|
|
# (min, max) bounds per input feature. gradio_interface passes them to each
# gr.Number as minimum/maximum and prints them in the input's label; features
# without an entry get no bounds.
# NOTE: "Comorbidities" is not referenced by any FEATURES list in this file.
FEATURE_RANGES = {
    "Pedigree": (0, 67),
    "Age at diagnosis": (36.0, 92.0),
    "FVC (L) at diagnosis": (0.0, 5.0),
    "FVC (%) at diagnosis": (0.0, 200.0),
    "DLCO (%) at diagnosis": (0.0, 200.0),
    "RadioWorsening2y": (0, 3),
    "Severity of telomere shortening - Transform 4": (1, 6),
    "Progressive disease": (0, 1),
    "Antifibrotic Drug": (0, 1),
    "Prednisone": (0, 1),
    "Mycophenolate": (0, 1),
    "FVC (L) 1 year after diagnosis": (0.0, 5.0),
    "FVC (%) 1 year after diagnosis": (0.0, 200.0),
    "DLCO (%) 1 year after diagnosis": (0.0, 200.0),
    "Genetic mutation studied in patient": (0, 1),
    "Comorbidities": (0, 1),
}
|
|
|
|
|
|
|
def make_prediction(input_features, friendly_model_name):
    """
    Predict using the selected model and input features.

    Args:
        input_features: Sequence of raw input values, one per feature. The
            Gradio UI passes them in the order of
            ``FEATURES[friendly_model_name]``; an empty Number field arrives
            as ``None``.
        friendly_model_name: Display name of the model; looked up in
            ``MODEL_MAPPING`` to find the loaded model.

    Returns:
        str: ``"Prediction for <name>: <result>"`` on success. Otherwise it
        returns a human-readable message saying why no prediction was made:
        an unknown model, the wrong number of inputs, empty fields, or
        non-numeric values.
    """
    target_model = MODEL_MAPPING.get(friendly_model_name)
    # An unknown display name yields None here; this relies on None not being
    # a key of `models`.
    if target_model not in models:
        return f"Model '{friendly_model_name}' not found. Please select a valid model."

    model = models[target_model]
    features = model_features[target_model]

    if len(input_features) != len(features):
        return f"Invalid input. Expected features: {features}"

    # Gradio delivers an empty Number field as None. float(None) would raise
    # TypeError and the user would only see a generic error, so report which
    # fields are empty instead.
    # NOTE(review): names come from `features`, which assumes the model's
    # feature order matches the UI's FEATURES order — confirm.
    missing = [
        name
        for name, value in zip(features, input_features)
        if value is None or (isinstance(value, str) and not value.strip())
    ]
    if missing:
        return f"Missing values for: {', '.join(map(str, missing))}"

    try:
        input_array = [float(x) for x in input_features]
    except (TypeError, ValueError):
        return f"Invalid input. All values must be numeric. Expected features: {features}"

    prediction = predict_with_model(model, input_array)
    return f"Prediction for {friendly_model_name}: {prediction}"
|
|
|
|
|
def gradio_interface():
    """
    Build the tabbed Gradio app with one prediction tab per entry in FEATURES.

    Each tab has one ``gr.Number`` input per feature, in ``FEATURES`` order.
    Inputs are bounded by ``FEATURE_RANGES`` when a range is known. Each tab
    also has a text output showing what ``make_prediction`` returns.

    Returns:
        gr.TabbedInterface: The assembled app; the caller is responsible for
        launching it.
    """

    def create_inputs_for_features(features):
        """Return one gr.Number component per feature name, in order."""
        inputs = []
        for feature in features:
            min_val, max_val = FEATURE_RANGES.get(feature, (None, None))
            # Only advertise a range when one is defined; otherwise the label
            # would read "(Range: None - None)".
            if min_val is None and max_val is None:
                label = feature
            else:
                label = f"{feature} (Range: {min_val} - {max_val})"
            inputs.append(gr.Number(label=label, minimum=min_val, maximum=max_val))
        return inputs

    interfaces = []
    for target, features in FEATURES.items():
        interface = gr.Interface(
            # `target` is bound as a default argument so each tab's callback
            # keeps its own model name instead of the loop's final value.
            fn=lambda *args, target=target: make_prediction(args, target),
            inputs=create_inputs_for_features(features),
            outputs=gr.Text(label="Prediction Result"),
            title=f"Prediction for {target}",
            description=f"Provide values for features relevant to {target}",
        )
        interfaces.append(interface)

    return gr.TabbedInterface(
        interface_list=interfaces,
        tab_names=list(FEATURES.keys()),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Build the tabbed app and serve it with Gradio's default launch settings.
    gradio_interface().launch()
|
|