File size: 4,744 Bytes
b1446bf
 
 
 
 
2aeb05a
 
b1446bf
2aeb05a
 
 
 
 
 
 
 
 
 
 
 
 
b1446bf
2aeb05a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1446bf
 
 
2aeb05a
 
 
 
 
 
 
b1446bf
2aeb05a
 
 
 
b1446bf
 
2aeb05a
 
 
 
b1446bf
 
2aeb05a
 
b1446bf
2aeb05a
b1446bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2aeb05a
 
b1446bf
 
 
 
 
 
 
 
 
008464e
 
 
 
 
 
b1446bf
 
2aeb05a
 
b1446bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f587c3
b1446bf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import onnxruntime as ort
import os
import logging

# Root logger setup: INFO and above, one "<time> - <level> - <message>" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Feature columns, in the order they are handed to the scaler and the model.
# CRUCIAL - must match your training data.
feature_names = [
    'Age',
    'Sex',
    'CD4+ T-cell count',
    'Viral load',
    'WBC count',
    'Hemoglobin',
    'Platelet count',
]

# Nothing is loaded yet: both artifacts start out as None and both
# "loaded" flags start out False until the loading step says otherwise.
ort_session = scaler = None
model_loaded = scaler_loaded = False

# --- Attempt to Load Model and Scaler ---
# Both artifacts are loaded into temporaries first and only published to the
# module-level names in the `else` branch, i.e. once BOTH loads succeeded.
# Any failure leaves ort_session/scaler as None and both flags False.
try:
    # 1. Set the current working directory (as a precaution) so the relative
    #    artifact paths below resolve next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    logging.info(f"Current working directory set to: {os.getcwd()}")

    # 2. Check if files exist
    model_path = "hiv_model.onnx"
    scaler_path = "hiv_scaler.pkl"

    # Only raise here: the FileNotFoundError handler below is the single
    # place that logs the problem, so each missing file is reported once.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if not os.path.exists(scaler_path):
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    # 3. Load the model and scaler
    _loaded_session = ort.InferenceSession(model_path)
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only ship a scaler file that comes from a trusted source.
    _loaded_scaler = joblib.load(scaler_path)

except FileNotFoundError as e:
    logging.error(f"File not found: {e}")
    ort_session = scaler = None
    model_loaded = scaler_loaded = False  # Make sure these are false if loading fails!

except Exception as e:
    # logging.exception records the full traceback for debugging.
    logging.exception(f"An error occurred while loading the model or scaler: {e}")
    ort_session = scaler = None
    model_loaded = scaler_loaded = False

else:
    ort_session = _loaded_session
    scaler = _loaded_scaler
    model_loaded = scaler_loaded = True
    logging.info("Model and scaler loaded successfully.")
# --- End Model Loading ---

def _is_finite_number(value):
    """Return True when *value* converts to a finite float (not None, NaN, inf or text)."""
    if value is None:
        return False
    try:
        return bool(np.isfinite(float(value)))
    except (TypeError, ValueError):
        return False


def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV risk probability for one set of input values.

    Args:
        age: Value for the 'Age' feature.
        sex: "Female" is encoded as 0; any other non-None value is encoded as 1.
        cd4_count: Value for the 'CD4+ T-cell count' feature.
        viral_load: Value for the 'Viral load' feature.
        wbc_count: Value for the 'WBC count' feature.
        hemoglobin: Value for the 'Hemoglobin' feature.
        platelet_count: Value for the 'Platelet count' feature.

    Returns:
        A human-readable string: the risk probability on success, or a message
        describing why no prediction could be made (artifacts not loaded,
        missing/invalid input, or an error raised during prediction).
        This function never raises; prediction errors are logged and reported
        in the returned text.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded.  Check the logs for errors. Ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    # Reject blank or non-finite fields up front. Without this check a missing
    # value can travel through the pipeline as NaN, and a NaN probability fails
    # every range comparison below, which used to be reported as a risk of 0.
    numeric_inputs = {
        'Age': age,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }
    invalid_fields = [
        name for name, value in numeric_inputs.items() if not _is_finite_number(value)
    ]
    if sex is None:
        invalid_fields.append('Sex')
    if invalid_fields:
        return f"Invalid input: please provide a valid value for {', '.join(invalid_fields)}."

    try:
        # 1. Create a single-row DataFrame keyed by the training feature names
        input_df = pd.DataFrame({
            'Age': [age],
            'Sex': [0 if sex == "Female" else 1],  # Encode Sex
            'CD4+ T-cell count': [cd4_count],
            'Viral load': [viral_load],
            'WBC count': [wbc_count],
            'Hemoglobin': [hemoglobin],
            'Platelet count': [platelet_count],
        })

        # 2. Standardize the data; selecting by feature_names fixes the column order
        scaled_values = scaler.transform(input_df[feature_names])

        # 3. ONNX Prediction (the session is fed a float32 array)
        input_array = np.ascontiguousarray(scaled_values, dtype=np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_array}
        ort_outs = ort_session.run(None, ort_inputs)

        # 4. Process Output
        # NOTE(review): assumes the first model output holds per-class scores
        # for the single input row, with index 1 = high risk; confirm against
        # the exported model.
        probabilities = ort_outs[0][0]
        risk_probability = float(probabilities[1])  # Probability of High Risk
        if not np.isfinite(risk_probability):
            # Surface a broken prediction instead of reporting it as zero risk.
            raise ValueError("model returned a non-finite probability")

        if 0 < risk_probability <= 100:
            return f"HIV Risk Probability: {risk_probability:.4f}"
        if risk_probability > 100:
            return "HIV Risk Probability: 100"
        return "HIV Risk Probability: 0"

    except Exception as e:
        # Top-level boundary for the UI: log the traceback, show a short message.
        logging.exception(f"An error occurred during prediction: {e}")
        return f"An error occurred during prediction: {e}.  Check the logs for details."

# Input widgets for the form; each `value` pre-fills the field with a sample default.
age_input = gr.Number(value=30, label="Age")
sex_input = gr.Radio(choices=["Female", "Male"], value="Female", label="Sex")
cd4_input = gr.Number(value=500, label="CD4+ T-cell count")
viral_input = gr.Number(value=10000, label="Viral load")
wbc_input = gr.Number(value=7000, label="WBC count")
hemoglobin_input = gr.Number(value=14.0, label="Hemoglobin")
platelet_input = gr.Number(value=250000, label="Platelet count")

# Wire the widgets to predict_risk. The list order matches the order of
# predict_risk's parameters, and the result is shown as plain text.
iface = gr.Interface(
    fn=predict_risk,
    inputs=[
        age_input,
        sex_input,
        cd4_input,
        viral_input,
        wbc_input,
        hemoglobin_input,
        platelet_input,
    ],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.Low risk : <1% probability of HIV infection, Moderate risk: 1% to 5% probability,High risk: >5% probability",
)

# Start serving the interface.
iface.launch()