File size: 12,926 Bytes
a7ced43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1af97d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7ced43
 
 
5ae966a
1af97d1
5ae966a
1af97d1
 
5ae966a
 
 
 
 
 
 
 
 
1af97d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae966a
 
 
1af97d1
 
 
 
79ec034
 
 
 
 
 
1af97d1
 
 
 
 
 
 
79ec034
1af97d1
 
d62c290
79ec034
d62c290
79ec034
d62c290
79ec034
 
1af97d1
79ec034
1af97d1
 
 
79ec034
1af97d1
79ec034
 
 
 
1af97d1
79ec034
1af97d1
 
 
79ec034
 
 
5ae966a
 
 
1af97d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae966a
1af97d1
a7ced43
 
 
 
 
0f1f788
a7ced43
 
0f1f788
a7ced43
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import streamlit as st
import pandas as pd
import re
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL
from utils import load_model, load_label_encoders
from prediction import predict_susceptibility
from ai_assistant import initialize_openai_client, get_ai_response

# Streamlit UI
# Configure the page before any other work. Streamlit documents
# set_page_config as needing to be the first Streamlit command on a page; the
# loaders below are defined in other modules (not visible here) and might emit
# Streamlit elements themselves, so the configuration is done ahead of them.
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")

# Load assets: the prediction model and the label encoders used by every page.
model = load_model(MODEL_PATH)
encoders = load_label_encoders(ENCODER_DIR)

# Initialize OpenAI client used by the AI assistant panel.
client = initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)

# Sidebar navigation: `page` holds the label of the page to render.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "Susceptibility Analysis", "Data Upload", "About"])

# ---------------------------------------------------------------------------
# Page rendering
#
# Streamlit re-executes this script on every widget interaction. Anything that
# must survive a rerun (the last prediction, the chat history, batch results)
# is therefore kept in st.session_state and re-drawn from there, instead of
# being drawn only inside the `if button:` branch that produced it.
# ---------------------------------------------------------------------------

# Columns every uploaded dataset must provide.
_REQUIRED_COLUMNS = ['organism', 'antibiotic', 'was_positive']
# Columns appended to an uploaded dataset by the batch predictor.
_RESULT_COLUMNS = ["Prediction", "Rule Guidance", "Model Prediction", "Decision Reason"]


def _render_home_page():
    """Draw the landing page: title, banner image and feature summary."""
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    st.markdown("""
    **Welcome to the Microbial Susceptibility Analyzer!**  
    This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.
    - Predict microbial susceptibility.
    - Ask an AI assistant for expert advice.
    - Upload datasets for batch predictions.
    """)


def _show_prediction(result):
    """Draw one prediction result, or its error message.

    `result` is whatever predict_susceptibility returned; this code reads the
    keys "Error", "Final Output", "Rule Guidance", "Model Prediction" and
    "Decision Reason" from it.
    """
    st.subheader("Prediction Results")
    if "Error" in result:
        st.error(result["Error"])
        return
    st.write(f"**Final Decision:** {result['Final Output']}")
    st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
    st.write(f"**Model Prediction:** {result['Model Prediction']}")
    st.write(f"**Decision Explanation:** {result['Decision Reason']}")


def _ask_assistant(question):
    """Send `question` to the AI assistant and record the exchange.

    The question is prefixed with the organism/antibiotic of the current
    prediction and, when a successful prediction is stored, with its final
    output, so questions about "this result" have something to refer to.
    Both the user's question and the reply are appended to
    st.session_state.messages.
    """
    organism = st.session_state.get('current_organism')
    antibiotic = st.session_state.get('current_antibiotic')
    context = f"For organism {organism} and antibiotic {antibiotic}"

    result = st.session_state.get('current_result')
    if result is not None and "Error" not in result and "Final Output" in result:
        context += f" (predicted result: {result['Final Output']})"

    response = get_ai_response(client, f"{context}: {question}")
    st.session_state.messages.append({"role": "user", "content": question})
    st.session_state.messages.append({"role": "assistant", "content": response})


