Update app.py

app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
 import numpy as np
 import plotly.express as px
 import plotly.graph_objects as go
+from scipy.stats import pearsonr, spearmanr
 from sklearn.preprocessing import StandardScaler, LabelEncoder
 from sklearn.model_selection import train_test_split
 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
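Note: the functions added further down (train_model, validate_model, and the model-saving UI) reference names that no visible hunk imports — SimpleImputer, GridSearchCV, cross_val_score, permutation_importance, the GradientBoosting and MLP estimators, LogisticRegression, SVC, the sklearn.metrics helpers, joblib, and os. Assuming they are not already imported elsewhere in app.py, an import block along these lines would be needed; this is a sketch, not part of the commit:

```python
# Assumed imports for the functions added below; verify against the full app.py.
import os
import joblib
from sklearn.impute import SimpleImputer
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn.metrics import (accuracy_score, confusion_matrix,
                             classification_report, mean_squared_error, r2_score)
```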
@@ -192,39 +193,6 @@ with st.sidebar:
 # --------------------------
 # Main App Pages
 # --------------------------
-if app_mode == "Data Upload":
-    st.title("📤 Data Upload & Profiling")
-
-    uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
-
-    if uploaded_file:
-        try:
-            if uploaded_file.name.endswith('.csv'):
-                df = pd.read_csv(uploaded_file)
-            else:
-                df = pd.read_excel(uploaded_file)
-
-            st.session_state.raw_data = df
-
-            col1, col2, col3 = st.columns(3)
-            with col1:
-                st.metric("Rows", df.shape[0])
-            with col2:
-                st.metric("Columns", df.shape[1])
-            with col3:
-                st.metric("Missing Values", df.isna().sum().sum())
-
-            with st.expander("Data Preview", expanded=True):
-                st.dataframe(df.head(10), use_container_width=True)
-
-            if st.button("Generate Full Profile Report"):
-                with st.spinner("Generating comprehensive analysis..."):
-                    pr = ProfileReport(df, explorative=True)
-                    st_profile_report(pr)
-
-        except Exception as e:
-            st.error(f"Error loading file: {str(e)}")
-
 elif app_mode == "Data Cleaning":
     st.title("🧹 Smart Data Cleaning")
 
@@ -232,9 +200,43 @@ elif app_mode == "Data Cleaning":
         st.warning("Please upload data first")
         st.stop()
 
-
+    # Initialize session state for undo functionality
+    if 'data_versions' not in st.session_state:
+        st.session_state.data_versions = [st.session_state.raw_data.copy()]
+
+    df = st.session_state.data_versions[-1].copy()
+
+    # --------------------------
+    # Data Health Dashboard
+    # --------------------------
+    with st.expander("📊 Data Health Dashboard", expanded=True):
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            st.metric("Total Columns", len(df.columns))
+        with col2:
+            st.metric("Total Rows", len(df))
+        with col3:
+            st.metric("Missing Values", df.isna().sum().sum())
+
+        # Generate quick profile report
+        if st.button("Generate Data Health Report"):
+            with st.spinner("Analyzing data..."):
+                profile = ProfileReport(df, minimal=True)
+                st_profile_report(profile)
+
+    # --------------------------
+    # Undo Functionality
+    # --------------------------
+    if len(st.session_state.data_versions) > 1:
+        if st.button("⏮️ Undo Last Action"):
+            st.session_state.data_versions.pop()
+            df = st.session_state.data_versions[-1].copy()
+            st.session_state.cleaned_data = df
+            st.success("Last action undone!")
 
+    # --------------------------
     # Missing Value Handling
+    # --------------------------
     with st.expander("🔍 Missing Values Treatment", expanded=True):
         missing_cols = df.columns[df.isna().any()].tolist()
         if missing_cols:
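The cleaning page implements undo by snapshotting the DataFrame into st.session_state.data_versions before every mutating action and popping the list to revert. The same pattern in isolation, as a minimal sketch (the VersionedFrame name is hypothetical):

```python
import pandas as pd

class VersionedFrame:
    """Keep a stack of DataFrame copies so the last edit can be undone."""
    def __init__(self, df: pd.DataFrame):
        self.versions = [df.copy()]

    @property
    def current(self) -> pd.DataFrame:
        return self.versions[-1]

    def apply(self, func) -> pd.DataFrame:
        # Snapshot before mutating, mirroring data_versions.append(df.copy()).
        self.versions.append(func(self.current.copy()))
        return self.current

    def undo(self) -> pd.DataFrame:
        # Never pop the original frame off the stack.
        if len(self.versions) > 1:
            self.versions.pop()
        return self.current

vf = VersionedFrame(pd.DataFrame({"a": [1, None, 3]}))
vf.apply(lambda d: d.dropna())
print(vf.undo())  # back to the original frame, NaN restored
```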
@@ -242,22 +244,43 @@ elif app_mode == "Data Cleaning":
             method = st.selectbox("Imputation Method", [
                 "Drop Missing",
                 "Mean/Median",
-                "Custom Value"
+                "Custom Value",
+                "Forward Fill",
+                "Backward Fill"
             ])
 
+            if method == "Custom Value":
+                custom_val = st.text_input("Enter custom value")
+
             if st.button("Apply Treatment"):
-
-
-
-
-
-
-
-
+                st.session_state.data_versions.append(df.copy())
+                try:
+                    if method == "Drop Missing":
+                        df = df.dropna(subset=cols)
+                    elif method == "Mean/Median":
+                        for col in cols:
+                            if pd.api.types.is_numeric_dtype(df[col]):
+                                df[col] = df[col].fillna(df[col].median())
+                            else:
+                                df[col] = df[col].fillna(df[col].mode()[0])
+                    elif method == "Custom Value" and custom_val:
+                        for col in cols:
+                            df[col] = df[col].fillna(custom_val)
+                    elif method == "Forward Fill":
+                        df[cols] = df[cols].ffill()
+                    elif method == "Backward Fill":
+                        df[cols] = df[cols].bfill()
+
+                    st.session_state.cleaned_data = df
+                    st.success("Missing values handled successfully!")
+                except Exception as e:
+                    st.error(f"Error: {str(e)}")
         else:
-            st.success("No missing values found!")
+            st.success("✨ No missing values found!")
 
+    # --------------------------
     # Data Type Conversion
+    # --------------------------
     with st.expander("🔄 Data Type Conversion"):
         col_to_convert = st.selectbox("Select column", df.columns)
         new_type = st.selectbox("New data type", [
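For reference, the five strategies the imputation selectbox now offers behave like this on a toy Series (plain pandas, outside Streamlit):

```python
import pandas as pd

s = pd.Series([1.0, None, None, 4.0])
print(s.dropna().tolist())            # Drop Missing  -> [1.0, 4.0]
print(s.fillna(s.median()).tolist())  # Mean/Median   -> [1.0, 2.5, 2.5, 4.0]
print(s.fillna(0).tolist())           # Custom Value  -> [1.0, 0.0, 0.0, 4.0]
print(s.ffill().tolist())             # Forward Fill  -> [1.0, 1.0, 1.0, 4.0]
print(s.bfill().tolist())             # Backward Fill -> [1.0, 4.0, 4.0, 4.0]
```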
@@ -265,64 +288,119 @@ elif app_mode == "Data Cleaning":
             "Boolean", "Datetime"
         ])
 
+        if new_type == "Datetime":
+            date_format = st.text_input("Date format (e.g. %Y-%m-%d)", "%Y-%m-%d")
+
         if st.button("Convert"):
+            st.session_state.data_versions.append(df.copy())
             try:
                 if new_type == "String":
                     df[col_to_convert] = df[col_to_convert].astype(str)
                 elif new_type == "Integer":
-                    df[col_to_convert]
+                    if df[col_to_convert].dtype == 'object':
+                        st.error("Cannot convert text column to integer!")
+                    else:
+                        df[col_to_convert] = pd.to_numeric(df[col_to_convert], errors='coerce').astype('Int64')
+                elif new_type == "Float":
+                    if df[col_to_convert].dtype == 'object':
+                        st.error("Cannot convert text column to float!")
+                    else:
+                        df[col_to_convert] = pd.to_numeric(df[col_to_convert], errors='coerce')
+                elif new_type == "Boolean":
+                    df[col_to_convert] = df[col_to_convert].astype(bool)
+                elif new_type == "Datetime":
+                    df[col_to_convert] = pd.to_datetime(df[col_to_convert], format=date_format, errors='coerce')
+
                 st.session_state.cleaned_data = df
                 st.success("Conversion successful!")
             except Exception as e:
                 st.error(f"Error: {str(e)}")
 
-
-
-
-
-    # Drop Columns
+    # --------------------------
+    # Drop Columns
+    # --------------------------
     with st.expander("🗑️ Drop Columns"):
         columns_to_drop = st.multiselect("Select columns to drop", df.columns)
-        if
-
-        st.
-
-
-
-
-
+        if columns_to_drop:
+            st.warning(f"Will drop: {', '.join(columns_to_drop)}")
+            if st.button("Confirm Drop"):
+                st.session_state.data_versions.append(df.copy())
+                df = df.drop(columns=columns_to_drop)
+                st.session_state.cleaned_data = df
+                st.success("Selected columns dropped successfully!")
+
+    # --------------------------
+    # Label Encoding
+    # --------------------------
+    with st.expander("🔢 Label Encoding"):
+        data_to_encode = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
         if data_to_encode:
-
-
-
-
-
-
-
+            if st.button("Apply Label Encoding"):
+                st.session_state.data_versions.append(df.copy())
+                label_encoders = {}
+                for col in data_to_encode:
+                    le = LabelEncoder()
+                    df[col] = le.fit_transform(df[col].astype(str))
+                    label_encoders[col] = le
+                st.session_state.cleaned_data = df
+                st.success("Label encoding applied successfully!")
 
-
-    with st.expander("✨ Encoded Data Preview"):
-        st.dataframe(st.session_state.cleaned_data.head(), use_container_width=True)
+    # --------------------------
     # StandardScaler
+    # --------------------------
     with st.expander("📏 StandardScaler"):
-        scale_cols = st.multiselect("Select columns to scale", df.columns)
-
-        if st.button("Apply
-            try:
-
-                df[scale_cols] = scaler.fit_transform(df[scale_cols])
-                st.session_state.cleaned_data = df
-                st.success("
-
-                # Optionally, display the scaled data
-                with st.expander("✨ Scaled Data Preview"):
-                    st.dataframe(st.session_state.cleaned_data.head(), use_container_width=True)
-            except Exception as e:
-                st.error(f"Error: {str(e)}")
+        scale_cols = st.multiselect("Select numeric columns to scale", df.select_dtypes(include=np.number).columns)
+        if scale_cols:
+            if st.button("Apply StandardScaler"):
+                st.session_state.data_versions.append(df.copy())
+                try:
+                    scaler = StandardScaler()
+                    df[scale_cols] = scaler.fit_transform(df[scale_cols])
+                    st.session_state.cleaned_data = df
+                    st.success("Standard scaling applied successfully!")
+                except Exception as e:
+                    st.error(f"Error: {str(e)}")
+
+    # --------------------------
+    # Pattern-Based Cleaning
+    # --------------------------
+    with st.expander("🕵️ Pattern-Based Cleaning"):
+        selected_col = st.selectbox("Select text column", df.select_dtypes(include='object').columns)
+        pattern = st.text_input(r"Regex pattern (e.g. \d+ for numbers)")
+        replacement = st.text_input("Replacement value")
+
+        if st.button("Apply Pattern Replacement"):
+            st.session_state.data_versions.append(df.copy())
+            try:
+                df[selected_col] = df[selected_col].str.replace(pattern, replacement, regex=True)
+                st.session_state.cleaned_data = df
+                st.success("Pattern replacement applied successfully!")
+            except Exception as e:
+                st.error(f"Error: {str(e)}")
+
+    # --------------------------
+    # Bulk Operations
+    # --------------------------
+    with st.expander("🚀 Bulk Actions"):
+        if st.button("Auto-Clean Common Issues"):
+            st.session_state.data_versions.append(df.copy())
+            df = df.dropna(axis=1, how='all')  # Remove empty cols
+            df = df.convert_dtypes()  # Better type inference
+            text_cols = df.select_dtypes(include='object').columns
+            df[text_cols] = df[text_cols].apply(lambda x: x.str.strip())
+            st.session_state.cleaned_data = df
+            st.success("Bulk cleaning completed!")
+
+    # --------------------------
+    # Cleaned Data Preview
+    # --------------------------
+    if st.session_state.cleaned_data is not None:
+        with st.expander("✨ Cleaned Data Preview", expanded=True):
+            st.dataframe(st.session_state.cleaned_data.head(), use_container_width=True)
 
-
 
-
+# Main function for EDA
+def eda():
     st.title("🔍 Exploratory Data Analysis")
 
     if st.session_state.cleaned_data is None:
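The label-encoding block keeps each fitted LabelEncoder in a label_encoders dict, which is what makes the integer codes reversible later. A standalone illustration of that round trip:

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
codes = le.fit_transform(["red", "green", "red", "blue"])
print(codes.tolist())               # [2, 1, 2, 0] -- classes are sorted: blue=0, green=1, red=2
print(le.inverse_transform(codes))  # ['red' 'green' 'red' 'blue']
```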
@@ -331,96 +409,562 @@ elif app_mode == "EDA":
 
     df = st.session_state.cleaned_data
 
+    # --------------------------
+    # Data Overview
+    # --------------------------
+    with st.expander("📊 Data Overview", expanded=True):
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            st.metric("Total Rows", df.shape[0])
+        with col2:
+            st.metric("Total Columns", df.shape[1])
+        with col3:
+            st.metric("Missing Values", df.isna().sum().sum())
+
+        if st.checkbox("Show Data Preview"):
+            st.dataframe(df.head(), use_container_width=True)
+
+    # --------------------------
     # Visualization Selector
+    # --------------------------
+    st.subheader("📈 Visualization Setup")
     col1, col2 = st.columns([1, 3])
     with col1:
-        st.subheader("Visualization Setup")
         plot_type = st.selectbox("Choose plot type", [
             "Scatter Plot", "Histogram",
-            "Box Plot", "Correlation Matrix"
+            "Box Plot", "Correlation Matrix",
+            "Line Chart", "Heatmap", "Violin Plot",
+            "3D Scatter Plot", "Parallel Coordinates",
+            "Pair Plot", "Density Contour"
         ])
 
         x_axis = st.selectbox("X-Axis", df.columns)
-        y_axis = st.selectbox("Y-Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot"] else None
+        y_axis = st.selectbox("Y-Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart", "Violin Plot", "3D Scatter Plot", "Density Contour"] else None
+        z_axis = st.selectbox("Z-Axis", df.columns) if plot_type == "3D Scatter Plot" else None
         color_by = st.selectbox("Color By", [None] + df.columns.tolist())
+        facet_col = st.selectbox("Facet By", [None] + df.columns.tolist())
 
     with col2:
-        st.subheader("Visualization")
+        st.subheader("📊 Visualization")
         try:
             if plot_type == "Scatter Plot":
-                fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by)
+                fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by, facet_col=facet_col)
             elif plot_type == "Histogram":
-                fig = px.histogram(df, x=x_axis, color=color_by)
+                fig = px.histogram(df, x=x_axis, color=color_by, facet_col=facet_col)
             elif plot_type == "Box Plot":
-                fig = px.box(df, x=x_axis, y=y_axis, color=color_by)
+                fig = px.box(df, x=x_axis, y=y_axis, color=color_by, facet_col=facet_col)
             elif plot_type == "Correlation Matrix":
                 corr = df.select_dtypes(include=np.number).corr()
-                fig = px.imshow(corr, text_auto=True)
+                fig = px.imshow(corr, text_auto=True, color_continuous_scale='Viridis')
+            elif plot_type == "Line Chart":
+                fig = px.line(df, x=x_axis, y=y_axis, color=color_by, facet_col=facet_col)
+            elif plot_type == "Heatmap":
+                num_df = df.select_dtypes(include=np.number)  # corr() requires numeric columns
+                fig = go.Figure(data=go.Heatmap(
+                    z=num_df.corr().values,
+                    x=num_df.columns,
+                    y=num_df.columns,
+                    colorscale='Viridis'))
+            elif plot_type == "Violin Plot":
+                fig = px.violin(df, x=x_axis, y=y_axis, color=color_by, facet_col=facet_col)
+            elif plot_type == "3D Scatter Plot":
+                fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=color_by)
+            elif plot_type == "Parallel Coordinates":
+                fig = px.parallel_coordinates(df, color=color_by)
+            elif plot_type == "Pair Plot":
+                fig = px.scatter_matrix(df, color=color_by)
+            elif plot_type == "Density Contour":
+                fig = px.density_contour(df, x=x_axis, y=y_axis, color=color_by)
 
             st.plotly_chart(fig, use_container_width=True)
         except Exception as e:
             st.error(f"Visualization error: {str(e)}")
+
+    # --------------------------
+    # Relationship Diagnostics
+    # --------------------------
+    st.subheader("🔗 Relationship Diagnostics")
+    selected_columns = st.multiselect("Select columns to analyze relationships", df.columns)
+    if selected_columns:
+        if len(selected_columns) == 2:
+            col1, col2 = st.columns(2)
+            with col1:
+                st.write(f"**Scatter Plot: {selected_columns[0]} vs {selected_columns[1]}**")
+                fig = px.scatter(df, x=selected_columns[0], y=selected_columns[1], trendline="ols")
+                st.plotly_chart(fig, use_container_width=True)
+
+            with col2:
+                st.write("**Statistical Summary**")
+                st.write(df[selected_columns].describe())
+
+                # Correlation Analysis
+                pearson_corr, _ = pearsonr(df[selected_columns[0]], df[selected_columns[1]])
+                spearman_corr, _ = spearmanr(df[selected_columns[0]], df[selected_columns[1]])
+
+                st.metric("Pearson Correlation", f"{pearson_corr:.2f}")
+                st.metric("Spearman Correlation", f"{spearman_corr:.2f}")
+
+                st.write("**Regression Line**")
+                ols = px.get_trendline_results(fig).px_fit_results.iloc[0]  # fitted OLS trendline
+                st.write(f"Equation: y = {ols.params[1]:.4f} * x + {ols.params[0]:.4f}")
+        elif len(selected_columns) > 2:
+            st.warning("Please select only two columns for relationship analysis.")
+    else:
+        st.warning("Please select at least two columns for relationship analysis.")
+
+    # --------------------------
+    # Advanced Statistics
+    # --------------------------
+    with st.expander("📊 Advanced Statistics", expanded=False):
+        st.write("**Column-wise Statistics**")
+        selected_col = st.selectbox("Select a column for detailed analysis", df.columns)
+        if selected_col:
+            if pd.api.types.is_numeric_dtype(df[selected_col]):
+                st.write(f"**Distribution of {selected_col}**")
+                fig = px.histogram(df, x=selected_col, nbins=30)
+                st.plotly_chart(fig, use_container_width=True)
+
+                st.write("**Outlier Detection**")
+                Q1 = df[selected_col].quantile(0.25)
+                Q3 = df[selected_col].quantile(0.75)
+                IQR = Q3 - Q1
+                outliers = df[(df[selected_col] < (Q1 - 1.5 * IQR)) | (df[selected_col] > (Q3 + 1.5 * IQR))]
+                st.write(f"Number of outliers: {len(outliers)}")
+                st.dataframe(outliers.head(), use_container_width=True)
+            else:
+                st.write(f"**Value Counts for {selected_col}**")
+                value_counts = df[selected_col].value_counts()
+                st.bar_chart(value_counts)
+
+    # --------------------------
+    # Save Visualizations
+    # --------------------------
+    st.subheader("💾 Save Visualizations")
+    if st.button("Export Current Visualization as PNG"):
+        try:
+            fig.write_image("visualization.png")
+            st.success("Visualization saved as PNG!")
+        except Exception as e:
+            st.error(f"Error saving visualization: {str(e)}")
+
+# Call the EDA function
+eda()
 
+# Function to train the model (Separated for clarity and reusability)
+def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
+    """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
+
+    try:
+        X = df[features]
+        y = df[target]
+
+        # Input Validation
+        if target not in df.columns:
+            raise ValueError(f"Target variable '{target}' not found in DataFrame.")
+        for feature in features:
+            if feature not in df.columns:
+                raise ValueError(f"Feature '{feature}' not found in DataFrame.")
+
+        # Preprocessing Pipeline: Handles missing values, encoding, scaling
+        # Imputation: Handle missing values BEFORE encoding (numerical only for SimpleImputer)
+        numerical_features = X.select_dtypes(include=np.number).columns
+        categorical_features = X.select_dtypes(exclude=np.number).columns
+
+        imputer_numerical = SimpleImputer(strategy='mean')  # Or 'median', 'most_frequent', 'constant'
+        X[numerical_features] = imputer_numerical.fit_transform(X[numerical_features])
+
+        # Encoding (One-Hot Encode Categorical Features)
+        X = pd.get_dummies(X, columns=categorical_features, dummy_na=False)  # dummy_na=False: we imputed already.
+
+        # Target Encoding (if classification)
+        label_encoder = None  # Initialize label_encoder
+        if problem_type == "Classification" or problem_type == "Multiclass":
+            label_encoder = LabelEncoder()
+            y = label_encoder.fit_transform(y)
+
+        # Split the data
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=test_size, random_state=42
+        )
+
+        # Scaling (AFTER splitting!)
+        scaler = StandardScaler()  # Or try MinMaxScaler, RobustScaler, QuantileTransformer
+        X_train = scaler.fit_transform(X_train)  # Fit to the training data ONLY
+        X_test = scaler.transform(X_test)  # Transform the test data using the fitted scaler
+
+        # Model Selection and Hyperparameter Tuning
+        if problem_type == "Regression":
+            if model_type == "Random Forest":
+                model = RandomForestRegressor(random_state=42)
+                param_grid = {
+                    'n_estimators': [100, 200],
+                    'max_depth': [None, 5, 10],
+                    'min_samples_split': [2, 5]
+                }
+            elif model_type == "Gradient Boosting":
+                model = GradientBoostingRegressor(random_state=42)
+                param_grid = {
+                    'n_estimators': [100, 200],
+                    'learning_rate': [0.01, 0.1],
+                    'max_depth': [3, 5]
+                }
+            elif model_type == "Neural Network":
+                model = MLPRegressor(random_state=42, max_iter=500)  # set max_iter to 500
+                param_grid = {
+                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],  # example sizes for depth
+                    'activation': ['relu', 'tanh'],
+                    'alpha': [0.0001, 0.001]
+                }
+            else:
+                raise ValueError(f"Invalid model type: {model_type}")
+
+        elif problem_type == "Classification":  # Binary
+            if model_type == "Random Forest":
+                model = RandomForestClassifier(random_state=42)
+                param_grid = {
+                    'n_estimators': [100, 200],
+                    'max_depth': [None, 5, 10],
+                    'min_samples_split': [2, 5]
+                }
+            elif model_type == "Gradient Boosting":
+                model = GradientBoostingClassifier(random_state=42)
+                param_grid = {
+                    'n_estimators': [100, 200],
+                    'learning_rate': [0.01, 0.1],
+                    'max_depth': [3, 5]
+                }
+            elif model_type == "Neural Network":
+                model = MLPClassifier(random_state=42, max_iter=500)  # set max_iter to 500
+                param_grid = {
+                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],  # example sizes for depth
+                    'activation': ['relu', 'tanh'],
+                    'alpha': [0.0001, 0.001]
+                }
+            else:
+                raise ValueError(f"Invalid model type: {model_type}")
+        elif problem_type == "Multiclass":  # Multiclass
+            if model_type == "Logistic Regression":
+                model = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')  # 'ovr' for one-vs-rest
+                param_grid = {'C': [0.1, 1.0, 10.0]}  # Regularization parameter
+
+            elif model_type == "Support Vector Machine":
+                model = SVC(random_state=42, probability=True)  # probability=True for probabilities
+                param_grid = {'C': [0.1, 1.0, 10.0], 'kernel': ['rbf', 'linear']}
+
+            elif model_type == "Random Forest":
+                model = RandomForestClassifier(random_state=42)
+                param_grid = {
+                    'n_estimators': [100, 200],
+                    'max_depth': [None, 5, 10],
+                    'min_samples_split': [2, 5],
+                    'criterion': ['gini', 'entropy']  # criterion for decision
+                }
+
+            else:
+                raise ValueError(f"Invalid model type: {model_type} for Multiclass")
+        else:
+            raise ValueError(f"Invalid problem type: {problem_type}")
+
+        # Update param_grid with user-defined parameters
+        param_grid.update(model_params)  # This is key to use the model_params provided by user
+
+        if use_grid_search:
+            grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error', verbose=1, n_jobs=-1)
+            grid_search.fit(X_train, y_train)
+            model = grid_search.best_estimator_  # Use the best model found
+            st.write("Best hyperparameters found by Grid Search:", grid_search.best_params_)  # Print best parameters
+
+        else:
+            model.fit(X_train, y_train)
+
+        # Cross-Validation (after hyperparameter tuning, if applicable)
+        cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error')
+        st.write("Cross-validation scores:", cv_scores)
+        st.write("Mean cross-validation score:", cv_scores.mean())
+
+        # Evaluation
+        y_pred = model.predict(X_test)
+        metrics = {}  # Store metrics in a dictionary
+
+        if problem_type == "Classification":
+            metrics['accuracy'] = accuracy_score(y_test, y_pred)
+            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)  # Get report as dictionary
+
+        elif problem_type == "Multiclass":
+            metrics['accuracy'] = accuracy_score(y_test, y_pred)
+            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)  # Get report as dictionary
+        else:
+            metrics['mse'] = mean_squared_error(y_test, y_pred)
+            metrics['r2'] = r2_score(y_test, y_pred)
+
+        # Feature Importance (Permutation Importance for potentially better handling of correlated features)
+        try:
+            result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)  # Permutation Feature Importance
+            importance = result.importances_mean
+
+        except Exception as e:
+            st.warning(f"Could not calculate feature importance: {e}")
+            importance = None
+
+        # Store the column order for prediction purposes
+        column_order = X.columns
+
+        return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance
+
+    except Exception as e:
+        st.error(f"Training failed: {str(e)}")
+        return None, None, None, None, None, None, None
+
+# Model Validation Function
+def validate_model(model_path, df, target, features, test_size):
+    """Loads a model, preprocesses data, and evaluates the model on a validation set."""
+    try:
+        loaded_data = joblib.load(model_path)
+        model = loaded_data['model']
+        scaler = loaded_data['scaler']
+        label_encoder = loaded_data['label_encoder']
+        imputer_numerical = loaded_data['imputer_numerical']
+        column_order = loaded_data['column_order']
+        problem_type = loaded_data['problem_type']
+
+        X = df[features]
+        y = df[target]
+
+        # Imputation
+        numerical_features = X.select_dtypes(include=np.number).columns
+        X[numerical_features] = imputer_numerical.transform(X[numerical_features])
+
+        # Encoding
+        X = pd.get_dummies(X, columns=X.select_dtypes(exclude=np.number).columns, dummy_na=False)
+
+        # Ensure correct column order
+        X = X[column_order]  # Reorder the columns
+
+        # Split the data
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=test_size, random_state=42
+        )
+
+        # Scaling
+        X_train = scaler.transform(X_train)
+        X_test = scaler.transform(X_test)
+
+        # Target Encoding (if classification) - Use the same encoder used during training
+        if problem_type == "Classification" or problem_type == "Multiclass":
+            y = label_encoder.transform(y)
+
+        y_pred = model.predict(X_test)
+
+        metrics = {}
+        if problem_type == "Classification":
+            metrics['accuracy'] = accuracy_score(y_test, y_pred)
+            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
+
+        elif problem_type == "Multiclass":
+            metrics['accuracy'] = accuracy_score(y_test, y_pred)
+            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
+        else:
+            metrics['mse'] = mean_squared_error(y_test, y_pred)
+            metrics['r2'] = r2_score(y_test, y_pred)
+
+        return metrics, problem_type
+
+    except Exception as e:
+        st.error(f"Validation failed: {str(e)}")
+        return None, None
+
+# Streamlit App
 elif app_mode == "Model Training":
     st.title("🤖 Intelligent Model Training")
-
-    if st.session_state.cleaned_data is None:
+
+    if st.session_state.get("cleaned_data") is None:
         st.warning("Please clean your data first")
         st.stop()
-
+
     df = st.session_state.cleaned_data
-
+
     # Model Setup
-    col1, col2 = st.columns(
+    col1, col2, col3 = st.columns(3)
     with col1:
         target = st.selectbox("Select Target Variable", df.columns)
-        problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
+        problem_type = st.selectbox("Problem Type", ["Classification", "Regression", "Multiclass"])  # Added Multiclass
     with col2:
-
+        available_features = df.columns.drop(target)
+        features = st.multiselect("Select Features", available_features, default=list(available_features))  # Select all as default
+    with col3:
         test_size = st.slider("Test Size", 0.1, 0.5, 0.2)
-
-
-        try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-
-            st.metric("
-
-    # Feature Importance
-
-
-
-
-
-
-
-
-
-
+
+    # Model Type Selection
+    if problem_type == "Regression":
+        model_type = st.selectbox("Select Regression Model", ["Random Forest", "Gradient Boosting", "Neural Network"])
+    elif problem_type == "Classification":
+        model_type = st.selectbox("Select Classification Model", ["Random Forest", "Gradient Boosting", "Neural Network"])
+    elif problem_type == "Multiclass":
+        model_type = st.selectbox("Select Multiclass Model", ["Logistic Regression", "Support Vector Machine", "Random Forest"])  # Added SVM and Logistic Regression
+    else:
+        model_type = None  # handle this
+
+    # Hyperparameter Configuration - Dynamic based on Model Type
+    st.subheader("Hyperparameter Configuration")
+    model_params = {}
+
+    if model_type == "Neural Network":  # Add options for NN parameters
+        hidden_layers = st.text_input("Hidden Layer Sizes (e.g., 50,50 for two layers of 50 neurons)", "50,50")
+        activation = st.selectbox("Activation Function", ["relu", "tanh", "logistic"])
+        alpha = st.number_input("L2 Regularization (Alpha)", value=0.0001)
+
+        # Process the hidden layers string to a tuple of ints
+        try:
+            hidden_layer_sizes = tuple(map(int, hidden_layers.split(',')))
+            model_params['hidden_layer_sizes'] = hidden_layer_sizes
+        except ValueError:
+            st.error("Invalid format for Hidden Layer Sizes. Use comma-separated integers (e.g., 50,50)")
+
+        model_params['activation'] = activation
+        model_params['alpha'] = alpha
+
+    elif model_type == "Gradient Boosting":
+        n_estimators = st.slider("Number of Estimators", 50, 300, 100)
+        learning_rate = st.number_input("Learning Rate", value=0.1)
+        max_depth = st.slider("Max Depth", 2, 10, 3)
+
+        model_params['n_estimators'] = n_estimators
+        model_params['learning_rate'] = learning_rate
+        model_params['max_depth'] = max_depth
+    elif model_type == "Logistic Regression":
+        c_value = st.number_input("C (Regularization)", value=1.0)
+        model_params['C'] = c_value
+
+    elif model_type == "Support Vector Machine":
+        c_value = st.number_input("C (Regularization)", value=1.0)
+        kernel_type = st.selectbox("Kernel Type", ['rbf', 'linear', 'poly', 'sigmoid'])
+        model_params['C'] = c_value
+        model_params['kernel'] = kernel_type
+
+    elif model_type == "Random Forest":
+        n_estimators = st.slider("Number of Estimators", 50, 300, 100)
+        max_depth = st.slider("Max Depth", 2, 10, 3)
+        model_params['n_estimators'] = n_estimators
+        model_params['max_depth'] = max_depth
+
+    use_grid_search = st.checkbox("Use Grid Search for Hyperparameter Tuning")
+
+    if st.button("Train Model"):
+        if not features:
+            st.error("Please select at least one feature.")
+            st.stop()
+
+        # Call the training function
+        model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance = train_model(df.copy(), target, features, problem_type, test_size, model_type, model_params, use_grid_search)  # Pass a copy to avoid modifying the original
+
+        if model:  # Only proceed if training was successful
+            st.success("Model trained successfully!")
+
+            # Display Metrics
+            st.subheader("Model Evaluation Metrics")
+            if problem_type in ["Classification", "Multiclass"]:  # Combined here
+                st.metric("Accuracy", f"{metrics['accuracy']:.2%}")
+
+                # Confusion Matrix Visualization
+                st.subheader("Confusion Matrix")
+                cm = metrics['confusion_matrix']
+                class_names = [str(i) for i in np.unique(df[target])]  # Get original class names
+                fig_cm = px.imshow(cm,
+                                   labels=dict(x="Predicted", y="Actual"),
+                                   x=class_names,
+                                   y=class_names,
+                                   color_continuous_scale="Viridis")
+                st.plotly_chart(fig_cm, use_container_width=True)
+
+                # Classification Report
+                st.subheader("Classification Report")
+                report = metrics['classification_report']
+                report_df = pd.DataFrame(report).transpose()
+                st.dataframe(report_df)
+
+            else:
+                st.metric("MSE", f"{metrics['mse']:.2f}")
+                st.metric("R2", f"{metrics['r2']:.2f}")
+
+            # Feature Importance
+            st.subheader("Feature Importance")
+            try:
+                fig_importance = px.bar(
+                    x=importance,
+                    y=column_order,  # Use stored column order
+                    orientation='h',
+                    title="Feature Importance"
+                )
+                st.plotly_chart(fig_importance, use_container_width=True)
+            except Exception as e:
+                st.warning(f"Could not display feature importance: {e}")
+
+            # Explainable AI (Placeholder)
+            st.subheader("Explainable AI (XAI)")
+            st.write("Future implementation will include model explanations using techniques like SHAP or LIME.")  # To be implemented
+            if st.checkbox("Show a random model explanation (example)"):  # Example of a feature, to be implemented
+                st.write("This feature is important because...")
+
+            # Save Model
+            st.subheader("Save Model")
+            model_name = st.text_input("Enter model name (without extension)", "my_model")
+            if st.button("Save Model"):
+                try:
+                    model_path = f"{model_name}.joblib"
+                    joblib.dump({
+                        'model': model,
+                        'scaler': scaler,
+                        'label_encoder': label_encoder,
+                        'imputer_numerical': imputer_numerical,
+                        'column_order': column_order,
+                        'features': features,
+                        'target': target,
+                        'problem_type': problem_type,
+                        'model_type': model_type,
+                        'model_params': model_params
+                    }, model_path)
+                    st.success(f"Model saved as {model_path}")
+                except Exception as e:
+                    st.error(f"Error saving model: {e}")
+
+    # Model Validation Section
+    st.header("Model Validation")
+    model_path_validate = st.text_input("Enter path to saved model for validation", "my_model.joblib")
+    if st.button("Validate Model"):
+        if not os.path.exists(model_path_validate):
+            st.error("Model file not found.")
+        else:
+            validation_metrics, problem_type = validate_model(model_path_validate, df.copy(), target, features, test_size)  # Pass a copy of the dataframe
+            if validation_metrics:
+                st.subheader("Validation Metrics")
+                if problem_type in ["Classification", "Multiclass"]:  # Combined here
+                    st.metric("Accuracy", f"{validation_metrics['accuracy']:.2%}")
+                    st.subheader("Confusion Matrix")
+                    cm = validation_metrics['confusion_matrix']
+                    class_names = [str(i) for i in np.unique(df[target])]  # Get original class names
+                    fig_cm = px.imshow(cm,
+                                       labels=dict(x="Predicted", y="Actual"),
+                                       x=class_names,
+                                       y=class_names,
+                                       color_continuous_scale="Viridis")
+                    st.plotly_chart(fig_cm, use_container_width=True)
+                    st.subheader("Classification Report")
+                    report = validation_metrics['classification_report']
+                    report_df = pd.DataFrame(report).transpose()
+                    st.dataframe(report_df)
+
+                else:
+                    st.metric("MSE", f"{validation_metrics['mse']:.2f}")
+                    st.metric("R2", f"{validation_metrics['r2']:.2f}")
 
 elif app_mode == "Predictions":
     st.title("🔮 Predictive Analytics")
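Because "Save Model" serializes a dict (model, scaler, label encoder, imputer, column order, metadata) rather than a bare estimator, any downstream consumer has to repeat the same preprocessing steps, in the same order, before calling predict. A minimal consumer sketch assuming that bundle format (the file name and new_df are illustrative):

```python
import joblib
import numpy as np
import pandas as pd

bundle = joblib.load("my_model.joblib")  # dict written by the Save Model button

def predict_new(new_df: pd.DataFrame) -> np.ndarray:
    X = new_df[bundle['features']].copy()
    num_cols = X.select_dtypes(include=np.number).columns
    X[num_cols] = bundle['imputer_numerical'].transform(X[num_cols])   # same imputer as training
    X = pd.get_dummies(X, columns=X.select_dtypes(exclude=np.number).columns, dummy_na=False)
    X = X.reindex(columns=bundle['column_order'], fill_value=0)        # align one-hot columns
    X = bundle['scaler'].transform(X)                                  # scaler fitted on training data
    y = bundle['model'].predict(X)
    if bundle['label_encoder'] is not None:                            # decode class labels
        y = bundle['label_encoder'].inverse_transform(y)
    return y
```

The reindex step matters: unseen categories in new data would otherwise produce dummy columns the model never saw, or drop columns it expects.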
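The Relationship Diagnostics panel reports both Pearson (linear) and Spearman (rank/monotonic) correlation; note that both scipy calls expect numeric, NaN-free inputs, so selecting text columns or columns with missing values will raise inside the panel's code path. A quick comparison of the two statistics on synthetic data:

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = np.exp(x)  # monotonic in x, but nonlinear

print(f"pearson:  {pearsonr(x, y)[0]:.2f}")   # < 1: penalized by the nonlinearity
print(f"spearman: {spearmanr(x, y)[0]:.2f}")  # 1.00: ranks agree perfectly
```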