|
import re
from pathlib import Path

import pandas as pd
import streamlit as st

from ai_assistant import get_ai_response, initialize_openai_client
from config import ENCODER_DIR, MODEL_PATH, OPENAI_API_KEY, OPENAI_BASE_URL
from prediction import predict_susceptibility
from utils import load_label_encoders, load_model
|
|
|
|
|
# Page configuration runs before any other Streamlit call: Streamlit documents
# st.set_page_config as needing to be the first Streamlit command, and the
# project loaders below may emit Streamlit output of their own.
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy objects below are cached with st.cache_resource and created once per
# server process instead of on every rerun. The cached objects are shared by
# all sessions; NOTE(review): assumes the model/client are safe to share.
@st.cache_resource
def _get_model():
    """Load the trained susceptibility model from MODEL_PATH (cached)."""
    return load_model(MODEL_PATH)


@st.cache_resource
def _get_encoders():
    """Load the label encoders from ENCODER_DIR (cached)."""
    return load_label_encoders(ENCODER_DIR)


@st.cache_resource
def _get_client():
    """Create the AI assistant API client (cached)."""
    return initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)


model = _get_model()
encoders = _get_encoders()
client = _get_client()

# Sidebar navigation; `page` selects which branch of the dispatch below renders.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "Susceptibility Analysis", "Data Upload", "About"])
|
|
|
|
|
if page == "Home":
    st.title("Microbial Susceptibility Analyzer")

    # A bare relative path is resolved against the process working directory,
    # i.e. wherever `streamlit run` was launched, not this script's folder.
    # Look next to the script first, fall back to the working directory
    # (the previous behaviour), and skip the purely decorative banner instead
    # of crashing the landing page when the file cannot be found.
    banner_path = Path(__file__).resolve().parent / "bacteria.jpeg"
    if not banner_path.exists():
        banner_path = Path("bacteria.jpeg")
    if banner_path.exists():
        st.image(str(banner_path), use_container_width=True)

    st.markdown("""

    **Welcome to the Microbial Susceptibility Analyzer!**

    This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.

    - Predict microbial susceptibility.

    - Ask an AI assistant for expert advice.

    - Upload datasets for batch predictions.

    """)