def _render_analysis_page():
    """Draw the single-prediction form next to the AI assistant panel."""
    st.title("Susceptibility Prediction")

    # Initialize session state for messages if not exists
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    form_col, chat_col = st.columns([1, 1])

    with form_col:
        with st.form("prediction_form"):
            organism = st.selectbox('Organism', options=encoders['organism'].keys())
            antibiotic = st.selectbox('Antibiotic', options=encoders['antibiotic'].keys())
            was_positive = st.selectbox('Was Positive', options=[1, 0])
            submitted = st.form_submit_button("Predict")

        if submitted:
            # Remember the inputs and the outcome so both columns can use them
            # on later reruns (e.g. after an assistant button is clicked).
            st.session_state['current_organism'] = organism
            st.session_state['current_antibiotic'] = antibiotic
            st.session_state['current_was_positive'] = was_positive
            st.session_state['current_result'] = predict_susceptibility({
                'was_positive': was_positive,
                'organism': organism,
                'antibiotic': antibiotic
            }, model, encoders)
            # The old conversation was about the previous combination.
            st.session_state.messages = []

        stored_result = st.session_state.get('current_result')
        if stored_result is not None:
            _show_prediction(stored_result)

    with chat_col:
        st.subheader("DeepSeek AI Assistant")

        # Only show assistant if a prediction has been made
        if 'current_organism' not in st.session_state:
            st.info("Make a prediction first to get specific AI assistance")
            return

        st.markdown(f"Ask about **{st.session_state.get('current_organism')}** and **{st.session_state.get('current_antibiotic')}**:")

        example_prompts = [
            "Explain why this combination might show resistance",
            "Suggest alternative antibiotics for this organism",
            "What resistance mechanisms are common here?",
            "How should this result influence treatment decisions?"
        ]

        # At most one of these sources yields a question per rerun.
        pending_question = None
        for i, prompt in enumerate(example_prompts):
            if st.button(prompt, key=f"prompt_{i}"):
                pending_question = prompt

        # Reserve the slot above the input box for the conversation; it is
        # filled after the new question (if any) has been answered.
        history_area = st.container()

        user_prompt = st.chat_input("Ask about this result...")
        if user_prompt:
            pending_question = user_prompt

        if pending_question:
            _ask_assistant(pending_question)

        with history_area:
            for message in st.session_state.messages:
                st.chat_message(message["role"]).markdown(message["content"])


def _normalize_was_positive(column):
    """Return `column` as an int64 Series of 0/1 values, or None if invalid.

    Boolean columns map to 1/0. Anything that is not exactly 0 or 1 after
    numeric conversion (text, missing values, 2, 0.5, ...) makes the whole
    column invalid.
    """
    if pd.api.types.is_bool_dtype(column):
        return column.astype('int64')
    flags = pd.to_numeric(column, errors='coerce')
    if flags.isna().any() or not flags.isin([0, 1]).all():
        return None
    return flags.astype('int64')


def _warn_unknown(label, values, known):
    """Warn about entries of `values` that are missing from `known`.

    `label` is the plural noun used in the message ("organisms",
    "antibiotics"). At most five offending values are listed.
    """
    unknown = [value for value in values.unique() if value not in known]
    if not unknown:
        return
    # str() keeps the join safe for NaN and other non-string cells.
    shown = ', '.join(str(value) for value in unknown[:5])
    suffix = '...' if len(unknown) > 5 else ''
    st.warning(f"Found {len(unknown)} {label} not in the training data: {shown}{suffix}")


def _predict_row(row):
    """Return the result-column values for one dataset row.

    Rows whose organism or antibiotic is unknown to the encoders are marked
    "Invalid data". Any failure is reported in the "Prediction" column so one
    bad row does not abort the whole batch.
    """
    outcome = dict.fromkeys(_RESULT_COLUMNS, "")
    try:
        if (row['organism'] not in encoders['organism']
                or row['antibiotic'] not in encoders['antibiotic']):
            outcome["Prediction"] = "Invalid data"
            return outcome

        result = predict_susceptibility({
            'was_positive': row['was_positive'],
            'organism': row['organism'],
            'antibiotic': row['antibiotic']
        }, model, encoders)

        if "Error" in result:
            outcome["Prediction"] = f"Error: {result['Error']}"
        else:
            outcome["Prediction"] = result["Final Output"]
            outcome["Rule Guidance"] = result["Rule Guidance"]
            outcome["Model Prediction"] = result["Model Prediction"]
            outcome["Decision Reason"] = result["Decision Reason"]
    except Exception as exc:  # deliberate per-row best effort
        outcome = dict.fromkeys(_RESULT_COLUMNS, "")
        outcome["Prediction"] = f"Error: {exc}"
    return outcome


