CosmickVisions commited on
Commit
e5613af
·
verified ·
1 Parent(s): 8db18fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -523
app.py CHANGED
@@ -1,35 +1,30 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import numpy as np
4
  import plotly.express as px
5
- import plotly.graph_objects as go
 
 
 
 
 
6
  from ydata_profiling import ProfileReport
7
  from streamlit_pandas_profiling import st_profile_report
8
- import os
9
- from dotenv import load_dotenv
10
  from groq import Groq
11
  from langchain_community.vectorstores import FAISS
12
- from langchain_community.document_loaders import TextLoader
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
- from langchain.embeddings import HuggingFaceEmbeddings
15
- import re
16
- from scipy import stats
17
- from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
18
  import tempfile
19
 
20
- # Set page config as the first Streamlit command
21
- st.set_page_config(page_title="Data-Vision Pro", layout="wide")
22
-
23
- # Load environment variables
24
- load_dotenv()
25
-
26
- # Initialize Groq client
27
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
28
-
29
- # Initialize HuggingFace embeddings for FAISS
30
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
31
 
32
- # Custom CSS with Silver, Blue, and Gold Theme + Responsiveness
 
 
 
33
  st.markdown("""
34
  <style>
35
  :root {
@@ -41,7 +36,7 @@ st.markdown("""
41
  .stApp {
42
  background-color: var(--silver);
43
  font-family: 'Inter', sans-serif;
44
- max-width: 900px;
45
  margin: 0 auto;
46
  padding: 10px;
47
  }
@@ -50,69 +45,71 @@ st.markdown("""
50
  color: white;
51
  padding: 15px;
52
  border-radius: 5px;
53
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
54
  text-align: center;
 
55
  }
56
  .header-title {
57
- font-size: 1.5rem;
58
  font-weight: 700;
59
  margin: 0;
60
  }
61
  .header-subtitle {
62
- font-size: 0.9rem;
63
  margin-top: 5px;
64
  }
65
- .sidebar .sidebar-content {
66
  background-color: white;
67
  border-radius: 5px;
68
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
69
  padding: 15px;
 
 
 
 
70
  }
71
- .chat-container {
 
 
 
 
 
 
 
 
 
 
 
72
  background-color: white;
73
  border-radius: 5px;
74
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
 
 
 
 
 
 
75
  padding: 15px;
76
  margin-top: 20px;
 
77
  }
78
  .user-message {
79
  background-color: var(--blue);
80
  color: white;
81
- border-radius: 18px 18px 4px 18px;
82
- padding: 12px 16px;
83
- margin-left: auto;
84
  max-width: 80%;
 
85
  margin-bottom: 10px;
86
  }
87
  .bot-message {
88
  background-color: #F0F0F0;
89
  color: var(--text-color);
90
- border-radius: 18px 18px 18px 4px;
91
- padding: 12px 16px;
92
- margin-right: auto;
93
  max-width: 80%;
 
94
  margin-bottom: 10px;
95
  }
96
- .footer {
97
- text-align: center;
98
- margin-top: 20px;
99
- color: var(--text-color);
100
- font-size: 0.8rem;
101
- }
102
- .tech-badge {
103
- display: inline-block;
104
- background-color: #E6ECEF;
105
- color: var(--blue);
106
- padding: 4px 8px;
107
- border-radius: 12px;
108
- font-size: 0.7rem;
109
- margin: 0 4px;
110
- }
111
- h2 {
112
- color: var(--blue);
113
- border-bottom: 2px solid var(--gold);
114
- padding-bottom: 5px;
115
- }
116
  .stButton > button {
117
  background-color: var(--gold);
118
  color: white;
@@ -126,48 +123,76 @@ st.markdown("""
126
  }
127
  @media (max-width: 768px) {
128
  .header-title {
129
- font-size: 1.2rem;
130
  }
131
  .header-subtitle {
132
- font-size: 0.8rem;
 
 
 
 
133
  }
134
- .chat-container, .sidebar .sidebar-content {
 
 
 
 
 
135
  padding: 10px;
136
  }
137
  .stApp {
138
  padding: 5px;
139
  }
140
- h2 {
141
- font-size: 1.2rem;
142
- }
143
  }
144
- </style>
 
 
 
145
  """, unsafe_allow_html=True)
146
 
147
- # Helper Functions (unchanged)
148
- def enhance_section_title(title):
149
- st.markdown(f"<h2 style='border-bottom: 2px solid var(--gold); padding-bottom: 5px; color: var(--blue);'>{title}</h2>", unsafe_allow_html=True)
150
-
151
- def update_cleaned_data(df):
152
- st.session_state.cleaned_data = df
153
- if 'data_versions' not in st.session_state:
154
- st.session_state.data_versions = [st.session_state.raw_data.copy()]
155
- st.session_state.data_versions.append(df.copy())
156
- st.session_state.dataset_text = convert_df_to_text(df)
157
- st.success("✅ Action completed successfully!")
158
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
 
160
  def convert_df_to_text(df):
161
  text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
162
  text += f"Missing Values: {df.isna().sum().sum()}\n"
163
- text += "Columns:\n"
164
  for col in df.columns:
165
- text += f"- {col} ({df[col].dtype}): "
166
- if pd.api.types.is_numeric_dtype(df[col]):
167
- text += f"Mean={df[col].mean():.2f}, Min={df[col].min()}, Max={df[col].max()}"
168
- else:
169
- text += f"Unique={df[col].nunique()}, Top={df[col].mode()[0] if not df[col].mode().empty else 'N/A'}"
170
- text += f", Missing={df[col].isna().sum()}\n"
171
  return text
172
 
173
  def create_vector_store(df_text):
@@ -176,469 +201,122 @@ def create_vector_store(df_text):
176
  temp_path = temp_file.name
177
  loader = TextLoader(temp_path)
178
  documents = loader.load()
179
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
180
- texts = text_splitter.split_documents(documents)
181
  vector_store = FAISS.from_documents(texts, embeddings)
182
  os.unlink(temp_path)
183
  return vector_store
184
 
185
- def update_vector_store_with_plot(plot_text, existing_vector_store):
186
- with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
187
- temp_file.write(plot_text)
188
- temp_path = temp_file.name
189
- loader = TextLoader(temp_path)
190
- documents = loader.load()
191
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
192
- texts = text_splitter.split_documents(documents)
193
- if existing_vector_store:
194
- existing_vector_store.add_documents(texts)
195
- else:
196
- existing_vector_store = FAISS.from_documents(texts, embeddings)
197
- os.unlink(temp_path)
198
- return existing_vector_store
199
-
200
- def extract_plot_data(plot_info, df):
201
- plot_type = plot_info["type"]
202
- x_col = plot_info["x"]
203
- y_col = plot_info["y"] if "y" in plot_info else None
204
- data = pd.read_json(plot_info["data"])
205
- plot_text = f"Plot Type: {plot_type}\n"
206
- plot_text += f"X-Axis: {x_col}\n"
207
- if y_col:
208
- plot_text += f"Y-Axis: {y_col}\n"
209
- if plot_type == "Scatter Plot" and y_col:
210
- correlation = data[x_col].corr(data[y_col])
211
- slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
212
- plot_text += f"Correlation: {correlation:.2f}\n"
213
- plot_text += f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
214
- plot_text += f"X Stats: Mean={data[x_col].mean():.2f}, Std={data[x_col].std():.2f}, Min={data[x_col].min():.2f}, Max={data[x_col].max():.2f}\n"
215
- plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Min={data[y_col].min():.2f}, Max={data[y_col].max():.2f}\n"
216
- elif plot_type == "Histogram":
217
- plot_text += f"Stats: Mean={data[x_col].mean():.2f}, Median={data[x_col].median():.2f}, Std={data[x_col].std():.2f}\n"
218
- plot_text += f"Skewness: {data[x_col].skew():.2f}\n"
219
- plot_text += f"Range: [{data[x_col].min():.2f}, {data[x_col].max():.2f}]\n"
220
- elif plot_type == "Box Plot" and y_col:
221
- q1, q3 = data[y_col].quantile(0.25), data[y_col].quantile(0.75)
222
- iqr = q3 - q1
223
- plot_text += f"Y Stats: Median={data[y_col].median():.2f}, Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}\n"
224
- plot_text += f"Outliers: {len(data[y_col][(data[y_col] < q1 - 1.5 * iqr) | (data[y_col] > q3 + 1.5 * iqr)])} potential outliers\n"
225
- elif plot_type == "Line Chart" and y_col:
226
- plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Trend={'increasing' if data[y_col].iloc[-1] > data[y_col].iloc[0] else 'decreasing'}\n"
227
- elif plot_type == "Bar Chart":
228
- plot_text += f"Counts: {data[x_col].value_counts().to_dict()}\n"
229
- elif plot_type == "Correlation Matrix":
230
- corr = data.corr()
231
- plot_text += "Correlation Matrix:\n"
232
- for col1 in corr.columns:
233
- for col2 in corr.index:
234
- if col1 < col2:
235
- plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
236
- return plot_text
237
-
238
- def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
239
- system_prompt = (
240
- "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
241
- f"The user is on the '{app_mode}' page:\n"
242
- "- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
243
- "- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
244
- "- **EDA**: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
245
- "When analyzing plots, provide detailed insights based on numerical data extracted from them."
246
- )
247
  context = ""
248
- if vector_store:
249
- docs = vector_store.similarity_search(user_input, k=3)
250
- if docs:
251
- context = "\n\nDataset and Plot Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
252
- system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
253
- else:
254
- system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
255
  try:
256
  response = client.chat.completions.create(
257
- model=model,
258
  messages=[
259
- {"role": "system", "content": system_prompt},
260
- {"role": "user", "content": user_input}
261
- ],
262
- temperature=0.7,
263
- max_tokens=1024
264
- )
265
- return response.choices[0].message.content
266
  except Exception as e:
267
  return f"Error: {str(e)}"
268
 
269
- # Command Functions
270
- def drop_columns(columns):
271
- if 'cleaned_data' in st.session_state:
272
- df = st.session_state.cleaned_data.copy()
273
- columns_to_drop = [col.strip() for col in columns.split(',')]
274
- valid_columns = [col for col in columns_to_drop if col in df.columns]
275
- if valid_columns:
276
- df.drop(valid_columns, axis=1, inplace=True)
277
- update_cleaned_data(df)
278
- return f"Dropped columns: {', '.join(valid_columns)}"
279
- else:
280
- return "No valid columns found to drop."
281
- return "No dataset loaded."
282
-
283
- def generate_scatter_plot(params):
284
- df = st.session_state.cleaned_data
285
- match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
286
- if match and len(match.groups()) >= 2:
287
- x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
288
- if x_axis in df.columns and y_axis in df.columns:
289
- fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
290
- st.plotly_chart(fig)
291
- st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
292
- return f"Generated scatter plot of {x_axis} vs {y_axis}"
293
- return "Invalid columns for scatter plot."
294
-
295
- def generate_histogram(params):
296
- df = st.session_state.cleaned_data
297
- x_axis = params.strip()
298
- if x_axis in df.columns:
299
- fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
300
- st.plotly_chart(fig)
301
- st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
302
- return f"Generated histogram of {x_axis}"
303
- return "Invalid column for histogram."
304
-
305
- def analyze_plot():
306
- if "last_plot" not in st.session_state:
307
- return "No plot available to analyze."
308
- plot_info = st.session_state.last_plot
309
- df = pd.read_json(plot_info["data"])
310
- plot_text = extract_plot_data(plot_info, df)
311
- return f"Analysis of the last plot:\n{plot_text}"
312
-
313
- def parse_command(command):
314
- command = command.lower().strip()
315
- if "drop columns" in command or "drop column" in command:
316
- columns = command.replace("drop columns", "").replace("drop column", "").strip()
317
- return drop_columns, columns
318
- elif "show a scatter plot" in command or "scatter plot of" in command:
319
- params = command.replace("show a scatter plot of", "").replace("scatter plot of", "").strip()
320
- return generate_scatter_plot, params
321
- elif "show a histogram" in command or "histogram of" in command:
322
- params = command.replace("show a histogram of", "").replace("histogram of", "").strip()
323
- return generate_histogram, params
324
- elif "analyze plot" in command:
325
- return lambda x: analyze_plot(), None
326
- return None, command
327
-
328
- # Dataset Preview Function
329
- def display_dataset_preview():
330
- if 'cleaned_data' in st.session_state:
331
- st.subheader("Current Dataset Preview")
332
- st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
333
- st.markdown("---")
334
-
335
- # Main App
336
  def main():
337
- # Header
338
- st.markdown("""
339
- <div class="header">
340
- <h1 class="header-title">Data-Vision Pro</h1>
341
- <div class="header-subtitle">Advanced Data Analysis with Groq Inference</div>
342
- </div>
343
- """, unsafe_allow_html=True)
344
-
345
- # Sidebar Navigation
346
- with st.sidebar:
347
- st.markdown("### 🔮 Data-Vision Pro")
348
- st.markdown("Your AI-powered data analysis suite with RAG.")
349
- st.markdown("---")
350
- app_mode = st.selectbox(
351
- "Navigation",
352
- ["Data Upload", "Data Cleaning", "EDA"],
353
- format_func=lambda x: f"📌 {x}"
354
- )
355
- model = st.selectbox(
356
- "Select Groq Model",
357
- ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"],
358
- index=0
359
- )
360
- if app_mode == "Data Upload":
361
- st.info("⬆️ Upload your CSV or XLSX dataset to begin.")
362
- elif app_mode == "Data Cleaning":
363
- st.info("🧹 Clean and preprocess your data.")
364
- elif app_mode == "EDA":
365
- st.info("🔍 Explore your data visually.")
366
-
367
- if 'cleaned_data' in st.session_state:
368
- csv = st.session_state.cleaned_data.to_csv(index=False)
369
- st.download_button(
370
- label="Download Cleaned Data",
371
- data=csv,
372
- file_name='cleaned_data.csv',
373
- mime='text/csv',
374
- )
375
- st.markdown("---")
376
- st.markdown("Built with <span class='tech-badge'>Streamlit</span> + <span class='tech-badge'>Groq</span>", unsafe_allow_html=True)
377
-
378
- # Initialize Session State
379
- if 'vector_store' not in st.session_state:
380
- st.session_state.vector_store = None
381
- if 'chat_history' not in st.session_state:
382
- st.session_state.chat_history = []
383
-
384
- # Display Dataset Preview
385
- display_dataset_preview()
386
-
387
- # App Pages
388
- if app_mode == "Data Upload":
389
- st.header("📤 Data Upload & Profiling")
390
- uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
391
  if uploaded_file:
392
- st.session_state.pop('raw_data', None)
393
- st.session_state.pop('cleaned_data', None)
394
- st.session_state.pop('data_versions', None)
395
- try:
396
- if uploaded_file.name.endswith('.csv'):
397
- df = pd.read_csv(uploaded_file)
398
- else:
399
- df = pd.read_excel(uploaded_file)
400
- if df.empty:
401
- st.error("Uploaded file is empty.")
402
- st.stop()
403
- st.session_state.raw_data = df
404
- st.session_state.cleaned_data = df.copy()
405
- st.session_state.dataset_text = convert_df_to_text(df)
406
- st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
407
- if 'data_versions' not in st.session_state:
408
- st.session_state.data_versions = [df.copy()]
409
- col1, col2, col3 = st.columns(3)
410
- with col1: st.metric("Rows", df.shape[0])
411
- with col2: st.metric("Columns", df.shape[1])
412
- with col3: st.metric("Missing Values", df.isna().sum().sum())
413
- if st.checkbox("Show Data Preview"):
414
- st.dataframe(df.head(10), use_container_width=True)
415
- if st.button("Generate Full Profile Report"):
416
- with st.spinner("Generating report..."):
417
- pr = ProfileReport(df, explorative=True)
418
- st_profile_report(pr)
419
- st.success("✅ Data loaded successfully!")
420
- except Exception as e:
421
- st.error(f"An error occurred: {str(e)}")
422
 
423
- elif app_mode == "Data Cleaning":
424
- st.header("🧹 Smart Data Cleaning")
425
- if 'raw_data' not in st.session_state:
426
- st.warning("Please upload data first in the Data Upload section.")
427
- st.stop()
428
- if 'cleaned_data' in st.session_state:
429
- df = st.session_state.cleaned_data.copy()
 
 
 
 
 
 
 
430
  else:
431
- st.session_state.cleaned_data = st.session_state.raw_data.copy()
432
- df = st.session_state.cleaned_data.copy()
433
-
434
- enhance_section_title("📊 Data Health Dashboard")
435
- with st.expander("Explore Data Health Metrics", expanded=True):
436
- col1, col2, col3 = st.columns(3)
437
- with col1: st.metric("Columns", len(df.columns))
438
- with col2: st.metric("Rows", len(df))
439
- with col3: st.metric("Missing Values", df.isna().sum().sum())
440
- if st.button("Generate Detailed Health Report"):
441
- with st.spinner("Generating report..."):
442
- profile = ProfileReport(df, minimal=True)
443
- st_profile_report(profile)
444
- if 'data_versions' in st.session_state and len(st.session_state.data_versions) > 1:
445
- if st.button("Undo Last Action"):
446
- st.session_state.data_versions.pop()
447
- st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()
448
- st.session_state.dataset_text = convert_df_to_text(st.session_state.cleaned_data)
449
- st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
450
- st.rerun()
451
-
452
- with st.expander("🛠️ Data Cleaning Operations", expanded=True):
453
- enhance_section_title("🔍 Missing Values Treatment")
454
- missing_cols = df.columns[df.isna().any()].tolist()
455
- if missing_cols:
456
- cols = st.multiselect("Select columns with missing values", missing_cols)
457
- method = st.selectbox("Choose imputation method", [
458
- "Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
459
- ])
460
- if method == "Fill with Custom Value":
461
- custom_val = st.text_input("Enter custom value:")
462
- if st.button("Apply Missing Value Treatment"):
463
- new_df = df.copy()
464
- if method == "Drop Missing Values":
465
- new_df = new_df.dropna(subset=cols)
466
- elif method == "Fill with Mean/Median":
467
- for col in cols:
468
- if pd.api.types.is_numeric_dtype(new_df[col]):
469
- new_df[col] = new_df[col].fillna(new_df[col].median())
470
- else:
471
- new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
472
- elif method == "Fill with Custom Value" and custom_val:
473
- new_df[cols] = new_df[cols].fillna(custom_val)
474
- elif method == "Forward Fill":
475
- new_df[cols] = new_df[cols].ffill()
476
- elif method == "Backward Fill":
477
- new_df[cols] = new_df[cols].bfill()
478
- update_cleaned_data(new_df)
479
  else:
480
- st.success(" No missing values detected!")
481
-
482
- enhance_section_title("🔄 Data Type Conversion")
483
- col_to_convert = st.selectbox("Select column to convert", df.columns)
484
- new_type = st.selectbox("Select new data type", ["String", "Integer", "Float", "Boolean", "Datetime"])
485
- if new_type == "Datetime":
486
- date_format = st.text_input("Enter date format (e.g., %Y-%m-%d):", "%Y-%m-%d")
487
- if st.button("Convert Data Type"):
488
- new_df = df.copy()
489
- if new_type == "String":
490
- new_df[col_to_convert] = new_df[col_to_convert].astype(str)
491
- elif new_type == "Integer":
492
- new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce').astype('Int64')
493
- elif new_type == "Float":
494
- new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce')
495
- elif new_type == "Boolean":
496
- new_df[col_to_convert] = new_df[col_to_convert].astype(bool)
497
- elif new_type == "Datetime":
498
- new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
499
- update_cleaned_data(new_df)
500
-
501
- enhance_section_title("🗑️ Drop Columns")
502
- columns_to_drop = st.multiselect("Select columns to remove", df.columns)
503
- if columns_to_drop and st.button("Confirm Column Removal"):
504
- new_df = df.copy()
505
- new_df = new_df.drop(columns=columns_to_drop)
506
- update_cleaned_data(new_df)
507
-
508
- enhance_section_title("🔢 Encoding Options")
509
- encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
510
- data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
511
- if data_to_encode and st.button("Apply Encoding"):
512
- new_df = df.copy()
513
- if encoding_method == "Label Encoding":
514
- for col in data_to_encode:
515
- le = LabelEncoder()
516
- new_df[col] = le.fit_transform(new_df[col].astype(str))
517
- elif encoding_method == "One-Hot Encoding":
518
- new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
519
- update_cleaned_data(new_df)
520
-
521
- enhance_section_title("📏 StandardScaler")
522
- scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
523
- if scale_cols and st.button("Apply StandardScaler"):
524
- new_df = df.copy()
525
  scaler = StandardScaler()
526
- new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
527
- update_cleaned_data(new_df)
528
-
529
- elif app_mode == "EDA":
530
- st.header("🔍 Interactive Data Explorer")
531
- if 'cleaned_data' not in st.session_state:
532
- st.warning("Please upload and clean data first.")
533
- st.stop()
534
- df = st.session_state.cleaned_data.copy()
535
-
536
- enhance_section_title("Dataset Overview")
537
- with st.container():
538
- col1, col2, col3, col4 = st.columns(4)
539
- col1.metric("Total Rows", df.shape[0])
540
- col2.metric("Total Columns", df.shape[1])
541
- missing_percentage = df.isna().sum().sum() / df.size * 100
542
- col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
543
- col4.metric("Duplicates", df.duplicated().sum())
544
-
545
- tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
546
- with tab1:
547
- st.write("First few rows of the dataset:")
548
- st.dataframe(df.head(), use_container_width=True)
549
- with tab2:
550
- st.write("Column Data Types:")
551
- type_counts = df.dtypes.value_counts().reset_index()
552
- type_counts.columns = ['Type', 'Count']
553
- st.dataframe(type_counts, use_container_width=True)
554
- with tab3:
555
- st.write("Missing Values Matrix:")
556
- fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
557
- fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
558
- st.plotly_chart(fig_missing, use_container_width=True)
559
-
560
- enhance_section_title("Interactive Visualization Builder")
561
- with st.container():
562
- col1, col2 = st.columns([1, 3])
563
- with col1:
564
- plot_type = st.selectbox("Choose visualization type", [
565
- "Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix"
566
- ])
567
- x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
568
- y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart"] else None
569
- color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
570
-
571
- with col2:
572
- try:
573
- fig = None
574
- if plot_type == "Scatter Plot" and x_axis and y_axis:
575
- fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Scatter Plot of {x_axis} vs {y_axis}')
576
- elif plot_type == "Histogram" and x_axis:
577
- fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, title=f'Histogram of {x_axis}')
578
- elif plot_type == "Box Plot" and x_axis and y_axis:
579
- fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
580
- elif plot_type == "Line Chart" and x_axis and y_axis:
581
- fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
582
- elif plot_type == "Bar Chart" and x_axis:
583
- fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
584
- elif plot_type == "Correlation Matrix":
585
- numeric_df = df.select_dtypes(include=np.number)
586
- if len(numeric_df.columns) > 1:
587
- corr = numeric_df.corr()
588
- fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
589
-
590
- if fig:
591
- fig.update_layout(template="plotly_white")
592
- st.plotly_chart(fig, use_container_width=True)
593
- st.session_state.last_plot = {
594
- "type": plot_type,
595
- "x": x_axis,
596
- "y": y_axis,
597
- "data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
598
- }
599
- plot_text = extract_plot_data(st.session_state.last_plot, df)
600
- st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
601
- with st.expander("Extracted Plot Data"):
602
- st.text(plot_text)
603
- else:
604
- st.error("Please provide required inputs for the selected plot type.")
605
- except Exception as e:
606
- st.error(f"Couldn't create visualization: {str(e)}")
607
-
608
- # Chatbot Section
609
- st.markdown("---")
610
- st.markdown('<div class="chat-container">', unsafe_allow_html=True)
611
- st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
612
- st.info("Ask about your data or app features! Try: 'drop columns X, Y', 'scatter plot of X vs Y', 'analyze plot'")
613
-
614
- for message in st.session_state.chat_history:
615
- with st.chat_message(message["role"]):
616
- st.markdown(f'<div class="{message["role"]}-message">{message["content"]}</div>', unsafe_allow_html=True)
617
-
618
- user_input = st.chat_input("Ask me anything...")
619
- if user_input:
620
- st.session_state.chat_history.append({"role": "user", "content": user_input})
621
- with st.chat_message("user"):
622
- st.markdown(f'<div class="user-message">{user_input}</div>', unsafe_allow_html=True)
623
- with st.spinner("Processing..."):
624
- func, param = parse_command(user_input)
625
- if func:
626
- response = func(param) if param else func(None)
627
- else:
628
- response = get_chatbot_response(user_input, app_mode, st.session_state.vector_store, model)
629
- st.session_state.chat_history.append({"role": "assistant", "content": response})
630
- with st.chat_message("assistant"):
631
- st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
632
-
633
- st.markdown('</div>', unsafe_allow_html=True)
634
-
635
- # Footer
636
- st.markdown("""
637
- <div class="footer">
638
- <div>Built with <span class="tech-badge">Streamlit</span> + <span class="tech-badge">Groq</span> + <span class="tech-badge">LangChain</span> + <span class="tech-badge">FAISS</span></div>
639
- <div style="margin-top: 8px;">Fast inference for data insights</div>
640
- </div>
641
- """, unsafe_allow_html=True)
642
 
643
  if __name__ == "__main__":
644
  main()
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import plotly.express as px
4
+ import numpy as np
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.neural_network import MLPClassifier, MLPRegressor
7
+ from sklearn.cluster import KMeans
8
+ from sklearn.metrics import accuracy_score, r2_score, silhouette_score
9
+ from sklearn.preprocessing import StandardScaler
10
  from ydata_profiling import ProfileReport
11
  from streamlit_pandas_profiling import st_profile_report
 
 
12
  from groq import Groq
13
  from langchain_community.vectorstores import FAISS
 
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain_huggingface import HuggingFaceEmbeddings
16
+ from langchain_community.document_loaders import TextLoader
17
+ import os
 
18
  import tempfile
19
 
20
+ # Initialize clients
 
 
 
 
 
 
21
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
22
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
23
 
24
+ # Set page config
25
+ st.set_page_config(page_title="Neural-Vision Enhanced", layout="wide")
26
+
27
+ # Custom CSS for Responsive Silver-Blue-Gold Theme with Top Nav
28
  st.markdown("""
29
  <style>
30
  :root {
 
36
  .stApp {
37
  background-color: var(--silver);
38
  font-family: 'Inter', sans-serif;
39
+ max-width: 1200px;
40
  margin: 0 auto;
41
  padding: 10px;
42
  }
 
45
  color: white;
46
  padding: 15px;
47
  border-radius: 5px;
 
48
  text-align: center;
49
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
50
  }
51
  .header-title {
52
+ font-size: 1.8rem;
53
  font-weight: 700;
54
  margin: 0;
55
  }
56
  .header-subtitle {
57
+ font-size: 1rem;
58
  margin-top: 5px;
59
  }
60
+ .nav-bar {
61
  background-color: white;
62
  border-radius: 5px;
63
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
64
  padding: 15px;
65
+ margin-bottom: 20px;
66
+ display: flex;
67
+ justify-content: space-around;
68
+ align-items: center;
69
  }
70
+ .nav-item {
71
+ color: var(--blue);
72
+ font-weight: 500;
73
+ cursor: pointer;
74
+ padding: 5px 10px;
75
+ border-radius: 5px;
76
+ }
77
+ .nav-item:hover {
78
+ background-color: var(--gold);
79
+ color: white;
80
+ }
81
+ .card {
82
  background-color: white;
83
  border-radius: 5px;
84
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
85
+ padding: 20px;
86
+ margin-bottom: 20px;
87
+ }
88
+ .chat-container {
89
+ background-color: white;
90
+ border-radius: 5px;
91
  padding: 15px;
92
  margin-top: 20px;
93
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
94
  }
95
  .user-message {
96
  background-color: var(--blue);
97
  color: white;
98
+ border-radius: 15px 15px 5px 15px;
99
+ padding: 10px;
 
100
  max-width: 80%;
101
+ margin-left: auto;
102
  margin-bottom: 10px;
103
  }
104
  .bot-message {
105
  background-color: #F0F0F0;
106
  color: var(--text-color);
107
+ border-radius: 15px 15px 15px 5px;
108
+ padding: 10px;
 
109
  max-width: 80%;
110
+ margin-right: auto;
111
  margin-bottom: 10px;
112
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  .stButton > button {
114
  background-color: var(--gold);
115
  color: white;
 
123
  }
124
  @media (max-width: 768px) {
125
  .header-title {
126
+ font-size: 1.4rem;
127
  }
128
  .header-subtitle {
129
+ font-size: 0.9rem;
130
+ }
131
+ .nav-bar {
132
+ flex-direction: column;
133
+ padding: 10px;
134
  }
135
+ .nav-item {
136
+ margin: 5px 0;
137
+ width: 100%;
138
+ text-align: center;
139
+ }
140
+ .card, .chat-container {
141
  padding: 10px;
142
  }
143
  .stApp {
144
  padding: 5px;
145
  }
 
 
 
146
  }
147
+ # Footer
148
+ <footer style='text-align: center; padding: 10px; background-color: var(--blue); color: white; border-radius: 5px; margin-top: 20px;'>
149
+ <p>Created by Calvin Allen-Crawford</p>
150
+ </footer>
151
  """, unsafe_allow_html=True)
152
 
153
+ # Session State Initialization
154
+ if 'metrics' not in st.session_state:
155
+ st.session_state.metrics = {}
156
+ if 'chat_history' not in st.session_state:
157
+ st.session_state.chat_history = []
158
+ if 'vector_store' not in st.session_state:
159
+ st.session_state.vector_store = None
160
+ if 'custom_layers' not in st.session_state:
161
+ st.session_state.custom_layers = []
162
+ if 'prebuilt_selection' not in st.session_state:
163
+ st.session_state.prebuilt_selection = None
164
+ if 'model_config' not in st.session_state:
165
+ st.session_state.model_config = {}
166
+ if 'model_builder_mode' not in st.session_state:
167
+ st.session_state.model_builder_mode = "prebuilt"
168
+ if 'custom_model_type' not in st.session_state:
169
+ st.session_state.custom_model_type = "classification"
170
+
171
+ # Prebuilt Models
172
+ PREBUILT_MODELS = {
173
+ "Legal Document Classifier": {
174
+ "description": "Optimized for legal document classification.",
175
+ "architecture": {"type": "classification", "hidden_layers": [(128, "relu"), (64, "relu")], "dropout": 0.3, "optimizer": "adam", "learning_rate": 0.001},
176
+ "domain": "Legal"
177
+ },
178
+ "Financial Fraud Detector": {
179
+ "description": "Detects anomalies in financial transactions.",
180
+ "architecture": {"type": "classification", "hidden_layers": [(256, "relu"), (128, "relu"), (64, "relu")], "dropout": 0.4, "optimizer": "adam", "learning_rate": 0.0005},
181
+ "domain": "Financial"
182
+ },
183
+ "Customer Segmentation Engine": {
184
+ "description": "Advanced customer segmentation.",
185
+ "architecture": {"type": "clustering", "n_clusters": 5, "algorithm": "kmeans", "init": "k-means++", "n_init": 10},
186
+ "domain": "Marketing"
187
+ }
188
+ }
189
 
190
+ # Helper Functions (unchanged)
191
  def convert_df_to_text(df):
192
  text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
193
  text += f"Missing Values: {df.isna().sum().sum()}\n"
 
194
  for col in df.columns:
195
+ text += f"- {col} ({df[col].dtype}): Mean={df[col].mean():.2f if pd.api.types.is_numeric_dtype(df[col]) else 'N/A'}\n"
 
 
 
 
 
196
  return text
197
 
198
  def create_vector_store(df_text):
 
201
  temp_path = temp_file.name
202
  loader = TextLoader(temp_path)
203
  documents = loader.load()
204
+ texts = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(documents)
 
205
  vector_store = FAISS.from_documents(texts, embeddings)
206
  os.unlink(temp_path)
207
  return vector_store
208
 
209
+ def get_groq_response(prompt, mode):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  context = ""
211
+ if st.session_state.vector_store:
212
+ docs = st.session_state.vector_store.similarity_search(prompt, k=3)
213
+ context += "\nDataset Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
 
 
 
 
214
  try:
215
  response = client.chat.completions.create(
216
+ model="llama3-70b-8192",
217
  messages=[
218
+ {"role": "system", "content": f"You are an expert in {mode} data analysis.\n{context}"},
219
+ {"role": "user", "content": prompt}
220
+ ]
221
+ ).choices[0].message.content
222
+ return response
 
 
223
  except Exception as e:
224
  return f"Error: {str(e)}"
225
 
226
+ def build_model_from_config(config, X, y=None):
227
+ problem_type = config.get("type", "classification")
228
+ if problem_type == "clustering":
229
+ return KMeans(n_clusters=config.get("n_clusters", 3), init=config.get("init", "k-means++"), n_init=config.get("n_init", 10), random_state=42)
230
+ hidden_layers = config.get("hidden_layers", [(100, "relu")])
231
+ layer_sizes = [size for size, _ in hidden_layers]
232
+ activation = hidden_layers[0][1] if hidden_layers else "relu"
233
+ if problem_type == "classification":
234
+ return MLPClassifier(hidden_layer_sizes=layer_sizes, activation=activation, solver=config.get("optimizer", "adam"), learning_rate_init=config.get("learning_rate", 0.001), random_state=42)
235
+ return MLPRegressor(hidden_layer_sizes=layer_sizes, activation=activation, solver=config.get("optimizer", "adam"), learning_rate_init=config.get("learning_rate", 0.001), random_state=42)
236
+
237
+ # Main Application
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def main():
239
+ st.markdown('<div class="header"><h1 class="header-title">Neural-Vision Enhanced</h1><p class="header-subtitle">Build & Train Neural Networks</p></div>', unsafe_allow_html=True)
240
+
241
+ # Top Navigation Bar
242
+ st.markdown('<div class="nav-bar">', unsafe_allow_html=True)
243
+ col1, col2, col3 = st.columns([1, 2, 1])
244
+ with col1:
245
+ st.markdown('<div class="nav-item">Data Input</div>', unsafe_allow_html=True)
246
+ uploaded_file = st.file_uploader("Upload CSV Dataset", type=["csv"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  if uploaded_file:
248
+ df = pd.read_csv(uploaded_file)
249
+ st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
250
+ st.success("Dataset uploaded!")
251
+ with col2:
252
+ st.markdown('<div class="nav-item">Navigation</div>', unsafe_allow_html=True)
253
+ nav_option = st.selectbox("Navigate", ["Model Builder", "Chat", "Train Model"], label_visibility="collapsed")
254
+ with col3:
255
+ st.markdown('<div class="nav-item">Info</div>', unsafe_allow_html=True)
256
+ st.write("Built with Streamlit & Groq")
257
+ st.markdown('</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ # Main Content
260
+ if nav_option == "Model Builder":
261
+ st.markdown('<div class="card"><h2>Model Builder</h2></div>', unsafe_allow_html=True)
262
+ mode = st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
263
+ model_builder_mode = st.radio("Mode", ["Prebuilt", "Custom"])
264
+ st.session_state.model_builder_mode = "prebuilt" if model_builder_mode == "Prebuilt" else "custom"
265
+
266
+ if st.session_state.model_builder_mode == "prebuilt":
267
+ for name, details in PREBUILT_MODELS.items():
268
+ if st.button(f"{name}: {details['description']}", key=name):
269
+ st.session_state.prebuilt_selection = name
270
+ st.session_state.model_config = details["architecture"]
271
+ if st.session_state.prebuilt_selection:
272
+ st.json(st.session_state.model_config)
273
  else:
274
+ st.session_state.custom_model_type = st.selectbox("Type", ["classification", "regression", "clustering"])
275
+ if st.session_state.custom_model_type != "clustering":
276
+ layer_count = st.number_input("Layers", min_value=1, value=1)
277
+ st.session_state.custom_layers = []
278
+ for i in range(int(layer_count)):
279
+ size = st.number_input(f"Layer {i+1} Size", min_value=1, value=100, key=f"size_{i}")
280
+ activation = st.selectbox(f"Layer {i+1} Activation", ["relu", "tanh"], key=f"act_{i}")
281
+ st.session_state.custom_layers.append((size, activation))
282
+ optimizer = st.selectbox("Optimizer", ["adam", "sgd"])
283
+ st.session_state.model_config = {"type": st.session_state.custom_model_type, "hidden_layers": st.session_state.custom_layers, "optimizer": optimizer, "learning_rate": 0.001}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  else:
285
+ st.session_state.model_config = {"type": "clustering", "n_clusters": st.number_input("Clusters", min_value=2, value=3)}
286
+ if st.button("Finalize"): st.json(st.session_state.model_config)
287
+
288
+ elif nav_option == "Chat":
289
+ st.markdown('<div class="chat-container"><h3>Chat with Grok</h3></div>', unsafe_allow_html=True)
290
+ mode = st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
291
+ prompt = st.text_input("Ask a question:")
292
+ if prompt:
293
+ response = get_groq_response(prompt, mode)
294
+ st.session_state.chat_history.append({"role": "user", "content": prompt})
295
+ st.session_state.chat_history.append({"role": "bot", "content": response})
296
+ for msg in st.session_state.chat_history:
297
+ st.markdown(f'<div class={"user-message" if msg["role"] == "user" else "bot-message"}>{msg["content"]}</div>', unsafe_allow_html=True)
298
+
299
+ elif nav_option == "Train Model":
300
+ if uploaded_file and st.session_state.model_config:
301
+ st.markdown('<div class="card"><h2>Train Model</h2></div>', unsafe_allow_html=True)
302
+ df = pd.read_csv(uploaded_file)
303
+ X = df.drop(columns=[df.columns[-1]]) if st.session_state.model_config["type"] != "clustering" else df
304
+ y = df[df.columns[-1]] if st.session_state.model_config["type"] != "clustering" else None
305
+ if st.button("Train"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  scaler = StandardScaler()
307
+ X_scaled = scaler.fit_transform(X)
308
+ model = build_model_from_config(st.session_state.model_config, X_scaled, y)
309
+ if st.session_state.model_config["type"] != "clustering":
310
+ X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
311
+ model.fit(X_train, y_train)
312
+ y_pred = model.predict(X_test)
313
+ st.session_state.metrics = {"accuracy" if st.session_state.model_config["type"] == "classification" else "r2_score": accuracy_score(y_test, y_pred) if st.session_state.model_config["type"] == "classification" else r2_score(y_test, y_pred)}
314
+ else:
315
+ model.fit(X_scaled)
316
+ st.session_state.metrics = {"silhouette_score": silhouette_score(X_scaled, model.labels_)}
317
+ st.json(st.session_state.metrics)
318
+ else:
319
+ st.warning("Upload a dataset and configure a model first!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  if __name__ == "__main__":
322
  main()