Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,15 +2,20 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
-
from sklearn.model_selection import train_test_split
|
6 |
-
from sklearn.linear_model import LinearRegression
|
7 |
-
from sklearn.tree import DecisionTreeRegressor
|
8 |
-
from sklearn.
|
|
|
|
|
9 |
from sklearn.impute import KNNImputer
|
10 |
-
from sklearn.preprocessing import RobustScaler
|
|
|
|
|
11 |
from ydata_profiling import ProfileReport
|
12 |
from streamlit_pandas_profiling import st_profile_report
|
13 |
from io import StringIO
|
|
|
14 |
|
15 |
# Configuration
|
16 |
st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="🧙")
|
@@ -25,12 +30,12 @@ st.markdown(
|
|
25 |
color: #e0e0ff; /* Light text */
|
26 |
font-family: 'Courier New', monospace; /* Monospace font */
|
27 |
}
|
28 |
-
|
29 |
/* Main content area */
|
30 |
.stApp {
|
31 |
background-color: #0a0a1a; /* Match body background */
|
32 |
}
|
33 |
-
|
34 |
/* Containers and blocks */
|
35 |
.st-emotion-cache-16idsys,
|
36 |
.st-emotion-cache-1v0mbdj,
|
@@ -46,44 +51,44 @@ st.markdown(
|
|
46 |
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
|
47 |
color: #e0e0ff; /* Light text color */
|
48 |
}
|
49 |
-
|
50 |
/* Sidebar */
|
51 |
.st-bb {
|
52 |
background-color: #141422; /* Dark sidebar background */
|
53 |
padding: 20px;
|
54 |
border-radius: 10px;
|
55 |
}
|
56 |
-
|
57 |
/* Headers */
|
58 |
h1, h2, h3, h4, h5, h6, .st-bb {
|
59 |
color: #00f7ff; /* Cyan color for headers */
|
60 |
}
|
61 |
-
|
62 |
/* Selectboxes and Buttons */
|
63 |
.st-cb, .st-ci, .st-cj, .st-ch {
|
64 |
background-color: #141422; /* Dark selectbox background */
|
65 |
color: #00f7ff !important; /* Cyan text color */
|
66 |
border: 1px solid #00f7ff; /* Cyan border */
|
67 |
}
|
68 |
-
|
69 |
/* Selectbox text */
|
70 |
.st-cv {
|
71 |
color: #00f7ff !important; /* Cyan color for selectbox text */
|
72 |
}
|
73 |
-
|
74 |
/* Number input and text input */
|
75 |
.st-cr {
|
76 |
background-color: #141422 !important; /* Dark input background */
|
77 |
color: #00f7ff !important; /* Cyan text color */
|
78 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
79 |
}
|
80 |
-
|
81 |
/* Slider */
|
82 |
.st-cw {
|
83 |
background-color: #141422 !important; /* Dark slider background */
|
84 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
85 |
}
|
86 |
-
|
87 |
/* Buttons */
|
88 |
.st-bz, .st-b0 {
|
89 |
background-color: #141422; /* Darker Button background */
|
@@ -95,7 +100,7 @@ st.markdown(
|
|
95 |
background-color: #00f7ff; /* Hover color */
|
96 |
color: #0a0a1a; /* Hover text color */
|
97 |
}
|
98 |
-
|
99 |
/* File uploader */
|
100 |
.st-ae {
|
101 |
background-color: #141422 !important; /* Dark file uploader background */
|
@@ -103,24 +108,21 @@ st.markdown(
|
|
103 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
104 |
border-radius: 10px; /* Rounded corners */
|
105 |
}
|
106 |
-
|
107 |
/* Metric */
|
108 |
.st-emotion-cache-10trblm {
|
109 |
border-radius: 10px !important; /* Rounded corners */
|
110 |
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
|
111 |
}
|
112 |
-
|
113 |
/* Dataframes and tables */
|
114 |
.dataframe {
|
115 |
background-color: #1e1e30 !important; /* Dark table background */
|
116 |
color: #e0e0ff !important; /* Light text in tables */
|
117 |
border: 1px solid #00f7ff !important; /* Cyan border for tables */
|
118 |
}
|
119 |
-
|
120 |
.dataframe tr:nth-child(odd) {
|
121 |
background-color: #141422 !important; /* Alternating row color */
|
122 |
}
|
123 |
-
|
124 |
/* Expanders*/
|
125 |
.st-emotion-cache-10oheav {
|
126 |
color: #00f7ff !important; /* Cyan text color */
|
@@ -142,10 +144,20 @@ st.markdown(
|
|
142 |
# Cache decorators
|
143 |
@st.cache_data(ttl=3600)
|
144 |
def load_data(uploaded_file):
|
145 |
-
"""Load and cache dataset"""
|
146 |
-
if uploaded_file
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
@st.cache_data(ttl=3600)
|
151 |
def generate_profile(df):
|
@@ -161,12 +173,14 @@ if 'train_test' not in st.session_state:
|
|
161 |
st.session_state.train_test = {}
|
162 |
if 'model' not in st.session_state:
|
163 |
st.session_state.model = None
|
|
|
|
|
164 |
|
165 |
# Sidebar Navigation
|
166 |
st.sidebar.title("🔮 Data Wizard Pro")
|
167 |
app_mode = st.sidebar.radio("Navigate", [
|
168 |
-
"Data Upload",
|
169 |
-
"Smart Cleaning",
|
170 |
"Advanced EDA",
|
171 |
"Model Training",
|
172 |
"Predictions",
|
@@ -176,35 +190,36 @@ app_mode = st.sidebar.radio("Navigate", [
|
|
176 |
# Data Upload Section
|
177 |
if app_mode == "Data Upload":
|
178 |
st.title("📤 Data Upload & Analysis")
|
179 |
-
|
180 |
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx"])
|
181 |
if uploaded_file:
|
182 |
df = load_data(uploaded_file)
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
200 |
|
201 |
# Smart Cleaning Section
|
202 |
elif app_mode == "Smart Cleaning":
|
203 |
st.title("🧼 Intelligent Data Cleaning")
|
204 |
-
|
205 |
if st.session_state.raw_data is not None:
|
206 |
df = st.session_state.cleaned_data
|
207 |
-
|
208 |
# Cleaning Toolkit
|
209 |
col1, col2 = st.columns([1, 3])
|
210 |
with col1:
|
@@ -213,9 +228,10 @@ elif app_mode == "Smart Cleaning":
|
|
213 |
"Handle Missing Values",
|
214 |
"Remove Duplicates",
|
215 |
"Normalize Data",
|
216 |
-
"Encode Categories"
|
|
|
217 |
])
|
218 |
-
|
219 |
if clean_action == "Handle Missing Values":
|
220 |
method = st.selectbox("Imputation Method", [
|
221 |
"KNN Imputation",
|
@@ -223,160 +239,383 @@ elif app_mode == "Smart Cleaning":
|
|
223 |
"Mean Fill",
|
224 |
"Drop Missing"
|
225 |
])
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
with col2:
|
228 |
if st.button("Apply Transformation"):
|
229 |
with st.spinner("Applying changes..."):
|
230 |
if clean_action == "Handle Missing Values":
|
231 |
-
if
|
232 |
-
|
233 |
-
df = pd.DataFrame(imputer.fit_transform(df), columns=df.columns)
|
234 |
-
elif method == "Median Fill":
|
235 |
-
df = df.fillna(df.median())
|
236 |
-
elif method == "Mean Fill":
|
237 |
-
df = df.fillna(df.mean())
|
238 |
else:
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
elif clean_action == "Remove Duplicates":
|
241 |
df = df.drop_duplicates()
|
242 |
elif clean_action == "Normalize Data":
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
st.session_state.cleaned_data = df
|
248 |
st.success("Transformation applied!")
|
249 |
-
|
250 |
# Data Comparison
|
251 |
st.subheader("Data Version Comparison")
|
252 |
col1, col2 = st.columns(2)
|
253 |
with col1:
|
254 |
-
st.write("Original Data", st.session_state.raw_data.head(3))
|
255 |
with col2:
|
256 |
st.write("Cleaned Data", df.head(3))
|
257 |
|
258 |
# Advanced EDA Section
|
259 |
elif app_mode == "Advanced EDA":
|
260 |
st.title("🔍 Advanced Exploratory Analysis")
|
261 |
-
|
262 |
if st.session_state.cleaned_data is not None:
|
263 |
df = st.session_state.cleaned_data
|
264 |
-
|
265 |
# Visualization Selector
|
266 |
plot_type = st.selectbox("Choose Visualization", [
|
267 |
-
"Histogram",
|
268 |
"Scatter Plot",
|
269 |
"Box Plot",
|
270 |
"Correlation Heatmap",
|
271 |
-
"3D Scatter"
|
|
|
|
|
272 |
])
|
273 |
-
|
274 |
# Dynamic Axis Selection
|
275 |
cols = st.columns(3)
|
276 |
with cols[0]:
|
277 |
x_col = st.selectbox("X Axis", df.columns)
|
278 |
with cols[1]:
|
279 |
-
y_col = st.selectbox("Y Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot"] else None
|
280 |
with cols[2]:
|
281 |
z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
# Generate Plot
|
284 |
if st.button("Generate Visualization"):
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
305 |
# Model Training Section
|
306 |
elif app_mode == "Model Training":
|
307 |
st.title("🤖 Model Training Studio")
|
308 |
-
|
309 |
if st.session_state.cleaned_data is not None:
|
310 |
df = st.session_state.cleaned_data
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
312 |
# Model Setup
|
313 |
col1, col2 = st.columns([1, 3])
|
314 |
with col1:
|
315 |
-
|
316 |
-
|
317 |
-
"
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
test_size = st.slider("Test Size", 0.1, 0.5, 0.2)
|
321 |
target = st.selectbox("Target Variable", df.columns)
|
322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
with col2:
|
324 |
if st.button("Train Model"):
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
|
349 |
# Predictions Section
|
350 |
elif app_mode == "Predictions":
|
351 |
st.title("🔮 Make Predictions")
|
352 |
-
|
353 |
-
if st.session_state.model is not None:
|
354 |
-
model = st.session_state.model
|
355 |
-
|
|
|
356 |
# Prediction Interface
|
357 |
input_data = {}
|
358 |
-
for
|
359 |
-
|
360 |
-
|
|
|
|
|
|
|
361 |
if st.button("Predict"):
|
362 |
-
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
|
366 |
elif app_mode == "Visualization Lab":
|
367 |
st.title("📊 Advanced Visualization Lab")
|
368 |
-
|
369 |
if st.session_state.cleaned_data is not None:
|
370 |
df = st.session_state.cleaned_data
|
371 |
-
|
372 |
# Visualization Gallery
|
373 |
viz_type = st.selectbox("Choose Visualization Type", [
|
374 |
"3D Scatter Plot",
|
375 |
"Interactive Heatmap",
|
376 |
"Time Series Analysis",
|
377 |
-
"Cluster Analysis"
|
378 |
])
|
379 |
-
|
380 |
# Dynamic Controls
|
381 |
cols = st.columns(3)
|
382 |
with cols[0]:
|
@@ -385,23 +624,26 @@ elif app_mode == "Visualization Lab":
|
|
385 |
y_axis = st.selectbox("Y Axis", df.columns)
|
386 |
with cols[2]:
|
387 |
z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
|
388 |
-
|
389 |
# Generate Visualization
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
|
|
|
|
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
+
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
|
6 |
+
from sklearn.linear_model import LinearRegression, LogisticRegression
|
7 |
+
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
|
8 |
+
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
|
9 |
+
from sklearn.svm import SVR, SVC
|
10 |
+
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
11 |
from sklearn.impute import KNNImputer
|
12 |
+
from sklearn.preprocessing import RobustScaler, StandardScaler, OneHotEncoder
|
13 |
+
from sklearn.compose import ColumnTransformer
|
14 |
+
from sklearn.pipeline import Pipeline
|
15 |
from ydata_profiling import ProfileReport
|
16 |
from streamlit_pandas_profiling import st_profile_report
|
17 |
from io import StringIO
|
18 |
+
import joblib # For saving and loading models
|
19 |
|
20 |
# Configuration
|
21 |
st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="🧙")
|
|
|
30 |
color: #e0e0ff; /* Light text */
|
31 |
font-family: 'Courier New', monospace; /* Monospace font */
|
32 |
}
|
33 |
+
|
34 |
/* Main content area */
|
35 |
.stApp {
|
36 |
background-color: #0a0a1a; /* Match body background */
|
37 |
}
|
38 |
+
|
39 |
/* Containers and blocks */
|
40 |
.st-emotion-cache-16idsys,
|
41 |
.st-emotion-cache-1v0mbdj,
|
|
|
51 |
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
|
52 |
color: #e0e0ff; /* Light text color */
|
53 |
}
|
54 |
+
|
55 |
/* Sidebar */
|
56 |
.st-bb {
|
57 |
background-color: #141422; /* Dark sidebar background */
|
58 |
padding: 20px;
|
59 |
border-radius: 10px;
|
60 |
}
|
61 |
+
|
62 |
/* Headers */
|
63 |
h1, h2, h3, h4, h5, h6, .st-bb {
|
64 |
color: #00f7ff; /* Cyan color for headers */
|
65 |
}
|
66 |
+
|
67 |
/* Selectboxes and Buttons */
|
68 |
.st-cb, .st-ci, .st-cj, .st-ch {
|
69 |
background-color: #141422; /* Dark selectbox background */
|
70 |
color: #00f7ff !important; /* Cyan text color */
|
71 |
border: 1px solid #00f7ff; /* Cyan border */
|
72 |
}
|
73 |
+
|
74 |
/* Selectbox text */
|
75 |
.st-cv {
|
76 |
color: #00f7ff !important; /* Cyan color for selectbox text */
|
77 |
}
|
78 |
+
|
79 |
/* Number input and text input */
|
80 |
.st-cr {
|
81 |
background-color: #141422 !important; /* Dark input background */
|
82 |
color: #00f7ff !important; /* Cyan text color */
|
83 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
84 |
}
|
85 |
+
|
86 |
/* Slider */
|
87 |
.st-cw {
|
88 |
background-color: #141422 !important; /* Dark slider background */
|
89 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
90 |
}
|
91 |
+
|
92 |
/* Buttons */
|
93 |
.st-bz, .st-b0 {
|
94 |
background-color: #141422; /* Darker Button background */
|
|
|
100 |
background-color: #00f7ff; /* Hover color */
|
101 |
color: #0a0a1a; /* Hover text color */
|
102 |
}
|
103 |
+
|
104 |
/* File uploader */
|
105 |
.st-ae {
|
106 |
background-color: #141422 !important; /* Dark file uploader background */
|
|
|
108 |
border: 1px solid #00f7ff !important; /* Cyan border */
|
109 |
border-radius: 10px; /* Rounded corners */
|
110 |
}
|
|
|
111 |
/* Metric */
|
112 |
.st-emotion-cache-10trblm {
|
113 |
border-radius: 10px !important; /* Rounded corners */
|
114 |
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
|
115 |
}
|
116 |
+
|
117 |
/* Dataframes and tables */
|
118 |
.dataframe {
|
119 |
background-color: #1e1e30 !important; /* Dark table background */
|
120 |
color: #e0e0ff !important; /* Light text in tables */
|
121 |
border: 1px solid #00f7ff !important; /* Cyan border for tables */
|
122 |
}
|
|
|
123 |
.dataframe tr:nth-child(odd) {
|
124 |
background-color: #141422 !important; /* Alternating row color */
|
125 |
}
|
|
|
126 |
/* Expanders*/
|
127 |
.st-emotion-cache-10oheav {
|
128 |
color: #00f7ff !important; /* Cyan text color */
|
|
|
144 |
# Cache decorators
|
145 |
@st.cache_data(ttl=3600)
def load_data(uploaded_file):
    """Read an uploaded CSV or Excel file into a DataFrame.

    The result is cached by Streamlit for one hour (``ttl=3600``).

    Args:
        uploaded_file: The object returned by ``st.file_uploader``, or None.
            Only its ``name`` attribute is inspected to pick a reader.

    Returns:
        The parsed ``pandas.DataFrame``, or None when no file was given or the
        extension is not ``csv``, ``xlsx`` or ``xls`` (an error is shown in
        the UI in the latter case).
    """
    if uploaded_file is None:
        return None

    # Everything after the last dot, lower-cased, decides which reader is used.
    extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xlsx": pd.read_excel,
        "xls": pd.read_excel,
    }
    reader = readers.get(extension)
    if reader is None:
        st.error("Unsupported file type. Please upload a CSV or Excel file.")
        return None
    return reader(uploaded_file)
|
160 |
+
|
161 |
|
162 |
@st.cache_data(ttl=3600)
|
163 |
def generate_profile(df):
|
|
|
173 |
st.session_state.train_test = {}
|
174 |
if 'model' not in st.session_state:
|
175 |
st.session_state.model = None
|
176 |
+
if 'preprocessor' not in st.session_state:
|
177 |
+
st.session_state.preprocessor = None # to store the column transformer
|
178 |
|
179 |
# Sidebar Navigation
|
180 |
st.sidebar.title("🔮 Data Wizard Pro")
|
181 |
app_mode = st.sidebar.radio("Navigate", [
|
182 |
+
"Data Upload",
|
183 |
+
"Smart Cleaning",
|
184 |
"Advanced EDA",
|
185 |
"Model Training",
|
186 |
"Predictions",
|
|
|
190 |
# Data Upload Section
|
191 |
if app_mode == "Data Upload":
    st.title("📤 Data Upload & Analysis")

    uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx"])
    if uploaded_file:
        df = load_data(uploaded_file)
        # load_data returns None for unsupported files; nothing to show then.
        if df is not None:
            # Keep the untouched upload plus a separate working copy that the
            # cleaning page reads and overwrites.
            st.session_state.raw_data = df
            st.session_state.cleaned_data = df.copy()

            # Data overview cards: one metric per column of the layout.
            n_rows, n_columns = df.shape
            overview = [
                ("Rows", n_rows),
                ("Columns", n_columns),
                ("Missing Values", df.isna().sum().sum()),
            ]
            for card, (label, value) in zip(st.columns(3), overview):
                with card:
                    st.metric(label, value)

            # Automated EDA report, built only when the button is pressed.
            with st.expander("🚀 Automated Data Report"):
                if st.button("Generate Smart Report"):
                    st_profile_report(generate_profile(df))
|
215 |
|
216 |
# Smart Cleaning Section
|
217 |
elif app_mode == "Smart Cleaning":
|
218 |
st.title("🧼 Intelligent Data Cleaning")
|
219 |
+
|
220 |
if st.session_state.raw_data is not None:
|
221 |
df = st.session_state.cleaned_data
|
222 |
+
|
223 |
# Cleaning Toolkit
|
224 |
col1, col2 = st.columns([1, 3])
|
225 |
with col1:
|
|
|
228 |
"Handle Missing Values",
|
229 |
"Remove Duplicates",
|
230 |
"Normalize Data",
|
231 |
+
"Encode Categories",
|
232 |
+
"Outlier Removal"
|
233 |
])
|
234 |
+
|
235 |
if clean_action == "Handle Missing Values":
|
236 |
method = st.selectbox("Imputation Method", [
|
237 |
"KNN Imputation",
|
|
|
239 |
"Mean Fill",
|
240 |
"Drop Missing"
|
241 |
])
|
242 |
+
impute_cols = st.multiselect("Columns to Impute", df.columns)
|
243 |
+
|
244 |
+
elif clean_action == "Normalize Data":
|
245 |
+
scaler_type = st.selectbox("Scaler Type", ["RobustScaler", "StandardScaler"])
|
246 |
+
normalize_cols = st.multiselect("Columns to Normalize", df.select_dtypes(include=np.number).columns.tolist())
|
247 |
+
|
248 |
+
elif clean_action == "Encode Categories":
|
249 |
+
encode_cols = st.multiselect("Columns to Encode", df.select_dtypes(include='object').columns.tolist())
|
250 |
+
encoding_method = st.selectbox("Encoding Method", ["OneHotEncoder"]) # Add more if needed
|
251 |
+
|
252 |
+
elif clean_action == "Outlier Removal":
|
253 |
+
outlier_cols = st.multiselect("Columns to Remove Outliers From", df.select_dtypes(include=np.number).columns.tolist())
|
254 |
+
outlier_method = st.selectbox("Outlier Removal Method", ["IQR", "Z-score"])
|
255 |
+
if outlier_method == "IQR":
|
256 |
+
iqr_threshold = st.slider("IQR Threshold", 1.0, 3.0, 1.5)
|
257 |
+
else:
|
258 |
+
zscore_threshold = st.slider("Z-score Threshold", 2.0, 4.0, 3.0)
|
259 |
+
|
260 |
+
|
261 |
with col2:
|
262 |
if st.button("Apply Transformation"):
|
263 |
with st.spinner("Applying changes..."):
|
264 |
if clean_action == "Handle Missing Values":
|
265 |
+
if not impute_cols:
|
266 |
+
st.warning("Please select columns to impute.")
|
|
|
|
|
|
|
|
|
|
|
267 |
else:
|
268 |
+
if method == "KNN Imputation":
|
269 |
+
imputer = KNNImputer()
|
270 |
+
df[impute_cols] = imputer.fit_transform(df[impute_cols])
|
271 |
+
elif method == "Median Fill":
|
272 |
+
df[impute_cols] = df[impute_cols].fillna(df[impute_cols].median())
|
273 |
+
elif method == "Mean Fill":
|
274 |
+
df[impute_cols] = df[impute_cols].fillna(df[impute_cols].mean())
|
275 |
+
else:
|
276 |
+
df = df.dropna(subset=impute_cols)
|
277 |
elif clean_action == "Remove Duplicates":
|
278 |
df = df.drop_duplicates()
|
279 |
elif clean_action == "Normalize Data":
|
280 |
+
if not normalize_cols:
|
281 |
+
st.warning("Please select columns to normalize.")
|
282 |
+
else:
|
283 |
+
if scaler_type == "RobustScaler":
|
284 |
+
scaler = RobustScaler()
|
285 |
+
else:
|
286 |
+
scaler = StandardScaler()
|
287 |
+
df[normalize_cols] = scaler.fit_transform(df[normalize_cols])
|
288 |
+
|
289 |
+
elif clean_action == "Encode Categories":
|
290 |
+
if not encode_cols:
|
291 |
+
st.warning("Please select columns to encode.")
|
292 |
+
else:
|
293 |
+
if encoding_method == "OneHotEncoder":
|
294 |
+
encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
|
295 |
+
encoded_data = encoder.fit_transform(df[encode_cols])
|
296 |
+
encoded_df = pd.DataFrame(encoded_data, columns=encoder.get_feature_names_out(encode_cols))
|
297 |
+
df = pd.concat([df.drop(columns=encode_cols), encoded_df], axis=1)
|
298 |
+
|
299 |
+
elif clean_action == "Outlier Removal":
|
300 |
+
if not outlier_cols:
|
301 |
+
st.warning("Please select columns to remove outliers from.")
|
302 |
+
else:
|
303 |
+
for col in outlier_cols:
|
304 |
+
if outlier_method == "IQR":
|
305 |
+
Q1 = df[col].quantile(0.25)
|
306 |
+
Q3 = df[col].quantile(0.75)
|
307 |
+
IQR = Q3 - Q1
|
308 |
+
lower_bound = Q1 - iqr_threshold * IQR
|
309 |
+
upper_bound = Q3 + iqr_threshold * IQR
|
310 |
+
df = df[(df[col] >= lower_bound) & (df[col] <= upper_bound)]
|
311 |
+
else: # Z-score
|
312 |
+
z_scores = np.abs((df[col] - df[col].mean()) / df[col].std())
|
313 |
+
df = df[z_scores <= zscore_threshold]
|
314 |
+
|
315 |
+
|
316 |
st.session_state.cleaned_data = df
|
317 |
st.success("Transformation applied!")
|
318 |
+
|
319 |
# Data Comparison
|
320 |
st.subheader("Data Version Comparison")
|
321 |
col1, col2 = st.columns(2)
|
322 |
with col1:
|
323 |
+
st.write("Original Data", st.session_state.raw_data.head(3) if st.session_state.raw_data is not None else "No data uploaded")
|
324 |
with col2:
|
325 |
st.write("Cleaned Data", df.head(3))
|
326 |
|
327 |
# Advanced EDA Section
|
328 |
elif app_mode == "Advanced EDA":
|
329 |
st.title("🔍 Advanced Exploratory Analysis")
|
330 |
+
|
331 |
if st.session_state.cleaned_data is not None:
|
332 |
df = st.session_state.cleaned_data
|
333 |
+
|
334 |
# Visualization Selector
|
335 |
plot_type = st.selectbox("Choose Visualization", [
|
336 |
+
"Histogram",
|
337 |
"Scatter Plot",
|
338 |
"Box Plot",
|
339 |
"Correlation Heatmap",
|
340 |
+
"3D Scatter",
|
341 |
+
"Violin Plot",
|
342 |
+
"Time Series"
|
343 |
])
|
344 |
+
|
345 |
# Dynamic Axis Selection
|
346 |
cols = st.columns(3)
|
347 |
with cols[0]:
|
348 |
x_col = st.selectbox("X Axis", df.columns)
|
349 |
with cols[1]:
|
350 |
+
y_col = st.selectbox("Y Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Time Series"] else None
|
351 |
with cols[2]:
|
352 |
z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
|
353 |
+
if plot_type == "Time Series":
|
354 |
+
time_col = x_col # rename for clarity
|
355 |
+
value_col = y_col
|
356 |
+
|
357 |
+
#Interactive filtering
|
358 |
+
filter_col = st.selectbox("Filter Column", [None] + list(df.columns))
|
359 |
+
if filter_col:
|
360 |
+
unique_values = df[filter_col].unique()
|
361 |
+
filter_options = st.multiselect("Filter Values", unique_values, default=unique_values)
|
362 |
+
df = df[df[filter_col].isin(filter_options)]
|
363 |
+
|
364 |
+
|
365 |
# Generate Plot
|
366 |
if st.button("Generate Visualization"):
|
367 |
+
try: # add try-except block for potential errors
|
368 |
+
if plot_type == "Histogram":
|
369 |
+
fig = px.histogram(df, x=x_col, nbins=30, template="plotly_dark")
|
370 |
+
elif plot_type == "Scatter Plot":
|
371 |
+
fig = px.scatter(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
|
372 |
+
elif plot_type == "3D Scatter":
|
373 |
+
fig = px.scatter_3d(df, x=x_col, y=y_col, z=z_col, color=x_col)
|
374 |
+
elif plot_type == "Correlation Heatmap":
|
375 |
+
corr = df.corr(numeric_only=True) #handle non-numeric cols
|
376 |
+
fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
|
377 |
+
elif plot_type == "Box Plot":
|
378 |
+
fig = px.box(df,x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
|
379 |
+
elif plot_type == "Violin Plot":
|
380 |
+
fig = px.violin(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
|
381 |
+
elif plot_type == "Time Series":
|
382 |
+
fig = px.line(df, x=time_col, y=value_col)
|
383 |
+
|
384 |
+
fig.update_layout(
|
385 |
+
plot_bgcolor="#1e1e30",
|
386 |
+
paper_bgcolor="#1e1e30",
|
387 |
+
font_color="#e0e0ff"
|
388 |
+
)
|
389 |
+
|
390 |
+
st.plotly_chart(fig, use_container_width=True)
|
391 |
+
except Exception as e:
|
392 |
+
st.error(f"Error generating plot: {e}")
|
393 |
|
394 |
# Model Training Section
|
395 |
elif app_mode == "Model Training":
|
396 |
st.title("🤖 Model Training Studio")
|
397 |
+
|
398 |
if st.session_state.cleaned_data is not None:
|
399 |
df = st.session_state.cleaned_data
|
400 |
+
|
401 |
+
# Check for missing values before proceeding
|
402 |
+
if df.isnull().sum().sum() > 0:
|
403 |
+
st.error("Data contains missing values. Please handle them in the 'Smart Cleaning' section before training.")
|
404 |
+
st.stop()
|
405 |
+
|
406 |
# Model Setup
|
407 |
col1, col2 = st.columns([1, 3])
|
408 |
with col1:
|
409 |
+
task_type = st.selectbox("Choose Task", ["Regression", "Classification"])
|
410 |
+
if task_type == "Regression":
|
411 |
+
model_type = st.selectbox("Choose Model", [
|
412 |
+
"Linear Regression",
|
413 |
+
"Decision Tree",
|
414 |
+
"Random Forest",
|
415 |
+
"Gradient Boosting"
|
416 |
+
])
|
417 |
+
else: # Classification
|
418 |
+
model_type = st.selectbox("Choose Model", [
|
419 |
+
"Logistic Regression",
|
420 |
+
"Decision Tree",
|
421 |
+
"Random Forest",
|
422 |
+
"Support Vector Machine" #SVC
|
423 |
+
])
|
424 |
+
|
425 |
test_size = st.slider("Test Size", 0.1, 0.5, 0.2)
|
426 |
target = st.selectbox("Target Variable", df.columns)
|
427 |
+
features = [col for col in df.columns if col != target] #Exclude target
|
428 |
+
numeric_features = df[features].select_dtypes(include=np.number).columns.tolist()
|
429 |
+
categorical_features = [col for col in features if col not in numeric_features]
|
430 |
+
|
431 |
+
# Hyperparameter tuning options (example for RandomForest)
|
432 |
+
enable_hyperparameter_tuning = st.checkbox("Enable Hyperparameter Tuning")
|
433 |
+
if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]: # Add more models later
|
434 |
+
st.write("Hyperparameter Tuning Options:")
|
435 |
+
if model_type == "Random Forest":
|
436 |
+
n_estimators = st.slider("Number of Estimators", 50, 200, 100)
|
437 |
+
max_depth = st.slider("Max Depth", 5, 20, None) #None for unlimited
|
438 |
+
param_grid = {'n_estimators': [n_estimators], 'max_depth': [max_depth]}
|
439 |
+
|
440 |
+
elif model_type == "Gradient Boosting":
|
441 |
+
n_estimators = st.slider("Number of Estimators", 50, 200, 100, key = "gb_n_estimators")
|
442 |
+
learning_rate = st.slider("Learning Rate", 0.01, 0.1, 0.05, key = "gb_learning_rate")
|
443 |
+
max_depth = st.slider("Max Depth", 3, 10, 5, key = "gb_max_depth")
|
444 |
+
param_grid = {'n_estimators': [n_estimators], 'learning_rate': [learning_rate], 'max_depth': [max_depth]}
|
445 |
+
|
446 |
+
elif model_type == "Support Vector Machine": #SVC/SVR
|
447 |
+
kernel = st.selectbox("Kernel", ['linear', 'rbf', 'poly'])
|
448 |
+
C = st.slider("C (Regularization)", 0.1, 1.0, 0.5)
|
449 |
+
param_grid = {'kernel': [kernel], 'C': [C]}
|
450 |
+
elif model_type == "Logistic Regression":
|
451 |
+
C = st.slider("C (Regularization)", 0.1, 1.0, 0.5)
|
452 |
+
param_grid = {'C': [C]} # add more as needed
|
453 |
+
elif model_type == "Decision Tree":
|
454 |
+
max_depth = st.slider("Max Depth", 5, 20, None) # None for unlimited
|
455 |
+
param_grid = {'max_depth': [max_depth]}
|
456 |
+
|
457 |
+
|
458 |
+
|
459 |
with col2:
|
460 |
if st.button("Train Model"):
|
461 |
+
try:
|
462 |
+
X = df.drop(columns=[target])
|
463 |
+
y = df[target]
|
464 |
+
|
465 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
466 |
+
X, y, test_size=test_size, random_state=42
|
467 |
+
)
|
468 |
+
|
469 |
+
# Preprocessing
|
470 |
+
numeric_transformer = StandardScaler() #StandardScaler or other scalers
|
471 |
+
categorical_transformer = OneHotEncoder(handle_unknown='ignore', sparse_output=False) #sparse=False for array output
|
472 |
+
|
473 |
+
preprocessor = ColumnTransformer(
|
474 |
+
transformers=[
|
475 |
+
('num', numeric_transformer, numeric_features),
|
476 |
+
('cat', categorical_transformer, categorical_features)
|
477 |
+
],
|
478 |
+
remainder='passthrough' # or 'drop' if you want to drop untransformed cols
|
479 |
+
)
|
480 |
+
|
481 |
+
X_train = preprocessor.fit_transform(X_train)
|
482 |
+
X_test = preprocessor.transform(X_test)
|
483 |
+
st.session_state.preprocessor = preprocessor #store for prediction later
|
484 |
+
|
485 |
+
|
486 |
+
# Model Training
|
487 |
+
if task_type == "Regression":
|
488 |
+
if model_type == "Linear Regression":
|
489 |
+
model = LinearRegression()
|
490 |
+
elif model_type == "Decision Tree":
|
491 |
+
model = DecisionTreeRegressor()
|
492 |
+
elif model_type == "Random Forest":
|
493 |
+
model = RandomForestRegressor()
|
494 |
+
elif model_type == "Gradient Boosting":
|
495 |
+
model = GradientBoostingRegressor()
|
496 |
+
elif model_type == "Support Vector Machine":
|
497 |
+
model = SVR()
|
498 |
+
else: #Classification
|
499 |
+
if model_type == "Logistic Regression":
|
500 |
+
model = LogisticRegression(max_iter=1000) #increase max_iter if needed
|
501 |
+
elif model_type == "Decision Tree":
|
502 |
+
model = DecisionTreeClassifier()
|
503 |
+
elif model_type == "Random Forest":
|
504 |
+
model = RandomForestClassifier()
|
505 |
+
elif model_type == "Support Vector Machine":
|
506 |
+
model = SVC(probability=True) #probability=True needed for ROC AUC
|
507 |
+
|
508 |
+
|
509 |
+
#Hyperparameter tuning
|
510 |
+
if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]:
|
511 |
+
grid_search = GridSearchCV(model, param_grid, cv=3, scoring='neg_mean_squared_error' if task_type == "Regression" else 'accuracy')
|
512 |
+
grid_search.fit(X_train, y_train)
|
513 |
+
model = grid_search.best_estimator_ #use best model
|
514 |
+
st.write("Best Parameters:", grid_search.best_params_)
|
515 |
+
|
516 |
+
else:
|
517 |
+
model.fit(X_train, y_train)
|
518 |
+
|
519 |
+
st.session_state.model = model
|
520 |
+
st.session_state.train_test = {
|
521 |
+
'X_test': X_test,
|
522 |
+
'y_test': y_test,
|
523 |
+
'task': task_type #Store task for eval
|
524 |
+
}
|
525 |
+
|
526 |
+
# Evaluation Metrics
|
527 |
+
y_pred = model.predict(X_test)
|
528 |
+
|
529 |
+
if task_type == "Regression":
|
530 |
+
r2 = r2_score(y_test, y_pred)
|
531 |
+
mse = mean_squared_error(y_test, y_pred)
|
532 |
+
mae = mean_absolute_error(y_test, y_pred) #ADDED
|
533 |
+
st.metric("R² Score", round(r2, 2))
|
534 |
+
st.metric("MSE", round(mse, 2))
|
535 |
+
st.metric("MAE", round(mae, 2)) #ADDED
|
536 |
+
else: #Classification
|
537 |
+
accuracy = accuracy_score(y_test, y_pred)
|
538 |
+
precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
|
539 |
+
recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
|
540 |
+
f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
|
541 |
+
try:
|
542 |
+
roc_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]) #requires probabilities
|
543 |
+
st.metric("ROC AUC", round(roc_auc, 2))
|
544 |
+
except:
|
545 |
+
st.warning("ROC AUC score not available for this classifier.")
|
546 |
+
st.metric("Accuracy", round(accuracy, 2))
|
547 |
+
st.metric("Precision", round(precision, 2))
|
548 |
+
st.metric("Recall", round(recall, 2))
|
549 |
+
st.metric("F1 Score", round(f1, 2))
|
550 |
+
|
551 |
+
# Cross-validation: 5-fold scores on the training split, using the metric
# that matches the task type. For regression the scorer is the *negated*
# MSE, so the reported values are <= 0 and closer to zero is better.
cv_metric = "accuracy" if task_type != "Regression" else "neg_mean_squared_error"
scores = cross_val_score(model, X_train, y_train, cv=5, scoring=cv_metric)
st.write("Cross-Validation Scores:", scores)
st.write("Mean Cross-Validation Score:", scores.mean())
|
555 |
+
|
556 |
+
|
557 |
+
|
558 |
+
# Model persistence: optionally dump the fitted model together with its
# preprocessor so both can be reloaded as a pair.
if st.checkbox("Save Model"):
    import os  # local import: only needed to sanitise the user-supplied name

    model_filename = st.text_input("Model Filename", "trained_model.joblib")
    # The name comes from free-text user input. Keep only the final path
    # component so it cannot point outside the app's working directory
    # (e.g. "../../etc/x" or an absolute path).
    safe_filename = os.path.basename(model_filename.strip())
    if safe_filename and safe_filename not in (".", ".."):
        joblib.dump((model, preprocessor), safe_filename)  # save both model AND preprocessor
        st.success(f"Model saved as {safe_filename}")
    else:
        st.warning("Please enter a valid model filename.")
|
563 |
+
except Exception as e:
|
564 |
+
st.error(f"Error during training: {e}")
|
565 |
|
566 |
# Predictions Section
|
567 |
elif app_mode == "Predictions":
|
568 |
st.title("🔮 Make Predictions")
|
569 |
+
|
570 |
+
# Prediction form: requires a trained model and a fitted preprocessor in
# session state; otherwise the user is asked to train a model first.
if st.session_state.model is not None and st.session_state.preprocessor is not None:
    model, preprocessor = st.session_state.model, st.session_state.preprocessor
    # One numeric input is rendered per column of the stored X_test.
    X_test_cols = st.session_state.train_test['X_test'].shape[1]

    # Prediction Interface
    # NOTE(review): the stored X_test is reportedly already preprocessed, so
    # these placeholder names/count may not match the raw columns the
    # preprocessor was fitted on — confirm against the training section.
    input_data = {
        f"feature_{i}": st.number_input(f"Feature {i+1}", value=0.0)
        for i in range(X_test_cols)
    }

    if st.button("Predict"):
        try:
            input_df = pd.DataFrame([input_data])
            # Apply the same preprocessing used at training time.
            input_processed = preprocessor.transform(input_df)
            prediction = model.predict(input_processed)

            if st.session_state.train_test['task'] == "Regression":
                st.success(f"Predicted Value: {prediction[0]:.2f}")
            else:
                st.success(f"Predicted Class: {prediction[0]}")
                # Show per-class probabilities when the classifier exposes them.
                if hasattr(model, "predict_proba"):
                    proba = model.predict_proba(input_processed)[0]
                    # Label each probability with the estimator's own class
                    # label; the positional index is only a fallback, since
                    # it differs from the label for non 0..n-1 targets.
                    class_labels = getattr(model, "classes_", range(len(proba)))
                    for label, p in zip(class_labels, proba):
                        st.write(f"Probability of class {label}: {p:.2f}")

        except Exception as e:
            # Page-level boundary: surface the failure in the UI instead of crashing.
            st.error(f"Error during prediction: {e}")
else:
    st.warning("Please train a model first.")
|
604 |
|
605 |
elif app_mode == "Visualization Lab":
|
606 |
st.title("📊 Advanced Visualization Lab")
|
607 |
+
|
608 |
if st.session_state.cleaned_data is not None:
|
609 |
df = st.session_state.cleaned_data
|
610 |
+
|
611 |
# Visualization Gallery: the chart types offered to the user.
viz_options = [
    "3D Scatter Plot",
    "Interactive Heatmap",
    "Time Series Analysis",
    "Cluster Analysis (Coming Soon)",  # placeholder entry; only shows a notice for now
]
viz_type = st.selectbox("Choose Visualization Type", viz_options)
|
618 |
+
|
619 |
# Dynamic Controls
|
620 |
cols = st.columns(3)
|
621 |
with cols[0]:
|
|
|
624 |
y_axis = st.selectbox("Y Axis", df.columns)
|
625 |
with cols[2]:
|
626 |
z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
|
627 |
+
|
628 |
# Generate Visualization: build the figure for the selected chart type and
# render it once; any failure is reported in the UI instead of crashing.
try:
    fig = None

    if viz_type == "3D Scatter Plot":
        fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=x_axis)

    elif viz_type == "Interactive Heatmap":
        # Correlations are computed over numeric columns only.
        fig = px.imshow(
            df.corr(numeric_only=True),
            text_auto=True,
            color_continuous_scale='RdBu',
        )

    elif viz_type == "Time Series Analysis":
        # Basic line plot of one column against another chosen as the time axis.
        time_col = st.selectbox("Time Column", df.columns)
        value_col = st.selectbox("Value Column", df.columns)
        fig = px.line(df, x=time_col, y=value_col)

    elif viz_type == "Cluster Analysis (Coming Soon)":
        # Placeholder for future development: no figure is produced.
        st.write("Cluster Analysis Feature Coming Soon!")

    if fig is not None:
        st.plotly_chart(fig, use_container_width=True)
except Exception as e:
    st.error(f"Error generating visualization: {e}")
|