"""Sentinel-P1: Gradio demo that estimates an HIV risk probability.

At import time this module loads an ONNX classifier ("hiv_model.onnx") and a
fitted scaler ("hiv_scaler.pkl") from the directory containing this script,
then defines a form-based Gradio interface. This is a demonstration model and
must not be used for medical advice.
"""

from pathlib import Path

import gradio as gr
import pandas as pd
import numpy as np
import joblib
import onnxruntime as ort

# Resolve model files next to this script (not the current working directory),
# matching the "same directory" expectation stated in the error message below.
_APP_DIR = Path(__file__).resolve().parent

# Encoding of the 'Sex' feature: Female -> 0, Male -> 1 (same as the radio choices).
SEX_CODES = {"Female": 0, "Male": 1}

# Load the ONNX model and scaler once at import time rather than per request.
try:
    # An explicit provider list is required on ORT >= 1.9 builds that ship
    # several execution providers (e.g. GPU builds); CPU is always available.
    ort_session = ort.InferenceSession(
        str(_APP_DIR / "hiv_model.onnx"), providers=["CPUExecutionProvider"]
    )
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code -- only ship a scaler file from a trusted source.
    scaler = joblib.load(_APP_DIR / "hiv_scaler.pkl")
    # Column order fed to the scaler/model; must match the training data.
    feature_names = [
        "Age",
        "Sex",
        "CD4+ T-cell count",
        "Viral load",
        "WBC count",
        "Hemoglobin",
        "Platelet count",
    ]
    model_loaded = True
    scaler_loaded = True
except Exception as e:
    # Keep the app running so the UI can report the problem instead of crashing.
    print(f"Error loading model or scaler: {e}")
    model_loaded = False
    scaler_loaded = False
    ort_session = None
    scaler = None
    feature_names = []  # Set to empty to avoid errors later


def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV high-risk probability for one set of form inputs.

    Args:
        age, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count:
            Numeric values from the Gradio form. A field left empty arrives
            as ``None`` and is reported back to the user instead of being
            passed to the model.
        sex: ``"Female"`` or ``"Male"``; encoded as 0 / 1 respectively.

    Returns:
        A human-readable string: ``"High Risk Probability: <p>"`` (4 decimal
        places) on success, otherwise an explanatory message. Problems are
        returned rather than raised so they appear in the Gradio output box.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Please ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    numeric_inputs = {
        "Age": age,
        "CD4+ T-cell count": cd4_count,
        "Viral load": viral_load,
        "WBC count": wbc_count,
        "Hemoglobin": hemoglobin,
        "Platelet count": platelet_count,
    }
    # Empty fields would otherwise become NaN and yield a meaningless score.
    missing = [name for name, value in numeric_inputs.items() if value is None]
    if missing:
        return f"Please provide a value for: {', '.join(missing)}."
    # Previously any non-"Female" value was silently encoded as Male.
    if sex not in SEX_CODES:
        return "Please select a value for Sex (Female or Male)."

    try:
        # 1. Build a one-row frame with the training column names.
        input_df = pd.DataFrame(
            {
                "Age": [age],
                "Sex": [SEX_CODES[sex]],
                "CD4+ T-cell count": [cd4_count],
                "Viral load": [viral_load],
                "WBC count": [wbc_count],
                "Hemoglobin": [hemoglobin],
                "Platelet count": [platelet_count],
            }
        )

        # 2. Standardize in the training column order.
        scaled_values = scaler.transform(input_df[feature_names])

        # 3. ONNX inference; the model input is fed as float32.
        input_array = np.asarray(scaled_values, dtype=np.float32)
        input_name = ort_session.get_inputs()[0].name
        ort_outs = ort_session.run(None, {input_name: input_array})

        # 4. Assumes the first output holds class probabilities for the single
        # row, with index/key 1 = high risk -- TODO confirm against the export.
        probabilities = ort_outs[0][0]
        risk_probability = float(probabilities[1])
        return f"High Risk Probability: {risk_probability:.4f}"
    except Exception as e:
        # UI boundary: surface the failure to the user rather than crash.
        return f"An error occurred during prediction: {e}"


# Define Gradio inputs
age_input = gr.Number(label="Age", value=30)
sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
viral_input = gr.Number(label="Viral load", value=10000)
wbc_input = gr.Number(label="WBC count", value=7000)
hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
platelet_input = gr.Number(label="Platelet count", value=250000)

# Create Gradio interface
iface = gr.Interface(
    fn=predict_risk,
    inputs=[
        age_input,
        sex_input,
        cd4_input,
        viral_input,
        wbc_input,
        hemoglobin_input,
        platelet_input,
    ],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()