# NOTE: page-scrape residue (not program code), kept as comments so the file parses:
#   EAV123's picture
#   Update app.py
#   0f1f788 verified
import streamlit as st
import pandas as pd
import re
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL
from utils import load_model, load_label_encoders
from prediction import predict_susceptibility
from ai_assistant import initialize_openai_client, get_ai_response
# Streamlit UI configuration.
# This is issued before anything else on purpose: Streamlit releases prior to
# the relaxation of this rule require set_page_config to be the first
# Streamlit command in the script, and the loaders below are opaque from this
# file (NOTE(review): they may cache or report through st.* — confirm).
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")

# Load assets: the model and encoders are passed to predict_susceptibility,
# and the encoders' keys populate the organism/antibiotic select boxes.
model = load_model(MODEL_PATH)
encoders = load_label_encoders(ENCODER_DIR)

# Initialize OpenAI client (handed to get_ai_response by the assistant panel).
client = initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)

# Sidebar navigation: `page` holds the label of the page to render.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "Susceptibility Analysis", "Data Upload", "About"])
# ---------------------------------------------------------------------------
# Page renderers. Each function draws exactly one page; the dispatch table at
# the bottom calls the one matching the sidebar selection stored in `page`.
# They read the module-level `model`, `encoders` and `client` objects.
# ---------------------------------------------------------------------------


def _render_home_page():
    """Draw the landing page: title, banner image and feature overview."""
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    st.markdown(
        "\n"
        "**Welcome to the Microbial Susceptibility Analyzer!**\n"
        "This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.\n"
        "- Predict microbial susceptibility.\n"
        "- Ask an AI assistant for expert advice.\n"
        "- Upload datasets for batch predictions.\n"
    )


def _render_prediction_panel():
    """Draw the prediction form and the most recent prediction result.

    The result is stored in ``st.session_state['current_result']`` rather than
    drawn only on the submitting run: every widget interaction (for example an
    assistant button) reruns the script with the submit flag back to False,
    which would otherwise make the result disappear.
    """
    with st.form("prediction_form"):
        # list(...) materializes the keys view into a plain sequence of options.
        organism = st.selectbox('Organism', options=list(encoders['organism'].keys()))
        antibiotic = st.selectbox('Antibiotic', options=list(encoders['antibiotic'].keys()))
        was_positive = st.selectbox('Was Positive', options=[1, 0])
        submit_button = st.form_submit_button("Predict")

    if submit_button:
        # Store inputs in session state so the assistant panel can refer to them.
        st.session_state['current_organism'] = organism
        st.session_state['current_antibiotic'] = antibiotic
        st.session_state['current_was_positive'] = was_positive
        st.session_state['current_result'] = predict_susceptibility({
            'was_positive': was_positive,
            'organism': organism,
            'antibiotic': antibiotic
        }, model, encoders)
        # A new prediction starts a new conversation.
        st.session_state.messages = []

    result = st.session_state.get('current_result')
    if result is None:
        return

    st.subheader("Prediction Results")
    if "Error" in result:
        st.error(result["Error"])
    else:
        st.write(f"**Final Decision:** {result['Final Output']}")
        st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
        st.write(f"**Model Prediction:** {result['Model Prediction']}")
        st.write(f"**Decision Explanation:** {result['Decision Reason']}")


