Gordon-H commited on
Commit
b1446bf
·
verified ·
1 Parent(s): 31db67b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import onnxruntime as ort
6
+
7
+ # Load the ONNX model and scaler outside the function for efficiency
8
+ try:
9
+ ort_session = ort.InferenceSession("hiv_model.onnx")
10
+ scaler = joblib.load("hiv_scaler.pkl")
11
+ feature_names = ['Age', 'Sex', 'CD4+ T-cell count', 'Viral load', 'WBC count', 'Hemoglobin', 'Platelet count'] # Match your training data
12
+
13
+ model_loaded = True
14
+ scaler_loaded = True
15
+ except Exception as e:
16
+ print(f"Error loading model or scaler: {e}")
17
+ model_loaded = False
18
+ scaler_loaded = False
19
+ ort_session = None
20
+ scaler = None
21
+ feature_names = [] # Set to empty to avoid errors later
22
+
23
+ def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
24
+ """
25
+ Predicts HIV risk probability based on input features.
26
+ """
27
+ if not model_loaded or not scaler_loaded:
28
+ return "Model or scaler not loaded. Please ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."
29
+
30
+ try:
31
+ # 1. Create a DataFrame
32
+ input_data = {
33
+ 'Age': [age],
34
+ 'Sex': [0 if sex == "Female" else 1], # Encode Sex
35
+ 'CD4+ T-cell count': [cd4_count],
36
+ 'Viral load': [viral_load],
37
+ 'WBC count': [wbc_count],
38
+ 'Hemoglobin': [hemoglobin],
39
+ 'Platelet count': [platelet_count]
40
+ }
41
+ input_df = pd.DataFrame(input_data)
42
+
43
+ # 2. Standardize the data
44
+ scaled_values = scaler.transform(input_df[feature_names])
45
+ scaled_df = pd.DataFrame(scaled_values, columns=feature_names)
46
+
47
+ # 3. ONNX Prediction
48
+ input_array = scaled_df[feature_names].values.astype(np.float32) # Enforce float32
49
+ ort_inputs = {ort_session.get_inputs()[0].name: input_array}
50
+ ort_outs = ort_session.run(None, ort_inputs)
51
+
52
+ # 4. Process Output
53
+ probabilities = ort_outs[0][0]
54
+ risk_probability = probabilities[1] # Probability of High Risk
55
+
56
+ return f"High Risk Probability: {risk_probability:.4f}"
57
+
58
+ except Exception as e:
59
+ return f"An error occurred during prediction: {e}"
60
+
61
+
62
+ # Define Gradio inputs
63
+ age_input = gr.Number(label="Age", value=30)
64
+ sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
65
+ cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
66
+ viral_input = gr.Number(label="Viral load", value=10000)
67
+ wbc_input = gr.Number(label="WBC count", value=7000)
68
+ hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
69
+ platelet_input = gr.Number(label="Platelet count", value=250000)
70
+
71
+ # Create Gradio interface
72
+ iface = gr.Interface(
73
+ fn=predict_risk,
74
+ inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
75
+ outputs="text",
76
+ title="Sentinel-P1: HIV Risk Prediction Demo",
77
+ description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
78
+ )
79
+
80
+ iface.launch()