File size: 12,926 Bytes
a7ced43 1af97d1 a7ced43 5ae966a 1af97d1 5ae966a 1af97d1 5ae966a 1af97d1 5ae966a 1af97d1 79ec034 1af97d1 79ec034 1af97d1 d62c290 79ec034 d62c290 79ec034 d62c290 79ec034 1af97d1 79ec034 1af97d1 79ec034 1af97d1 79ec034 1af97d1 79ec034 1af97d1 79ec034 5ae966a 1af97d1 5ae966a 1af97d1 a7ced43 0f1f788 a7ced43 0f1f788 a7ced43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
import streamlit as st
import pandas as pd
import re
from config import MODEL_PATH, ENCODER_DIR, OPENAI_API_KEY, OPENAI_BASE_URL
from utils import load_model, load_label_encoders
from prediction import predict_susceptibility
from ai_assistant import initialize_openai_client, get_ai_response
# Streamlit UI
# Configure the page before anything else: Streamlit documents
# set_page_config as the first Streamlit command of a page, and the loaders
# below come from project modules that may issue Streamlit calls of their own.
st.set_page_config(page_title="Microbial Susceptibility Analyzer", layout="wide")


@st.cache_resource
def _load_assets():
    """Load the model and the label encoders once and reuse them across reruns.

    Streamlit re-executes this script on every widget interaction; caching
    avoids re-reading the model and encoder files each time.

    Returns:
        A ``(model, encoders)`` tuple as produced by ``load_model`` and
        ``load_label_encoders``.
    """
    return load_model(MODEL_PATH), load_label_encoders(ENCODER_DIR)


@st.cache_resource
def _get_ai_client():
    """Create the AI assistant client once and reuse it across reruns."""
    return initialize_openai_client(OPENAI_API_KEY, OPENAI_BASE_URL)


# Load assets
model, encoders = _load_assets()
# Initialize OpenAI client
client = _get_ai_client()

# Sidebar navigation: `page` selects which branch of the page chain renders.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "Susceptibility Analysis", "Data Upload", "About"])
# Home Page
if page == "Home":
    # Landing view: heading, banner image, then a short feature overview.
    welcome_text = """
**Welcome to the Microbial Susceptibility Analyzer!**
This app helps analyze **antibiotic resistance** using **machine learning and rule-based decisions**.
- Predict microbial susceptibility.
- Ask an AI assistant for expert advice.
- Upload datasets for batch predictions.
"""
    st.title("Microbial Susceptibility Analyzer")
    st.image("bacteria.jpeg", use_container_width=True)
    st.markdown(welcome_text)
