"""Streamlit front-end for the Microbial Susceptibility Analyzer.

Pages (selected from the sidebar):

* Home                    - introduction and feature overview.
* Susceptibility Analysis - single prediction plus an AI assistant whose
                            questions are prefixed with the predicted
                            organism/antibiotic pair.
* Data Upload             - batch predictions for an uploaded CSV file.
* About                   - author information.

Streamlit re-runs this whole script on every widget interaction, so state
that must survive a rerun (latest prediction, chat history, batch results)
is kept in ``st.session_state``.
"""
import streamlit as st
import pandas as pd
import re
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL
from utils import load_model, load_label_encoders
from prediction import predict_susceptibility
from ai_assistant import initialize_openai_client, get_ai_response

# Configure the page before any other Streamlit command: older Streamlit
# versions reject a late set_page_config, and the cached loader below may
# render a spinner.
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")

# Columns a batch-upload CSV must provide.
REQUIRED_COLUMNS = ["organism", "antibiotic", "was_positive"]
# Columns appended to the uploaded data by a batch run.
RESULT_COLUMNS = ["Prediction", "Rule Guidance", "Model Prediction", "Decision Reason"]
# Columns shown in the "Basic Results" tab and the basic CSV download.
BASIC_COLUMNS = ["organism", "antibiotic", "was_positive", "Prediction"]
# One-click questions offered by the assistant panel.
EXAMPLE_PROMPTS = [
    "Explain why this combination might show resistance",
    "Suggest alternative antibiotics for this organism",
    "What resistance mechanisms are common here?",
    "How should this result influence treatment decisions?",
]

HOME_MARKDOWN = """
**Welcome to the Microbial Susceptibility Analyzer!**

This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.

- Predict microbial susceptibility.
- Ask an AI assistant for expert advice.
- Upload datasets for batch predictions.
"""

ABOUT_MARKDOWN = """
- Developed by **Okunromade Joseph Oluwaseun**
- Uses **Machine Learning & Rule-based AI**
- Integrated with **DeepSeek AI** for advanced queries
- Matric No: 22/SCI01/172
"""


@st.cache_resource
def _load_assets():
    """Load the model and label encoders once per server process.

    Without caching, every widget interaction (i.e. every rerun) would reload
    them from disk. NOTE(review): cached resources are shared by all sessions;
    this assumes predict_susceptibility does not mutate them — confirm.
    """
    return load_model(MODEL_PATH), load_label_encoders(ENCODER_DIR)


model, encoders = _load_assets()
client = initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)


# --------------------------------------------------------------------------
# Home / About
# --------------------------------------------------------------------------
def render_home():
    """Render the landing page."""
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    st.markdown(HOME_MARKDOWN)


def render_about():
    """Render the author/about page."""
    st.title("About this App")
    st.markdown(ABOUT_MARKDOWN)


# --------------------------------------------------------------------------
# Susceptibility Analysis
# --------------------------------------------------------------------------
def _context_prompt(question):
    """Prefix *question* with the organism/antibiotic of the last prediction."""
    return (
        f"For organism {st.session_state.get('current_organism')} "
        f"and antibiotic {st.session_state.get('current_antibiotic')}: {question}"
    )


def _ask_assistant(question):
    """Send *question* (with prediction context) to the assistant and record both chat turns."""
    with st.spinner("Thinking..."):
        response = get_ai_response(client, _context_prompt(question))
    st.session_state.messages.append({"role": "user", "content": question})
    st.session_state.messages.append({"role": "assistant", "content": response})


def _render_prediction_result(result):
    """Display a result dict from predict_susceptibility, or its "Error" entry."""
    st.subheader("Prediction Results")
    if "Error" in result:
        st.error(result["Error"])
        return
    st.write(f"**Final Decision:** {result['Final Output']}")
    st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
    st.write(f"**Model Prediction:** {result['Model Prediction']}")
    st.write(f"**Decision Explanation:** {result['Decision Reason']}")


