Update app.py

app.py (CHANGED)

--- a/app.py
@@ -1,59 +1,203 @@
 import streamlit as st
 import pandas as pd
 import numpy as np
-from pycaret.classification import setup as classification_setup, compare_models as compare_classification_models, evaluate_model as evaluate_classification_model, save_model as save_classification_model, plot_model as plot_classification_model
-from pycaret.regression import setup as regression_setup, compare_models as compare_regression_models, evaluate_model as evaluate_regression_model, save_model as save_regression_model, plot_model as plot_regression_model
-from pycaret.clustering import setup as clustering_setup, evaluate_model as evaluate_clustering_model, save_model as save_clustering_model, plot_model as plot_clustering_model
-from ydata_profiling import ProfileReport
-from streamlit_pandas_profiling import st_profile_report
-from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
-from sklearn.decomposition import PCA
-from scipy import stats
 import plotly.express as px
 import plotly.graph_objects as go
 import os

-
-st.
-
 # Sidebar Navigation
 with st.sidebar:
-    st.title("🔮
-    st.markdown("Your AI-powered
-    st.markdown("---")
-    app_mode = st.selectbox("Navigation", ["Data Upload", "Data Cleaning", "EDA", "Model Training", "Validation & Exploration"])
-    data_type = st.selectbox("Data Type", ["Tabular"])
     st.markdown("---")
-    st.
-
-
-
-
-    if
-    st.
-
-
-
-
-    st.markdown(f"<h2 style='text-align: center; color: #1e3a8a;'>{title}</h2>", unsafe_allow_html=True)

-
 if app_mode == "Data Upload":
-    st.title("📤 Data Upload")
-
     if uploaded_file:
-
-        st.session_state
-        st.session_state
-
-
-
-
-
-
-
-

 elif app_mode == "Data Cleaning":
     st.title("🧹 Smart Data Cleaning")
@@ -79,165 +223,9 @@ elif app_mode == "Data Cleaning":
     if st.button("Undo Last Action"):
         st.session_state.data_versions.pop()
         st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()
         st.rerun()

-    with st.expander("🛠️ Data Cleaning Operations", expanded=True):
-        enhance_section_title("🔍 Missing Values Treatment")
-        missing_cols = df.columns[df.isna().any()].tolist()
-        if missing_cols:
-            cols = st.multiselect("Select columns with missing values", missing_cols)
-            method = st.selectbox("Choose imputation method", [
-                "Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
-            ])
-            if method == "Fill with Custom Value":
-                custom_val = st.text_input("Enter custom value:")
-            if st.button("Apply Missing Value Treatment"):
-                new_df = df.copy()
-                if method == "Drop Missing Values":
-                    new_df = new_df.dropna(subset=cols)
-                elif method == "Fill with Mean/Median":
-                    for col in cols:
-                        if pd.api.types.is_numeric_dtype(new_df[col]):
-                            new_df[col] = new_df[col].fillna(new_df[col].median())
-                        else:
-                            new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
-                elif method == "Fill with Custom Value" and custom_val:
-                    new_df[cols] = new_df[cols].fillna(custom_val)
-                elif method == "Forward Fill":
-                    new_df[cols] = new_df[cols].ffill()
-                elif method == "Backward Fill":
-                    new_df[cols] = new_df[cols].bfill()
-                update_cleaned_data(new_df)
-        else:
-            st.success("✨ No missing values detected!")
-
-        enhance_section_title("🔄 Data Type Conversion")
-        col_to_convert = st.selectbox("Select column to convert", df.columns)
-        new_type = st.selectbox("Select new data type", ["String", "Integer", "Float", "Boolean", "Datetime"])
-        if new_type == "Datetime":
-            date_format = st.text_input("Enter date format (e.g., %Y-%m-%d):", "%Y-%m-%d")
-        if st.button("Convert Data Type"):
-            new_df = df.copy()
-            if new_type == "String":
-                new_df[col_to_convert] = new_df[col_to_convert].astype(str)
-            elif new_type == "Integer":
-                new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce').astype('Int64')
-            elif new_type == "Float":
-                new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce')
-            elif new_type == "Boolean":
-                new_df[col_to_convert] = new_df[col_to_convert].astype(bool)
-            elif new_type == "Datetime":
-                new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
-            update_cleaned_data(new_df)
-
-        enhance_section_title("🗑️ Drop Columns")
-        columns_to_drop = st.multiselect("Select columns to remove", df.columns)
-        if columns_to_drop and st.button("Confirm Column Removal"):
-            new_df = df.copy()
-            new_df = new_df.drop(columns=columns_to_drop)
-            update_cleaned_data(new_df)
-
-        enhance_section_title("🔢 Encoding Options")
-        encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
-        data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
-        if data_to_encode and st.button("Apply Encoding"):
-            new_df = df.copy()
-            if encoding_method == "Label Encoding":
-                for col in data_to_encode:
-                    le = LabelEncoder()
-                    new_df[col] = le.fit_transform(new_df[col].astype(str))
-            elif encoding_method == "One-Hot Encoding":
-                new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
-            update_cleaned_data(new_df)
-
-        enhance_section_title("📏 StandardScaler")
-        scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
-        if scale_cols and st.button("Apply StandardScaler"):
-            new_df = df.copy()
-            scaler = StandardScaler()
-            new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
-            update_cleaned_data(new_df)
-
-        enhance_section_title("🕵️ Pattern-Based Cleaning")
-        selected_col = st.selectbox("Select text column for pattern cleaning", df.select_dtypes(include='object').columns)
-        pattern = st.text_input("Enter regex pattern:")
-        replacement = st.text_input("Enter replacement value:")
-        if st.button("Apply Pattern Replacement"):
-            new_df = df.copy()
-            new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
-            update_cleaned_data(new_df)
-
-        enhance_section_title("🚀 Bulk Actions")
-        bulk_action = st.selectbox("Choose bulk action", [
-            "Auto-Clean Common Issues", "Drop All Missing Values", "Fill All Missing Values",
-            "One-Hot Encode All Categorical Columns", "Apply Min-Max Scaling to All Numeric Columns",
-            "Remove Outliers from All Numeric Columns", "Principal Component Analysis (PCA)"
-        ])
-        if bulk_action == "Auto-Clean Common Issues" and st.button("Run Auto-Clean"):
-            new_df = df.copy()
-            new_df = new_df.dropna(axis=1, how='all')
-            new_df = new_df.convert_dtypes()
-            text_cols = new_df.select_dtypes(include='object').columns
-            new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
-            update_cleaned_data(new_df)
-        elif bulk_action == "Drop All Missing Values" and st.button("Drop All Missing"):
-            new_df = df.copy()
-            new_df = new_df.dropna()
-            update_cleaned_data(new_df)
-        elif bulk_action == "Fill All Missing Values":
-            fill_value = st.text_input("Enter fill value:", "0")
-            if st.button("Fill Missing Values"):
-                new_df = df.copy()
-                new_df = new_df.fillna(fill_value)
-                update_cleaned_data(new_df)
-        elif bulk_action == "One-Hot Encode All Categorical Columns" and st.button("One-Hot Encode All"):
-            new_df = df.copy()
-            categorical_cols = new_df.select_dtypes(include='object').columns
-            new_df = pd.get_dummies(new_df, columns=categorical_cols, drop_first=True, dtype=int)
-            update_cleaned_data(new_df)
-        elif bulk_action == "Apply Min-Max Scaling to All Numeric Columns" and st.button("Apply Min-Max Scaling"):
-            new_df = df.copy()
-            scaler = MinMaxScaler()
-            numerical_cols = new_df.select_dtypes(include=np.number).columns
-            new_df[numerical_cols] = scaler.fit_transform(new_df[numerical_cols])
-            update_cleaned_data(new_df)
-        elif bulk_action == "Remove Outliers from All Numeric Columns" and st.button("Remove All Outliers"):
-            new_df = df.copy()
-            z_scores = np.abs(stats.zscore(new_df.select_dtypes(include=np.number)))
-            new_df = new_df[(z_scores < 3).all(axis=1)]
-            update_cleaned_data(new_df)
-        elif bulk_action == "Principal Component Analysis (PCA)":
-            n_components_bulk = st.slider("Number of components", 1, min(df.shape[1], 10), 2)
-            if st.button("Apply PCA (Bulk)"):
-                new_df = df.copy()
-                pca = PCA(n_components=n_components_bulk)
-                numerical_cols = new_df.select_dtypes(include=np.number).columns
-                pca_result = pca.fit_transform(new_df[numerical_cols])
-                new_df = pd.DataFrame(pca_result, columns=[f'PC{i+1}' for i in range(pca_result.shape[1])])
-                update_cleaned_data(new_df.reset_index(drop=True))
-
-        enhance_section_title("📊 Principal Component Analysis (PCA)")
-        numerical_cols = df.select_dtypes(include=np.number).columns.tolist()
-        if numerical_cols:
-            pca_cols = st.multiselect("Select columns for PCA", numerical_cols, default=numerical_cols)
-            if pca_cols:
-                st.subheader("Covariance Matrix Heatmap")
-                cov_matrix = df[pca_cols].cov()
-                fig_cov = px.imshow(cov_matrix, labels=dict(x="Features", y="Features", color="Covariance"), color_continuous_scale='RdBu_r')
-                st.plotly_chart(fig_cov)
-                n_components = st.slider("Number of components", 1, min(len(pca_cols), 10), 2)
-                if st.button("Apply PCA"):
-                    new_df = df.copy()
-                    scaler = StandardScaler()
-                    scaled_data = scaler.fit_transform(new_df[pca_cols])
-                    pca = PCA(n_components=n_components)
-                    pca_result = pca.fit_transform(scaled_data)
-                    pca_df = pd.DataFrame(pca_result, columns=[f'PC{i+1}' for i in range(n_components)])
-                    update_cleaned_data(pca_df.reset_index(drop=True))
-                    st.write("Explained Variance Ratio:", pca.explained_variance_ratio_)
-        else:
-            st.warning("No numerical columns available for PCA.")
-
 elif app_mode == "EDA":
     st.title("🔍 Interactive Data Explorer")
     if 'cleaned_data' not in st.session_state:
@@ -245,10 +233,7 @@ elif app_mode == "EDA":
         st.stop()
     df = st.session_state.cleaned_data.copy()

-    # Enhanced Section Title
     enhance_section_title("Dataset Overview")
-
-    # Dataset Overview with More Visual Appeal
     with st.container():
         col1, col2, col3, col4 = st.columns(4)
         col1.metric("Total Rows", df.shape[0])
@@ -257,197 +242,26 @@ elif app_mode == "EDA":
         col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
         col4.metric("Duplicates", df.duplicated().sum())

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    with st.
-
-        with col1:
-            plot_type = st.selectbox("Choose visualization type", [
-                "Scatter Plot", "Histogram", "Box Plot", "Violin Plot", "Line Chart", "Bar Chart",
-                "Correlation Matrix", "Pair Plot", "Heatmap", "3D Scatter", "Parallel Categories",
-                "Segmented Bar Chart", "Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot",
-                "Count Plot", "Lollipop Chart"
-            ])
-            x_axis = st.selectbox("X-axis", df.columns) if plot_type not in ["Correlation Matrix", "Pair Plot"] else None
-            y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Line Chart", "Heatmap", "Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot", "Lollipop Chart"] else None
-            z_axis = st.selectbox("Z-axis", df.columns) if plot_type == "3D Scatter" else None
-            color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type not in ["Correlation Matrix", "Pair Plot"] else None
-
-            if plot_type == "Parallel Categories":
-                dimensions = st.multiselect("Dimensions", df.columns.tolist(), default=df.columns[:3].tolist())
-            elif plot_type == "Segmented Bar Chart":
-                segment_col = st.selectbox("Segment Column (Categorical)", df.select_dtypes(exclude=np.number).columns)
-            elif plot_type == "Bubble Plot":
-                size_col = st.selectbox("Size Column", df.columns)
-            elif plot_type == "Pair Plot":
-                pair_cols = st.multiselect("Select columns for Pair Plot", df.columns, default=df.columns[:5].tolist())
-
-        with col2:
-            try:
-                fig = None
-                if plot_type == "Scatter Plot" and x_axis and y_axis:
-                    fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, trendline="lowess", title=f'Scatter Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Histogram" and x_axis:
-                    fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, marginal="box", title=f'Histogram of {x_axis}')
-                elif plot_type == "Box Plot" and x_axis and y_axis:
-                    fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Violin Plot" and x_axis and y_axis:
-                    fig = px.violin(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, box=True, title=f'Violin Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Line Chart" and x_axis and y_axis:
-                    fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
-                elif plot_type == "Bar Chart" and x_axis:
-                    fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
-                elif plot_type == "Correlation Matrix":
-                    numeric_df = df.select_dtypes(include=np.number)
-                    if len(numeric_df.columns) > 1:
-                        corr = numeric_df.corr()
-                        fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
-                elif plot_type == "Pair Plot":
-                    if pair_cols:
-                        fig = px.scatter_matrix(df[pair_cols], color=color_by if color_by != "None" else None, title='Pair Plot')
-                elif plot_type == "Heatmap" and x_axis and y_axis:
-                    fig = px.density_heatmap(df, x=x_axis, y=y_axis, facet_col=color_by if color_by != "None" else None, title=f'Heatmap of {x_axis} vs {y_axis}')
-                elif plot_type == "3D Scatter" and x_axis and y_axis and z_axis:
-                    fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=color_by if color_by != "None" else None, title=f'3D Scatter Plot of {x_axis} vs {y_axis} vs {z_axis}')
-                elif plot_type == "Parallel Categories" and dimensions:
-                    fig = px.parallel_categories(df, dimensions=dimensions, color=color_by if color_by != "None" else None, title='Parallel Categories Plot')
-                elif plot_type == "Segmented Bar Chart" and x_axis and segment_col:
-                    segment_counts = df.groupby([x_axis, segment_col]).size().reset_index(name='counts')
-                    fig = px.bar(segment_counts, x=x_axis, y='counts', color=segment_col, title=f'Segmented Bar Chart of {x_axis} by {segment_col}')
-                    fig.update_layout(yaxis_title="Count")
-                elif plot_type == "Swarm Plot" and x_axis and y_axis:
-                    fig = px.strip(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Swarm Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Ridge Plot" and x_axis and y_axis:
-                    fig = px.histogram(df, x=x_axis, color=y_axis, marginal="rug", title=f'Ridge Plot of {x_axis} by {y_axis}')
-                elif plot_type == "Bubble Plot" and x_axis and y_axis and size_col:
-                    fig = px.scatter(df, x=x_axis, y=y_axis, size=size_col, color=color_by if color_by != "None" else None, title=f'Bubble Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Density Plot" and x_axis and y_axis:
-                    fig = px.density_heatmap(df, x=x_axis, y=y_axis, color_continuous_scale="Viridis", title=f'Density Plot of {x_axis} vs {y_axis}')
-                elif plot_type == "Count Plot" and x_axis:
-                    fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Count Plot of {x_axis}')
-                    fig.update_layout(yaxis_title="Count")
-                elif plot_type == "Lollipop Chart" and x_axis and y_axis:
-                    fig = go.Figure()
-                    fig.add_trace(go.Scatter(x=df[x_axis], y=df[y_axis], mode='markers', marker=dict(size=10)))
-                    for i in range(len(df)):
-                        fig.add_trace(go.Scatter(x=[df[x_axis].iloc[i], df[x_axis].iloc[i]], y=[0, df[y_axis].iloc[i]], mode='lines', line=dict(color='gray')))
-                    fig.update_layout(showlegend=False, title=f'Lollipop Chart of {x_axis} vs {y_axis}')
-
-                if fig:
-                    fig.update_layout(template="plotly_white")
-                    st.plotly_chart(fig, use_container_width=True)
-                else:
-                    st.error("Please provide required inputs for the selected plot type.")
-            except Exception as e:
-                st.error(f"Couldn't create visualization: {str(e)}")
-
-elif app_mode == "Model Training":
-    st.title("🧠 Model Training")
-    if 'cleaned_data' not in st.session_state:
-        st.warning("Please upload and clean data first.")
-        st.stop()
-    df = st.session_state.cleaned_data.copy()
-    problem_type = st.selectbox("Problem Type", ["Classification", "Regression", "Clustering"])
-    target = st.selectbox("Select Target Column", df.columns) if problem_type != "Clustering" else None
-
-    if st.button("Setup PyCaret"):
-        with st.spinner("Setting up PyCaret..."):
-            if problem_type == "Classification":
-                classification_setup(data=df, target=target, session_id=123, verbose=False)
-                st.session_state['problem_type'] = "Classification"
-                st.session_state['setup_complete'] = True
-            elif problem_type == "Regression":
-                regression_setup(data=df, target=target, session_id=123, verbose=False)
-                st.session_state['problem_type'] = "Regression"
-                st.session_state['setup_complete'] = True
-            elif problem_type == "Clustering":
-                clustering_setup(data=df, session_id=123, verbose=False)
-                st.session_state['problem_type'] = "Clustering"
-                st.session_state['setup_complete'] = True
-            st.success("PyCaret setup complete! You can now train models.")
-
-    if st.session_state.get('setup_complete', False):
-        st.subheader("Train Models")
-        if st.button("Compare Models"):
-            with st.spinner("Comparing models..."):
-                if st.session_state['problem_type'] == "Classification":
-                    best_model = compare_classification_models()
-                elif st.session_state['problem_type'] == "Regression":
-                    best_model = compare_regression_models()
-                elif st.session_state['problem_type'] == "Clustering":
-                    st.info("Model comparison is not available for clustering. Please proceed with evaluation or create a model manually.")
-                    best_model = None
-                else:
-                    best_model = None
-                if best_model is not None:
-                    st.session_state['best_model'] = best_model
-                    st.success(f"Best Model: {best_model}")
-
-    if 'best_model' in st.session_state and st.session_state['best_model'] is not None:
-        st.subheader("Model Evaluation")
-        if st.button("Evaluate Model"):
-            with st.spinner("Evaluating model..."):
-                if st.session_state['problem_type'] == "Classification":
-                    evaluate_classification_model(st.session_state['best_model'])
-                elif st.session_state['problem_type'] == "Regression":
-                    evaluate_regression_model(st.session_state['best_model'])
-                elif st.session_state['problem_type'] == "Clustering":
-                    evaluate_clustering_model(st.session_state['best_model'])
-                st.success("Model evaluation complete!")
-
-        if st.button("Save Model"):
-            if st.session_state['problem_type'] == "Classification":
-                save_classification_model(st.session_state['best_model'], "best_model")
-            elif st.session_state['problem_type'] == "Regression":
-                save_regression_model(st.session_state['best_model'], "best_model")
-            elif st.session_state['problem_type'] == "Clustering":
-                save_clustering_model(st.session_state['best_model'], "best_model")
-            st.success("Model saved as `best_model.pkl`!")
-            with open("best_model.pkl", "rb") as f:
-                st.download_button("Download Model", f, file_name="best_model.pkl")
-
-elif app_mode == "Validation & Exploration":
-    st.title("🔍 Validation & Exploration")
-    if 'best_model' not in st.session_state or st.session_state['best_model'] is None:
-        st.warning("Please train a model first. Note: Clustering does not support automatic model comparison.")
-        st.stop()
-
-    st.subheader("Model Performance")
-    if st.session_state['problem_type'] == "Classification":
-        st.write("Classification Report:")
-        plot_classification_model(st.session_state['best_model'], plot="confusion_matrix", display_format="streamlit")
-        plot_classification_model(st.session_state['best_model'], plot="auc", display_format="streamlit")
-    elif st.session_state['problem_type'] == "Regression":
-        st.write("Regression Metrics:")
-        plot_regression_model(st.session_state['best_model'], plot="residuals", display_format="streamlit")
-        plot_regression_model(st.session_state['best_model'], plot="error", display_format="streamlit")
-    elif st.session_state['problem_type'] == "Clustering":
-        st.write("Clustering Results:")
-        plot_clustering_model(st.session_state['best_model'], plot="cluster", display_format="streamlit")
-
-# Custom CSS
-st.markdown("""
-<style>
-.stButton>button {background-color: #4CAF50; color: white;}
-h1, h2 {color: #1e3a8a;}
-</style>
-""", unsafe_allow_html=True)
+++ b/app.py
 import streamlit as st
 import pandas as pd
 import numpy as np
 import plotly.express as px
 import plotly.graph_objects as go
+from ydata_profiling import ProfileReport
+from streamlit_pandas_profiling import st_profile_report
 import os
+import requests
+import json
+from datetime import datetime
+import re
+import tempfile
+from scipy import stats
+from sklearn.impute import SimpleImputer
+from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
+from sklearn.decomposition import PCA
+import streamlit.components.v1 as components
+from io import StringIO
+from dotenv import load_dotenv
+from flask import Flask, request, jsonify
+from openai import OpenAI
+import threading
+from sentence_transformers import SentenceTransformer
+
+# Load environment variables
+load_dotenv()
+
+# Initialize Flask app
+flask_app = Flask(__name__)
+FLASK_PORT = 5000  # Internal port for Flask, not exposed externally
+
+# Initialize OpenAI client
+api_key = os.getenv("OPENAI_API_KEY")
+if not api_key:
+    st.error("OPENAI_API_KEY not set. Please configure it in the Hugging Face Space secrets.")
+    st.stop()
+client = OpenAI(api_key=api_key)
+
+# Flask RAG Endpoint
+@flask_app.route('/rag_chat', methods=['POST'])
+def rag_chat():
+    data = request.get_json()
+    user_input = data.get('user_input', '')
+    app_mode = data.get('app_mode', 'Data Upload')
+    dataset_text = data.get('dataset_text', '')
+
+    # RAG Logic: Use dataset_text as retrieval context
+    system_prompt = (
+        "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
+        "The app has three pages:\n"
+        "- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
+        "- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
+        "- **EDA**: Visualize data (e.g., scatter plots, histograms).\n"
+        f"The user is on the '{app_mode}' page.\n"
+    )
+
+    if dataset_text:
+        system_prompt += (
+            "Using the following dataset context, augment your response:\n"
+            f"{dataset_text}\n"
+            "Answer based on this data where relevant, otherwise provide general assistance."
+        )
+    else:
+        system_prompt += "No dataset is loaded. Assist based on app functionality."
+
+    try:
+        response = client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_input}
+            ],
+            max_tokens=100,  # Increased for RAG context
+            temperature=0.7
+        )
+        return jsonify({"response": response.choices[0].message.content})
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+# Run Flask in a background thread
+def run_flask():
+    flask_app.run(host='0.0.0.0', port=FLASK_PORT, debug=False, use_reloader=False)
+
+# Start Flask thread
+flask_thread = threading.Thread(target=run_flask, daemon=True)
+flask_thread.start()
+
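Everything above runs in one process: Streamlit keeps the main thread, while the daemon thread serves /rag_chat on localhost (FLASK_PORT is internal and never exposed externally). A minimal, self-contained sketch of the same pattern; the /echo route and port 5001 are invented for illustration, not part of this app:

import threading
import time

import requests
from flask import Flask, jsonify, request

demo_app = Flask(__name__)

@demo_app.route('/echo', methods=['POST'])
def echo():
    # Stand-in for the real RAG logic: echo the JSON body back
    return jsonify({"you_sent": request.get_json()})

def serve():
    # use_reloader=False is required when Flask runs off the main thread
    demo_app.run(host='127.0.0.1', port=5001, debug=False, use_reloader=False)

threading.Thread(target=serve, daemon=True).start()
time.sleep(1)  # give the server a moment to bind

print(requests.post("http://localhost:5001/echo", json={"hi": "there"}, timeout=5).json())
# {'you_sent': {'hi': 'there'}}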
+# Helper Functions
+def enhance_section_title(title):
+    st.markdown(f"<h2 style='border-bottom: 2px solid #ccc; padding-bottom: 5px;'>{title}</h2>", unsafe_allow_html=True)

+def update_cleaned_data(df):
+    st.session_state.cleaned_data = df
+    if 'data_versions' not in st.session_state:
+        st.session_state.data_versions = [st.session_state.raw_data.copy()]
+    st.session_state.data_versions.append(df.copy())
+    st.success("✅ Action completed successfully!")
+    st.rerun()
+
+def convert_csv_to_json_and_text(df):
+    """Convert DataFrame to JSON and then to plain text."""
+    json_data = df.to_json(orient="records")
+    data_dict = json.loads(json_data)
+    text_summary = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
+    text_summary += f"Missing Values: {df.isna().sum().sum()}\n"
+    text_summary += "Columns:\n"
+    for col in df.columns:
+        text_summary += f"- {col} ({df[col].dtype}): "
+        if pd.api.types.is_numeric_dtype(df[col]):
+            text_summary += f"Mean={df[col].mean():.2f}, Min={df[col].min()}, Max={df[col].max()}"
+        else:
+            text_summary += f"Unique={df[col].nunique()}, Top={df[col].mode()[0] if not df[col].mode().empty else 'N/A'}"
+        text_summary += f", Missing={df[col].isna().sum()}\n"
+    return text_summary
+
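For a sense of what the chatbot receives as retrieval context, here is the same summary logic applied to a toy DataFrame (the data is invented; the format mirrors convert_csv_to_json_and_text above):

import pandas as pd

toy = pd.DataFrame({"age": [25, 32, None], "city": ["Oslo", "Oslo", "Bergen"]})

summary = f"Dataset Summary: {toy.shape[0]} rows, {toy.shape[1]} columns\n"
summary += f"Missing Values: {toy.isna().sum().sum()}\n"
summary += "Columns:\n"
for col in toy.columns:
    summary += f"- {col} ({toy[col].dtype}): "
    if pd.api.types.is_numeric_dtype(toy[col]):
        summary += f"Mean={toy[col].mean():.2f}, Min={toy[col].min()}, Max={toy[col].max()}"
    else:
        summary += f"Unique={toy[col].nunique()}, Top={toy[col].mode()[0]}"
    summary += f", Missing={toy[col].isna().sum()}\n"

print(summary)
# Dataset Summary: 3 rows, 2 columns
# Missing Values: 1
# Columns:
# - age (float64): Mean=28.50, Min=25.0, Max=32.0, Missing=1
# - city (object): Unique=2, Top=Oslo, Missing=0

Note that the JSON round-trip in the function (to_json followed by json.loads) produces data_dict, which is never used afterwards; only the plain-text summary reaches the model.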
+def get_chatbot_response(user_input, app_mode, dataset_text=""):
+    """Send request to internal Flask RAG endpoint."""
+    payload = {
+        "user_input": user_input,
+        "app_mode": app_mode,
+        "dataset_text": dataset_text
+    }
+    try:
+        response = requests.post(f"http://localhost:{FLASK_PORT}/rag_chat", json=payload, timeout=5)
+        response.raise_for_status()
+        return response.json().get("response", "Error: No response from server")
+    except requests.exceptions.RequestException as e:
+        return f"Error: Could not connect to RAG server. {str(e)}"
+
+# Streamlit App
 # Sidebar Navigation
 with st.sidebar:
+    st.title("🔮 Data-Vision Pro")
+    st.markdown("Your AI-powered data analysis suite with RAG.")
     st.markdown("---")
+    app_mode = st.selectbox(
+        "Navigation",
+        ["Data Upload", "Data Cleaning", "EDA"],
+        format_func=lambda x: f"📌 {x}"
+    )
+    if app_mode == "Data Upload":
+        st.info("⬆️ Upload your CSV or XLSX dataset to begin.")
+    elif app_mode == "Data Cleaning":
+        st.info("🧹 Clean and preprocess your data using various tools.")
+    elif app_mode == "EDA":
+        st.info("🔍 Explore your data visually and statistically.")

+    st.markdown("---")
+    st.markdown("**Note**: Requires dependencies in `requirements.txt`.")
+    if 'cleaned_data' in st.session_state:
+        csv = st.session_state.cleaned_data.to_csv(index=False)
+        st.download_button(
+            label="Download Cleaned Data as CSV",
+            data=csv,
+            file_name='cleaned_data.csv',
+            mime='text/csv',
+        )
+    st.markdown("Created by Calvin Allen-Crawford")
+    st.markdown("v1.0 | © 2025")
+
+# Main App Pages
 if app_mode == "Data Upload":
+    st.title("📤 Data Upload & Profiling")
+    st.header("Upload Your Dataset")
+    st.write("Supported formats: CSV, XLSX")
+
+    if 'raw_data' not in st.session_state:
+        st.info("It looks like no dataset has been uploaded yet. Would you like to upload a CSV or XLSX file?")
+
+    uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
     if uploaded_file:
+        st.session_state.pop('raw_data', None)
+        st.session_state.pop('cleaned_data', None)
+        st.session_state.pop('data_versions', None)
+        try:
+            if uploaded_file.name.endswith('.csv'):
+                df = pd.read_csv(uploaded_file)
+            else:
+                df = pd.read_excel(uploaded_file)
+            if df.empty:
+                st.error("Uploaded file is empty.")
+                st.stop()
+            st.session_state.raw_data = df
+            st.session_state.dataset_text = convert_csv_to_json_and_text(df)
+            if 'data_versions' not in st.session_state:
+                st.session_state.data_versions = [df.copy()]
+            col1, col2, col3 = st.columns(3)
+            with col1: st.metric("Rows", df.shape[0])
+            with col2: st.metric("Columns", df.shape[1])
+            with col3: st.metric("Missing Values", df.isna().sum().sum())
+            if st.checkbox("Show Data Preview"):
+                st.dataframe(df.head(10), use_container_width=True)
+            if st.button("Generate Full Profile Report"):
+                with st.spinner("Generating report..."):
+                    pr = ProfileReport(df, explorative=True)
+                    st_profile_report(pr)
+            st.success("✅ Data loaded successfully!")
+        except Exception as e:
+            st.error(f"An error occurred: {str(e)}")

 elif app_mode == "Data Cleaning":
     st.title("🧹 Smart Data Cleaning")
...
     if st.button("Undo Last Action"):
         st.session_state.data_versions.pop()
         st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()
+        st.session_state.dataset_text = convert_csv_to_json_and_text(st.session_state.cleaned_data)
         st.rerun()

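The data_versions list acts as an undo stack: update_cleaned_data appends a copy after every successful action, and "Undo Last Action" pops back one version, now also regenerating dataset_text so the chatbot sees the reverted data. The pattern in isolation (toy data, outside Streamlit):

import pandas as pd

raw = pd.DataFrame({"x": [1.0, 2.0, None]})
data_versions = [raw.copy()]            # version 0: the raw upload

cleaned = raw.dropna()                  # some cleaning action
data_versions.append(cleaned.copy())    # record the new version

data_versions.pop()                     # "Undo Last Action"
current = data_versions[-1].copy()      # back to the raw upload
assert current.equals(raw)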
 elif app_mode == "EDA":
     st.title("🔍 Interactive Data Explorer")
     if 'cleaned_data' not in st.session_state:
...
         st.stop()
     df = st.session_state.cleaned_data.copy()

     enhance_section_title("Dataset Overview")
     with st.container():
         col1, col2, col3, col4 = st.columns(4)
         col1.metric("Total Rows", df.shape[0])
...
         col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
         col4.metric("Duplicates", df.duplicated().sum())

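missing_percentage is defined on an unchanged line that this diff does not show; judging from its use in the metric above, it is presumably computed along these lines (an assumption, not the file's actual code):

# Hypothetical reconstruction: share of all cells that are missing
missing_percentage = (df.isna().sum().sum() / (df.shape[0] * df.shape[1])) * 100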
+    # Chatbot Section
+    st.markdown("---")
+    st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
+    st.info("Ask me about the app or your data! Try: 'What can I do here?' or 'What’s in the dataset?'")
+    if "chat_history" not in st.session_state:
+        st.session_state.chat_history = []
+
+    for message in st.session_state.chat_history:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    user_input = st.chat_input("Ask me anything about the app or your data...")
+    if user_input:
+        st.session_state.chat_history.append({"role": "user", "content": user_input})
+        with st.chat_message("user"):
+            st.markdown(user_input)
+
+        with st.spinner("Thinking with RAG..."):
+            dataset_text = st.session_state.get("dataset_text", "")
+            response = get_chatbot_response(user_input, app_mode, dataset_text)
+        st.session_state.chat_history.append({"role": "assistant", "content": response})
+        with st.chat_message("assistant"):
+            st.markdown(response)