File size: 2,993 Bytes
b1446bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import onnxruntime as ort

# Load the ONNX session and the scaler once, at import time, so that every
# prediction call reuses them instead of reloading from disk.
try:
    ort_session = ort.InferenceSession("hiv_model.onnx")
    scaler = joblib.load("hiv_scaler.pkl")
except Exception as load_error:
    # Any failure leaves the module in a consistent "nothing loaded" state.
    print(f"Error loading model or scaler: {load_error}")
    ort_session = scaler = None
    model_loaded = scaler_loaded = False
    feature_names = []  # Empty so later code has a defined (if unusable) value
else:
    # Column order used to select features before scaling; must match training.
    feature_names = [
        'Age',
        'Sex',
        'CD4+ T-cell count',
        'Viral load',
        'WBC count',
        'Hemoglobin',
        'Platelet count',
    ]
    model_loaded = scaler_loaded = True

def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """
    Predict the high-risk probability for one set of blood report values.

    The values are assembled into a single-row DataFrame, standardized with
    the module-level ``scaler`` and fed to the module-level ONNX
    ``ort_session``.

    Args:
        age: Value for the 'Age' feature.
        sex: "Female" is encoded as 0; any other non-missing value as 1.
        cd4_count: Value for the 'CD4+ T-cell count' feature.
        viral_load: Value for the 'Viral load' feature.
        wbc_count: Value for the 'WBC count' feature.
        hemoglobin: Value for the 'Hemoglobin' feature.
        platelet_count: Value for the 'Platelet count' feature.

    Returns:
        A human-readable string: the formatted probability on success, or a
        message describing missing inputs, an unloaded model/scaler, or the
        error raised during prediction (such errors are caught, not
        propagated).
    """
    raw_values = {
        'Age': age,
        'Sex': sex,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }

    # Reject missing values up front and name them. Left unchecked, a None
    # sex would silently be encoded as 1, and a None number would flow into
    # the scaler/model and produce an opaque error or a meaningless result.
    # This check needs neither the model nor the scaler, so it runs first.
    missing = [name for name, value in raw_values.items() if value is None]
    if missing:
        return f"Missing input value(s): {', '.join(missing)}. Please fill in every field."

    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Please ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    try:
        # 1. Build a single-row DataFrame with Sex encoded numerically
        row = {**raw_values, 'Sex': 0 if sex == "Female" else 1}
        input_df = pd.DataFrame([row])

        # 2. Standardize the data (columns selected in training order)
        scaled_values = scaler.transform(input_df[feature_names])

        # 3. ONNX prediction; the session is fed float32 explicitly
        input_array = np.asarray(scaled_values, dtype=np.float32)
        input_name = ort_session.get_inputs()[0].name
        ort_outs = ort_session.run(None, {input_name: input_array})

        # 4. Process output
        # NOTE(review): assumes the first model output holds per-class
        # probabilities for the single row, with index 1 = high risk --
        # confirm against the exported model's output signature.
        probabilities = ort_outs[0][0]
        risk_probability = probabilities[1]

        return f"High Risk Probability: {risk_probability:.4f}"

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"An error occurred during prediction: {e}"


# Input widgets: one per predict_risk parameter, each with a default value
age_input = gr.Number(value=30, label="Age")
sex_input = gr.Radio(choices=["Female", "Male"], value="Female", label="Sex")
cd4_input = gr.Number(value=500, label="CD4+ T-cell count")
viral_input = gr.Number(value=10000, label="Viral load")
wbc_input = gr.Number(value=7000, label="WBC count")
hemoglobin_input = gr.Number(value=14.0, label="Hemoglobin")
platelet_input = gr.Number(value=250000, label="Platelet count")

# Wire the widgets to predict_risk; the list order below must follow the
# function's positional parameter order.
iface = gr.Interface(
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
    fn=predict_risk,
    inputs=[
        age_input,
        sex_input,
        cd4_input,
        viral_input,
        wbc_input,
        hemoglobin_input,
        platelet_input,
    ],
    outputs="text",
)

# Start the demo app
iface.launch()