Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -231,4 +231,423 @@ def extract_plot_data(plot_info, df):
|
|
231 |
if y_col:
|
232 |
plot_text += f"Y-Axis: {y_col}\n"
|
233 |
if plot_type == "Scatter Plot" and y_col:
|
234 |
-
correlation = data[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
if y_col:
|
232 |
plot_text += f"Y-Axis: {y_col}\n"
|
233 |
if plot_type == "Scatter Plot" and y_col:
|
234 |
+
correlation = data[x_col].corr(data[y_col])
|
235 |
+
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
236 |
+
plot_text += f"Correlation: {correlation:.2f}\n"
|
237 |
+
plot_text += f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, RΒ²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
238 |
+
plot_text += f"X Stats: Mean={data[x_col].mean():.2f}, Std={data[x_col].std():.2f}, Min={data[x_col].min():.2f}, Max={data[x_col].max():.2f}\n"
|
239 |
+
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Min={data[y_col].min():.2f}, Max={data[y_col].max():.2f}\n"
|
240 |
+
elif plot_type == "Histogram":
|
241 |
+
plot_text += f"Stats: Mean={data[x_col].mean():.2f}, Median={data[x_col].median():.2f}, Std={data[x_col].std():.2f}\n"
|
242 |
+
plot_text += f"Skewness: {data[x_col].skew():.2f}\n"
|
243 |
+
plot_text += f"Range: [{data[x_col].min():.2f}, {data[x_col].max():.2f}]\n"
|
244 |
+
elif plot_type == "Box Plot" and y_col:
|
245 |
+
q1, q3 = data[y_col].quantile(0.25), data[y_col].quantile(0.75)
|
246 |
+
iqr = q3 - q1
|
247 |
+
plot_text += f"Y Stats: Median={data[y_col].median():.2f}, Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}\n"
|
248 |
+
plot_text += f"Outliers: {len(data[y_col][(data[y_col] < q1 - 1.5 * iqr) | (data[y_col] > q3 + 1.5 * iqr)])} potential outliers\n"
|
249 |
+
elif plot_type == "Line Chart" and y_col:
|
250 |
+
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Trend={'increasing' if data[y_col].iloc[-1] > data[y_col].iloc[0] else 'decreasing'}\n"
|
251 |
+
elif plot_type == "Bar Chart":
|
252 |
+
plot_text += f"Counts: {data[x_col].value_counts().to_dict()}\n"
|
253 |
+
elif plot_type == "Correlation Matrix":
|
254 |
+
corr = data.corr()
|
255 |
+
plot_text += "Correlation Matrix:\n"
|
256 |
+
for col1 in corr.columns:
|
257 |
+
for col2 in corr.index:
|
258 |
+
if col1 < col2:
|
259 |
+
plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
|
260 |
+
return plot_text
|
261 |
+
|
262 |
+
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
263 |
+
system_prompt = (
|
264 |
+
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
265 |
+
f"The user is on the '{app_mode}' page:\n"
|
266 |
+
"- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
267 |
+
"- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
|
268 |
+
"- **EDA**: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
269 |
+
"When analyzing plots, provide detailed insights based on numerical data extracted from them."
|
270 |
+
)
|
271 |
+
context = ""
|
272 |
+
if vector_store:
|
273 |
+
docs = vector_store.similarity_search(user_input, k=3)
|
274 |
+
if docs:
|
275 |
+
context = "\n\nDataset and Plot Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
276 |
+
system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
|
277 |
+
else:
|
278 |
+
system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
|
279 |
+
try:
|
280 |
+
response = client.chat.completions.create(
|
281 |
+
model=model,
|
282 |
+
messages=[
|
283 |
+
{"role": "system", "content": system_prompt},
|
284 |
+
{"role": "user", "content": user_input}
|
285 |
+
],
|
286 |
+
temperature=0.7,
|
287 |
+
max_tokens=1024
|
288 |
+
)
|
289 |
+
return response.choices[0].message.content
|
290 |
+
except Exception as e:
|
291 |
+
return f"Error: {str(e)}"
|
292 |
+
|
293 |
+
# Command Functions
|
294 |
+
def drop_columns(columns):
|
295 |
+
if 'cleaned_data' in st.session_state:
|
296 |
+
df = st.session_state.cleaned_data.copy()
|
297 |
+
columns_to_drop = [col.strip() for col in columns.split(',')]
|
298 |
+
valid_columns = [col for col in columns_to_drop if col in df.columns]
|
299 |
+
if valid_columns:
|
300 |
+
df.drop(valid_columns, axis=1, inplace=True)
|
301 |
+
update_cleaned_data(df)
|
302 |
+
return f"Dropped columns: {', '.join(valid_columns)}"
|
303 |
+
else:
|
304 |
+
return "No valid columns found to drop."
|
305 |
+
return "No dataset loaded."
|
306 |
+
|
307 |
+
def generate_scatter_plot(params):
|
308 |
+
df = st.session_state.cleaned_data
|
309 |
+
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
|
310 |
+
if match and len(match.groups()) >= 2:
|
311 |
+
x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
|
312 |
+
if x_axis in df.columns and y_axis in df.columns:
|
313 |
+
fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
314 |
+
st.plotly_chart(fig)
|
315 |
+
st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
|
316 |
+
return f"Generated scatter plot of {x_axis} vs {y_axis}"
|
317 |
+
return "Invalid columns for scatter plot."
|
318 |
+
|
319 |
+
def generate_histogram(params):
|
320 |
+
df = st.session_state.cleaned_data
|
321 |
+
x_axis = params.strip()
|
322 |
+
if x_axis in df.columns:
|
323 |
+
fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
|
324 |
+
st.plotly_chart(fig)
|
325 |
+
st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
|
326 |
+
return f"Generated histogram of {x_axis}"
|
327 |
+
return "Invalid column for histogram."
|
328 |
+
|
329 |
+
def analyze_plot():
|
330 |
+
if "last_plot" not in st.session_state:
|
331 |
+
return "No plot available to analyze."
|
332 |
+
plot_info = st.session_state.last_plot
|
333 |
+
df = pd.read_json(plot_info["data"])
|
334 |
+
plot_text = extract_plot_data(plot_info, df)
|
335 |
+
return f"Analysis of the last plot:\n{plot_text}"
|
336 |
+
|
337 |
+
def parse_command(command):
|
338 |
+
command = command.lower().strip()
|
339 |
+
if "drop columns" in command or "drop column" in command:
|
340 |
+
columns = command.replace("drop columns", "").replace("drop column", "").strip()
|
341 |
+
return drop_columns, columns
|
342 |
+
elif "show a scatter plot" in command or "scatter plot of" in command:
|
343 |
+
params = command.replace("show a scatter plot of", "").replace("scatter plot of", "").strip()
|
344 |
+
return generate_scatter_plot, params
|
345 |
+
elif "show a histogram" in command or "histogram of" in command:
|
346 |
+
params = command.replace("show a histogram of", "").replace("histogram of", "").strip()
|
347 |
+
return generate_histogram, params
|
348 |
+
elif "analyze plot" in command:
|
349 |
+
return lambda x: analyze_plot(), None
|
350 |
+
return None, command
|
351 |
+
|
352 |
+
# Dataset Preview Function
|
353 |
+
def display_dataset_preview():
|
354 |
+
if 'cleaned_data' in st.session_state:
|
355 |
+
st.subheader("Current Dataset Preview")
|
356 |
+
st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
|
357 |
+
st.markdown("---")
|
358 |
+
|
359 |
+
# Main App
|
360 |
+
def main():
|
361 |
+
# Header
|
362 |
+
st.markdown("""
|
363 |
+
<div class="header">
|
364 |
+
<h1 class="header-title">Data-Vision Pro</h1>
|
365 |
+
<div class="header-subtitle">Advanced Data Analysis with Groq Inference</div>
|
366 |
+
</div>
|
367 |
+
""", unsafe_allow_html=True)
|
368 |
+
|
369 |
+
# Top Navigation Bar
|
370 |
+
st.markdown('<div class="nav-bar">', unsafe_allow_html=True)
|
371 |
+
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
|
372 |
+
with col1:
|
373 |
+
st.markdown('<div class="nav-item">Data Input</div>', unsafe_allow_html=True)
|
374 |
+
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
|
375 |
+
with col2:
|
376 |
+
st.markdown('<div class="nav-item">Navigation</div>', unsafe_allow_html=True)
|
377 |
+
app_mode = st.selectbox("Navigation", ["Data Upload", "Data Cleaning", "EDA"], format_func=lambda x: f"π {x}", label_visibility="collapsed")
|
378 |
+
with col3:
|
379 |
+
st.markdown('<div class="nav-item">Model</div>', unsafe_allow_html=True)
|
380 |
+
model = st.selectbox("Select Groq Model", ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"], index=0, label_visibility="collapsed")
|
381 |
+
with col4:
|
382 |
+
st.markdown('<div class="nav-item">Download</div>', unsafe_allow_html=True)
|
383 |
+
if 'cleaned_data' in st.session_state:
|
384 |
+
csv = st.session_state.cleaned_data.to_csv(index=False)
|
385 |
+
st.download_button(label="Download Cleaned Data", data=csv, file_name='cleaned_data.csv', mime='text/csv')
|
386 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
387 |
+
|
388 |
+
# Initialize Session State
|
389 |
+
if 'vector_store' not in st.session_state:
|
390 |
+
st.session_state.vector_store = None
|
391 |
+
if 'chat_history' not in st.session_state:
|
392 |
+
st.session_state.chat_history = []
|
393 |
+
|
394 |
+
# Display Dataset Preview
|
395 |
+
display_dataset_preview()
|
396 |
+
|
397 |
+
# App Pages
|
398 |
+
if app_mode == "Data Upload":
|
399 |
+
st.header("π€ Data Upload & Profiling")
|
400 |
+
if uploaded_file:
|
401 |
+
st.session_state.pop('raw_data', None)
|
402 |
+
st.session_state.pop('cleaned_data', None)
|
403 |
+
st.session_state.pop('data_versions', None)
|
404 |
+
try:
|
405 |
+
if uploaded_file.name.endswith('.csv'):
|
406 |
+
df = pd.read_csv(uploaded_file)
|
407 |
+
else:
|
408 |
+
df = pd.read_excel(uploaded_file)
|
409 |
+
if df.empty:
|
410 |
+
st.error("Uploaded file is empty.")
|
411 |
+
st.stop()
|
412 |
+
st.session_state.raw_data = df
|
413 |
+
st.session_state.cleaned_data = df.copy()
|
414 |
+
st.session_state.dataset_text = convert_df_to_text(df)
|
415 |
+
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
416 |
+
if 'data_versions' not in st.session_state:
|
417 |
+
st.session_state.data_versions = [df.copy()]
|
418 |
+
col1, col2, col3 = st.columns(3)
|
419 |
+
with col1: st.metric("Rows", df.shape[0])
|
420 |
+
with col2: st.metric("Columns", df.shape[1])
|
421 |
+
with col3: st.metric("Missing Values", df.isna().sum().sum())
|
422 |
+
if st.checkbox("Show Data Preview"):
|
423 |
+
st.dataframe(df.head(10), use_container_width=True)
|
424 |
+
if st.button("Generate Full Profile Report"):
|
425 |
+
with st.spinner("Generating report..."):
|
426 |
+
pr = ProfileReport(df, explorative=True)
|
427 |
+
st_profile_report(pr)
|
428 |
+
st.success("β
Data loaded successfully!")
|
429 |
+
except Exception as e:
|
430 |
+
st.error(f"An error occurred: {str(e)}")
|
431 |
+
|
432 |
+
elif app_mode == "Data Cleaning":
|
433 |
+
st.header("π§Ή Smart Data Cleaning")
|
434 |
+
if 'raw_data' not in st.session_state:
|
435 |
+
st.warning("Please upload data first in the Data Upload section.")
|
436 |
+
st.stop()
|
437 |
+
if 'cleaned_data' in st.session_state:
|
438 |
+
df = st.session_state.cleaned_data.copy()
|
439 |
+
else:
|
440 |
+
st.session_state.cleaned_data = st.session_state.raw_data.copy()
|
441 |
+
df = st.session_state.cleaned_data.copy()
|
442 |
+
|
443 |
+
enhance_section_title("π Data Health Dashboard")
|
444 |
+
with st.expander("Explore Data Health Metrics", expanded=True):
|
445 |
+
col1, col2, col3 = st.columns(3)
|
446 |
+
with col1: st.metric("Columns", len(df.columns))
|
447 |
+
with col2: st.metric("Rows", len(df))
|
448 |
+
with col3: st.metric("Missing Values", df.isna().sum().sum())
|
449 |
+
if st.button("Generate Detailed Health Report"):
|
450 |
+
with st.spinner("Generating report..."):
|
451 |
+
profile = ProfileReport(df, minimal=True)
|
452 |
+
st_profile_report(profile)
|
453 |
+
if 'data_versions' in st.session_state and len(st.session_state.data_versions) > 1:
|
454 |
+
if st.button("Undo Last Action"):
|
455 |
+
st.session_state.data_versions.pop()
|
456 |
+
st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()
|
457 |
+
st.session_state.dataset_text = convert_df_to_text(st.session_state.cleaned_data)
|
458 |
+
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
459 |
+
st.rerun()
|
460 |
+
|
461 |
+
with st.expander("π οΈ Data Cleaning Operations", expanded=True):
|
462 |
+
enhance_section_title("π Missing Values Treatment")
|
463 |
+
missing_cols = df.columns[df.isna().any()].tolist()
|
464 |
+
if missing_cols:
|
465 |
+
cols = st.multiselect("Select columns with missing values", missing_cols)
|
466 |
+
method = st.selectbox("Choose imputation method", [
|
467 |
+
"Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
|
468 |
+
])
|
469 |
+
if method == "Fill with Custom Value":
|
470 |
+
custom_val = st.text_input("Enter custom value:")
|
471 |
+
if st.button("Apply Missing Value Treatment"):
|
472 |
+
new_df = df.copy()
|
473 |
+
if method == "Drop Missing Values":
|
474 |
+
new_df = new_df.dropna(subset=cols)
|
475 |
+
elif method == "Fill with Mean/Median":
|
476 |
+
for col in cols:
|
477 |
+
if pd.api.types.is_numeric_dtype(new_df[col]):
|
478 |
+
new_df[col] = new_df[col].fillna(new_df[col].median())
|
479 |
+
else:
|
480 |
+
new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
|
481 |
+
elif method == "Fill with Custom Value" and custom_val:
|
482 |
+
new_df[cols] = new_df[cols].fillna(custom_val)
|
483 |
+
elif method == "Forward Fill":
|
484 |
+
new_df[cols] = new_df[cols].ffill()
|
485 |
+
elif method == "Backward Fill":
|
486 |
+
new_df[cols] = new_df[cols].bfill()
|
487 |
+
update_cleaned_data(new_df)
|
488 |
+
else:
|
489 |
+
st.success("β¨ No missing values detected!")
|
490 |
+
|
491 |
+
enhance_section_title("π Data Type Conversion")
|
492 |
+
col_to_convert = st.selectbox("Select column to convert", df.columns)
|
493 |
+
new_type = st.selectbox("Select new data type", ["String", "Integer", "Float", "Boolean", "Datetime"])
|
494 |
+
if new_type == "Datetime":
|
495 |
+
date_format = st.text_input("Enter date format (e.g., %Y-%m-%d):", "%Y-%m-%d")
|
496 |
+
if st.button("Convert Data Type"):
|
497 |
+
new_df = df.copy()
|
498 |
+
if new_type == "String":
|
499 |
+
new_df[col_to_convert] = new_df[col_to_convert].astype(str)
|
500 |
+
elif new_type == "Integer":
|
501 |
+
new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce').astype('Int64')
|
502 |
+
elif new_type == "Float":
|
503 |
+
new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce')
|
504 |
+
elif new_type == "Boolean":
|
505 |
+
new_df[col_to_convert] = new_df[col_to_convert].astype(bool)
|
506 |
+
elif new_type == "Datetime":
|
507 |
+
new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
|
508 |
+
update_cleaned_data(new_df)
|
509 |
+
|
510 |
+
enhance_section_title("ποΈ Drop Columns")
|
511 |
+
columns_to_drop = st.multiselect("Select columns to remove", df.columns)
|
512 |
+
if columns_to_drop and st.button("Confirm Column Removal"):
|
513 |
+
new_df = df.copy()
|
514 |
+
new_df = new_df.drop(columns=columns_to_drop)
|
515 |
+
update_cleaned_data(new_df)
|
516 |
+
|
517 |
+
enhance_section_title("π’ Encoding Options")
|
518 |
+
encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
|
519 |
+
data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
|
520 |
+
if data_to_encode and st.button("Apply Encoding"):
|
521 |
+
new_df = df.copy()
|
522 |
+
if encoding_method == "Label Encoding":
|
523 |
+
for col in data_to_encode:
|
524 |
+
le = LabelEncoder()
|
525 |
+
new_df[col] = le.fit_transform(new_df[col].astype(str))
|
526 |
+
elif encoding_method == "One-Hot Encoding":
|
527 |
+
new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
|
528 |
+
update_cleaned_data(new_df)
|
529 |
+
|
530 |
+
enhance_section_title("π StandardScaler")
|
531 |
+
scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
|
532 |
+
if scale_cols and st.button("Apply StandardScaler"):
|
533 |
+
new_df = df.copy()
|
534 |
+
scaler = StandardScaler()
|
535 |
+
new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
|
536 |
+
update_cleaned_data(new_df)
|
537 |
+
|
538 |
+
elif app_mode == "EDA":
|
539 |
+
st.header("π Interactive Data Explorer")
|
540 |
+
if 'cleaned_data' not in st.session_state:
|
541 |
+
st.warning("Please upload and clean data first.")
|
542 |
+
st.stop()
|
543 |
+
df = st.session_state.cleaned_data.copy()
|
544 |
+
|
545 |
+
enhance_section_title("Dataset Overview")
|
546 |
+
with st.container():
|
547 |
+
col1, col2, col3, col4 = st.columns(4)
|
548 |
+
col1.metric("Total Rows", df.shape[0])
|
549 |
+
col2.metric("Total Columns", df.shape[1])
|
550 |
+
missing_percentage = df.isna().sum().sum() / df.size * 100
|
551 |
+
col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
|
552 |
+
col4.metric("Duplicates", df.duplicated().sum())
|
553 |
+
|
554 |
+
tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
|
555 |
+
with tab1:
|
556 |
+
st.write("First few rows of the dataset:")
|
557 |
+
st.dataframe(df.head(), use_container_width=True)
|
558 |
+
with tab2:
|
559 |
+
st.write("Column Data Types:")
|
560 |
+
type_counts = df.dtypes.value_counts().reset_index()
|
561 |
+
type_counts.columns = ['Type', 'Count']
|
562 |
+
st.dataframe(type_counts, use_container_width=True)
|
563 |
+
with tab3:
|
564 |
+
st.write("Missing Values Matrix:")
|
565 |
+
fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
|
566 |
+
fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
|
567 |
+
st.plotly_chart(fig_missing, use_container_width=True)
|
568 |
+
|
569 |
+
enhance_section_title("Interactive Visualization Builder")
|
570 |
+
with st.container():
|
571 |
+
col1, col2 = st.columns([1, 3])
|
572 |
+
with col1:
|
573 |
+
plot_type = st.selectbox("Choose visualization type", [
|
574 |
+
"Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix"
|
575 |
+
])
|
576 |
+
x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
|
577 |
+
y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart"] else None
|
578 |
+
color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
|
579 |
+
|
580 |
+
with col2:
|
581 |
+
try:
|
582 |
+
fig = None
|
583 |
+
if plot_type == "Scatter Plot" and x_axis and y_axis:
|
584 |
+
fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
585 |
+
elif plot_type == "Histogram" and x_axis:
|
586 |
+
fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, title=f'Histogram of {x_axis}')
|
587 |
+
elif plot_type == "Box Plot" and x_axis and y_axis:
|
588 |
+
fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
|
589 |
+
elif plot_type == "Line Chart" and x_axis and y_axis:
|
590 |
+
fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
|
591 |
+
elif plot_type == "Bar Chart" and x_axis:
|
592 |
+
fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
|
593 |
+
elif plot_type == "Correlation Matrix":
|
594 |
+
numeric_df = df.select_dtypes(include=np.number)
|
595 |
+
if len(numeric_df.columns) > 1:
|
596 |
+
corr = numeric_df.corr()
|
597 |
+
fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
|
598 |
+
|
599 |
+
if fig:
|
600 |
+
fig.update_layout(template="plotly_white")
|
601 |
+
st.plotly_chart(fig, use_container_width=True)
|
602 |
+
st.session_state.last_plot = {
|
603 |
+
"type": plot_type,
|
604 |
+
"x": x_axis,
|
605 |
+
"y": y_axis,
|
606 |
+
"data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
|
607 |
+
}
|
608 |
+
plot_text = extract_plot_data(st.session_state.last_plot, df)
|
609 |
+
st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
|
610 |
+
with st.expander("Extracted Plot Data"):
|
611 |
+
st.text(plot_text)
|
612 |
+
else:
|
613 |
+
st.error("Please provide required inputs for the selected plot type.")
|
614 |
+
except Exception as e:
|
615 |
+
st.error(f"Couldn't create visualization: {str(e)}")
|
616 |
+
|
617 |
+
# Chatbot Section
|
618 |
+
st.markdown("---")
|
619 |
+
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
|
620 |
+
st.subheader("π¬ AI Chatbot Assistant (RAG Enabled)")
|
621 |
+
st.info("Ask about your data or app features! Try: 'drop columns X, Y', 'scatter plot of X vs Y', 'analyze plot'")
|
622 |
+
|
623 |
+
for message in st.session_state.chat_history:
|
624 |
+
with st.chat_message(message["role"]):
|
625 |
+
st.markdown(f'<div class="{message["role"]}-message">{message["content"]}</div>', unsafe_allow_html=True)
|
626 |
+
|
627 |
+
user_input = st.chat_input("Ask me anything...")
|
628 |
+
if user_input:
|
629 |
+
st.session_state.chat_history.append({"role": "user", "content": user_input})
|
630 |
+
with st.chat_message("user"):
|
631 |
+
st.markdown(f'<div class="user-message">{user_input}</div>', unsafe_allow_html=True)
|
632 |
+
with st.spinner("Processing..."):
|
633 |
+
func, param = parse_command(user_input)
|
634 |
+
if func:
|
635 |
+
response = func(param) if param else func(None)
|
636 |
+
else:
|
637 |
+
response = get_chatbot_response(user_input, app_mode, st.session_state.vector_store, model)
|
638 |
+
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
639 |
+
with st.chat_message("assistant"):
|
640 |
+
st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
|
641 |
+
|
642 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
643 |
+
|
644 |
+
# Footer
|
645 |
+
st.markdown("""
|
646 |
+
<div class="footer">
|
647 |
+
<div>Built with <span class="tech-badge">Streamlit</span> + <span class="tech-badge">Groq</span> + <span class="tech-badge">LangChain</span> + <span class="tech-badge">FAISS</span></div>
|
648 |
+
<div style="margin-top: 8px;">Fast inference for data insights</div>
|
649 |
+
</div>
|
650 |
+
""", unsafe_allow_html=True)
|
651 |
+
|
652 |
+
if __name__ == "__main__":
|
653 |
+
main()
|