|
|
|
|
|
elif page == "Susceptibility Analysis": |
|
st.title("Susceptibility Prediction") |
|
|
|
|
|
if 'messages' not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
col1, col2 = st.columns([1, 1]) |
|
|
|
with col1: |
|
with st.form("prediction_form"): |
|
organism = st.selectbox('Organism', options=encoders['organism'].keys()) |
|
antibiotic = st.selectbox('Antibiotic', options=encoders['antibiotic'].keys()) |
|
was_positive = st.selectbox('Was Positive', options=[1, 0]) |
|
submit_button = st.form_submit_button("Predict") |
|
|
|
if submit_button: |
|
|
|
st.session_state['current_organism'] = organism |
|
st.session_state['current_antibiotic'] = antibiotic |
|
st.session_state['current_was_positive'] = was_positive |
|
|
|
result = predict_susceptibility({ |
|
'was_positive': was_positive, |
|
'organism': organism, |
|
'antibiotic': antibiotic |
|
}, model, encoders) |
|
|
|
st.subheader("Prediction Results") |
|
if "Error" in result: |
|
st.error(result["Error"]) |
|
else: |
|
st.write(f"**Final Decision:** {result['Final Output']}") |
|
st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}") |
|
st.write(f"**Model Prediction:** {result['Model Prediction']}") |
|
st.write(f"**Decision Explanation:** {result['Decision Reason']}") |
|
|
|
|
|
st.session_state.messages = [] |
|
|
|
with col2: |
|
st.subheader("DeepSeek AI Assistant") |
|
|
|
|
|
if 'current_organism' in st.session_state: |
|
st.markdown(f"Ask about **{st.session_state.get('current_organism')}** and **{st.session_state.get('current_antibiotic')}**:") |
|
|
|
|
|
example_prompts = [ |
|
"Explain why this combination might show resistance", |
|
"Suggest alternative antibiotics for this organism", |
|
"What resistance mechanisms are common here?", |
|
"How should this result influence treatment decisions?" |
|
] |
|
|
|
|
|
for i, prompt in enumerate(example_prompts): |
|
if st.button(prompt, key=f"prompt_{i}"): |
|
|
|
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \ |
|
f"and antibiotic {st.session_state.get('current_antibiotic')}: {prompt}" |
|
response = get_ai_response(client, enhanced_prompt) |
|
|
|
|
|
st.chat_message("user").markdown(prompt) |
|
st.chat_message("assistant").markdown(response) |
|
|
|
|
|
user_prompt = st.chat_input("Ask about this result...") |
|
if user_prompt: |
|
|
|
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \ |
|
f"and antibiotic {st.session_state.get('current_antibiotic')}: {user_prompt}" |
|
response = get_ai_response(client, enhanced_prompt) |
|
|
|
|
|
st.chat_message("user").markdown(user_prompt) |
|
st.chat_message("assistant").markdown(response) |
|
else: |
|
st.info("Make a prediction first to get specific AI assistance") |
|
|
|
|
|
elif page == "Data Upload": |
|
st.title("Batch Prediction: Upload CSV") |
|
|
|
|
|
st.markdown("### Sample Data") |
|
sample_data = pd.DataFrame({ |
|
'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'], |
|
'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'], |
|
'was_positive': [1, 0, 1] |
|
}) |
|
csv_sample = sample_data.to_csv(index=False) |
|
st.download_button( |
|
label="Download Sample CSV Template", |
|
data=csv_sample, |
|
file_name="sample_template.csv", |
|
mime='text/csv' |
|
) |
|
|
|
uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"]) |
|
|
|
if uploaded_file: |
|
try: |
|
|
|
df = pd.read_csv(uploaded_file) |
|
|
|
|
|
st.write("Uploaded Data Preview:", df.head()) |
|
|
|
|
|
required_columns = ['organism', 'antibiotic', 'was_positive'] |
|
missing_columns = [col for col in required_columns if col not in df.columns] |
|
|
|
if missing_columns: |
|
st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}") |
|
else: |
|
|
|
if df['was_positive'].dtype != 'int64': |
|
try: |
|
df['was_positive'] = df['was_positive'].astype(int) |
|
st.info("Converted 'was_positive' column to integer type.") |
|
except ValueError: |
|
st.error("The 'was_positive' column must contain only 0 or 1 values.") |
|
st.stop() |
|
|
|
|
|
invalid_organisms = [org for org in df['organism'].unique() if org not in encoders['organism']] |
|
invalid_antibiotics = [ab for ab in df['antibiotic'].unique() if ab not in encoders['antibiotic']] |
|
|
|
if invalid_organisms: |
|
st.warning(f"Found {len(invalid_organisms)} organisms not in the training data: {', '.join(invalid_organisms[:5])}{'...' if len(invalid_organisms) > 5 else ''}") |
|
|
|
if invalid_antibiotics: |
|
st.warning(f"Found {len(invalid_antibiotics)} antibiotics not in the training data: {', '.join(invalid_antibiotics[:5])}{'...' if len(invalid_antibiotics) > 5 else ''}") |
|
|
|
|
|
if st.button("Predict for Dataset"): |
|
with st.spinner("Processing predictions..."): |
|
|
|
progress_bar = st.progress(0) |
|
total_rows = len(df) |
|
|
|
|
|
results_df = pd.DataFrame(index=df.index) |
|
results_df["Prediction"] = "" |
|
results_df["Rule Guidance"] = "" |
|
results_df["Model Prediction"] = "" |
|
results_df["Decision Reason"] = "" |
|
|
|
|
|
for i, (index, row) in enumerate(df.iterrows()): |
|
try: |
|
|
|
if (row['organism'] not in encoders['organism'] or |
|
row['antibiotic'] not in encoders['antibiotic']): |
|
results_df.at[index, "Prediction"] = "Invalid data" |
|
continue |
|
|
|
|
|
input_data = { |
|
'was_positive': row['was_positive'], |
|
'organism': row['organism'], |
|
'antibiotic': row['antibiotic'] |
|
} |
|
|
|
|
|
result = predict_susceptibility(input_data, model, encoders) |
|
|
|
|
|
if "Error" in result: |
|
results_df.at[index, "Prediction"] = "Error: " + result["Error"] |
|
else: |
|
results_df.at[index, "Prediction"] = result["Final Output"] |
|
results_df.at[index, "Rule Guidance"] = result["Rule Guidance"] |
|
results_df.at[index, "Model Prediction"] = result["Model Prediction"] |
|
results_df.at[index, "Decision Reason"] = result["Decision Reason"] |
|
except Exception as e: |
|
results_df.at[index, "Prediction"] = f"Error: {str(e)}" |
|
|
|
|
|
progress_bar.progress((i + 1) / total_rows) |
|
|
|
|
|
df = pd.concat([df, results_df], axis=1) |
|
|
|
st.success("Predictions complete!") |
|
|
|
|
|
tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"]) |
|
|
|
with tab1: |
|
st.dataframe(df[['organism', 'antibiotic', 'was_positive', 'Prediction']]) |
|
|
|
with tab2: |
|
st.dataframe(df) |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
with col1: |
|
|
|
csv_basic = df[['organism', 'antibiotic', 'was_positive', 'Prediction']].to_csv(index=False) |
|
st.download_button( |
|
label="Download Basic Results", |
|
data=csv_basic, |
|
file_name="predictions_basic.csv", |
|
mime='text/csv' |
|
) |
|
|
|
with col2: |
|
|
|
csv_detailed = df.to_csv(index=False) |
|
st.download_button( |
|
label="Download Detailed Results", |
|
data=csv_detailed, |
|
file_name="predictions_detailed.csv", |
|
mime='text/csv' |
|
) |
|
except pd.errors.EmptyDataError: |
|
st.error("The uploaded file is empty.") |
|
except pd.errors.ParserError: |
|
st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.") |
|
except Exception as e: |
|
st.error(f"An unexpected error occurred: {str(e)}") |
|
|
|
|
|
elif page == "About": |
|
st.title("About this App") |
|
st.markdown(""" |
|
- Developed by **Okunromade Joseph Oluwaseun** |
|
- Uses **Machine Learning & Rule-based AI** |
|
- Integrated with **DeepSeek AI** for advanced queries |
|
- Matric No: 22/SCI01/172 |
|
""") |