Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
import
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
from ydata_profiling import ProfileReport
|
|
|
6 |
import os
|
7 |
from dotenv import load_dotenv
|
8 |
from groq import Groq
|
@@ -16,6 +17,9 @@ from sklearn.preprocessing import StandardScaler, LabelEncoder
|
|
16 |
import tempfile
|
17 |
import json
|
18 |
|
|
|
|
|
|
|
19 |
# Load environment variables
|
20 |
load_dotenv()
|
21 |
|
@@ -25,128 +29,6 @@ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
25 |
# Initialize HuggingFace embeddings
|
26 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
27 |
|
28 |
-
# Helper Functions (unchanged from your original)
|
29 |
-
def update_cleaned_data(df):
|
30 |
-
gr.State(value=df)
|
31 |
-
if 'data_versions' not in gr.State():
|
32 |
-
gr.State(value=[gr.State(value=df.copy())])
|
33 |
-
gr.State(value=gr.State(value=gr.State(value=df.copy())))
|
34 |
-
return df, "✅ Action completed successfully!"
|
35 |
-
|
36 |
-
def convert_df_to_text(df):
|
37 |
-
text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
38 |
-
text += f"Missing Values: {df.isna().sum().sum()}\n"
|
39 |
-
text += "Columns:\n"
|
40 |
-
for col in df.columns:
|
41 |
-
text += f"- {col} ({df[col].dtype}): "
|
42 |
-
if pd.api.types.is_numeric_dtype(df[col]):
|
43 |
-
text += f"Mean={df[col].mean():.2f}, Min={df[col].min()}, Max={df[col].max()}"
|
44 |
-
else:
|
45 |
-
text += f"Unique={df[col].nunique()}, Top={df[col].mode()[0] if not df[col].mode().empty else 'N/A'}"
|
46 |
-
text += f", Missing={df[col].isna().sum()}\n"
|
47 |
-
return text
|
48 |
-
|
49 |
-
def create_vector_store(df_text):
|
50 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
51 |
-
temp_file.write(df_text)
|
52 |
-
temp_path = temp_file.name
|
53 |
-
loader = TextLoader(temp_path)
|
54 |
-
documents = loader.load()
|
55 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
56 |
-
texts = text_splitter.split_documents(documents)
|
57 |
-
vector_store = FAISS.from_documents(texts, embeddings)
|
58 |
-
os.unlink(temp_path)
|
59 |
-
return vector_store
|
60 |
-
|
61 |
-
def update_vector_store_with_plot(plot_text, existing_vector_store):
|
62 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
63 |
-
temp_file.write(plot_text)
|
64 |
-
temp_path = temp_file.name
|
65 |
-
loader = TextLoader(temp_path)
|
66 |
-
documents = loader.load()
|
67 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
68 |
-
texts = text_splitter.split_documents(documents)
|
69 |
-
if existing_vector_store:
|
70 |
-
existing_vector_store.add_documents(texts)
|
71 |
-
else:
|
72 |
-
existing_vector_store = FAISS.from_documents(texts, embeddings)
|
73 |
-
os.unlink(temp_path)
|
74 |
-
return existing_vector_store
|
75 |
-
|
76 |
-
def extract_plot_data(plot_info, df):
|
77 |
-
plot_type = plot_info["type"]
|
78 |
-
x_col = plot_info["x"]
|
79 |
-
y_col = plot_info["y"] if "y" in plot_info else None
|
80 |
-
data = pd.read_json(plot_info["data"])
|
81 |
-
plot_text = f"Plot Type: {plot_type}\nX-Axis: {x_col}\n"
|
82 |
-
if y_col:
|
83 |
-
plot_text += f"Y-Axis: {y_col}\n"
|
84 |
-
if plot_type == "Scatter Plot" and y_col:
|
85 |
-
correlation = data[x_col].corr(data[y_col])
|
86 |
-
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
87 |
-
plot_text += f"Correlation: {correlation:.2f}\nLinear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
88 |
-
return plot_text
|
89 |
-
|
90 |
-
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
91 |
-
system_prompt = (
|
92 |
-
f"You are an AI assistant in Data-Vision Pro, on the '{app_mode}' page:\n"
|
93 |
-
"- Data Upload: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
94 |
-
"- Data Cleaning: Clean data (e.g., handle missing values, encode variables).\n"
|
95 |
-
"- EDA: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
96 |
-
"Use context if provided."
|
97 |
-
)
|
98 |
-
context = ""
|
99 |
-
if vector_store:
|
100 |
-
docs = vector_store.similarity_search(user_input, k=3)
|
101 |
-
if docs:
|
102 |
-
context = "\n\nContext:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
103 |
-
try:
|
104 |
-
response = client.chat.completions.create(
|
105 |
-
model=model,
|
106 |
-
messages=[
|
107 |
-
{"role": "system", "content": system_prompt + context},
|
108 |
-
{"role": "user", "content": user_input}
|
109 |
-
],
|
110 |
-
temperature=0.7,
|
111 |
-
max_tokens=1024
|
112 |
-
)
|
113 |
-
return response.choices[0].message.content
|
114 |
-
except Exception as e:
|
115 |
-
return f"Error: {str(e)}"
|
116 |
-
|
117 |
-
def parse_command(command, df, vector_store):
|
118 |
-
command = command.lower().strip()
|
119 |
-
if "drop columns" in command:
|
120 |
-
columns = command.replace("drop columns", "").strip().split(',')
|
121 |
-
valid_cols = [col.strip() for col in columns if col.strip() in df.columns]
|
122 |
-
if valid_cols:
|
123 |
-
df = df.drop(columns=valid_cols)
|
124 |
-
return update_cleaned_data(df)[0], f"Dropped columns: {', '.join(valid_cols)}"
|
125 |
-
return df, "No valid columns to drop."
|
126 |
-
elif "scatter plot of" in command:
|
127 |
-
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", command)
|
128 |
-
if match:
|
129 |
-
x, y = match.group(1).strip(), match.group(2).strip()
|
130 |
-
if x in df.columns and y in df.columns:
|
131 |
-
fig = px.scatter(df, x=x, y=y)
|
132 |
-
plot_info = {"type": "Scatter Plot", "x": x, "y": y, "data": df[[x, y]].to_json()}
|
133 |
-
return df, fig, plot_info
|
134 |
-
return df, None, "Invalid scatter plot command."
|
135 |
-
elif "histogram of" in command:
|
136 |
-
col = command.replace("histogram of", "").strip()
|
137 |
-
if col in df.columns:
|
138 |
-
fig = px.histogram(df, x=col)
|
139 |
-
plot_info = {"type": "Histogram", "x": col, "data": df[[col]].to_json()}
|
140 |
-
return df, fig, plot_info
|
141 |
-
return df, None, "Invalid histogram command."
|
142 |
-
elif "analyze plot" in command and "last_plot" in gr.State():
|
143 |
-
plot_info = gr.State(value="last_plot")
|
144 |
-
plot_text = extract_plot_data(plot_info, df)
|
145 |
-
if vector_store:
|
146 |
-
vector_store = update_vector_store_with_plot(plot_text, vector_store)
|
147 |
-
return df, plot_text
|
148 |
-
return df, None, None
|
149 |
-
|
150 |
# Custom HTML/JS for Enhanced UI
|
151 |
custom_html = """
|
152 |
<style>
|
@@ -159,6 +41,11 @@ custom_html = """
|
|
159 |
--gold: #A87E01;
|
160 |
--shadow-color: rgba(0,0,0,0.1);
|
161 |
}
|
|
|
|
|
|
|
|
|
|
|
162 |
.header {
|
163 |
background: linear-gradient(90deg, var(--blue) 80%, var(--blue-dark) 100%);
|
164 |
color: white;
|
@@ -224,7 +111,7 @@ custom_html = """
|
|
224 |
</style>
|
225 |
<div class="header">
|
226 |
<h1>Data-Vision Pro</h1>
|
227 |
-
<div>Advanced Data Analysis with Groq</div>
|
228 |
</div>
|
229 |
<div class="nav-tabs">
|
230 |
<div class="nav-tab active" data-tab="upload">Data Upload</div>
|
@@ -233,18 +120,18 @@ custom_html = """
|
|
233 |
</div>
|
234 |
<div id="upload" class="tab-content active">
|
235 |
<h2>📤 Data Upload & Profiling</h2>
|
236 |
-
|
237 |
</div>
|
238 |
<div id="cleaning" class="tab-content">
|
239 |
<h2>🧹 Data Cleaning</h2>
|
240 |
-
|
241 |
</div>
|
242 |
<div id="eda" class="tab-content">
|
243 |
<h2>🔍 Interactive Data Explorer</h2>
|
244 |
-
|
245 |
</div>
|
246 |
<div class="chat-container">
|
247 |
-
<h2>💬 AI Chatbot Assistant</h2>
|
248 |
<div id="chat" style="max-height:300px; overflow-y:auto;"></div>
|
249 |
<input id="chat-input" placeholder="Ask me anything..." style="width:80%;">
|
250 |
<button onclick="sendChat()">Send</button>
|
@@ -271,12 +158,12 @@ custom_html = """
|
|
271 |
chat.innerHTML += `<div class="message user-message">${message}</div>`;
|
272 |
chat.scrollTop = chat.scrollHeight;
|
273 |
|
274 |
-
// Trigger
|
275 |
-
|
276 |
-
document.dispatchEvent(
|
277 |
}
|
278 |
|
279 |
-
// Listen for bot responses from
|
280 |
document.addEventListener('bot_response', (e) => {
|
281 |
const chat = document.getElementById('chat');
|
282 |
chat.innerHTML += `<div class="message bot-message">${e.detail}</div>`;
|
@@ -285,120 +172,252 @@ custom_html = """
|
|
285 |
</script>
|
286 |
"""
|
287 |
|
288 |
-
#
|
289 |
-
def
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
else:
|
297 |
-
df =
|
298 |
-
|
299 |
-
|
300 |
-
metrics_html = f"""
|
301 |
-
<div class="metrics">
|
302 |
-
<div class="metric">Rows: {df.shape[0]}</div>
|
303 |
-
<div class="metric">Columns: {df.shape[1]}</div>
|
304 |
-
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
305 |
-
</div>
|
306 |
-
"""
|
307 |
-
outputs["upload_output"] = gr.HTML(value=metrics_html + f"<pre>{df.head().to_string()}</pre>")
|
308 |
-
outputs["status"] = msg
|
309 |
-
outputs["cleaned_data"] = cleaned_data
|
310 |
-
outputs["vector_store"] = vector_store
|
311 |
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
|
|
|
|
|
|
335 |
|
336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
if chat_input:
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
response = f"Generated {
|
346 |
-
elif isinstance(
|
347 |
-
response =
|
348 |
if "Dropped columns" in response:
|
349 |
-
|
350 |
-
|
|
|
351 |
else:
|
352 |
-
response = get_chatbot_response(chat_input, app_mode, vector_store, model)
|
353 |
-
|
354 |
-
|
355 |
-
outputs["chat_output"] = gr.HTML(value=f"""
|
356 |
<script>
|
357 |
document.dispatchEvent(new CustomEvent('bot_response', {{ detail: {json.dumps(response)} }}));
|
358 |
</script>
|
359 |
-
""")
|
360 |
-
|
361 |
-
return outputs
|
362 |
-
|
363 |
-
# Gradio App
|
364 |
-
with gr.Blocks(title="Data-Vision Pro") as demo:
|
365 |
-
# State Variables
|
366 |
-
cleaned_data = gr.State()
|
367 |
-
vector_store = gr.State()
|
368 |
-
last_plot = gr.State()
|
369 |
-
|
370 |
-
# Custom HTML
|
371 |
-
gr.HTML(custom_html)
|
372 |
-
|
373 |
-
# Hidden App Mode Input
|
374 |
-
app_mode = gr.Textbox(value="Data Upload", elem_id="app-mode", visible=False)
|
375 |
-
|
376 |
-
# Inputs
|
377 |
-
with gr.Row():
|
378 |
-
file_input = gr.File(label="Upload CSV/XLSX")
|
379 |
-
model = gr.Dropdown(choices=["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"], value="llama3-70b-8192", label="Groq Model")
|
380 |
-
|
381 |
-
# Outputs
|
382 |
-
upload_output = gr.HTML(label="Upload Results", elem_id="upload-output")
|
383 |
-
cleaning_output = gr.HTML(label="Cleaning Results", elem_id="cleaning-output")
|
384 |
-
eda_output = gr.HTML(label="EDA Results", elem_id="eda-output")
|
385 |
-
plot = gr.Plot(label="Visualization")
|
386 |
-
status = gr.Textbox(label="Status")
|
387 |
-
chat_output = gr.HTML(visible=False) # Hidden output to trigger JS
|
388 |
-
|
389 |
-
# Chat Input
|
390 |
-
chat_input = gr.Textbox(label="Chat with AI", interactive=True, placeholder="Ask me anything...")
|
391 |
-
|
392 |
-
# Event Handling
|
393 |
-
file_input.change(
|
394 |
-
main_interface,
|
395 |
-
inputs=[file_input, chat_input, cleaned_data, vector_store, last_plot, app_mode, model],
|
396 |
-
outputs=[upload_output, cleaning_output, eda_output, plot, status, chat_output, cleaned_data, vector_store, last_plot]
|
397 |
-
)
|
398 |
-
chat_input.submit(
|
399 |
-
main_interface,
|
400 |
-
inputs=[file_input, chat_input, cleaned_data, vector_store, last_plot, app_mode, model],
|
401 |
-
outputs=[upload_output, cleaning_output, eda_output, plot, status, chat_output, cleaned_data, vector_store, last_plot]
|
402 |
-
)
|
403 |
|
404 |
-
|
|
|
|
1 |
+
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
from ydata_profiling import ProfileReport
|
6 |
+
from streamlit_pandas_profiling import st_profile_report
|
7 |
import os
|
8 |
from dotenv import load_dotenv
|
9 |
from groq import Groq
|
|
|
17 |
import tempfile
|
18 |
import json
|
19 |
|
20 |
+
# Set page config
|
21 |
+
st.set_page_config(page_title="Data-Vision Pro", layout="wide")
|
22 |
+
|
23 |
# Load environment variables
|
24 |
load_dotenv()
|
25 |
|
|
|
29 |
# Initialize HuggingFace embeddings
|
30 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# Custom HTML/JS for Enhanced UI
|
33 |
custom_html = """
|
34 |
<style>
|
|
|
41 |
--gold: #A87E01;
|
42 |
--shadow-color: rgba(0,0,0,0.1);
|
43 |
}
|
44 |
+
.stApp {
|
45 |
+
background: linear-gradient(135deg, var(--silver-light) 0%, var(--silver-dark) 100%);
|
46 |
+
font-family: 'Inter', sans-serif;
|
47 |
+
transition: all 0.3s ease;
|
48 |
+
}
|
49 |
.header {
|
50 |
background: linear-gradient(90deg, var(--blue) 80%, var(--blue-dark) 100%);
|
51 |
color: white;
|
|
|
111 |
</style>
|
112 |
<div class="header">
|
113 |
<h1>Data-Vision Pro</h1>
|
114 |
+
<div>Advanced Data Analysis with Groq Inference</div>
|
115 |
</div>
|
116 |
<div class="nav-tabs">
|
117 |
<div class="nav-tab active" data-tab="upload">Data Upload</div>
|
|
|
120 |
</div>
|
121 |
<div id="upload" class="tab-content active">
|
122 |
<h2>📤 Data Upload & Profiling</h2>
|
123 |
+
<div id="upload-output"></div>
|
124 |
</div>
|
125 |
<div id="cleaning" class="tab-content">
|
126 |
<h2>🧹 Data Cleaning</h2>
|
127 |
+
<div id="cleaning-output"></div>
|
128 |
</div>
|
129 |
<div id="eda" class="tab-content">
|
130 |
<h2>🔍 Interactive Data Explorer</h2>
|
131 |
+
<div id="eda-output"></div>
|
132 |
</div>
|
133 |
<div class="chat-container">
|
134 |
+
<h2>💬 AI Chatbot Assistant (RAG Enabled)</h2>
|
135 |
<div id="chat" style="max-height:300px; overflow-y:auto;"></div>
|
136 |
<input id="chat-input" placeholder="Ask me anything..." style="width:80%;">
|
137 |
<button onclick="sendChat()">Send</button>
|
|
|
158 |
chat.innerHTML += `<div class="message user-message">${message}</div>`;
|
159 |
chat.scrollTop = chat.scrollHeight;
|
160 |
|
161 |
+
// Trigger Streamlit event via hidden input
|
162 |
+
document.getElementById('chat-trigger').value = message;
|
163 |
+
document.getElementById('chat-trigger').dispatchEvent(new Event('change'));
|
164 |
}
|
165 |
|
166 |
+
// Listen for bot responses from Streamlit
|
167 |
document.addEventListener('bot_response', (e) => {
|
168 |
const chat = document.getElementById('chat');
|
169 |
chat.innerHTML += `<div class="message bot-message">${e.detail}</div>`;
|
|
|
172 |
</script>
|
173 |
"""
|
174 |
|
175 |
+
# Helper Functions (mostly unchanged)
|
176 |
+
def update_cleaned_data(df):
|
177 |
+
st.session_state.cleaned_data = df
|
178 |
+
if 'data_versions' not in st.session_state:
|
179 |
+
st.session_state.data_versions = [st.session_state.raw_data.copy()]
|
180 |
+
st.session_state.data_versions.append(df.copy())
|
181 |
+
st.session_state.dataset_text = convert_df_to_text(df)
|
182 |
+
return "✅ Action completed successfully!"
|
183 |
+
|
184 |
+
def convert_df_to_text(df):
|
185 |
+
text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
186 |
+
text += f"Missing Values: {df.isna().sum().sum()}\n"
|
187 |
+
text += "Columns:\n"
|
188 |
+
for col in df.columns:
|
189 |
+
text += f"- {col} ({df[col].dtype}): "
|
190 |
+
if pd.api.types.is_numeric_dtype(df[col]):
|
191 |
+
text += f"Mean={df[col].mean():.2f}, Min={df[col].min()}, Max={df[col].max()}"
|
192 |
else:
|
193 |
+
text += f"Unique={df[col].nunique()}, Top={df[col].mode()[0] if not df[col].mode().empty else 'N/A'}"
|
194 |
+
text += f", Missing={df[col].isna().sum()}\n"
|
195 |
+
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
+
def create_vector_store(df_text):
|
198 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
199 |
+
temp_file.write(df_text)
|
200 |
+
temp_path = temp_file.name
|
201 |
+
loader = TextLoader(temp_path)
|
202 |
+
documents = loader.load()
|
203 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
204 |
+
texts = text_splitter.split_documents(documents)
|
205 |
+
vector_store = FAISS.from_documents(texts, embeddings)
|
206 |
+
os.unlink(temp_path)
|
207 |
+
return vector_store
|
208 |
|
209 |
+
def update_vector_store_with_plot(plot_text, existing_vector_store):
|
210 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
211 |
+
temp_file.write(plot_text)
|
212 |
+
temp_path = temp_file.name
|
213 |
+
loader = TextLoader(temp_path)
|
214 |
+
documents = loader.load()
|
215 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
216 |
+
texts = text_splitter.split_documents(documents)
|
217 |
+
if existing_vector_store:
|
218 |
+
existing_vector_store.add_documents(texts)
|
219 |
+
else:
|
220 |
+
existing_vector_store = FAISS.from_documents(texts, embeddings)
|
221 |
+
os.unlink(temp_path)
|
222 |
+
return existing_vector_store
|
223 |
|
224 |
+
def extract_plot_data(plot_info, df):
|
225 |
+
plot_type = plot_info["type"]
|
226 |
+
x_col = plot_info["x"]
|
227 |
+
y_col = plot_info["y"] if "y" in plot_info else None
|
228 |
+
data = pd.read_json(plot_info["data"])
|
229 |
+
plot_text = f"Plot Type: {plot_type}\nX-Axis: {x_col}\n"
|
230 |
+
if y_col:
|
231 |
+
plot_text += f"Y-Axis: {y_col}\n"
|
232 |
+
if plot_type == "Scatter Plot" and y_col:
|
233 |
+
correlation = data[x_col].corr(data[y_col])
|
234 |
+
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
235 |
+
plot_text += f"Correlation: {correlation:.2f}\nLinear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
236 |
+
return plot_text
|
237 |
+
|
238 |
+
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
239 |
+
system_prompt = (
|
240 |
+
f"You are an AI assistant in Data-Vision Pro, on the '{app_mode}' page:\n"
|
241 |
+
"- Data Upload: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
242 |
+
"- Data Cleaning: Clean data (e.g., handle missing values, encode variables).\n"
|
243 |
+
"- EDA: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
244 |
+
"Use context if provided."
|
245 |
+
)
|
246 |
+
context = ""
|
247 |
+
if vector_store:
|
248 |
+
docs = vector_store.similarity_search(user_input, k=3)
|
249 |
+
if docs:
|
250 |
+
context = "\n\nContext:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
251 |
+
try:
|
252 |
+
response = client.chat.completions.create(
|
253 |
+
model=model,
|
254 |
+
messages=[
|
255 |
+
{"role": "system", "content": system_prompt + context},
|
256 |
+
{"role": "user", "content": user_input}
|
257 |
+
],
|
258 |
+
temperature=0.7,
|
259 |
+
max_tokens=1024
|
260 |
+
)
|
261 |
+
return response.choices[0].message.content
|
262 |
+
except Exception as e:
|
263 |
+
return f"Error: {str(e)}"
|
264 |
+
|
265 |
+
def parse_command(command, df):
|
266 |
+
command = command.lower().strip()
|
267 |
+
if "drop columns" in command:
|
268 |
+
columns = command.replace("drop columns", "").strip().split(',')
|
269 |
+
valid_cols = [col.strip() for col in columns if col.strip() in df.columns]
|
270 |
+
if valid_cols:
|
271 |
+
df = df.drop(columns=valid_cols)
|
272 |
+
update_cleaned_data(df)
|
273 |
+
return df, f"Dropped columns: {', '.join(valid_cols)}"
|
274 |
+
return df, "No valid columns to drop."
|
275 |
+
elif "scatter plot of" in command:
|
276 |
+
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", command)
|
277 |
+
if match:
|
278 |
+
x, y = match.group(1).strip(), match.group(2).strip()
|
279 |
+
if x in df.columns and y in df.columns:
|
280 |
+
fig = px.scatter(df, x=x, y=y)
|
281 |
+
plot_info = {"type": "Scatter Plot", "x": x, "y": y, "data": df[[x, y]].to_json()}
|
282 |
+
return df, fig, plot_info
|
283 |
+
return df, None, "Invalid scatter plot command."
|
284 |
+
elif "histogram of" in command:
|
285 |
+
col = command.replace("histogram of", "").strip()
|
286 |
+
if col in df.columns:
|
287 |
+
fig = px.histogram(df, x=col)
|
288 |
+
plot_info = {"type": "Histogram", "x": col, "data": df[[col]].to_json()}
|
289 |
+
return df, fig, plot_info
|
290 |
+
return df, None, "Invalid histogram command."
|
291 |
+
elif "analyze plot" in command and "last_plot" in st.session_state:
|
292 |
+
plot_info = st.session_state.last_plot
|
293 |
+
plot_text = extract_plot_data(plot_info, df)
|
294 |
+
st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
|
295 |
+
return df, plot_text
|
296 |
+
return df, None, None
|
297 |
+
|
298 |
+
# Main App
|
299 |
+
def main():
|
300 |
+
# Render Custom HTML
|
301 |
+
st.markdown(custom_html, unsafe_allow_html=True)
|
302 |
+
|
303 |
+
# Hidden Inputs for JS Interaction
|
304 |
+
if 'app_mode' not in st.session_state:
|
305 |
+
st.session_state.app_mode = "Data Upload"
|
306 |
+
app_mode = st.markdown('<input id="app-mode" type="hidden" value="Data Upload">', unsafe_allow_html=True)
|
307 |
+
chat_trigger = st.markdown('<input id="chat-trigger" type="hidden">', unsafe_allow_html=True)
|
308 |
+
|
309 |
+
# Sidebar
|
310 |
+
with st.sidebar:
|
311 |
+
st.markdown("### 🔮 Data-Vision Pro")
|
312 |
+
model = st.selectbox("Select Groq Model", ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"], index=0)
|
313 |
+
if 'cleaned_data' in st.session_state:
|
314 |
+
csv = st.session_state.cleaned_data.to_csv(index=False)
|
315 |
+
st.download_button(label="Download Cleaned Data", data=csv, file_name='cleaned_data.csv', mime='text/csv')
|
316 |
+
|
317 |
+
# Initialize Session State
|
318 |
+
if 'vector_store' not in st.session_state:
|
319 |
+
st.session_state.vector_store = None
|
320 |
+
if 'chat_history' not in st.session_state:
|
321 |
+
st.session_state.chat_history = []
|
322 |
+
|
323 |
+
# App Logic
|
324 |
+
if st.session_state.app_mode == "Data Upload":
|
325 |
+
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
|
326 |
+
if uploaded_file:
|
327 |
+
if uploaded_file.name.endswith('.csv'):
|
328 |
+
df = pd.read_csv(uploaded_file)
|
329 |
+
else:
|
330 |
+
df = pd.read_excel(uploaded_file)
|
331 |
+
st.session_state.raw_data = df
|
332 |
+
st.session_state.cleaned_data = df.copy()
|
333 |
+
st.session_state.dataset_text = convert_df_to_text(df)
|
334 |
+
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
335 |
+
if 'data_versions' not in st.session_state:
|
336 |
+
st.session_state.data_versions = [df.copy()]
|
337 |
+
metrics_html = f"""
|
338 |
+
<div class="metrics">
|
339 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
340 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
341 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
342 |
+
</div>
|
343 |
+
<pre>{df.head().to_string()}</pre>
|
344 |
+
"""
|
345 |
+
st.markdown(f'<div id="upload-output">{metrics_html}</div>', unsafe_allow_html=True)
|
346 |
+
if st.button("Generate Full Profile Report"):
|
347 |
+
with st.spinner("Generating report..."):
|
348 |
+
pr = ProfileReport(df, explorative=True)
|
349 |
+
st_profile_report(pr)
|
350 |
+
|
351 |
+
elif st.session_state.app_mode == "Data Cleaning":
|
352 |
+
if 'cleaned_data' not in st.session_state:
|
353 |
+
st.warning("Please upload data first.")
|
354 |
+
else:
|
355 |
+
df = st.session_state.cleaned_data
|
356 |
+
metrics_html = f"""
|
357 |
+
<div class="metrics">
|
358 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
359 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
360 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
361 |
+
</div>
|
362 |
+
"""
|
363 |
+
st.markdown(f'<div id="cleaning-output">{metrics_html}</div>', unsafe_allow_html=True)
|
364 |
+
cols_to_drop = st.multiselect("Select columns to drop", df.columns)
|
365 |
+
if cols_to_drop and st.button("Drop Columns"):
|
366 |
+
df = df.drop(columns=cols_to_drop)
|
367 |
+
update_cleaned_data(df)
|
368 |
+
st.rerun()
|
369 |
+
|
370 |
+
elif st.session_state.app_mode == "EDA":
|
371 |
+
if 'cleaned_data' not in st.session_state:
|
372 |
+
st.warning("Please upload data first.")
|
373 |
+
else:
|
374 |
+
df = st.session_state.cleaned_data
|
375 |
+
metrics_html = f"""
|
376 |
+
<div class="metrics">
|
377 |
+
<div class="metric">Rows: {df.shape[0]}</div>
|
378 |
+
<div class="metric">Columns: {df.shape[1]}</div>
|
379 |
+
<div class="metric">Missing: {df.isna().sum().sum()}</div>
|
380 |
+
</div>
|
381 |
+
"""
|
382 |
+
st.markdown(f'<div id="eda-output">{metrics_html}</div>', unsafe_allow_html=True)
|
383 |
+
plot_type = st.selectbox("Choose visualization type", ["Scatter Plot", "Histogram"])
|
384 |
+
x_axis = st.selectbox("X-axis", df.columns)
|
385 |
+
y_axis = st.selectbox("Y-axis", df.columns) if plot_type == "Scatter Plot" else None
|
386 |
+
if st.button("Generate Plot"):
|
387 |
+
if plot_type == "Scatter Plot" and x_axis and y_axis:
|
388 |
+
fig = px.scatter(df, x=x_axis, y=y_axis)
|
389 |
+
st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
|
390 |
+
elif plot_type == "Histogram" and x_axis:
|
391 |
+
fig = px.histogram(df, x=x_axis)
|
392 |
+
st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
|
393 |
+
st.plotly_chart(fig)
|
394 |
+
|
395 |
+
# Chatbot Logic
|
396 |
+
chat_input = st.session_state.get('chat_input', '')
|
397 |
if chat_input:
|
398 |
+
st.session_state.chat_history.append({"role": "user", "content": chat_input})
|
399 |
+
df = st.session_state.cleaned_data if 'cleaned_data' in st.session_state else pd.DataFrame()
|
400 |
+
new_df, result, plot_info = parse_command(chat_input, df)
|
401 |
+
if isinstance(result, px.scatter._chart_types.Scatter) or isinstance(result, px.histogram._chart_types.Histogram):
|
402 |
+
st.plotly_chart(result)
|
403 |
+
st.session_state.last_plot = plot_info
|
404 |
+
st.session_state.vector_store = update_vector_store_with_plot(extract_plot_data(plot_info, new_df), st.session_state.vector_store)
|
405 |
+
response = f"Generated {plot_info['type'].lower()}."
|
406 |
+
elif isinstance(result, str):
|
407 |
+
response = result
|
408 |
if "Dropped columns" in response:
|
409 |
+
st.session_state.cleaned_data = new_df
|
410 |
+
st.session_state.vector_store = create_vector_store(convert_df_to_text(new_df))
|
411 |
+
st.rerun()
|
412 |
else:
|
413 |
+
response = get_chatbot_response(chat_input, st.session_state.app_mode, st.session_state.vector_store, model)
|
414 |
+
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
415 |
+
st.markdown(f"""
|
|
|
416 |
<script>
|
417 |
document.dispatchEvent(new CustomEvent('bot_response', {{ detail: {json.dumps(response)} }}));
|
418 |
</script>
|
419 |
+
""", unsafe_allow_html=True)
|
420 |
+
st.session_state.chat_input = st.text_input("Chat with AI", key="chat_input", on_change=lambda: None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
|
422 |
+
if __name__ == "__main__":
|
423 |
+
main()
|