File size: 4,466 Bytes
6a9ff1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr
from model_utils import load_all_models, predict_with_model
# Load every trained model once, at import time, along with the feature
# list recorded for each model (both dicts are keyed by the real model name).
(models, model_features) = load_all_models()
# Map friendly (user-facing) target names to the real keys under which the
# trained models are stored; each real key is "<target>_random_forest_model".
MODEL_MAPPING = {
    friendly: f"{friendly}_random_forest_model"
    for friendly in (
        "Death",
        "Binary diagnosis",
        "Necessity of transplantation",
        "Progressive disease",
    )
}
# Inverse lookup (real model key -> friendly name); optional convenience.
INVERSE_MODEL_MAPPING = {real: friendly for friendly, real in MODEL_MAPPING.items()}
# Input features shown in the UI for each prediction target, in the order
# the values are passed on to make_prediction.
FEATURES = {
    "Death": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "RadioWorsening2y",
        "Severity of telomere shortening - Transform 4",
        "Progressive disease",
    ],
    "Binary diagnosis": [
        "Pedigree",
        "Age at diagnosis",
        "Antifibrotic Drug",
        "Prednisone",
        "Mycophenolate",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
    ],
    "Necessity of transplantation": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "FVC (L) 1 year after diagnosis",
        "FVC (%) 1 year after diagnosis",
        "DLCO (%) 1 year after diagnosis",
        "RadioWorsening2y",
    ],
    "Progressive disease": [
        "Pedigree",
        "Age at diagnosis",
        "FVC (L) at diagnosis",
        "FVC (%) at diagnosis",
        "DLCO (%) at diagnosis",
        "FVC (L) 1 year after diagnosis",
        "FVC (%) 1 year after diagnosis",
        "DLCO (%) 1 year after diagnosis",
        "RadioWorsening2y",
        "Genetic mutation studied in patient",
    ],
}
# (minimum, maximum) bounds per feature. gradio_interface shows them in each
# input's label and passes them to gr.Number as minimum/maximum. Int vs float
# bounds are kept as written because they are formatted into the labels.
# NOTE(review): 'Comorbidities' is not used by any FEATURES list.
FEATURE_RANGES = {
    # Baseline values at diagnosis.
    "Pedigree": (0, 67),
    "Age at diagnosis": (0, 200),
    "FVC (L) at diagnosis": (0.0, 5.0),
    "FVC (%) at diagnosis": (0.0, 200.0),
    "DLCO (%) at diagnosis": (0.0, 200.0),
    "RadioWorsening2y": (0, 3),
    "Severity of telomere shortening - Transform 4": (1, 6),
    # Entries bounded to (0, 1); presumably binary indicators — confirm.
    "Progressive disease": (0, 1),
    "Antifibrotic Drug": (0, 1),
    "Prednisone": (0, 1),
    "Mycophenolate": (0, 1),
    # Follow-up values one year after diagnosis.
    "FVC (L) 1 year after diagnosis": (0.0, 5.0),
    "FVC (%) 1 year after diagnosis": (0.0, 200.0),
    "DLCO (%) 1 year after diagnosis": (0.0, 200.0),
    # Remaining (0, 1)-bounded entries.
    "Genetic mutation studied in patient": (0, 1),
    "Comorbidities": (0, 1),
}
# Define prediction function
def make_prediction(input_features, friendly_model_name):
    """Predict using the selected model and the values entered in the UI.

    Args:
        input_features: Sequence of raw input values, one per model feature,
            in the order given by ``model_features`` for the selected model.
            An empty Gradio number box contributes ``None``.
        friendly_model_name: User-facing target name (a key of
            ``MODEL_MAPPING``).

    Returns:
        A human-readable string. It holds the prediction on success, or a
        message explaining why no prediction was made: unknown model, wrong
        number of inputs, missing values, or non-numeric values.
    """
    # Map the friendly model name to the real model name. Unknown names map
    # to None, which is never a key of ``models``.
    target_model = MODEL_MAPPING.get(friendly_model_name)
    if target_model not in models:
        return f"Model '{friendly_model_name}' not found. Please select a valid model."
    model = models[target_model]
    features = model_features[target_model]
    if len(input_features) != len(features):
        return f"Invalid input. Expected features: {features}"
    # Empty number boxes arrive as None. Report them by name instead of
    # letting float(None) raise a TypeError inside the UI callback.
    # NOTE(review): names come from model_features; this assumes the UI
    # input order matches that list (which prediction correctness needs too).
    missing = [name for name, value in zip(features, input_features) if value is None]
    if missing:
        return f"Missing values for features: {missing}"
    try:
        input_array = [float(x) for x in input_features]
    except (TypeError, ValueError):
        return f"Invalid input. All values must be numeric. Expected features: {features}"
    prediction = predict_with_model(model, input_array)
    return f"Prediction for {friendly_model_name}: {prediction}"
# Define Gradio interface
def gradio_interface():
    """Build the tabbed Gradio app, with one tab per prediction target.

    Each tab is a ``gr.Interface`` with one ``gr.Number`` input per feature
    listed in ``FEATURES`` for that target. An input is labelled with its
    allowed range from ``FEATURE_RANGES`` when one is defined. Submitting a
    tab calls ``make_prediction`` with the entered values and the tab's
    target name.

    Returns:
        The ``gr.TabbedInterface`` that combines all per-target interfaces.
    """

    def create_inputs_for_features(features):
        """Return one ``gr.Number`` input component per feature name."""
        inputs = []
        for feature in features:
            min_val, max_val = FEATURE_RANGES.get(feature, (None, None))
            # Only advertise a range when one is known, so features without
            # an entry are not labelled "Range: None - None".
            if min_val is None and max_val is None:
                label = feature
            else:
                label = f"{feature} (Range: {min_val} - {max_val})"
            inputs.append(gr.Number(label=label, minimum=min_val, maximum=max_val))
        return inputs

    def make_predict_fn(target):
        """Return a callback that forwards its inputs to ``make_prediction``.

        The factory binds ``target`` per tab, avoiding the late-binding
        closure pitfall. It also keeps the callback's signature to just
        ``*values``, with no overridable ``target`` keyword.
        """

        def predict(*values):
            return make_prediction(values, target)

        return predict

    # Create a separate interface for each target variable.
    interfaces = [
        gr.Interface(
            fn=make_predict_fn(target),
            inputs=create_inputs_for_features(features),
            outputs=gr.Text(label="Prediction Result"),
            title=f"Prediction for {target}",
            description=f"Provide values for features relevant to {target}",
        )
        for target, features in FEATURES.items()
    ]
    # Combine all interfaces into a tabbed layout; tab order follows FEATURES.
    return gr.TabbedInterface(
        interface_list=interfaces,
        tab_names=list(FEATURES),
    )
# Build and serve the Gradio app only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    gradio_interface().launch()
|