# Author: aaronemmanuel
# Last commit: "Update app.py" (12d314d, verified)
import os

import google.generativeai as gai
import gradio as gr
import joblib
import openai
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
# Load the trained classifier (serialized with joblib) that scores the vital-sign
# features assembled in medical_assistant_interaction.
model = joblib.load('model_pkl')

# Configure the Google Generative AI client. The key is read from the
# GOOGLE_API_KEY environment variable (e.g. a Space secret) rather than being
# hard-coded: a key committed to source is publicly exposed and should be revoked.
# If the variable is unset, api_key is None and LLM requests will fail; that
# failure is reported to the user by medical_assistant_interaction.
gai.configure(api_key=os.environ.get('GOOGLE_API_KEY'))
def _format_symptom_answer(answer):
    """Return a symptom answer for the prompt, labelling unanswered questions.

    A ``gr.Radio`` the user left unselected is expected to arrive as ``None``
    (verify against the installed Gradio version); without this the prompt would
    literally read e.g. "Headache: None".
    """
    return "Not answered" if answer is None else answer


def _extract_chat_text(response):
    """Return the stripped text of the first candidate of a ``gai.chat`` response.

    Never raises: when the response carries no usable candidate, a human-readable
    error string is returned instead so it can be shown to the user.
    """
    if not hasattr(response, 'candidates'):
        return "Error: Response object does not have 'candidates' attribute."
    candidates = response.candidates or []
    if not candidates:
        return "Error: No valid response from the model."
    first = candidates[0]
    # The google.generativeai chat response stores candidates as plain message
    # dicts ({'author': ..., 'content': ...}). hasattr(dict, 'content') is False,
    # so checking only for an attribute always reported "no valid response".
    # Attribute-style candidates are accepted too, for other library versions.
    if isinstance(first, dict):
        content = first.get('content')
    else:
        content = getattr(first, 'content', None)
    if not isinstance(content, str) or not content.strip():
        return "Error: No valid response from the model."
    return content.strip()


def _classifier_summary(prediction):
    """Translate the classifier's predicted label into a one-sentence summary.

    A label of ``1`` is treated as "high likelihood"; any other label as "low".
    """
    if prediction == 1:
        return "Based on the sensor readings, the classifier suggests a high likelihood of Lassa fever."
    return "Based on the sensor readings, the classifier suggests a low likelihood of Lassa fever."