def _render_prediction_panel():
    """Prediction form plus the most recent result (kept across reruns)."""
    with st.form("prediction_form"):
        organism = st.selectbox("Organism", options=list(encoders["organism"].keys()))
        antibiotic = st.selectbox("Antibiotic", options=list(encoders["antibiotic"].keys()))
        was_positive = st.selectbox("Was Positive", options=[1, 0])
        submitted = st.form_submit_button("Predict")

    if submitted:
        st.session_state["current_organism"] = organism
        st.session_state["current_antibiotic"] = antibiotic
        st.session_state["current_was_positive"] = was_positive
        st.session_state["last_result"] = predict_susceptibility(
            {"was_positive": was_positive, "organism": organism, "antibiotic": antibiotic},
            model,
            encoders,
        )
        # A new prediction starts a fresh conversation.
        st.session_state.messages = []

    # Render from session state rather than only when `submitted` is True:
    # clicking an assistant button reruns the script with submitted == False.
    if "last_result" in st.session_state:
        _render_prediction_result(st.session_state["last_result"])


def _render_assistant_panel():
    """AI assistant scoped to the last predicted organism/antibiotic pair."""
    st.subheader("DeepSeek AI Assistant")
    if "current_organism" not in st.session_state:
        st.info("Make a prediction first to get specific AI assistance")
        return

    st.markdown(
        f"Ask about **{st.session_state.get('current_organism')}** "
        f"and **{st.session_state.get('current_antibiotic')}**:"
    )
    for i, prompt in enumerate(EXAMPLE_PROMPTS):
        if st.button(prompt, key=f"prompt_{i}"):
            _ask_assistant(prompt)

    # Reserve the history slot above the input so the conversation reads
    # top-down, then fill it after any new question has been recorded.
    history = st.container()
    user_prompt = st.chat_input("Ask about this result...")
    if user_prompt:
        _ask_assistant(user_prompt)
    with history:
        for message in st.session_state.messages:
            st.chat_message(message["role"]).markdown(message["content"])


def render_susceptibility_analysis():
    """Single-prediction page: form on the left, assistant on the right."""
    st.title("Susceptibility Prediction")
    if "messages" not in st.session_state:
        st.session_state.messages = []

    col1, col2 = st.columns([1, 1])
    with col1:
        _render_prediction_panel()
    with col2:
        _render_assistant_panel()


# --------------------------------------------------------------------------
# Data Upload
# --------------------------------------------------------------------------
def _sample_template_csv():
    """Return a three-row example CSV containing the required upload columns."""
    sample = pd.DataFrame({
        "organism": ["Escherichia coli", "Staphylococcus aureus", "Pseudomonas aeruginosa"],
        "antibiotic": ["Amoxicillin", "Vancomycin", "Ciprofloxacin"],
        "was_positive": [1, 0, 1],
    })
    return sample.to_csv(index=False)


def _warn_unknown_values(df, column, label):
    """Warn about values of *column* that are absent from the matching encoder."""
    unknown = [value for value in df[column].unique() if value not in encoders[column]]
    if unknown:
        # map(str, ...) so NaN / numeric cells don't make join() raise TypeError.
        preview = ", ".join(map(str, unknown[:5]))
        more = "..." if len(unknown) > 5 else ""
        st.warning(f"Found {len(unknown)} {label} not in the training data: {preview}{more}")


def _validate_upload(df):
    """Check and normalise an uploaded frame.

    Returns:
        The (possibly converted) DataFrame, or None after an error has been shown.
    """
    missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing_columns:
        st.error(
            f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}"
        )
        return None

    # Validate by value, not just by castability: non-numeric text becomes NaN
    # and fails, and values such as 2 or 1.5 are rejected instead of kept/truncated.
    numeric = pd.to_numeric(df["was_positive"], errors="coerce")
    valid = numeric.notna() & ((numeric == 0) | (numeric == 1))
    if not valid.all():
        st.error("The 'was_positive' column must contain only 0 or 1 values.")
        return None
    if not pd.api.types.is_integer_dtype(df["was_positive"]):
        df["was_positive"] = numeric.astype(int)
        st.info("Converted 'was_positive' column to integer type.")

    _warn_unknown_values(df, "organism", "organisms")
    _warn_unknown_values(df, "antibiotic", "antibiotics")
    return df


