EAV123 commited on
Commit
1af97d1
·
verified ·
1 Parent(s): 5ae966a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -13
app.py CHANGED
@@ -117,11 +117,30 @@ elif page == "Susceptibility Analysis":
117
  # Data Upload Page
118
  elif page == "Data Upload":
119
  st.title("Batch Prediction: Upload CSV")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
121
 
122
  if uploaded_file:
123
  try:
 
124
  df = pd.read_csv(uploaded_file)
 
 
125
  st.write("Uploaded Data Preview:", df.head())
126
 
127
  # Validate the required columns
@@ -131,26 +150,103 @@ elif page == "Data Upload":
131
  if missing_columns:
132
  st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
133
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # Process predictions
135
  if st.button("Predict for Dataset"):
136
  with st.spinner("Processing predictions..."):
137
- df["Prediction"] = df.apply(
138
- lambda row: predict_susceptibility(row.to_dict(), model, encoders)["Final Output"], axis=1
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  st.success("Predictions complete!")
142
- st.write("Prediction Results:", df)
143
 
144
- # Optionally, download the results as a CSV
145
- csv = df.to_csv(index=False)
146
- st.download_button(
147
- label="Download Results as CSV",
148
- data=csv,
149
- file_name="predictions.csv",
150
- mime='text/csv'
151
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  except Exception as e:
153
- st.error(f"Error reading the uploaded CSV file: {e}")
154
 
155
  # About Page
156
  elif page == "About":
 
117
  # Data Upload Page
118
  elif page == "Data Upload":
119
  st.title("Batch Prediction: Upload CSV")
120
+
121
+ # Add sample data download option
122
+ st.markdown("### Sample Data")
123
+ sample_data = pd.DataFrame({
124
+ 'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'],
125
+ 'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'],
126
+ 'was_positive': [1, 0, 1]
127
+ })
128
+ csv_sample = sample_data.to_csv(index=False)
129
+ st.download_button(
130
+ label="Download Sample CSV Template",
131
+ data=csv_sample,
132
+ file_name="sample_template.csv",
133
+ mime='text/csv'
134
+ )
135
+
136
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
137
 
138
  if uploaded_file:
139
  try:
140
+ # Read the CSV file
141
  df = pd.read_csv(uploaded_file)
142
+
143
+ # Display preview with column check
144
  st.write("Uploaded Data Preview:", df.head())
145
 
146
  # Validate the required columns
 
150
  if missing_columns:
151
  st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
152
  else:
153
+ # Check data types and convert if necessary
154
+ if df['was_positive'].dtype != 'int64':
155
+ try:
156
+ df['was_positive'] = df['was_positive'].astype(int)
157
+ st.info("Converted 'was_positive' column to integer type.")
158
+ except ValueError:
159
+ st.error("The 'was_positive' column must contain only 0 or 1 values.")
160
+ st.stop()
161
+
162
+ # Validate organisms and antibiotics against encoders
163
+ invalid_organisms = [org for org in df['organism'].unique() if org not in encoders['organism']]
164
+ invalid_antibiotics = [ab for ab in df['antibiotic'].unique() if ab not in encoders['antibiotic']]
165
+
166
+ if invalid_organisms:
167
+ st.warning(f"Found {len(invalid_organisms)} organisms not in the training data: {', '.join(invalid_organisms[:5])}{'...' if len(invalid_organisms) > 5 else ''}")
168
+
169
+ if invalid_antibiotics:
170
+ st.warning(f"Found {len(invalid_antibiotics)} antibiotics not in the training data: {', '.join(invalid_antibiotics[:5])}{'...' if len(invalid_antibiotics) > 5 else ''}")
171
+
172
  # Process predictions
173
  if st.button("Predict for Dataset"):
174
  with st.spinner("Processing predictions..."):
175
+ # Create a progress bar
176
+ progress_bar = st.progress(0)
177
+ total_rows = len(df)
178
+
179
+ # Initialize results columns
180
+ df["Prediction"] = ""
181
+ df["Rule Guidance"] = ""
182
+ df["Model Prediction"] = ""
183
+ df["Decision Reason"] = ""
184
+
185
+ # Process each row with error handling
186
+ for i, (index, row) in enumerate(df.iterrows()):
187
+ try:
188
+ # Skip rows with invalid data
189
+ if (row['organism'] not in encoders['organism'] or
190
+ row['antibiotic'] not in encoders['antibiotic']):
191
+ df.at[index, "Prediction"] = "Invalid data"
192
+ continue
193
+
194
+ # Get full prediction result
195
+ result = predict_susceptibility(row.to_dict(), model, encoders)
196
+
197
+ # Store all results
198
+ if "Error" in result:
199
+ df.at[index, "Prediction"] = "Error: " + result["Error"]
200
+ else:
201
+ df.at[index, "Prediction"] = result["Final Output"]
202
+ df.at[index, "Rule Guidance"] = result["Rule Guidance"]
203
+ df.at[index, "Model Prediction"] = result["Model Prediction"]
204
+ df.at[index, "Decision Reason"] = result["Decision Reason"]
205
+ except Exception as e:
206
+ df.at[index, "Prediction"] = f"Error: {str(e)}"
207
+
208
+ # Update progress bar
209
+ progress_bar.progress((i + 1) / total_rows)
210
 
211
  st.success("Predictions complete!")
 
212
 
213
+ # Display results with tabs for different views
214
+ tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"])
215
+
216
+ with tab1:
217
+ st.dataframe(df[['organism', 'antibiotic', 'was_positive', 'Prediction']])
218
+
219
+ with tab2:
220
+ st.dataframe(df)
221
+
222
+ # Download options
223
+ col1, col2 = st.columns(2)
224
+
225
+ with col1:
226
+ # Download basic results
227
+ csv_basic = df[['organism', 'antibiotic', 'was_positive', 'Prediction']].to_csv(index=False)
228
+ st.download_button(
229
+ label="Download Basic Results",
230
+ data=csv_basic,
231
+ file_name="predictions_basic.csv",
232
+ mime='text/csv'
233
+ )
234
+
235
+ with col2:
236
+ # Download detailed results
237
+ csv_detailed = df.to_csv(index=False)
238
+ st.download_button(
239
+ label="Download Detailed Results",
240
+ data=csv_detailed,
241
+ file_name="predictions_detailed.csv",
242
+ mime='text/csv'
243
+ )
244
+ except pd.errors.EmptyDataError:
245
+ st.error("The uploaded file is empty.")
246
+ except pd.errors.ParserError:
247
+ st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
248
  except Exception as e:
249
+ st.error(f"An unexpected error occurred: {str(e)}")
250
 
251
  # About Page
252
  elif page == "About":