Update app.py
Browse files
app.py
CHANGED
@@ -117,11 +117,30 @@ elif page == "Susceptibility Analysis":
|
|
117 |
# Data Upload Page
|
118 |
elif page == "Data Upload":
|
119 |
st.title("Batch Prediction: Upload CSV")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
121 |
|
122 |
if uploaded_file:
|
123 |
try:
|
|
|
124 |
df = pd.read_csv(uploaded_file)
|
|
|
|
|
125 |
st.write("Uploaded Data Preview:", df.head())
|
126 |
|
127 |
# Validate the required columns
|
@@ -131,26 +150,103 @@ elif page == "Data Upload":
|
|
131 |
if missing_columns:
|
132 |
st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
|
133 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
# Process predictions
|
135 |
if st.button("Predict for Dataset"):
|
136 |
with st.spinner("Processing predictions..."):
|
137 |
-
|
138 |
-
|
139 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
st.success("Predictions complete!")
|
142 |
-
st.write("Prediction Results:", df)
|
143 |
|
144 |
-
#
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
except Exception as e:
|
153 |
-
st.error(f"
|
154 |
|
155 |
# About Page
|
156 |
elif page == "About":
|
|
|
117 |
# Data Upload Page
|
118 |
elif page == "Data Upload":
|
119 |
st.title("Batch Prediction: Upload CSV")
|
120 |
+
|
121 |
+
# Add sample data download option
|
122 |
+
st.markdown("### Sample Data")
|
123 |
+
sample_data = pd.DataFrame({
|
124 |
+
'organism': ['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa'],
|
125 |
+
'antibiotic': ['Amoxicillin', 'Vancomycin', 'Ciprofloxacin'],
|
126 |
+
'was_positive': [1, 0, 1]
|
127 |
+
})
|
128 |
+
csv_sample = sample_data.to_csv(index=False)
|
129 |
+
st.download_button(
|
130 |
+
label="Download Sample CSV Template",
|
131 |
+
data=csv_sample,
|
132 |
+
file_name="sample_template.csv",
|
133 |
+
mime='text/csv'
|
134 |
+
)
|
135 |
+
|
136 |
uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
137 |
|
138 |
if uploaded_file:
|
139 |
try:
|
140 |
+
# Read the CSV file
|
141 |
df = pd.read_csv(uploaded_file)
|
142 |
+
|
143 |
+
# Display preview with column check
|
144 |
st.write("Uploaded Data Preview:", df.head())
|
145 |
|
146 |
# Validate the required columns
|
|
|
150 |
if missing_columns:
|
151 |
st.error(f"The uploaded CSV is missing the following required columns: {', '.join(missing_columns)}")
|
152 |
else:
|
153 |
+
# Check data types and convert if necessary
|
154 |
+
if df['was_positive'].dtype != 'int64':
|
155 |
+
try:
|
156 |
+
df['was_positive'] = df['was_positive'].astype(int)
|
157 |
+
st.info("Converted 'was_positive' column to integer type.")
|
158 |
+
except ValueError:
|
159 |
+
st.error("The 'was_positive' column must contain only 0 or 1 values.")
|
160 |
+
st.stop()
|
161 |
+
|
162 |
+
# Validate organisms and antibiotics against encoders
|
163 |
+
invalid_organisms = [org for org in df['organism'].unique() if org not in encoders['organism']]
|
164 |
+
invalid_antibiotics = [ab for ab in df['antibiotic'].unique() if ab not in encoders['antibiotic']]
|
165 |
+
|
166 |
+
if invalid_organisms:
|
167 |
+
st.warning(f"Found {len(invalid_organisms)} organisms not in the training data: {', '.join(invalid_organisms[:5])}{'...' if len(invalid_organisms) > 5 else ''}")
|
168 |
+
|
169 |
+
if invalid_antibiotics:
|
170 |
+
st.warning(f"Found {len(invalid_antibiotics)} antibiotics not in the training data: {', '.join(invalid_antibiotics[:5])}{'...' if len(invalid_antibiotics) > 5 else ''}")
|
171 |
+
|
172 |
# Process predictions
|
173 |
if st.button("Predict for Dataset"):
|
174 |
with st.spinner("Processing predictions..."):
|
175 |
+
# Create a progress bar
|
176 |
+
progress_bar = st.progress(0)
|
177 |
+
total_rows = len(df)
|
178 |
+
|
179 |
+
# Initialize results columns
|
180 |
+
df["Prediction"] = ""
|
181 |
+
df["Rule Guidance"] = ""
|
182 |
+
df["Model Prediction"] = ""
|
183 |
+
df["Decision Reason"] = ""
|
184 |
+
|
185 |
+
# Process each row with error handling
|
186 |
+
for i, (index, row) in enumerate(df.iterrows()):
|
187 |
+
try:
|
188 |
+
# Skip rows with invalid data
|
189 |
+
if (row['organism'] not in encoders['organism'] or
|
190 |
+
row['antibiotic'] not in encoders['antibiotic']):
|
191 |
+
df.at[index, "Prediction"] = "Invalid data"
|
192 |
+
continue
|
193 |
+
|
194 |
+
# Get full prediction result
|
195 |
+
result = predict_susceptibility(row.to_dict(), model, encoders)
|
196 |
+
|
197 |
+
# Store all results
|
198 |
+
if "Error" in result:
|
199 |
+
df.at[index, "Prediction"] = "Error: " + result["Error"]
|
200 |
+
else:
|
201 |
+
df.at[index, "Prediction"] = result["Final Output"]
|
202 |
+
df.at[index, "Rule Guidance"] = result["Rule Guidance"]
|
203 |
+
df.at[index, "Model Prediction"] = result["Model Prediction"]
|
204 |
+
df.at[index, "Decision Reason"] = result["Decision Reason"]
|
205 |
+
except Exception as e:
|
206 |
+
df.at[index, "Prediction"] = f"Error: {str(e)}"
|
207 |
+
|
208 |
+
# Update progress bar
|
209 |
+
progress_bar.progress((i + 1) / total_rows)
|
210 |
|
211 |
st.success("Predictions complete!")
|
|
|
212 |
|
213 |
+
# Display results with tabs for different views
|
214 |
+
tab1, tab2 = st.tabs(["Basic Results", "Detailed Results"])
|
215 |
+
|
216 |
+
with tab1:
|
217 |
+
st.dataframe(df[['organism', 'antibiotic', 'was_positive', 'Prediction']])
|
218 |
+
|
219 |
+
with tab2:
|
220 |
+
st.dataframe(df)
|
221 |
+
|
222 |
+
# Download options
|
223 |
+
col1, col2 = st.columns(2)
|
224 |
+
|
225 |
+
with col1:
|
226 |
+
# Download basic results
|
227 |
+
csv_basic = df[['organism', 'antibiotic', 'was_positive', 'Prediction']].to_csv(index=False)
|
228 |
+
st.download_button(
|
229 |
+
label="Download Basic Results",
|
230 |
+
data=csv_basic,
|
231 |
+
file_name="predictions_basic.csv",
|
232 |
+
mime='text/csv'
|
233 |
+
)
|
234 |
+
|
235 |
+
with col2:
|
236 |
+
# Download detailed results
|
237 |
+
csv_detailed = df.to_csv(index=False)
|
238 |
+
st.download_button(
|
239 |
+
label="Download Detailed Results",
|
240 |
+
data=csv_detailed,
|
241 |
+
file_name="predictions_detailed.csv",
|
242 |
+
mime='text/csv'
|
243 |
+
)
|
244 |
+
except pd.errors.EmptyDataError:
|
245 |
+
st.error("The uploaded file is empty.")
|
246 |
+
except pd.errors.ParserError:
|
247 |
+
st.error("Error parsing the CSV file. Please ensure it's a valid CSV format.")
|
248 |
except Exception as e:
|
249 |
+
st.error(f"An unexpected error occurred: {str(e)}")
|
250 |
|
251 |
# About Page
|
252 |
elif page == "About":
|