# Provenance (Hugging Face file-view header, kept as comments so the file parses):
# author EAV123 - commit d62c290 "Update app.py" (verified)
import streamlit as st
import pandas as pd
import re
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL
from utils import load_model, load_label_encoders
from prediction import predict_susceptibility
from ai_assistant import initialize_openai_client, get_ai_response
# Load assets.
# Streamlit re-executes this whole script on every widget interaction, so the
# model, the label encoders and the API client are wrapped in
# st.cache_resource; otherwise they would be reloaded / rebuilt on every click.
# show_spinner=False keeps these cached calls from rendering a spinner element.
# NOTE(review): cached resources are shared by every session and are only
# reloaded after a server restart or cache clear; this assumes
# predict_susceptibility does not mutate `model` / `encoders` — confirm.
@st.cache_resource(show_spinner=False)
def _load_assets():
    """Return ``(model, encoders)`` loaded from MODEL_PATH / ENCODER_DIR, cached."""
    return load_model(MODEL_PATH), load_label_encoders(ENCODER_DIR)


@st.cache_resource(show_spinner=False)
def _create_client():
    """Return the OpenAI-compatible client built from config, cached."""
    return initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)


model, encoders = _load_assets()
# Initialize OpenAI client
client = _create_client()
# Streamlit UI: browser-tab title, wide layout, and the sidebar page selector.
# `page` holds the label of the selected page and drives the branches below.
st.set_page_config(layout="wide", page_title="Microbial Susceptibility Analyzer")
nav = st.sidebar
nav.title("Navigation")
page = nav.radio(
    "Go to",
    ["Home", "Susceptibility Analysis", "Data Upload", "About"],
)
# Home page: title, banner image and a short feature overview.
if page == "Home":
    welcome_lines = [
        "**Welcome to the Microbial Susceptibility Analyzer!**",
        "This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.",
        "- Predict microbial susceptibility.",
        "- Ask an AI assistant for expert advice.",
        "- Upload datasets for batch predictions.",
    ]
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    # Leading and trailing newline reproduce the original triple-quoted text.
    st.markdown("\n" + "\n".join(welcome_lines) + "\n")