# Susceptibility Analysis Page
elif page == "Susceptibility Analysis":
st.title("Susceptibility Prediction")
# Initialize session state for messages if not exists
if 'messages' not in st.session_state:
st.session_state.messages = []
# Create two columns for layout
col1, col2 = st.columns([1, 1])
with col1:
with st.form("prediction_form"):
organism = st.selectbox('Organism', options=encoders['organism'].keys())
antibiotic = st.selectbox('Antibiotic', options=encoders['antibiotic'].keys())
was_positive = st.selectbox('Was Positive', options=[1, 0])
submit_button = st.form_submit_button("Predict")
if submit_button:
# Store inputs in session state
st.session_state['current_organism'] = organism
st.session_state['current_antibiotic'] = antibiotic
st.session_state['current_was_positive'] = was_positive
result = predict_susceptibility({
'was_positive': was_positive,
'organism': organism,
'antibiotic': antibiotic
}, model, encoders)
st.subheader("Prediction Results")
if "Error" in result:
st.error(result["Error"])
else:
st.write(f"**Final Decision:** {result['Final Output']}")
st.write(f"**Rule-Based Guidance:** {result['Rule Guidance']}")
st.write(f"**Model Prediction:** {result['Model Prediction']}")
st.write(f"**Decision Explanation:** {result['Decision Reason']}")
# Clear previous messages when new prediction is made
st.session_state.messages = []
with col2:
st.subheader("DeepSeek AI Assistant")
# Only show assistant if a prediction has been made
if 'current_organism' in st.session_state:
st.markdown(f"Ask about **{st.session_state.get('current_organism')}** and **{st.session_state.get('current_antibiotic')}**:")
# Example prompts as buttons
example_prompts = [
"Explain why this combination might show resistance",
"Suggest alternative antibiotics for this organism",
"What resistance mechanisms are common here?",
"How should this result influence treatment decisions?"
]
# Create a unique key for each button
for i, prompt in enumerate(example_prompts):
if st.button(prompt, key=f"prompt_{i}"):
# Create context-enhanced prompt
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \
f"and antibiotic {st.session_state.get('current_antibiotic')}: {prompt}"
response = get_ai_response(client, enhanced_prompt)
# Display the user prompt and AI response
st.chat_message("user").markdown(prompt)
st.chat_message("assistant").markdown(response)
# Chat input
user_prompt = st.chat_input("Ask about this result...")
if user_prompt:
# Create context-enhanced prompt
enhanced_prompt = f"For organism {st.session_state.get('current_organism')} " \
f"and antibiotic {st.session_state.get('current_antibiotic')}: {user_prompt}"
response = get_ai_response(client, enhanced_prompt)
# Display the user prompt and AI response
st.chat_message("user").markdown(user_prompt)
st.chat_message("assistant").markdown(response)
else:
st.info("Make a prediction first to get specific AI assistance")
# Data Upload Page
elif page == "Data Upload":
st.title("Batch Prediction: Upload CSV")
# Add sample data download option
st.markdown("### Sample Data")
sample_data = pd.DataFrame({
'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'],
'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'],
'was_positive': [1, 0, 1]
})
csv_sample = sample_data.to_csv(index=False)
st.download_button(
label="Download Sample CSV Template",
data=csv_sample,
file_name="sample_template.csv",
mime='text/csv'
)
uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
if uploaded_file:
try:
# Read the CSV file
df = pd.read_csv(uploaded_file)
# Display preview with column check
st.write("Uploaded Data Preview:", df.head())
# Validate the required columns
required_columns = ['organism', 'antibiotic', 'was_positive']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
else:
# Check data types and convert if necessary
if df['was_positive'].dtype != 'int64':
try:
df['was_positive'] = df['was_positive'].astype(int)
st.info("Converted 'was_positive' column to integer type.")
except ValueError:
st.error("The 'was_positive' column must contain only 0 or 1 values.")
st.stop()
# Validate organisms and antibiotics against encoders
invalid_organisms = [org for org in df['organism'].unique() if org not in encoders['organism']]
invalid_antibiotics = [ab for ab in df['antibiotic'].unique() if ab not in encoders['antibiotic']]
if invalid_organisms:
st.warning(f"Found {len(invalid_organisms)} organisms not in the training data: {', '.join(invalid_organisms[:5])}{'...' if len(invalid_organisms) > 5 else ''}")
if invalid_antibiotics:
st.warning(f"Found {len(invalid_antibiotics)} antibiotics not in the training data: {', '.join(invalid_antibiotics[:5])}{'...' if len(invalid_antibiotics) > 5 else ''}")
# Process predictions
if st.button("Predict for Dataset"):
with st.spinner("Processing predictions..."):
# Create a progress bar
progress_bar = st.progress(0)
total_rows = len(df)
# Create a new results DataFrame with the same index as the original
results_df = pd.DataFrame(index=df.index)
results_df["Prediction"] = ""
results_df["Rule Guidance"] = ""
results_df["Model Prediction"] = ""
results_df["Decision Reason"] = ""
# Process each row with error handling
for i, (index, row) in enumerate(df.iterrows()):
try:
# Skip rows with invalid data
if (row['organism'] not in encoders['organism'] or
row['antibiotic'] not in encoders['antibiotic']):
results_df.at[index, "Prediction"] = "Invalid data"
continue
# Extract only the required columns for prediction in specific order
input_data = {
'was_positive': row['was_positive'],
'organism': row['organism'],
'antibiotic': row['antibiotic']
}
# Get full prediction result
result = predict_susceptibility(input_data, model, encoders)
# Store all results
if "Error" in result:
results_df.at[index, "Prediction"] = "Error: " + result["Error"]
else:
results_df.at[index, "Prediction"] = result["Final Output"]
results_df.at[index, "Rule Guidance"] = result["Rule Guidance"]
results_df.at[index, "Model Prediction"] = result["Model Prediction"]
results_df.at[index, "Decision Reason"] = result["Decision Reason"]
except Exception as e:
results_df.at[index, "Prediction"] = f"Error: {str(e)}"
# Update progress bar
progress_bar.progress((i + 1) / total_rows)
# Combine original data with results
df = pd.concat([df, results_df], axis=1)
st.success("Predictions complete!")
# Display results with tabs for different views
tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"])
with tab1:
st.dataframe(df[['organism', 'antibiotic', 'was_positive', 'Prediction']])
with tab2:
st.dataframe(df)
# Download options
col1, col2 = st.columns(2)
with col1:
# Download basic results
csv_basic = df[['organism', 'antibiotic', 'was_positive', 'Prediction']].to_csv(index=False)
st.download_button(
label="Download Basic Results",
data=csv_basic,
file_name="predictions_basic.csv",
mime='text/csv'
)
with col2:
# Download detailed results
csv_detailed = df.to_csv(index=False)
st.download_button(
label="Download Detailed Results",
data=csv_detailed,
file_name="predictions_detailed.csv",
mime='text/csv'
)
except pd.errors.EmptyDataError:
st.error("The uploaded file is empty.")
except pd.errors.ParserError:
st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
except Exception as e:
st.error(f"An unexpected error occurred: {str(e)}")
# About Page
elif page == "About":
st.title("About this App")
st.markdown("""
- Developed by **Okunromade Joseph Oluwaseun**
- Uses **Machine Learning & Rule-based AI**
- Integrated with **DeepSeek AI** for advanced queries
- Matric No: 22/SCI01/172
""") |