# Combine the vital-sign classifier with an LLM assessment of the reported symptoms.
def medical_assistant_interaction(pulse_rate, blood_pressure_systolic, blood_pressure_diastolic, temperature_celsius,
                                  headache, muscle_pain, sore_throat, nausea_vomiting, diarrhea, joint_pains, cough):
    """Assess the likelihood of Lassa fever from vital signs and reported symptoms.

    The four vital signs are scored by the joblib-loaded ``model``. All readings
    and the symptom answers are also sent to the ``models/chat-bison-001`` chat
    model, which is asked for a likelihood estimate and follow-up questions.

    Args:
        pulse_rate: Pulse rate in bpm.
        blood_pressure_systolic: Systolic blood pressure in mmHg.
        blood_pressure_diastolic: Diastolic blood pressure in mmHg.
        temperature_celsius: Body temperature in degrees Celsius.
        headache, muscle_pain, sore_throat, nausea_vomiting, diarrhea,
        joint_pains, cough: Symptom answers, "Yes"/"No" from the UI radios
            (``None`` when unanswered).

    Returns:
        The LLM's answer followed by a blank line and the classifier's summary.
        If the LLM request fails, its error text replaces the LLM answer and
        the classifier summary is still returned. Any other failure is reported
        as "An error occurred: ..."; the function never raises.
    """
    vitals = (pulse_rate, blood_pressure_systolic, blood_pressure_diastolic, temperature_celsius)
    if any(value is None for value in vitals):
        # Empty gr.Number fields would otherwise surface as an opaque NaN error
        # from the classifier.
        return ("Please enter all sensor readings (pulse rate, systolic and diastolic "
                "blood pressure, temperature) before submitting.")

    try:
        symptom_responses = {
            "Headache": headache,
            "Muscle Pain": muscle_pain,
            "Sore Throat": sore_throat,
            "Nausea and Vomiting": nausea_vomiting,
            "Diarrhea": diarrhea,
            "Joint Pains": joint_pains,
            "Dry Cough or Coughing Blood": cough,
        }
        patient_responses = ', '.join(
            f"{symptom}: {_format_symptom_answer(answer)}"
            for symptom, answer in symptom_responses.items()
        )

        # Column names must match those the classifier was trained with
        # (assumed from the original code; confirm against the training script).
        features = pd.DataFrame({
            'Pulse_Rate': [pulse_rate],
            'Blood_Pressure_Systolic': [blood_pressure_systolic],
            'Blood_Pressure_Diastolic': [blood_pressure_diastolic],
            'Temperature_Celsius': [temperature_celsius],
        })
        model_inference = _classifier_summary(model.predict(features)[0])

        prompt = (
            "A patient has the following sensor readings:\n"
            f"- Pulse Rate: {pulse_rate} bpm\n"
            f"- Systolic Blood Pressure: {blood_pressure_systolic} mmHg\n"
            f"- Diastolic Blood Pressure: {blood_pressure_diastolic} mmHg\n"
            f"- Temperature: {temperature_celsius}°C\n"
            f"The patient reported the following symptoms: {patient_responses}.\n"
            "Based on these symptoms, what is the likelihood of Lassa fever? "
            "Provide additional follow-up questions if necessary.\n"
        )

        # Ask the Google Generative AI chat model (not OpenAI) for its assessment.
        # The call gets its own handler so a network/auth/API failure still lets
        # the user see the classifier's result.
        try:
            response = gai.chat(
                model="models/chat-bison-001",
                messages=[{"content": prompt}],
            )
            assistant_message = _extract_chat_text(response)
        except Exception as e:  # UI boundary: report the failure, do not crash
            assistant_message = f"Error: the language model request failed ({e})."

        return f"{assistant_message}\n\n{model_inference}"
    except Exception as e:  # UI boundary: show the error text in the output box
        return f"An error occurred: {e}"
# Gradio callback: receives the widget values positionally, in the order the
# interface's `inputs` list declares them.
def gradio_interface(pulse_rate, blood_pressure_systolic, blood_pressure_diastolic, temperature_celsius,
                     headache, muscle_pain, sore_throat, nausea_vomiting, diarrhea, joint_pains, cough):
    """Forward the Gradio form values to ``medical_assistant_interaction``.

    Returns:
        Whatever ``medical_assistant_interaction`` returns, unchanged.
    """
    sensor_values = (pulse_rate, blood_pressure_systolic, blood_pressure_diastolic, temperature_celsius)
    symptom_answers = (headache, muscle_pain, sore_throat, nausea_vomiting, diarrhea, joint_pains, cough)
    return medical_assistant_interaction(*sensor_values, *symptom_answers)
# Numeric sensor-reading widgets, in the order gradio_interface expects them.
_SENSOR_INPUTS = [
    gr.Number(label="Pulse Rate (bpm)"),
    gr.Number(label="Systolic Blood Pressure (mmHg)"),
    gr.Number(label="Diastolic Blood Pressure (mmHg)"),
    gr.Number(label="Temperature (Celsius)"),
]

# Yes/No symptom questions, in the order of gradio_interface's symptom parameters.
_SYMPTOM_QUESTIONS = (
    "Are you experiencing headaches?",
    "Are you experiencing muscle pain?",
    "Are you experiencing a sore throat?",
    "Are you experiencing nausea and vomiting?",
    "Are you experiencing diarrhea?",
    "Are you experiencing joint pains?",
    "Are you experiencing a dry cough or coughing blood?",
)

_SYMPTOM_INPUTS = [gr.Radio(choices=["Yes", "No"], label=question) for question in _SYMPTOM_QUESTIONS]

# Single-page form: sensor readings plus symptom answers in, plain text out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_SENSOR_INPUTS + _SYMPTOM_INPUTS,
    outputs="text",
    title="Lassa Fever Medical Assistant",
    description="This assistant uses sensor data and patient symptom responses to infer the likelihood of Lassa fever.",
)

# Start the Gradio web app.
interface.launch()