Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
-
import
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
-
import plotly.graph_objects as go
|
6 |
from ydata_profiling import ProfileReport
|
7 |
-
from streamlit_pandas_profiling import st_profile_report
|
8 |
import os
|
9 |
from dotenv import load_dotenv
|
10 |
from groq import Groq
|
@@ -14,11 +12,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
14 |
from langchain.embeddings import HuggingFaceEmbeddings
|
15 |
import re
|
16 |
from scipy import stats
|
17 |
-
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
18 |
import tempfile
|
19 |
-
|
20 |
-
# Set page config as the first Streamlit command
|
21 |
-
st.set_page_config(page_title="Data-Vision Pro", layout="wide")
|
22 |
|
23 |
# Load environment variables
|
24 |
load_dotenv()
|
@@ -26,172 +22,16 @@ load_dotenv()
|
|
26 |
# Initialize Groq client
|
27 |
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
28 |
|
29 |
-
# Initialize HuggingFace embeddings
|
30 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
31 |
|
32 |
-
#
|
33 |
-
st.markdown("""
|
34 |
-
<style>
|
35 |
-
:root {
|
36 |
-
--silver-light: #D8D8D8;
|
37 |
-
--silver-dark: #B8B8B8;
|
38 |
-
--blue: #5C89BC;
|
39 |
-
--blue-dark: #4E73A0;
|
40 |
-
--blue-light: #6EA8E0;
|
41 |
-
--gold: #A87E01;
|
42 |
-
--text-color: #333333;
|
43 |
-
--shadow-color: rgba(0,0,0,0.1);
|
44 |
-
--shadow-color-stronger: rgba(0,0,0,0.2);
|
45 |
-
}
|
46 |
-
.stApp {
|
47 |
-
background: linear-gradient(135deg, var(--silver-light) 0%, var(--silver-dark) 100%);
|
48 |
-
font-family: 'Inter', sans-serif;
|
49 |
-
max-width: 900px;
|
50 |
-
margin: 0 auto;
|
51 |
-
padding: 10px;
|
52 |
-
transition: all 0.3s ease;
|
53 |
-
}
|
54 |
-
.header {
|
55 |
-
background: linear-gradient(90deg, var(--blue) 80%, var(--blue-dark) 100%);
|
56 |
-
color: white;
|
57 |
-
padding: 20px;
|
58 |
-
border-radius: 16px 16px 0 0;
|
59 |
-
box-shadow: 0 4px 12px var(--shadow-color);
|
60 |
-
text-align: center;
|
61 |
-
transition: transform 0.2s ease;
|
62 |
-
}
|
63 |
-
.header:hover {
|
64 |
-
transform: translateY(-2px);
|
65 |
-
box-shadow: 0 4px 12px var(--shadow-color-stronger);
|
66 |
-
}
|
67 |
-
.header-title {
|
68 |
-
font-size: 1.5rem;
|
69 |
-
font-weight: 700;
|
70 |
-
margin: 0;
|
71 |
-
}
|
72 |
-
.header-subtitle {
|
73 |
-
font-size: 0.9rem;
|
74 |
-
margin-top: 8px;
|
75 |
-
opacity: 0.9;
|
76 |
-
}
|
77 |
-
.sidebar .sidebar-content {
|
78 |
-
background-color: white;
|
79 |
-
border-radius: 16px;
|
80 |
-
box-shadow: 0 6px 16px var(--shadow-color);
|
81 |
-
padding: 20px;
|
82 |
-
transition: box-shadow 0.3s ease;
|
83 |
-
}
|
84 |
-
.sidebar .sidebar-content:hover {
|
85 |
-
box-shadow: 0 8px 20px var(--shadow-color-stronger);
|
86 |
-
}
|
87 |
-
.chat-container {
|
88 |
-
background-color: white;
|
89 |
-
border-radius: 16px;
|
90 |
-
box-shadow: 0 6px 16px var(--shadow-color);
|
91 |
-
padding: 20px;
|
92 |
-
margin-top: 25px;
|
93 |
-
transition: box-shadow 0.3s ease;
|
94 |
-
}
|
95 |
-
.chat-container:hover {
|
96 |
-
box-shadow: 0 8px 20px var(--shadow-color-stronger);
|
97 |
-
}
|
98 |
-
.user-message {
|
99 |
-
background: linear-gradient(45deg, var(--blue), var(--blue-light));
|
100 |
-
color: white;
|
101 |
-
border-radius: 20px 20px 6px 20px;
|
102 |
-
padding: 14px 18px;
|
103 |
-
margin-left: auto;
|
104 |
-
max-width: 80%;
|
105 |
-
margin-bottom: 12px;
|
106 |
-
box-shadow: 0 2px 8px var(--blue-dark);
|
107 |
-
transition: transform 0.2s ease;
|
108 |
-
}
|
109 |
-
.user-message:hover {
|
110 |
-
transform: scale(1.02);
|
111 |
-
}
|
112 |
-
.bot-message {
|
113 |
-
background-color: #F0F0F0;
|
114 |
-
color: var(--text-color);
|
115 |
-
border-radius: 20px 20px 20px 6px;
|
116 |
-
padding: 14px 18px;
|
117 |
-
margin-right: auto;
|
118 |
-
max-width: 80%;
|
119 |
-
margin-bottom: 12px;
|
120 |
-
box-shadow: 0 2px 8px var(--shadow-color);
|
121 |
-
transition: transform 0.2s ease;
|
122 |
-
}
|
123 |
-
.bot-message:hover {
|
124 |
-
transform: scale(1.02);
|
125 |
-
}
|
126 |
-
.footer {
|
127 |
-
text-align: center;
|
128 |
-
margin-top: 20px;
|
129 |
-
color: var(--text-color);
|
130 |
-
font-size: 0.8rem;
|
131 |
-
}
|
132 |
-
.tech-badge {
|
133 |
-
display: inline-block;
|
134 |
-
background-color: #E6ECEF;
|
135 |
-
color: var(--blue);
|
136 |
-
padding: 4px 8px;
|
137 |
-
border-radius: 12px;
|
138 |
-
font-size: 0.7rem;
|
139 |
-
margin: 0 4px;
|
140 |
-
}
|
141 |
-
h2 {
|
142 |
-
color: var(--blue);
|
143 |
-
border-bottom: 2px solid var(--gold);
|
144 |
-
padding-bottom: 5px;
|
145 |
-
font-size: 1.5rem;
|
146 |
-
font-weight: 700;
|
147 |
-
}
|
148 |
-
.stButton > button {
|
149 |
-
background-color: var(--gold);
|
150 |
-
color: white;
|
151 |
-
border-radius: 12px;
|
152 |
-
padding: 10px 20px;
|
153 |
-
border: none;
|
154 |
-
box-shadow: 0 4px 12px var(--shadow-color);
|
155 |
-
font-weight: 600;
|
156 |
-
transition: all 0.3s ease;
|
157 |
-
}
|
158 |
-
.stButton > button:hover {
|
159 |
-
background-color: #8C6B01;
|
160 |
-
transform: translateY(-2px);
|
161 |
-
box-shadow: 0 6px 16px var(--shadow-color-stronger);
|
162 |
-
}
|
163 |
-
@media (max-width: 768px) {
|
164 |
-
.header-title {
|
165 |
-
font-size: 1.2rem;
|
166 |
-
}
|
167 |
-
.header-subtitle {
|
168 |
-
font-size: 0.8rem;
|
169 |
-
}
|
170 |
-
.chat-container, .sidebar .sidebar-content {
|
171 |
-
padding: 10px;
|
172 |
-
}
|
173 |
-
.stApp {
|
174 |
-
padding: 5px;
|
175 |
-
}
|
176 |
-
h2 {
|
177 |
-
font-size: 1.2rem;
|
178 |
-
}
|
179 |
-
}
|
180 |
-
</style>
|
181 |
-
""", unsafe_allow_html=True)
|
182 |
-
|
183 |
-
# Helper Functions
|
184 |
-
def enhance_section_title(title):
|
185 |
-
st.markdown(f"<h2 style='border-bottom: 2px solid var(--gold); padding-bottom: 5px; color: var(--blue);'>{title}</h2>", unsafe_allow_html=True)
|
186 |
-
|
187 |
def update_cleaned_data(df):
|
188 |
-
|
189 |
-
if 'data_versions' not in
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
st.success("✅ Action completed successfully!")
|
194 |
-
st.rerun()
|
195 |
|
196 |
def convert_df_to_text(df):
|
197 |
text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
@@ -238,61 +78,33 @@ def extract_plot_data(plot_info, df):
|
|
238 |
x_col = plot_info["x"]
|
239 |
y_col = plot_info["y"] if "y" in plot_info else None
|
240 |
data = pd.read_json(plot_info["data"])
|
241 |
-
plot_text = f"Plot Type: {plot_type}\n"
|
242 |
-
plot_text += f"X-Axis: {x_col}\n"
|
243 |
if y_col:
|
244 |
plot_text += f"Y-Axis: {y_col}\n"
|
245 |
if plot_type == "Scatter Plot" and y_col:
|
246 |
correlation = data[x_col].corr(data[y_col])
|
247 |
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
248 |
-
plot_text += f"Correlation: {correlation:.2f}\n"
|
249 |
-
plot_text += f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
250 |
-
plot_text += f"X Stats: Mean={data[x_col].mean():.2f}, Std={data[x_col].std():.2f}, Min={data[x_col].min():.2f}, Max={data[x_col].max():.2f}\n"
|
251 |
-
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Min={data[y_col].min():.2f}, Max={data[y_col].max():.2f}\n"
|
252 |
-
elif plot_type == "Histogram":
|
253 |
-
plot_text += f"Stats: Mean={data[x_col].mean():.2f}, Median={data[x_col].median():.2f}, Std={data[x_col].std():.2f}\n"
|
254 |
-
plot_text += f"Skewness: {data[x_col].skew():.2f}\n"
|
255 |
-
plot_text += f"Range: [{data[x_col].min():.2f}, {data[x_col].max():.2f}]\n"
|
256 |
-
elif plot_type == "Box Plot" and y_col:
|
257 |
-
q1, q3 = data[y_col].quantile(0.25), data[y_col].quantile(0.75)
|
258 |
-
iqr = q3 - q1
|
259 |
-
plot_text += f"Y Stats: Median={data[y_col].median():.2f}, Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}\n"
|
260 |
-
plot_text += f"Outliers: {len(data[y_col][(data[y_col] < q1 - 1.5 * iqr) | (data[y_col] > q3 + 1.5 * iqr)])} potential outliers\n"
|
261 |
-
elif plot_type == "Line Chart" and y_col:
|
262 |
-
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Trend={'increasing' if data[y_col].iloc[-1] > data[y_col].iloc[0] else 'decreasing'}\n"
|
263 |
-
elif plot_type == "Bar Chart":
|
264 |
-
plot_text += f"Counts: {data[x_col].value_counts().to_dict()}\n"
|
265 |
-
elif plot_type == "Correlation Matrix":
|
266 |
-
corr = data.corr()
|
267 |
-
plot_text += "Correlation Matrix:\n"
|
268 |
-
for col1 in corr.columns:
|
269 |
-
for col2 in corr.index:
|
270 |
-
if col1 < col2:
|
271 |
-
plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
|
272 |
return plot_text
|
273 |
|
274 |
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
275 |
system_prompt = (
|
276 |
-
"You are an AI assistant in Data-Vision Pro,
|
277 |
-
|
278 |
-
"-
|
279 |
-
"-
|
280 |
-
"
|
281 |
-
"When analyzing plots, provide detailed insights based on numerical data extracted from them."
|
282 |
)
|
283 |
context = ""
|
284 |
if vector_store:
|
285 |
docs = vector_store.similarity_search(user_input, k=3)
|
286 |
if docs:
|
287 |
-
context = "\n\
|
288 |
-
system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
|
289 |
-
else:
|
290 |
-
system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
|
291 |
try:
|
292 |
response = client.chat.completions.create(
|
293 |
model=model,
|
294 |
messages=[
|
295 |
-
{"role": "system", "content": system_prompt},
|
296 |
{"role": "user", "content": user_input}
|
297 |
],
|
298 |
temperature=0.7,
|
@@ -302,379 +114,291 @@ def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-
|
|
302 |
except Exception as e:
|
303 |
return f"Error: {str(e)}"
|
304 |
|
305 |
-
|
306 |
-
def drop_columns(columns):
|
307 |
-
if 'cleaned_data' in st.session_state:
|
308 |
-
df = st.session_state.cleaned_data.copy()
|
309 |
-
columns_to_drop = [col.strip() for col in columns.split(',')]
|
310 |
-
valid_columns = [col for col in columns_to_drop if col in df.columns]
|
311 |
-
if valid_columns:
|
312 |
-
df.drop(valid_columns, axis=1, inplace=True)
|
313 |
-
update_cleaned_data(df)
|
314 |
-
return f"Dropped columns: {', '.join(valid_columns)}"
|
315 |
-
else:
|
316 |
-
return "No valid columns found to drop."
|
317 |
-
return "No dataset loaded."
|
318 |
-
|
319 |
-
def generate_scatter_plot(params):
|
320 |
-
df = st.session_state.cleaned_data
|
321 |
-
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
|
322 |
-
if match and len(match.groups()) >= 2:
|
323 |
-
x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
|
324 |
-
if x_axis in df.columns and y_axis in df.columns:
|
325 |
-
fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
326 |
-
st.plotly_chart(fig)
|
327 |
-
st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
|
328 |
-
return f"Generated scatter plot of {x_axis} vs {y_axis}"
|
329 |
-
return "Invalid columns for scatter plot."
|
330 |
-
|
331 |
-
def generate_histogram(params):
|
332 |
-
df = st.session_state.cleaned_data
|
333 |
-
x_axis = params.strip()
|
334 |
-
if x_axis in df.columns:
|
335 |
-
fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
|
336 |
-
st.plotly_chart(fig)
|
337 |
-
st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
|
338 |
-
return f"Generated histogram of {x_axis}"
|
339 |
-
return "Invalid column for histogram."
|
340 |
-
|
341 |
-
def analyze_plot():
|
342 |
-
if "last_plot" not in st.session_state:
|
343 |
-
return "No plot available to analyze."
|
344 |
-
plot_info = st.session_state.last_plot
|
345 |
-
df = pd.read_json(plot_info["data"])
|
346 |
-
plot_text = extract_plot_data(plot_info, df)
|
347 |
-
return f"Analysis of the last plot:\n{plot_text}"
|
348 |
-
|
349 |
-
def parse_command(command):
|
350 |
command = command.lower().strip()
|
351 |
-
if "drop columns" in command
|
352 |
-
columns = command.replace("drop columns", "").
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
-
#
|
415 |
-
if
|
416 |
-
|
417 |
-
|
418 |
-
st.session_state.chat_history = []
|
419 |
-
|
420 |
-
# Display Dataset Preview
|
421 |
-
display_dataset_preview()
|
422 |
-
|
423 |
-
# App Pages
|
424 |
-
if app_mode == "Data Upload":
|
425 |
-
st.header("📤 Data Upload & Profiling")
|
426 |
-
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
|
427 |
-
if uploaded_file:
|
428 |
-
st.session_state.pop('raw_data', None)
|
429 |
-
st.session_state.pop('cleaned_data', None)
|
430 |
-
st.session_state.pop('data_versions', None)
|
431 |
-
try:
|
432 |
-
if uploaded_file.name.endswith('.csv'):
|
433 |
-
df = pd.read_csv(uploaded_file)
|
434 |
-
else:
|
435 |
-
df = pd.read_excel(uploaded_file)
|
436 |
-
if df.empty:
|
437 |
-
st.error("Uploaded file is empty.")
|
438 |
-
st.stop()
|
439 |
-
st.session_state.raw_data = df
|
440 |
-
st.session_state.cleaned_data = df.copy()
|
441 |
-
st.session_state.dataset_text = convert_df_to_text(df)
|
442 |
-
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
443 |
-
if 'data_versions' not in st.session_state:
|
444 |
-
st.session_state.data_versions = [df.copy()]
|
445 |
-
col1, col2, col3 = st.columns(3)
|
446 |
-
with col1: st.metric("Rows", df.shape[0])
|
447 |
-
with col2: st.metric("Columns", df.shape[1])
|
448 |
-
with col3: st.metric("Missing Values", df.isna().sum().sum())
|
449 |
-
if st.checkbox("Show Data Preview"):
|
450 |
-
st.dataframe(df.head(10), use_container_width=True)
|
451 |
-
if st.button("Generate Full Profile Report"):
|
452 |
-
with st.spinner("Generating report..."):
|
453 |
-
pr = ProfileReport(df, explorative=True)
|
454 |
-
st_profile_report(pr)
|
455 |
-
st.success("✅ Data loaded successfully!")
|
456 |
-
except Exception as e:
|
457 |
-
st.error(f"An error occurred: {str(e)}")
|
458 |
-
|
459 |
-
elif app_mode == "Data Cleaning":
|
460 |
-
st.header("🧹 Smart Data Cleaning")
|
461 |
-
if 'raw_data' not in st.session_state:
|
462 |
-
st.warning("Please upload data first in the Data Upload section.")
|
463 |
-
st.stop()
|
464 |
-
if 'cleaned_data' in st.session_state:
|
465 |
-
df = st.session_state.cleaned_data.copy()
|
466 |
else:
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
if columns_to_drop and st.button("Confirm Column Removal"):
|
540 |
-
new_df = df.copy()
|
541 |
-
new_df = new_df.drop(columns=columns_to_drop)
|
542 |
-
update_cleaned_data(new_df)
|
543 |
-
|
544 |
-
enhance_section_title("🔢 Encoding Options")
|
545 |
-
encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
|
546 |
-
data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
|
547 |
-
if data_to_encode and st.button("Apply Encoding"):
|
548 |
-
new_df = df.copy()
|
549 |
-
if encoding_method == "Label Encoding":
|
550 |
-
for col in data_to_encode:
|
551 |
-
le = LabelEncoder()
|
552 |
-
new_df[col] = le.fit_transform(new_df[col].astype(str))
|
553 |
-
elif encoding_method == "One-Hot Encoding":
|
554 |
-
new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
|
555 |
-
update_cleaned_data(new_df)
|
556 |
-
|
557 |
-
enhance_section_title("📏 StandardScaler")
|
558 |
-
scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
|
559 |
-
if scale_cols and st.button("Apply StandardScaler"):
|
560 |
-
new_df = df.copy()
|
561 |
-
scaler = StandardScaler()
|
562 |
-
new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
|
563 |
-
update_cleaned_data(new_df)
|
564 |
-
|
565 |
-
elif app_mode == "EDA":
|
566 |
-
st.header("🔍 Interactive Data Explorer")
|
567 |
-
if 'cleaned_data' not in st.session_state:
|
568 |
-
st.warning("Please upload and clean data first.")
|
569 |
-
st.stop()
|
570 |
-
df = st.session_state.cleaned_data.copy()
|
571 |
-
|
572 |
-
enhance_section_title("Dataset Overview")
|
573 |
-
with st.container():
|
574 |
-
col1, col2, col3, col4 = st.columns(4)
|
575 |
-
col1.metric("Total Rows", df.shape[0])
|
576 |
-
col2.metric("Total Columns", df.shape[1])
|
577 |
-
missing_percentage = df.isna().sum().sum() / df.size * 100
|
578 |
-
col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
|
579 |
-
col4.metric("Duplicates", df.duplicated().sum())
|
580 |
-
|
581 |
-
tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
|
582 |
-
with tab1:
|
583 |
-
st.write("First few rows of the dataset:")
|
584 |
-
st.dataframe(df.head(), use_container_width=True)
|
585 |
-
with tab2:
|
586 |
-
st.write("Column Data Types:")
|
587 |
-
type_counts = df.dtypes.value_counts().reset_index()
|
588 |
-
type_counts.columns = ['Type', 'Count']
|
589 |
-
st.dataframe(type_counts, use_container_width=True)
|
590 |
-
with tab3:
|
591 |
-
st.write("Missing Values Matrix:")
|
592 |
-
fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
|
593 |
-
fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
|
594 |
-
st.plotly_chart(fig_missing, use_container_width=True)
|
595 |
-
|
596 |
-
enhance_section_title("Interactive Visualization Builder")
|
597 |
-
with st.container():
|
598 |
-
col1, col2 = st.columns([1, 3])
|
599 |
-
with col1:
|
600 |
-
plot_type = st.selectbox("Choose visualization type", [
|
601 |
-
"Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix"
|
602 |
-
])
|
603 |
-
x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
|
604 |
-
y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart"] else None
|
605 |
-
color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
|
606 |
-
|
607 |
-
with col2:
|
608 |
-
try:
|
609 |
-
fig = None
|
610 |
-
if plot_type == "Scatter Plot" and x_axis and y_axis:
|
611 |
-
fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
612 |
-
elif plot_type == "Histogram" and x_axis:
|
613 |
-
fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, title=f'Histogram of {x_axis}')
|
614 |
-
elif plot_type == "Box Plot" and x_axis and y_axis:
|
615 |
-
fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
|
616 |
-
elif plot_type == "Line Chart" and x_axis and y_axis:
|
617 |
-
fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
|
618 |
-
elif plot_type == "Bar Chart" and x_axis:
|
619 |
-
fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
|
620 |
-
elif plot_type == "Correlation Matrix":
|
621 |
-
numeric_df = df.select_dtypes(include=np.number)
|
622 |
-
if len(numeric_df.columns) > 1:
|
623 |
-
corr = numeric_df.corr()
|
624 |
-
fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
|
625 |
-
|
626 |
-
if fig:
|
627 |
-
fig.update_layout(template="plotly_white")
|
628 |
-
st.plotly_chart(fig, use_container_width=True)
|
629 |
-
st.session_state.last_plot = {
|
630 |
-
"type": plot_type,
|
631 |
-
"x": x_axis,
|
632 |
-
"y": y_axis,
|
633 |
-
"data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
|
634 |
-
}
|
635 |
-
plot_text = extract_plot_data(st.session_state.last_plot, df)
|
636 |
-
st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
|
637 |
-
with st.expander("Extracted Plot Data"):
|
638 |
-
st.text(plot_text)
|
639 |
-
else:
|
640 |
-
st.error("Please provide required inputs for the selected plot type.")
|
641 |
-
except Exception as e:
|
642 |
-
st.error(f"Couldn't create visualization: {str(e)}")
|
643 |
-
|
644 |
-
# Chatbot Section
|
645 |
-
st.markdown("---")
|
646 |
-
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
|
647 |
-
st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
|
648 |
-
st.info("Ask about your data or app features! Try: 'drop columns X, Y', 'scatter plot of X vs Y', 'analyze plot'")
|
649 |
-
|
650 |
-
for message in st.session_state.chat_history:
|
651 |
-
with st.chat_message(message["role"]):
|
652 |
-
st.markdown(f'<div class="{message["role"]}-message">{message["content"]}</div>', unsafe_allow_html=True)
|
653 |
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
with st.spinner("Processing..."):
|
660 |
-
func, param = parse_command(user_input)
|
661 |
-
if func:
|
662 |
-
response = func(param) if param else func(None)
|
663 |
-
else:
|
664 |
-
response = get_chatbot_response(user_input, app_mode, st.session_state.vector_store, model)
|
665 |
-
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
666 |
-
with st.chat_message("assistant"):
|
667 |
-
st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
|
668 |
|
669 |
-
|
|
|
|
|
|
|
670 |
|
671 |
-
#
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
|
679 |
-
|
680 |
-
main()
|
|
|
1 |
+
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
|
|
5 |
from ydata_profiling import ProfileReport
|
|
|
6 |
import os
|
7 |
from dotenv import load_dotenv
|
8 |
from groq import Groq
|
|
|
12 |
from langchain.embeddings import HuggingFaceEmbeddings
|
13 |
import re
|
14 |
from scipy import stats
|
15 |
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
16 |
import tempfile
|
17 |
+
import json
|
|
|
|
|
18 |
|
19 |
# Load environment variables
|
20 |
load_dotenv()
|
|
|
22 |
# Initialize Groq client
|
23 |
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
24 |
|
25 |
+
# Initialize HuggingFace embeddings
|
26 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
27 |
|
28 |
+
# Helper Functions (unchanged from your original)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def update_cleaned_data(df):
|
30 |
+
gr.State(value=df)
|
31 |
+
if 'data_versions' not in gr.State():
|
32 |
+
gr.State(value=[gr.State(value=df.copy())])
|
33 |
+
gr.State(value=gr.State(value=gr.State(value=df.copy())))
|
34 |
+
return df, "✅ Action completed successfully!"
|
|
|
|
|
35 |
|
36 |
def convert_df_to_text(df):
|
37 |
text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
|
|
78 |
x_col = plot_info["x"]
|
79 |
y_col = plot_info["y"] if "y" in plot_info else None
|
80 |
data = pd.read_json(plot_info["data"])
|
81 |
+
plot_text = f"Plot Type: {plot_type}\nX-Axis: {x_col}\n"
|
|
|
82 |
if y_col:
|
83 |
plot_text += f"Y-Axis: {y_col}\n"
|
84 |
if plot_type == "Scatter Plot" and y_col:
|
85 |
correlation = data[x_col].corr(data[y_col])
|
86 |
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
87 |
+
plot_text += f"Correlation: {correlation:.2f}\nLinear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
return plot_text
|
89 |
|
90 |
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
91 |
system_prompt = (
|
92 |
+
f"You are an AI assistant in Data-Vision Pro, on the '{app_mode}' page:\n"
|
93 |
+
"- Data Upload: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
94 |
+
"- Data Cleaning: Clean data (e.g., handle missing values, encode variables).\n"
|
95 |
+
"- EDA: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
96 |
+
"Use context if provided."
|
|
|
97 |
)
|
98 |
context = ""
|
99 |
if vector_store:
|
100 |
docs = vector_store.similarity_search(user_input, k=3)
|
101 |
if docs:
|
102 |
+
context = "\n\nContext:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
|
|
|
|
|
|
103 |
try:
|
104 |
response = client.chat.completions.create(
|
105 |
model=model,
|
106 |
messages=[
|
107 |
+
{"role": "system", "content": system_prompt + context},
|
108 |
{"role": "user", "content": user_input}
|
109 |
],
|
110 |
temperature=0.7,
|
|
|
114 |
except Exception as e:
|
115 |
return f"Error: {str(e)}"
|
116 |
|
117 |
+
def parse_command(command, df, vector_store):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
command = command.lower().strip()
|
119 |
+
if "drop columns" in command:
|
120 |
+
columns = command.replace("drop columns", "").strip().split(',')
|
121 |
+
valid_cols = [col.strip() for col in columns if col.strip() in df.columns]
|
122 |
+
if valid_cols:
|
123 |
+
df = df.drop(columns=valid_cols)
|
124 |
+
return update_cleaned_data(df)[0], f"Dropped columns: {', '.join(valid_cols)}"
|
125 |
+
return df, "No valid columns to drop."
|
126 |
+
elif "scatter plot of" in command:
|
127 |
+
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", command)
|
128 |
+
if match:
|
129 |
+
x, y = match.group(1).strip(), match.group(2).strip()
|
130 |
+
if x in df.columns and y in df.columns:
|
131 |
+
fig = px.scatter(df, x=x, y=y)
|
132 |
+
plot_info = {"type": "Scatter Plot", "x": x, "y": y, "data": df[[x, y]].to_json()}
|
133 |
+
return df, fig, plot_info
|
134 |
+
return df, None, "Invalid scatter plot command."
|
135 |
+
elif "histogram of" in command:
|
136 |
+
col = command.replace("histogram of", "").strip()
|
137 |
+
if col in df.columns:
|
138 |
+
fig = px.histogram(df, x=col)
|
139 |
+
plot_info = {"type": "Histogram", "x": col, "data": df[[col]].to_json()}
|
140 |
+
return df, fig, plot_info
|
141 |
+
return df, None, "Invalid histogram command."
|
142 |
+
elif "analyze plot" in command and "last_plot" in gr.State():
|
143 |
+
plot_info = gr.State(value="last_plot")
|
144 |
+
plot_text = extract_plot_data(plot_info, df)
|
145 |
+
if vector_store:
|
146 |
+
vector_store = update_vector_store_with_plot(plot_text, vector_store)
|
147 |
+
return df, plot_text
|
148 |
+
return df, None, None
|
149 |
+
|
150 |
+
# Custom HTML/JS for Enhanced UI
|
151 |
+
custom_html = """
|
152 |
+
<style>
|
153 |
+
:root {
|
154 |
+
--silver-light: #D8D8D8;
|
155 |
+
--silver-dark: #B8B8B8;
|
156 |
+
--blue: #5C89BC;
|
157 |
+
--blue-dark: #4E73A0;
|
158 |
+
--blue-light: #6EA8E0;
|
159 |
+
--gold: #A87E01;
|
160 |
+
--shadow-color: rgba(0,0,0,0.1);
|
161 |
+
}
|
162 |
+
.header {
|
163 |
+
background: linear-gradient(90deg, var(--blue) 80%, var(--blue-dark) 100%);
|
164 |
+
color: white;
|
165 |
+
padding: 20px;
|
166 |
+
border-radius: 16px 16px 0 0;
|
167 |
+
text-align: center;
|
168 |
+
box-shadow: 0 4px 12px var(--shadow-color);
|
169 |
+
}
|
170 |
+
.nav-tabs {
|
171 |
+
display: flex;
|
172 |
+
justify-content: space-around;
|
173 |
+
padding: 10px 0;
|
174 |
+
background: var(--silver-light);
|
175 |
+
border-bottom: 2px solid var(--gold);
|
176 |
+
}
|
177 |
+
.nav-tab {
|
178 |
+
padding: 10px 20px;
|
179 |
+
cursor: pointer;
|
180 |
+
color: var(--blue);
|
181 |
+
font-weight: 600;
|
182 |
+
transition: all 0.3s ease;
|
183 |
+
}
|
184 |
+
.nav-tab.active {
|
185 |
+
color: var(--gold);
|
186 |
+
border-bottom: 2px solid var(--gold);
|
187 |
+
background: white;
|
188 |
+
border-radius: 8px 8px 0 0;
|
189 |
+
}
|
190 |
+
.tab-content { display: none; padding: 20px; }
|
191 |
+
.tab-content.active { display: block; }
|
192 |
+
.chat-container {
|
193 |
+
background: white;
|
194 |
+
border-radius: 16px;
|
195 |
+
padding: 20px;
|
196 |
+
box-shadow: 0 6px 16px var(--shadow-color);
|
197 |
+
margin-top: 20px;
|
198 |
+
}
|
199 |
+
.message {
|
200 |
+
padding: 10px 15px;
|
201 |
+
margin: 5px 0;
|
202 |
+
border-radius: 12px;
|
203 |
+
max-width: 80%;
|
204 |
+
}
|
205 |
+
.user-message {
|
206 |
+
background: linear-gradient(45deg, var(--blue), var(--blue-light));
|
207 |
+
color: white;
|
208 |
+
margin-left: auto;
|
209 |
+
}
|
210 |
+
.bot-message {
|
211 |
+
background: #F0F0F0;
|
212 |
+
margin-right: auto;
|
213 |
+
}
|
214 |
+
.metrics {
|
215 |
+
display: flex;
|
216 |
+
gap: 20px;
|
217 |
+
margin: 10px 0;
|
218 |
+
}
|
219 |
+
.metric {
|
220 |
+
background: #F0F0F0;
|
221 |
+
padding: 10px;
|
222 |
+
border-radius: 8px;
|
223 |
+
}
|
224 |
+
</style>
|
225 |
+
<div class="header">
|
226 |
+
<h1>Data-Vision Pro</h1>
|
227 |
+
<div>Advanced Data Analysis with Groq</div>
|
228 |
+
</div>
|
229 |
+
<div class="nav-tabs">
|
230 |
+
<div class="nav-tab active" data-tab="upload">Data Upload</div>
|
231 |
+
<div class="nav-tab" data-tab="cleaning">Data Cleaning</div>
|
232 |
+
<div class="nav-tab" data-tab="eda">EDA</div>
|
233 |
+
</div>
|
234 |
+
<div id="upload" class="tab-content active">
|
235 |
+
<h2>📤 Data Upload & Profiling</h2>
|
236 |
+
<!-- Gradio components will be injected here -->
|
237 |
+
</div>
|
238 |
+
<div id="cleaning" class="tab-content">
|
239 |
+
<h2>🧹 Data Cleaning</h2>
|
240 |
+
<!-- Gradio components will be injected here -->
|
241 |
+
</div>
|
242 |
+
<div id="eda" class="tab-content">
|
243 |
+
<h2>🔍 Interactive Data Explorer</h2>
|
244 |
+
<!-- Gradio components will be injected here -->
|
245 |
+
</div>
|
246 |
+
<div class="chat-container">
|
247 |
+
<h2>💬 AI Chatbot Assistant</h2>
|
248 |
+
<div id="chat" style="max-height:300px; overflow-y:auto;"></div>
|
249 |
+
<input id="chat-input" placeholder="Ask me anything..." style="width:80%;">
|
250 |
+
<button onclick="sendChat()">Send</button>
|
251 |
+
</div>
|
252 |
+
<script>
|
253 |
+
// Tab Navigation
|
254 |
+
document.querySelectorAll('.nav-tab').forEach(tab => {
|
255 |
+
tab.addEventListener('click', () => {
|
256 |
+
document.querySelectorAll('.nav-tab').forEach(t => t.classList.remove('active'));
|
257 |
+
document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
|
258 |
+
tab.classList.add('active');
|
259 |
+
document.getElementById(tab.dataset.tab).classList.add('active');
|
260 |
+
document.getElementById('app-mode').value = tab.dataset.tab.charAt(0).toUpperCase() + tab.dataset.tab.slice(1);
|
261 |
+
});
|
262 |
+
});
|
263 |
+
|
264 |
+
// Chat Functionality
|
265 |
+
function sendChat() {
|
266 |
+
const input = document.getElementById('chat-input');
|
267 |
+
const message = input.value.trim();
|
268 |
+
if (!message) return;
|
269 |
+
input.value = '';
|
270 |
+
const chat = document.getElementById('chat');
|
271 |
+
chat.innerHTML += `<div class="message user-message">${message}</div>`;
|
272 |
+
chat.scrollTop = chat.scrollHeight;
|
273 |
+
|
274 |
+
// Trigger Gradio event
|
275 |
+
const event = new CustomEvent('chat_submit', { detail: message });
|
276 |
+
document.dispatchEvent(event);
|
277 |
+
}
|
278 |
+
|
279 |
+
// Listen for bot responses from Gradio
|
280 |
+
document.addEventListener('bot_response', (e) => {
|
281 |
+
const chat = document.getElementById('chat');
|
282 |
+
chat.innerHTML += `<div class="message bot-message">${e.detail}</div>`;
|
283 |
+
chat.scrollTop = chat.scrollHeight;
|
284 |
+
});
|
285 |
+
</script>
|
286 |
+
"""
|
287 |
+
|
288 |
+
# Gradio Interface
|
289 |
+
def main_interface(file, chat_input, cleaned_data, vector_store, last_plot, app_mode, model):
|
290 |
+
outputs = {}
|
291 |
|
292 |
+
# Data Upload
|
293 |
+
if file and app_mode == "Data Upload":
|
294 |
+
if file.name.endswith('.csv'):
|
295 |
+
df = pd.read_csv(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
else:
|
297 |
+
df = pd.read_excel(file)
|
298 |
+
cleaned_data, msg = update_cleaned_data(df)
|
299 |
+
vector_store = create_vector_store(convert_df_to_text(df))
|
300 |
+
metrics_html = f"""
|
301 |
+
<div class="metrics">
|
302 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
303 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
304 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
305 |
+
</div>
|
306 |
+
"""
|
307 |
+
outputs["upload_output"] = gr.HTML(value=metrics_html + f"<pre>{df.head().to_string()}</pre>")
|
308 |
+
outputs["status"] = msg
|
309 |
+
outputs["cleaned_data"] = cleaned_data
|
310 |
+
outputs["vector_store"] = vector_store
|
311 |
+
|
312 |
+
# Data Cleaning
|
313 |
+
elif app_mode == "Data Cleaning" and cleaned_data is not None:
|
314 |
+
df = cleaned_data
|
315 |
+
metrics_html = f"""
|
316 |
+
<div class="metrics">
|
317 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
318 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
319 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
320 |
+
</div>
|
321 |
+
"""
|
322 |
+
outputs["cleaning_output"] = gr.HTML(value=metrics_html)
|
323 |
+
|
324 |
+
# EDA
|
325 |
+
elif app_mode == "EDA" and cleaned_data is not None:
|
326 |
+
df = cleaned_data
|
327 |
+
metrics_html = f"""
|
328 |
+
<div class="metrics">
|
329 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
330 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
331 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
332 |
+
</div>
|
333 |
+
"""
|
334 |
+
outputs["eda_output"] = gr.HTML(value=metrics_html)
|
335 |
+
|
336 |
+
# Chatbot
|
337 |
+
if chat_input:
|
338 |
+
df = cleaned_data if cleaned_data is not None else pd.DataFrame()
|
339 |
+
new_df, plot_fig, plot_info_or_msg = parse_command(chat_input, df, vector_store)
|
340 |
+
if plot_fig:
|
341 |
+
outputs["plot"] = plot_fig
|
342 |
+
outputs["last_plot"] = plot_info_or_msg
|
343 |
+
vector_store = update_vector_store_with_plot(extract_plot_data(plot_info_or_msg, df), vector_store)
|
344 |
+
outputs["vector_store"] = vector_store
|
345 |
+
response = f"Generated {plot_info_or_msg['type'].lower()}."
|
346 |
+
elif isinstance(plot_info_or_msg, str):
|
347 |
+
response = plot_info_or_msg
|
348 |
+
if "Dropped columns" in response:
|
349 |
+
outputs["cleaned_data"] = new_df
|
350 |
+
outputs["vector_store"] = create_vector_store(convert_df_to_text(new_df))
|
351 |
+
else:
|
352 |
+
response = get_chatbot_response(chat_input, app_mode, vector_store, model)
|
353 |
+
outputs["status"] = response
|
354 |
+
# Trigger JS event for chatbot
|
355 |
+
outputs["chat_output"] = gr.HTML(value=f"""
|
356 |
+
<script>
|
357 |
+
document.dispatchEvent(new CustomEvent('bot_response', {{ detail: {json.dumps(response)} }}));
|
358 |
+
</script>
|
359 |
+
""")
|
360 |
+
|
361 |
+
return outputs
|
362 |
+
|
363 |
+
# Gradio App
|
364 |
+
with gr.Blocks(title="Data-Vision Pro") as demo:
|
365 |
+
# State Variables
|
366 |
+
cleaned_data = gr.State()
|
367 |
+
vector_store = gr.State()
|
368 |
+
last_plot = gr.State()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
+
# Custom HTML
|
371 |
+
gr.HTML(custom_html)
|
372 |
+
|
373 |
+
# Hidden App Mode Input
|
374 |
+
app_mode = gr.Textbox(value="Data Upload", elem_id="app-mode", visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
+
# Inputs
|
377 |
+
with gr.Row():
|
378 |
+
file_input = gr.File(label="Upload CSV/XLSX")
|
379 |
+
model = gr.Dropdown(choices=["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"], value="llama3-70b-8192", label="Groq Model")
|
380 |
|
381 |
+
# Outputs
|
382 |
+
upload_output = gr.HTML(label="Upload Results", elem_id="upload-output")
|
383 |
+
cleaning_output = gr.HTML(label="Cleaning Results", elem_id="cleaning-output")
|
384 |
+
eda_output = gr.HTML(label="EDA Results", elem_id="eda-output")
|
385 |
+
plot = gr.Plot(label="Visualization")
|
386 |
+
status = gr.Textbox(label="Status")
|
387 |
+
chat_output = gr.HTML(visible=False) # Hidden output to trigger JS
|
388 |
+
|
389 |
+
# Chat Input
|
390 |
+
chat_input = gr.Textbox(label="Chat with AI", interactive=True, placeholder="Ask me anything...")
|
391 |
+
|
392 |
+
# Event Handling
|
393 |
+
file_input.change(
|
394 |
+
main_interface,
|
395 |
+
inputs=[file_input, chat_input, cleaned_data, vector_store, last_plot, app_mode, model],
|
396 |
+
outputs=[upload_output, cleaning_output, eda_output, plot, status, chat_output, cleaned_data, vector_store, last_plot]
|
397 |
+
)
|
398 |
+
chat_input.submit(
|
399 |
+
main_interface,
|
400 |
+
inputs=[file_input, chat_input, cleaned_data, vector_store, last_plot, app_mode, model],
|
401 |
+
outputs=[upload_output, cleaning_output, eda_output, plot, status, chat_output, cleaned_data, vector_store, last_plot]
|
402 |
+
)
|
403 |
|
404 |
+
demo.launch()
|
|