|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import joblib |
|
import onnxruntime as ort |
|
import os |
|
import logging |
|
|
|
|
|
# Timestamped, INFO-level logging for the whole application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Column order handed to the scaler and to the ONNX model by predict_risk.
feature_names = ['Age', 'Sex', 'CD4+ T-cell count', 'Viral load', 'WBC count', 'Hemoglobin', 'Platelet count']

# Inference artifacts and their status flags. They start in the
# "nothing loaded" state and are replaced only after BOTH artifacts have
# loaded, so a failure at any step can never leave a half-initialised pair.
ort_session = None
scaler = None
model_loaded = False
scaler_loaded = False

try:
    # The artifacts are looked up by bare file name, so make the script's own
    # directory the working directory (existing module-level side effect).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    logging.info(f"Current working directory set to: {os.getcwd()}")

    model_path = "hiv_model.onnx"
    scaler_path = "hiv_scaler.pkl"

    # Fail early with a precise message. The handler below does the logging,
    # so each missing file produces exactly one log record.
    for artifact, path in (("Model", model_path), ("Scaler", scaler_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{artifact} file not found: {path}")

    loaded_session = ort.InferenceSession(model_path)
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only ship a scaler file that comes from a trusted source.
    loaded_scaler = joblib.load(scaler_path)
except FileNotFoundError as e:
    logging.error(f"File not found: {e}")
except Exception as e:
    # Top-level boundary: keep the app importable and report the traceback.
    logging.exception(f"An error occurred while loading the model or scaler: {e}")
else:
    # Publish both artifacts together, only on full success.
    ort_session = loaded_session
    scaler = loaded_scaler
    model_loaded = True
    scaler_loaded = True
    logging.info("Model and scaler loaded successfully.")
|
|
|
|
|
|
|
def _is_finite_number(value):
    """Return True when *value* can be read as a finite float (None is not)."""
    try:
        return bool(np.isfinite(float(value)))
    except (TypeError, ValueError, OverflowError):
        return False


def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV risk probability for one set of blood-report values.

    Args:
        age: Age of the patient.
        sex: Either "Female" (encoded as 0) or "Male" (encoded as 1).
        cd4_count: CD4+ T-cell count.
        viral_load: Viral load.
        wbc_count: White blood cell count.
        hemoglobin: Hemoglobin value.
        platelet_count: Platelet count.

    Returns:
        A human-readable string. On success it starts with
        "HIV Risk Probability:". Otherwise it explains what went wrong:
        invalid input, model/scaler not loaded, or an error raised while
        predicting. The function never raises; failures are logged.
    """
    # Validate the arguments first: a cleared number field arrives as None and
    # would otherwise flow through as NaN, and any unexpected `sex` value
    # would silently be encoded as Male.
    numeric_inputs = {
        'Age': age,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }
    invalid = []
    if sex not in ("Female", "Male"):
        invalid.append('Sex')
    invalid.extend(
        label for label, value in numeric_inputs.items() if not _is_finite_number(value)
    )
    if invalid:
        return f"Invalid input: please provide a valid value for {', '.join(invalid)}."

    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Check the logs for errors. Ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    try:
        # Single-row frame; selecting by feature_names fixes the column order
        # for the scaler and the model.
        row = {label: [float(value)] for label, value in numeric_inputs.items()}
        row['Sex'] = [0 if sex == "Female" else 1]
        input_df = pd.DataFrame(row)

        scaled_values = scaler.transform(input_df[feature_names])
        input_array = np.ascontiguousarray(scaled_values, dtype=np.float32)

        ort_inputs = {ort_session.get_inputs()[0].name: input_array}
        ort_outs = ort_session.run(None, ort_inputs)

        # NOTE(review): assumes the first model output is a
        # (batch, n_classes) probability array and that index 1 is the
        # positive class - confirm against the exported model.
        probabilities = ort_outs[0][0]
        risk_probability = probabilities[1]

        # NaN fails every comparison below and used to be reported as a
        # risk of 0; surface it as an error instead.
        if np.isnan(risk_probability):
            logging.error("Model returned a NaN probability for the given inputs.")
            return "An error occurred during prediction: the model returned an undefined (NaN) probability. Check the logs for details."

        # NOTE(review): the upper bound of 100 looks like a percentage scale
        # while the 4-decimal formatting looks like a 0-1 scale - confirm.
        if risk_probability > 100:
            return "HIV Risk Probability: 100"
        if risk_probability > 0:
            return f"HIV Risk Probability: {risk_probability:.4f}"
        return "HIV Risk Probability: 0"

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        logging.exception(f"An error occurred during prediction: {e}")
        return f"An error occurred during prediction: {e}. Check the logs for details."
|
|
|
|
|
# Input widgets, pre-filled with example values. They are passed to the
# interface in the same order as predict_risk's positional parameters.
age_input = gr.Number(label="Age", value=30)
sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
viral_input = gr.Number(label="Viral load", value=10000)
wbc_input = gr.Number(label="WBC count", value=7000)
hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
platelet_input = gr.Number(label="Platelet count", value=250000)

# Single-function UI: the widgets feed predict_risk and its returned string
# is shown as plain text.
iface = gr.Interface(
    fn=predict_risk,
    inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    # Spacing fixed so the sentences and risk bands no longer run together.
    description=(
        "Enter blood report values to estimate HIV risk. "
        "This is a demonstration model and should not be used for medical advice. "
        "Low risk: <1% probability of HIV infection, "
        "Moderate risk: 1% to 5% probability, "
        "High risk: >5% probability"
    ),
)

# Launched at module level, i.e. as soon as this file is executed or imported.
iface.launch()