# Susceptibility Analysis Page
elif page == "Susceptibility Analysis":
st.title("Susceptibility Prediction")
# Initialize session state for messages if not exists
if 'messages' not in st.session_state:
st.session_state.messages = []
# Create two columns for layout
col1, col2 = st.columns([1, 1])
with col1:
with st.form("prediction_form"):
organism = st.selectbox('Organism', options=encoders['organism'].keys())
antibiotic = st.selectbox('Antibiotic', options=encoders['antibiotic'].keys())
was_positive = st.selectbox('Was Positive', options=[1, 0])
submit_button = st.form_submit_button("Predict")
if submit_button:
# Store inputs in session state
st.session_state['current_organism'] = organism
st.session_state['current_antibiotic'] = antibiotic
st.session_state['current_was_positive'] = was_positive
result = predict_susceptibility({
'was_positive': was_positive,
'organism': organism,
'antibiotic': antibiotic
}, model, encoders)
st.subheader("Prediction Results")
if "Error" in result:
st.error(result["Error"])
else:
st.write(f"**Final Decision:** {result['Final Output']}")
st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
st.write(f"**Model Prediction:** {result['Model Prediction']}")
st.write(f"**Decision Explanation:** {result['Decision Reason']}")
# Clear previous messages when new prediction is made
st.session_state.messages = []
with col2:
st.subheader("DeepSeek AI Assistant")
# Only show assistant if a prediction has been made
if 'current_organism' in st.session_state:
st.markdown(f"Ask about **{st.session_state.get('current_organism')}** and **{st.session_state.get('current_antibiotic')}**:")
# Example prompts as buttons
example_prompts = [
"Explain why this combination might show resistance",
"Suggest alternative antibiotics for this organism",
"What resistance mechanisms are common here?",
"How should this result influence treatment decisions?"
]
# Create a unique key for each button
for i, prompt in enumerate(example_prompts):
if st.button(prompt, key=f"prompt_{i}"):
# Create context-enhanced prompt
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \
f"and antibiotic {st.session_state.get('current_antibiotic')}: {prompt}"
response = get_ai_response(client, enhanced_prompt)
# Display the user prompt and AI response
st.chat_message("user").markdown(prompt)
st.chat_message("assistant").markdown(response)
# Chat input
user_prompt = st.chat_input("Ask about this result...")
if user_prompt:
# Create context-enhanced prompt
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \
f"and antibiotic {st.session_state.get('current_antibiotic')}: {user_prompt}"
response = get_ai_response(client, enhanced_prompt)
# Display the user prompt and AI response
st.chat_message("user").markdown(user_prompt)
st.chat_message("assistant").markdown(response)
else:
st.info("Make a prediction first to get specific AI assistance")
# Data Upload Page
elif page == "Data Upload":
st.title("Batch Prediction: Upload CSV")
# Add sample data download option
st.markdown("### Sample Data")
sample_data = pd.DataFrame({
'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'],
'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'],
'was_positive': [1, 0, 1]
})
csv_sample = sample_data.to_csv(index=False)
st.download_button(
label="Download Sample CSV Template",
data=csv_sample,
file_name="sample_template.csv",
mime='text/csv'
)
uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
if uploaded_file:
try:
# Read the CSV file
df = pd.read_csv(uploaded_file)
# Display preview with column check
st.write("Uploaded Data Preview:", df.head())
# Validate the required columns
required_columns = ['organism', 'antibiotic', 'was_positive']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
else:
# Check data types and convert if necessary
if df['was_positive'].dtype != 'int64':
try:
df['was_positive'] = df['was_positive'].astype(int)
st.info("Converted 'was_positive' column to integer type.")
except ValueError:
st.error("The 'was_positive' column must contain only 0 or 1 values.")
st.stop()
# Validate organisms and antibiotics against encoders
invalid_organisms = [org for org in df['organism'].unique() if org not in encoders['organism']]
invalid_antibiotics = [ab for ab in df['antibiotic'].unique() if ab not in encoders['antibiotic']]
if invalid_organisms:
st.warning(f"Found {len(invalid_organisms)} organisms not in the training data: {', '.join(invalid_organisms[:5])}{'...' if len(invalid_organisms) > 5 else ''}")
if invalid_antibiotics:
st.warning(f"Found {len(invalid_antibiotics)} antibiotics not in the training data: {', '.join(invalid_antibiotics[:5])}{'...' if len(invalid_antibiotics) > 5 else ''}")
# Process predictions
if st.button("Predict for Dataset"):
with st.spinner("Processing predictions..."):
# Create a progress bar
progress_bar = st.progress(0)
total_rows = len(df)
# Create a new results DataFrame with the same index as the original
results_df = pd.DataFrame(index=df.index)
results_df["Prediction"] = ""
results_df["Rule Guidance"] = ""
results_df["Model Prediction"] = ""
results_df["Decision Reason"] = ""
# Process each row with error handling
for i, (index, row) in enumerate(df.iterrows()):
try:
# Skip rows with invalid data
if (row['organism'] not in encoders['organism'] or
row['antibiotic'] not in encoders['antibiotic']):
results_df.at[index, "Prediction"] = "Invalid data"
continue
# Extract only the required columns for prediction in specific order
input_data = {
'was_positive': row['was_positive'],
'organism': row['organism'],
'antibiotic': row['antibiotic']
}
# Get full prediction result
result = predict_susceptibility(input_data, model, encoders)
# Store all results
if "Error" in result:
results_df.at[index, "Prediction"] = "Error: " + result["Error"]
else:
results_df.at[index, "Prediction"] = result["Final Output"]
results_df.at[index, "Rule Guidance"] = result["Rule Guidance"]
results_df.at[index, "Model Prediction"] = result["Model Prediction"]
results_df.at[index, "Decision Reason"] = result["Decision Reason"]
except Exception as e:
results_df.at[index, "Prediction"] = f"Error: {str(e)}"
# Update progress bar
progress_bar.progress((i + 1) / total_rows)
# Combine original data with results
df = pd.concat([df, results_df], axis=1)
st.success("Predictions complete!")
# Display results with tabs for different views
tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"])
with tab1:
st.dataframe(df[['organism', 'antibiotic', 'was_positive', 'Prediction']])
with tab2:
st.dataframe(df)
# Download options
col1, col2 = st.columns(2)
with col1:
# Download basic results
csv_basic = df[['organism', 'antibiotic', 'was_positive', 'Prediction']].to_csv(index=False)
st.download_button(
label="Download Basic Results",
data=csv_basic,
file_name="predictions_basic.csv",
mime='text/csv'
)
with col2:
# Download detailed results
csv_detailed = df.to_csv(index=False)
st.download_button(
label="Download Detailed Results",
data=csv_detailed,
file_name="predictions_detailed.csv",
mime='text/csv'
)
except pd.errors.EmptyDataError:
st.error("The uploaded file is empty.")
except pd.errors.ParserError:
st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
except Exception as e:
st.error(f"An unexpected error occurred: {str(e)}")
# About Page
elif page == "About":
st.title("About this App")
st.markdown("""
- Developed by **Okunromade Joseph Oluwaseun**
- Uses **Machine Learning & Rule-based AI**
- Integrated with **DeepSeek AI** for advanced queries
- Matric No: 22/SCI01/172
""")