def _predict_row(row):
    """Return the RESULT_COLUMNS values for one CSV row; never raises."""
    outcome = dict.fromkeys(RESULT_COLUMNS, "")
    if row["organism"] not in encoders["organism"] or row["antibiotic"] not in encoders["antibiotic"]:
        outcome["Prediction"] = "Invalid data"
        return outcome

    try:
        result = predict_susceptibility(
            {
                "was_positive": row["was_positive"],
                "organism": row["organism"],
                "antibiotic": row["antibiotic"],
            },
            model,
            encoders,
        )
    except Exception as e:  # best effort: one bad row must not abort the batch
        outcome["Prediction"] = f"Error: {str(e)}"
        return outcome

    if "Error" in result:
        outcome["Prediction"] = f"Error: {result['Error']}"
    else:
        outcome["Prediction"] = result["Final Output"]
        outcome["Rule Guidance"] = result["Rule Guidance"]
        outcome["Model Prediction"] = result["Model Prediction"]
        outcome["Decision Reason"] = result["Decision Reason"]
    return outcome


def _run_batch_predictions(df):
    """Predict every row of *df* with a progress bar; return df plus RESULT_COLUMNS."""
    progress_bar = st.progress(0)
    total_rows = len(df)
    rows = []
    for done, (_, row) in enumerate(df.iterrows(), start=1):
        rows.append(_predict_row(row))
        progress_bar.progress(done / total_rows)

    results_df = pd.DataFrame(rows, index=df.index, columns=RESULT_COLUMNS)
    # Drop pre-existing result columns (e.g. a re-uploaded results file) so
    # the concat doesn't produce duplicate column names.
    base = df.drop(columns=RESULT_COLUMNS, errors="ignore")
    return pd.concat([base, results_df], axis=1)


def _render_batch_results(results):
    """Show batch results in two tabs and offer basic/detailed CSV downloads."""
    tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"])
    with tab1:
        st.dataframe(results[BASIC_COLUMNS])
    with tab2:
        st.dataframe(results)

    col1, col2 = st.columns(2)
    with col1:
        st.download_button(
            label="Download Basic Results",
            data=results[BASIC_COLUMNS].to_csv(index=False),
            file_name="predictions_basic.csv",
            mime="text/csv",
        )
    with col2:
        st.download_button(
            label="Download Detailed Results",
            data=results.to_csv(index=False),
            file_name="predictions_detailed.csv",
            mime="text/csv",
        )


def render_data_upload():
    """Batch-prediction page: template download, CSV upload, validation, results."""
    st.title("Batch Prediction: Upload CSV")

    st.markdown("### Sample Data")
    st.download_button(
        label="Download Sample CSV Template",
        data=_sample_template_csv(),
        file_name="sample_template.csv",
        mime="text/csv",
    )

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if not uploaded_file:
        return

    try:
        df = pd.read_csv(uploaded_file)
        st.write("Uploaded Data Preview:", df.head())

        df = _validate_upload(df)
        if df is None:
            return

        # Results are stored in session state (tagged with the upload they
        # belong to) because clicking a download button reruns the script,
        # which would otherwise discard them.
        upload_key = (uploaded_file.name, uploaded_file.size)
        if st.button("Predict for Dataset"):
            with st.spinner("Processing predictions..."):
                results = _run_batch_predictions(df)
            st.session_state["batch_results"] = (upload_key, results)
            st.success("Predictions complete!")

        stored = st.session_state.get("batch_results")
        if stored is not None and stored[0] == upload_key:
            _render_batch_results(stored[1])

    except pd.errors.EmptyDataError:
        st.error("The uploaded file is empty.")
    except pd.errors.ParserError:
        st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
    except Exception as e:  # UI boundary: surface anything else instead of a traceback
        st.error(f"An unexpected error occurred: {str(e)}")


# --------------------------------------------------------------------------
# Navigation
# --------------------------------------------------------------------------
# Insertion order defines the order of the sidebar options.
PAGES = {
    "Home": render_home,
    "Susceptibility Analysis": render_susceptibility_analysis,
    "Data Upload": render_data_upload,
    "About": render_about,
}

st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", list(PAGES))
PAGES[page]()