def _render_assistant_panel():
    """Draw the AI assistant: example prompts, chat input and chat history.

    Every question/answer pair is appended to ``st.session_state.messages`` and
    the whole history is redrawn on each run, so earlier exchanges survive the
    rerun triggered by the next question.
    """
    st.subheader("DeepSeek AI Assistant")

    # Only show assistant if a prediction has been made.
    if 'current_organism' not in st.session_state:
        st.info("Make a prediction first to get specific AI assistance")
        return

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    current_organism = st.session_state.get('current_organism')
    current_antibiotic = st.session_state.get('current_antibiotic')
    st.markdown(f"Ask about **{current_organism}** and **{current_antibiotic}**:")

    # Example prompts as buttons.
    example_prompts = [
        "Explain why this combination might show resistance",
        "Suggest alternative antibiotics for this organism",
        "What resistance mechanisms are common here?",
        "How should this result influence treatment decisions?"
    ]

    # At most one widget fires per rerun, so a single `question` slot suffices.
    question = None
    for i, prompt in enumerate(example_prompts):
        # Unique key per button so Streamlit can tell them apart.
        if st.button(prompt, key=f"prompt_{i}"):
            question = prompt

    # Reserve the history area above the chat input; it is filled in below,
    # after a possible new exchange has been appended.
    history_area = st.container()

    user_prompt = st.chat_input("Ask about this result...")
    if user_prompt:
        question = user_prompt

    if question:
        # Prefix the question with the current organism/antibiotic context.
        enhanced_prompt = (
            f"For organism {current_organism} "
            f"and antibiotic {current_antibiotic}: {question}"
        )
        response = get_ai_response(client, enhanced_prompt)
        st.session_state.messages.append({"role": "user", "content": question})
        st.session_state.messages.append({"role": "assistant", "content": response})

    with history_area:
        for message in st.session_state.messages:
            st.chat_message(message["role"]).markdown(message["content"])


def _render_analysis_page():
    """Draw the single-prediction page: form on the left, assistant on the right."""
    st.title("Susceptibility Prediction")

    # Initialize session state for messages if not exists.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # The prediction panel runs first so that a fresh submission is already in
    # session state when the assistant panel is drawn.
    col1, col2 = st.columns([1, 1])
    with col1:
        _render_prediction_panel()
    with col2:
        _render_assistant_panel()


def _batch_outcome(row):
    """Return the text for the "Prediction" column of one uploaded row.

    predict_susceptibility may return a dict carrying an "Error" key instead of
    "Final Output" (the single-prediction page handles that case), so the error
    text is reported per row instead of aborting the batch with a KeyError.
    """
    outcome = predict_susceptibility(row.to_dict(), model, encoders)
    if "Error" in outcome:
        return f"Error: {outcome['Error']}"
    return outcome["Final Output"]


def _render_upload_page():
    """Draw the batch page: upload a CSV, preview it, predict for every row."""
    st.title("Batch Prediction: Upload CSV")
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if not uploaded_file:
        return

    # Uploaded files are untrusted input: report unreadable files to the user
    # instead of letting the parser exception crash the page.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV file: {exc}")
        return

    st.write("Uploaded Data Preview:", df.head())

    # NOTE(review): the single-prediction form passes the keys 'was_positive',
    # 'organism' and 'antibiotic'; presumably the CSV needs those columns too —
    # confirm against predict_susceptibility.
    if st.button("Predict for Dataset"):
        # Built row by row (not via df.apply) so a header-only CSV yields an
        # empty column rather than a failed DataFrame-to-column assignment.
        df["Prediction"] = [_batch_outcome(row) for _, row in df.iterrows()]
        st.write("Prediction Results", df)


def _render_about_page():
    """Draw the static credits page."""
    st.title("About this App")
    st.markdown(
        "\n"
        "- Developed by **Okunromade Joseph Oluwaseun**\n"
        "- Uses **Machine Learning & Rule-based AI**\n"
        "- Integrated with **DeepSeek AI** for advanced queries\n"
        "- Matric No: 22/SCI01/172\n"
    )


# Sidebar label -> renderer. Labels must match the options of the sidebar radio.
_PAGE_RENDERERS = {
    "Home": _render_home_page,
    "Susceptibility Analysis": _render_analysis_page,
    "Data Upload": _render_upload_page,
    "About": _render_about_page,
}

# An unknown label renders nothing, as the original if/elif chain did.
_renderer = _PAGE_RENDERERS.get(page)
if _renderer is not None:
    _renderer()