File size: 4,744 Bytes
b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 2aeb05a b1446bf 008464e b1446bf 2aeb05a b1446bf 5f587c3 b1446bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import onnxruntime as ort
import os
import logging
# Logging setup: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Feature columns, in the order they are selected before scaling/inference.
# CRUCIAL - must match your training data.
feature_names = ['Age', 'Sex', 'CD4+ T-cell count', 'Viral load', 'WBC count', 'Hemoglobin', 'Platelet count']

# Nothing is loaded yet; both flags stay False until loading fully succeeds.
ort_session = scaler = None
model_loaded = scaler_loaded = False

# --- Attempt to Load Model and Scaler ---
try:
    # Switch to the script's own directory (as a precaution) so the relative
    # artifact paths below are resolved next to this file.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    logging.info(f"Current working directory set to: {os.getcwd()}")

    model_path = "hiv_model.onnx"
    scaler_path = "hiv_scaler.pkl"

    # Fail early with a specific message when either artifact is absent.
    if not os.path.exists(model_path):
        logging.error(f"Model file not found: {model_path}")
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(scaler_path):
        logging.error(f"Scaler file not found: {scaler_path}")
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    # Both files are present: build the ONNX session and unpickle the scaler.
    ort_session = ort.InferenceSession(model_path)
    scaler = joblib.load(scaler_path)
except FileNotFoundError as e:
    logging.error(f"File not found: {e}")
    # Make sure nothing half-loaded survives a failure.
    ort_session = scaler = None
    model_loaded = scaler_loaded = False
except Exception as e:
    # logging.exception records the full traceback for debugging.
    logging.exception(f"An error occurred while loading the model or scaler: {e}")
    ort_session = scaler = None
    model_loaded = scaler_loaded = False
else:
    # Reached only when both artifacts loaded without raising.
    model_loaded = scaler_loaded = True
    logging.info("Model and scaler loaded successfully.")
# --- End Model Loading ---
def _is_finite_number(value):
    """Return True when *value* can be read as a finite real number.

    None, non-numeric objects, NaN and +/-inf all yield False.
    """
    try:
        return bool(np.isfinite(float(value)))
    except (TypeError, ValueError, OverflowError):
        return False


def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV risk probability from one set of input features.

    Args:
        age: Numeric value for the 'Age' feature.
        sex: "Female" is encoded as 0; any other value is encoded as 1.
        cd4_count: Numeric value for 'CD4+ T-cell count'.
        viral_load: Numeric value for 'Viral load'.
        wbc_count: Numeric value for 'WBC count'.
        hemoglobin: Numeric value for 'Hemoglobin'.
        platelet_count: Numeric value for 'Platelet count'.

    Returns:
        A message string. On success it is "HIV Risk Probability: <value>",
        where values above 100 are reported as 100 and values at or below 0
        as 0. If the model/scaler is not loaded, an input is missing or not
        finite, or anything fails during inference, an explanatory message
        is returned instead; the function never raises.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Check the logs for errors. Ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    # Reject missing / non-finite numbers up front. Left unchecked, such a
    # value can surface as a NaN probability, which previously fell through
    # the range checks below and was reported as a risk of 0.
    # NOTE(review): presumably a cleared numeric form field arrives as None -
    # confirm against the UI library's behavior.
    numeric_inputs = {
        'Age': age,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }
    invalid = [name for name, value in numeric_inputs.items() if not _is_finite_number(value)]
    if invalid:
        return f"Invalid input: please provide a numeric value for {', '.join(invalid)}."

    try:
        # 1. Create a single-row DataFrame keyed by feature name
        input_df = pd.DataFrame({
            'Age': [age],
            'Sex': [0 if sex == "Female" else 1],  # Encode Sex
            'CD4+ T-cell count': [cd4_count],
            'Viral load': [viral_load],
            'WBC count': [wbc_count],
            'Hemoglobin': [hemoglobin],
            'Platelet count': [platelet_count],
        })

        # 2. Standardize the data; feature_names fixes the column order
        scaled_values = scaler.transform(input_df[feature_names])

        # 3. ONNX prediction (input enforced to float32)
        input_array = np.asarray(scaled_values, dtype=np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_array}
        ort_outs = ort_session.run(None, ort_inputs)

        # 4. Process output
        # Assumes the first output holds per-class probabilities for the one
        # input row, with index 1 being the high-risk class - TODO confirm
        # against the exported model.
        probabilities = ort_outs[0][0]
        risk_probability = probabilities[1]  # Probability of High Risk

        if np.isnan(risk_probability):
            # NaN fails every comparison, so it must be caught explicitly
            # rather than being reported as a probability of 0.
            logging.error("Model returned NaN as the risk probability.")
            return "An error occurred during prediction: the model returned an undefined probability. Check the logs for details."
        if risk_probability > 100:
            return "HIV Risk Probability: 100"
        if risk_probability > 0:
            return f"HIV Risk Probability: {risk_probability:.4f}"
        return "HIV Risk Probability: 0"
    except Exception as e:
        # Top-level boundary for the UI callback: log the traceback and
        # hand a readable message back instead of raising.
        logging.exception(f"An error occurred during prediction: {e}")
        return f"An error occurred during prediction: {e}. Check the logs for details."
# Define Gradio inputs
# One widget per argument of predict_risk; `value` is the initial form value.
age_input = gr.Number(label="Age", value=30)
sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
viral_input = gr.Number(label="Viral load", value=10000)
wbc_input = gr.Number(label="WBC count", value=7000)
hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
platelet_input = gr.Number(label="Platelet count", value=250000)

# Create Gradio interface
# The inputs list follows predict_risk's parameter order:
# age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count.
iface = gr.Interface(
    fn=predict_risk,
    inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.Low risk : <1% probability of HIV infection, Moderate risk: 1% to 5% probability,High risk: >5% probability",
)

# Launch the app.
iface.launch()