# File-viewer header (kept as a comment so this file parses as Python):
# author Gordon-H; latest commit "Update app.py" (008464e, verified); 4.63 kB.
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import onnxruntime as ort
import os
import logging
# App-wide logging: timestamp, level and message on every record.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Model input columns, in order. CRUCIAL: names and order must match the
# columns the scaler and model were trained on.
feature_names = [
    'Age',
    'Sex',
    'CD4+ T-cell count',
    'Viral load',
    'WBC count',
    'Hemoglobin',
    'Platelet count',
]

# Inference artifacts start out empty/unloaded; the model-loading code is
# responsible for filling them in and flipping the flags on success.
ort_session = scaler = None
model_loaded = scaler_loaded = False
# --- Attempt to Load Model and Scaler ---
# On success, sets ort_session/scaler and flips both *_loaded flags to True.
# On any failure, logs the problem once and leaves everything unloaded so
# predict_risk can report it instead of crashing at import time.
try:
    # 1. Set the current working directory to this script's directory (as a
    #    precaution) so the relative artifact paths below resolve no matter
    #    where the app was launched from. This affects the whole process.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    logging.info(f"Current working directory set to: {os.getcwd()}")

    # 2. Check that both files exist before loading either. Logging is left to
    #    the FileNotFoundError handler below so each problem is reported once.
    model_path = "hiv_model.onnx"
    scaler_path = "hiv_scaler.pkl"
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.isfile(scaler_path):
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    # 3. Load the model and scaler.
    ort_session = ort.InferenceSession(model_path)
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code -- only ever load a trusted, locally shipped scaler.
    scaler = joblib.load(scaler_path)
    model_loaded = True
    scaler_loaded = True
    logging.info("Model and scaler loaded successfully.")
except FileNotFoundError as e:
    logging.error(f"File not found: {e}")
    # Make sure these are cleared/false if loading fails.
    ort_session = None
    scaler = None
    model_loaded = False
    scaler_loaded = False
except Exception as e:
    # logging.exception also records the full traceback for debugging.
    logging.exception(f"An error occurred while loading the model or scaler: {e}")
    # Reset everything: e.g. the ONNX session may have loaded before the
    # scaler failed, and a half-loaded state must not look usable.
    ort_session = None
    scaler = None
    model_loaded = False
    scaler_loaded = False
# --- End Model Loading ---
def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV risk probability for one set of blood-report values.

    Args:
        age: Patient age.
        sex: "Female" is encoded as 0; any other value is encoded as 1.
        cd4_count, viral_load, wbc_count, hemoglobin, platelet_count: numeric
            lab values, presumably in the same units the scaler was fitted
            on (not visible here -- TODO confirm).

    Returns:
        A display string. On success it is "HIV Risk Probability: p", where
        p is the model's class-1 probability clamped to [0, 1] and formatted
        to four decimals. Otherwise it is a message explaining why no
        prediction was made. Errors are returned, not raised, so the UI
        always has something to show.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Check the logs for errors. Ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    # Gradio yields None for a cleared field. Reject it explicitly, rather
    # than letting the scaler fail on NaN or silently encoding sex=None as 1.
    numeric_inputs = (age, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count)
    if sex is None or any(value is None for value in numeric_inputs):
        return "Please fill in all input fields before predicting."

    try:
        # 1. One-row DataFrame, so the scaler sees named columns in training order.
        input_df = pd.DataFrame({
            'Age': [age],
            'Sex': [0 if sex == "Female" else 1],  # Encode Sex
            'CD4+ T-cell count': [cd4_count],
            'Viral load': [viral_load],
            'WBC count': [wbc_count],
            'Hemoglobin': [hemoglobin],
            'Platelet count': [platelet_count],
        })

        # 2. Standardize the data with the fitted scaler.
        scaled_values = scaler.transform(input_df[feature_names])

        # 3. ONNX prediction; the input is cast to float32 explicitly.
        input_array = np.asarray(scaled_values, dtype=np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_array}
        ort_outs = ort_session.run(None, ort_inputs)

        # 4. Take the class-1 ("high risk") probability of the single row.
        # NOTE(review): assumes the first model output holds per-class
        # probabilities -- confirm against how the ONNX model was exported.
        risk_probability = float(ort_outs[0][0][1])
    except Exception as e:
        logging.exception(f"An error occurred during prediction: {e}")
        return f"An error occurred during prediction: {e}. Check the logs for details."

    # A NaN must not be shown as "0 risk"; surface it as an error instead.
    if np.isnan(risk_probability):
        logging.error("Model returned a NaN probability.")
        return "An error occurred during prediction: the model returned a NaN probability. Check the logs for details."

    # A probability lives in [0, 1], so clamp to that range (not to 100,
    # which is a percentage bound) and format every outcome the same way.
    risk_probability = min(max(risk_probability, 0.0), 1.0)
    return f"HIV Risk Probability: {risk_probability:.4f}"
# Gradio input components, one per model feature. They are listed in the
# same order as predict_risk's parameters, which is how Interface maps them.
age_input = gr.Number(value=30, label="Age")
sex_input = gr.Radio(choices=["Female", "Male"], value="Female", label="Sex")
cd4_input = gr.Number(value=500, label="CD4+ T-cell count")
viral_input = gr.Number(value=10000, label="Viral load")
wbc_input = gr.Number(value=7000, label="WBC count")
hemoglobin_input = gr.Number(value=14.0, label="Hemoglobin")
platelet_input = gr.Number(value=250000, label="Platelet count")

prediction_inputs = [
    age_input,
    sex_input,
    cd4_input,
    viral_input,
    wbc_input,
    hemoglobin_input,
    platelet_input,
]

# Wire the predictor into a simple text-output UI and start the server.
iface = gr.Interface(
    fn=predict_risk,
    inputs=prediction_inputs,
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
)
iface.launch()