def _run_batch(df):
    """Predict every row of `df` and return `df` plus the result columns."""
    progress_bar = st.progress(0)
    total_rows = len(df)

    rows = []
    for done, (_, row) in enumerate(df.iterrows(), start=1):
        rows.append(_predict_row(row))
        progress_bar.progress(done / total_rows)

    results_df = pd.DataFrame(rows, index=df.index, columns=_RESULT_COLUMNS)
    # An upload that already carries result columns (e.g. a re-uploaded
    # export) would otherwise end up with duplicate column names.
    clashing = [name for name in _RESULT_COLUMNS if name in df.columns]
    return pd.concat([df.drop(columns=clashing), results_df], axis=1)


def _show_batch_results(results):
    """Draw the batch result tables and their CSV download buttons."""
    st.success("Predictions complete!")

    basic = results[_REQUIRED_COLUMNS + ["Prediction"]]

    basic_tab, detailed_tab = st.tabs(["Basic Results", "Detailed Results"])
    with basic_tab:
        st.dataframe(basic)
    with detailed_tab:
        st.dataframe(results)

    left, right = st.columns(2)
    with left:
        st.download_button(
            label="Download Basic Results",
            data=basic.to_csv(index=False),
            file_name="predictions_basic.csv",
            mime='text/csv'
        )
    with right:
        st.download_button(
            label="Download Detailed Results",
            data=results.to_csv(index=False),
            file_name="predictions_detailed.csv",
            mime='text/csv'
        )


def _process_upload(uploaded_file):
    """Validate an uploaded CSV, run predictions on request, show results.

    Exceptions from pandas' CSV reader propagate to the caller, which turns
    them into user-facing error messages.
    """
    df = pd.read_csv(uploaded_file)
    st.write("Uploaded Data Preview:", df.head())

    missing_columns = [name for name in _REQUIRED_COLUMNS if name not in df.columns]
    if missing_columns:
        st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
        return

    flags = _normalize_was_positive(df['was_positive'])
    if flags is None:
        st.error("The 'was_positive' column must contain only 0 or 1 values.")
        return
    if df['was_positive'].dtype != 'int64':
        df['was_positive'] = flags
        st.info("Converted 'was_positive' column to integer type.")

    _warn_unknown("organisms", df['organism'], encoders['organism'])
    _warn_unknown("antibiotics", df['antibiotic'], encoders['antibiotic'])

    # Identify the upload so stored results are only shown for the file they
    # were computed from. file_id is read defensively in case the running
    # Streamlit version does not expose it.
    batch_key = (getattr(uploaded_file, 'file_id', None), uploaded_file.name, uploaded_file.size)

    if st.button("Predict for Dataset"):
        with st.spinner("Processing predictions..."):
            combined = _run_batch(df)
        st.session_state['batch_results'] = {'key': batch_key, 'data': combined}

    # Draw from session state rather than inside the button branch: clicking a
    # download button reruns the script with the predict button reset, which
    # would otherwise make the results disappear.
    stored = st.session_state.get('batch_results')
    if stored is not None and stored['key'] == batch_key:
        _show_batch_results(stored['data'])


def _render_upload_page():
    """Draw the batch-prediction page: sample template, uploader, results."""
    st.title("Batch Prediction: Upload CSV")

    # Add sample data download option
    st.markdown("### Sample Data")
    sample_data = pd.DataFrame({
        'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'],
        'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'],
        'was_positive': [1, 0, 1]
    })
    st.download_button(
        label="Download Sample CSV Template",
        data=sample_data.to_csv(index=False),
        file_name="sample_template.csv",
        mime='text/csv'
    )

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if not uploaded_file:
        return

    # Page-level boundary: show a readable message instead of a traceback.
    try:
        _process_upload(uploaded_file)
    except pd.errors.EmptyDataError:
        st.error("The uploaded file is empty.")
    except pd.errors.ParserError:
        st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")


def _render_about_page():
    """Draw the static credits page."""
    st.title("About this App")
    st.markdown("""
    - Developed by **Okunromade Joseph Oluwaseun**  
    - Uses **Machine Learning & Rule-based AI**  
    - Integrated with **DeepSeek AI** for advanced queries  
    - Matric No: 22/SCI01/172
    """)


# Map each sidebar label to the function that draws that page.
_PAGE_RENDERERS = {
    "Home": _render_home_page,
    "Susceptibility Analysis": _render_analysis_page,
    "Data Upload": _render_upload_page,
    "About": _render_about_page,
}

_render_selected_page = _PAGE_RENDERERS.get(page)
if _render_selected_page is not None:
    _render_selected_page()