|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import joblib |
|
import onnxruntime as ort |
|
|
|
|
|
# Load the ONNX model and its fitted scaler once, at import time.
# Any failure is reported on stdout and the app still starts; predict_risk
# consults the two flags below and refuses to predict when they are False.
# NOTE: joblib.load unpickles the file, so only ship a scaler from a trusted source.
try:
    ort_session = ort.InferenceSession("hiv_model.onnx")
    scaler = joblib.load("hiv_scaler.pkl")
    # Column order the scaler and model expect; predict_risk builds its
    # input frame in exactly this order.
    feature_names = ['Age', 'Sex', 'CD4+ T-cell count', 'Viral load', 'WBC count', 'Hemoglobin', 'Platelet count']
    model_loaded = scaler_loaded = True
except Exception as e:
    print(f"Error loading model or scaler: {e}")
    model_loaded = scaler_loaded = False
    ort_session, scaler, feature_names = None, None, []
|
|
|
def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """
    Predict the probability that a patient falls in the high-risk class.

    The raw inputs are assembled into a single-row DataFrame, reordered to the
    module-level ``feature_names``, transformed with the loaded ``scaler`` and
    fed to the ONNX model (``ort_session``) as a float32 array.

    Args:
        age: Patient age.
        sex: "Female" is encoded as 0; any other non-empty value as 1.
        cd4_count: CD4+ T-cell count.
        viral_load: Viral load.
        wbc_count: White blood cell count.
        hemoglobin: Hemoglobin level.
        platelet_count: Platelet count.

    Returns:
        A display string: either "High Risk Probability: <p>" (4 decimals), or
        a message explaining why no prediction was made (artifacts not loaded,
        missing input, or an error raised during preprocessing/inference).
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Please ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    raw_inputs = {
        'Age': age,
        'Sex': sex,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }

    # Gradio may pass None for a cleared field. Catch that here with a clear
    # message instead of letting the scaler fail obscurely, or silently
    # encoding a missing sex as 1.
    missing = [name for name, value in raw_inputs.items() if value is None]
    if missing:
        return f"Please provide a value for: {', '.join(missing)}"

    try:
        encoded = dict(raw_inputs, Sex=0 if sex == "Female" else 1)
        input_df = pd.DataFrame({name: [value] for name, value in encoded.items()})

        # Select columns by name so the scaler sees the order it was fitted on.
        scaled_values = scaler.transform(input_df[feature_names])
        input_array = np.asarray(scaled_values, dtype=np.float32)

        ort_inputs = {ort_session.get_inputs()[0].name: input_array}
        ort_outs = ort_session.run(None, ort_inputs)

        # NOTE(review): assumes the first model output is a (batch, n_classes)
        # probability tensor whose column 1 is the high-risk class — confirm
        # against how hiv_model.onnx was exported (e.g. skl2onnx classifiers
        # emit the label tensor first and probabilities second).
        risk_probability = float(ort_outs[0][0][1])

        return f"High Risk Probability: {risk_probability:.4f}"

    except Exception as e:
        # UI boundary: surface the failure to the user rather than crashing the app.
        return f"An error occurred during prediction: {e}"
|
|
|
|
|
|
|
# --- Gradio UI ---------------------------------------------------------------
# One component per model input. gr.Interface passes component values to
# predict_risk positionally, so this list must match its parameter order.
age_input = gr.Number(label="Age", value=30)
sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
viral_input = gr.Number(label="Viral load", value=10000)
wbc_input = gr.Number(label="WBC count", value=7000)
hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
platelet_input = gr.Number(label="Platelet count", value=250000)

iface = gr.Interface(
    fn=predict_risk,
    inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
)

# Launch only when executed as a script (`python <this file>`). Importing the
# module, e.g. from tests or another app, no longer starts a blocking server.
if __name__ == "__main__":
    iface.launch()