|
import streamlit as st |
|
import pandas as pd |
|
import re |
|
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL |
|
from utils import load_model, load_label_encoders |
|
from prediction import predict_susceptibility |
|
from ai_assistant import initialize_openai_client, get_ai_response |
|
|
|
|
|
# Configure the page before any other Streamlit command runs. Older Streamlit
# releases raise if set_page_config is not the first Streamlit call, and the
# loaders below (or a cache spinner) may emit Streamlit elements.
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")


@st.cache_resource
def _load_model_and_encoders(model_path, encoder_dir):
    """Load the classifier and its label encoders once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching both artifacts would be reloaded from disk on every click.

    Args:
        model_path: Location passed straight to ``load_model``.
        encoder_dir: Location passed straight to ``load_label_encoders``.

    Returns:
        A ``(model, encoders)`` tuple as returned by the two loaders.
    """
    return load_model(model_path), load_label_encoders(encoder_dir)


# NOTE(review): the cached objects are shared by all sessions; this assumes the
# model and encoders are only read, never mutated, by predict_susceptibility.
model, encoders = _load_model_and_encoders(MODEL_PATH, ENCODER_DIR)

# The AI client is cheap to build and is deliberately not cached, so a failed
# initialisation is retried on the next rerun instead of being pinned.
client = initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)

st.sidebar.title("Navigation")

page = st.sidebar.radio("Go to", ["Home", "Susceptibility Analysis", "Data Upload", "About"])
|
|
|
|
|
# Columns each uploaded CSV row must provide: the same keys the single
# prediction form passes to predict_susceptibility.
_REQUIRED_COLUMNS = ("was_positive", "organism", "antibiotic")


def _show_prediction(result):
    """Render one prediction dict: its "Error" message, or the four decision fields."""
    st.subheader("Prediction Results")
    if "Error" in result:
        st.error(result["Error"])
    else:
        st.write(f"**Final Decision:** {result['Final Output']}")
        st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
        st.write(f"**Model Prediction:** {result['Model Prediction']}")
        st.write(f"**Decision Explanation:** {result['Decision Reason']}")


def _build_assistant_prompt(question):
    """Prefix *question* with the organism/antibiotic of the latest prediction."""
    return (
        f"For organism {st.session_state.get('current_organism')} "
        f"and antibiotic {st.session_state.get('current_antibiotic')}: {question}"
    )


def _batch_decision(row):
    """Return the final decision for one CSV row, or its error text.

    predict_susceptibility may return ``{"Error": ...}`` instead of a decision
    (the single-prediction view handles that case). Reporting the error per row
    keeps one bad row from aborting the whole batch with a KeyError.
    """
    result = predict_susceptibility(row.to_dict(), model, encoders)
    if "Error" in result:
        return f"Error: {result['Error']}"
    return result["Final Output"]


if page == "Home":
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    st.markdown("""

**Welcome to the Microbial Susceptibility Analyzer!**

This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.

- Predict microbial susceptibility.

- Ask an AI assistant for expert advice.

- Upload datasets for batch predictions.

""")

elif page == "Susceptibility Analysis":
    st.title("Susceptibility Prediction")

    # Chat history for the assistant panel: a list of {"role", "content"} dicts.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    col1, col2 = st.columns([1, 1])

    with col1:
        with st.form("prediction_form"):
            organism = st.selectbox('Organism', options=list(encoders['organism'].keys()))
            antibiotic = st.selectbox('Antibiotic', options=list(encoders['antibiotic'].keys()))
            was_positive = st.selectbox('Was Positive', options=[1, 0])
            submit_button = st.form_submit_button("Predict")

        if submit_button:
            st.session_state['current_organism'] = organism
            st.session_state['current_antibiotic'] = antibiotic
            st.session_state['current_was_positive'] = was_positive
            st.session_state['last_prediction'] = predict_susceptibility({
                'was_positive': was_positive,
                'organism': organism,
                'antibiotic': antibiotic
            }, model, encoders)
            # A new organism/antibiotic pair starts a fresh conversation.
            st.session_state.messages = []

        # Re-render the latest result on every rerun. Otherwise it disappears as
        # soon as an assistant button or chat message triggers a rerun.
        if 'last_prediction' in st.session_state:
            _show_prediction(st.session_state['last_prediction'])

    with col2:
        st.subheader("DeepSeek AI Assistant")

        if 'current_organism' in st.session_state:
            st.markdown(f"Ask about **{st.session_state.get('current_organism')}** and **{st.session_state.get('current_antibiotic')}**:")

            example_prompts = [
                "Explain why this combination might show resistance",
                "Suggest alternative antibiotics for this organism",
                "What resistance mechanisms are common here?",
                "How should this result influence treatment decisions?"
            ]

            # At most one question arrives per rerun: a clicked example or a typed one.
            pending_question = None
            for i, prompt in enumerate(example_prompts):
                if st.button(prompt, key=f"prompt_{i}"):
                    pending_question = prompt

            # Reserve the area above the chat box. The history is filled in
            # below, after any new question has been answered and recorded.
            history_area = st.container()

            user_prompt = st.chat_input("Ask about this result...")
            if user_prompt:
                pending_question = user_prompt

            if pending_question:
                response = get_ai_response(client, _build_assistant_prompt(pending_question))
                st.session_state.messages.append({"role": "user", "content": pending_question})
                st.session_state.messages.append({"role": "assistant", "content": response})

            with history_area:
                for message in st.session_state.messages:
                    st.chat_message(message["role"]).markdown(message["content"])
        else:
            st.info("Make a prediction first to get specific AI assistance")

elif page == "Data Upload":
    st.title("Batch Prediction: Upload CSV")
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        try:
            df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            st.error(f"Could not read the uploaded CSV: {exc}")
        else:
            st.write("Uploaded Data Preview:", df.head())
            missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
            if missing:
                st.error(f"Uploaded CSV is missing required column(s): {', '.join(missing)}")
            elif df.empty:
                st.warning("The uploaded CSV has no rows to predict.")
            elif st.button("Predict for Dataset"):
                # A list comprehension rather than df.apply(axis=1). On a
                # header-only frame, apply returns a DataFrame, and assigning
                # that to a single column crashes.
                df["Prediction"] = [_batch_decision(row) for _, row in df.iterrows()]
                st.write("Prediction Results", df)

elif page == "About":
    st.title("About this App")
    st.markdown("""

- Developed by **Okunromade Joseph Oluwaseun**

- Uses **Machine Learning & Rule-based AI**

- Integrated with **DeepSeek AI** for advanced queries

- Matric No: 22/SCI01/172

""")