Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,30 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
-
import numpy as np
|
4 |
import plotly.express as px
|
5 |
-
import
|
|
|
|
|
|
|
|
|
|
|
6 |
from ydata_profiling import ProfileReport
|
7 |
from streamlit_pandas_profiling import st_profile_report
|
8 |
-
import os
|
9 |
-
from dotenv import load_dotenv
|
10 |
from groq import Groq
|
11 |
from langchain_community.vectorstores import FAISS
|
12 |
-
from langchain_community.document_loaders import TextLoader
|
13 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
-
from
|
15 |
-
import
|
16 |
-
|
17 |
-
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
|
18 |
import tempfile
|
19 |
|
20 |
-
#
|
21 |
-
st.set_page_config(page_title="Data-Vision Pro", layout="wide")
|
22 |
-
|
23 |
-
# Load environment variables
|
24 |
-
load_dotenv()
|
25 |
-
|
26 |
-
# Initialize Groq client
|
27 |
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
28 |
-
|
29 |
-
# Initialize HuggingFace embeddings for FAISS
|
30 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
31 |
|
32 |
-
#
|
|
|
|
|
|
|
33 |
st.markdown("""
|
34 |
<style>
|
35 |
:root {
|
@@ -41,7 +36,7 @@ st.markdown("""
|
|
41 |
.stApp {
|
42 |
background-color: var(--silver);
|
43 |
font-family: 'Inter', sans-serif;
|
44 |
-
max-width:
|
45 |
margin: 0 auto;
|
46 |
padding: 10px;
|
47 |
}
|
@@ -50,69 +45,71 @@ st.markdown("""
|
|
50 |
color: white;
|
51 |
padding: 15px;
|
52 |
border-radius: 5px;
|
53 |
-
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
54 |
text-align: center;
|
|
|
55 |
}
|
56 |
.header-title {
|
57 |
-
font-size: 1.
|
58 |
font-weight: 700;
|
59 |
margin: 0;
|
60 |
}
|
61 |
.header-subtitle {
|
62 |
-
font-size:
|
63 |
margin-top: 5px;
|
64 |
}
|
65 |
-
.
|
66 |
background-color: white;
|
67 |
border-radius: 5px;
|
68 |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
69 |
padding: 15px;
|
|
|
|
|
|
|
|
|
70 |
}
|
71 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
background-color: white;
|
73 |
border-radius: 5px;
|
74 |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
padding: 15px;
|
76 |
margin-top: 20px;
|
|
|
77 |
}
|
78 |
.user-message {
|
79 |
background-color: var(--blue);
|
80 |
color: white;
|
81 |
-
border-radius:
|
82 |
-
padding:
|
83 |
-
margin-left: auto;
|
84 |
max-width: 80%;
|
|
|
85 |
margin-bottom: 10px;
|
86 |
}
|
87 |
.bot-message {
|
88 |
background-color: #F0F0F0;
|
89 |
color: var(--text-color);
|
90 |
-
border-radius:
|
91 |
-
padding:
|
92 |
-
margin-right: auto;
|
93 |
max-width: 80%;
|
|
|
94 |
margin-bottom: 10px;
|
95 |
}
|
96 |
-
.footer {
|
97 |
-
text-align: center;
|
98 |
-
margin-top: 20px;
|
99 |
-
color: var(--text-color);
|
100 |
-
font-size: 0.8rem;
|
101 |
-
}
|
102 |
-
.tech-badge {
|
103 |
-
display: inline-block;
|
104 |
-
background-color: #E6ECEF;
|
105 |
-
color: var(--blue);
|
106 |
-
padding: 4px 8px;
|
107 |
-
border-radius: 12px;
|
108 |
-
font-size: 0.7rem;
|
109 |
-
margin: 0 4px;
|
110 |
-
}
|
111 |
-
h2 {
|
112 |
-
color: var(--blue);
|
113 |
-
border-bottom: 2px solid var(--gold);
|
114 |
-
padding-bottom: 5px;
|
115 |
-
}
|
116 |
.stButton > button {
|
117 |
background-color: var(--gold);
|
118 |
color: white;
|
@@ -126,48 +123,76 @@ st.markdown("""
|
|
126 |
}
|
127 |
@media (max-width: 768px) {
|
128 |
.header-title {
|
129 |
-
font-size: 1.
|
130 |
}
|
131 |
.header-subtitle {
|
132 |
-
font-size: 0.
|
|
|
|
|
|
|
|
|
133 |
}
|
134 |
-
.
|
|
|
|
|
|
|
|
|
|
|
135 |
padding: 10px;
|
136 |
}
|
137 |
.stApp {
|
138 |
padding: 5px;
|
139 |
}
|
140 |
-
h2 {
|
141 |
-
font-size: 1.2rem;
|
142 |
-
}
|
143 |
}
|
144 |
-
|
|
|
|
|
|
|
145 |
""", unsafe_allow_html=True)
|
146 |
|
147 |
-
#
|
148 |
-
|
149 |
-
st.
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
st.session_state.
|
156 |
-
|
157 |
-
st.
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
|
|
160 |
def convert_df_to_text(df):
    """Describe a DataFrame as plain text for embedding in the vector store.

    The result starts with the row/column counts and the total number of
    missing cells, followed by one "- name (dtype): ..." line per column.
    Numeric columns report mean/min/max; all other columns report the
    number of unique values and the most frequent value ('N/A' when the
    column has no mode). Every column line ends with its missing count.
    """
    n_rows, n_cols = df.shape
    lines = [
        f"Dataset Summary: {n_rows} rows, {n_cols} columns",
        f"Missing Values: {df.isna().sum().sum()}",
        "Columns:",
    ]
    for name in df.columns:
        series = df[name]
        if pd.api.types.is_numeric_dtype(series):
            detail = f"Mean={series.mean():.2f}, Min={series.min()}, Max={series.max()}"
        else:
            modes = series.mode()
            top = 'N/A' if modes.empty else modes[0]
            detail = f"Unique={series.nunique()}, Top={top}"
        lines.append(f"- {name} ({series.dtype}): {detail}, Missing={series.isna().sum()}")
    # Each entry was newline-terminated in the original, including the last.
    return "\n".join(lines) + "\n"
|
172 |
|
173 |
def create_vector_store(df_text):
|
@@ -176,469 +201,122 @@ def create_vector_store(df_text):
|
|
176 |
temp_path = temp_file.name
|
177 |
loader = TextLoader(temp_path)
|
178 |
documents = loader.load()
|
179 |
-
|
180 |
-
texts = text_splitter.split_documents(documents)
|
181 |
vector_store = FAISS.from_documents(texts, embeddings)
|
182 |
os.unlink(temp_path)
|
183 |
return vector_store
|
184 |
|
185 |
-
def
|
186 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
187 |
-
temp_file.write(plot_text)
|
188 |
-
temp_path = temp_file.name
|
189 |
-
loader = TextLoader(temp_path)
|
190 |
-
documents = loader.load()
|
191 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
192 |
-
texts = text_splitter.split_documents(documents)
|
193 |
-
if existing_vector_store:
|
194 |
-
existing_vector_store.add_documents(texts)
|
195 |
-
else:
|
196 |
-
existing_vector_store = FAISS.from_documents(texts, embeddings)
|
197 |
-
os.unlink(temp_path)
|
198 |
-
return existing_vector_store
|
199 |
-
|
200 |
-
def extract_plot_data(plot_info, df):
    """Summarise the numbers behind a stored plot as plain text.

    Args:
        plot_info: Mapping with "type", "x", an optional "y" and "data",
            where "data" is a JSON string as produced by
            ``DataFrame.to_json()``.
        df: Not used by this function; kept so existing callers keep
            working. All statistics come from the frame decoded out of
            ``plot_info["data"]``.

    Returns:
        A newline-terminated, multi-line string: plot type and axes first,
        then statistics that depend on the plot type.
    """
    from io import StringIO

    plot_type = plot_info["type"]
    x_col = plot_info["x"]
    y_col = plot_info["y"] if "y" in plot_info else None
    # Wrap the payload: handing pd.read_json a literal JSON string is
    # deprecated in recent pandas releases.
    data = pd.read_json(StringIO(plot_info["data"]))

    lines = [f"Plot Type: {plot_type}", f"X-Axis: {x_col}"]
    if y_col:
        lines.append(f"Y-Axis: {y_col}")

    if plot_type == "Scatter Plot" and y_col:
        from scipy import stats

        x_vals, y_vals = data[x_col], data[y_col]
        correlation = x_vals.corr(y_vals)
        # Drop incomplete rows jointly. Dropping NaNs from each column on
        # its own misaligns the (x, y) pairs or yields unequal lengths.
        paired = pd.concat([x_vals, y_vals], axis=1).dropna()
        slope, intercept, r_value, p_value, _ = stats.linregress(
            paired.iloc[:, 0], paired.iloc[:, 1]
        )
        lines.append(f"Correlation: {correlation:.2f}")
        lines.append(
            f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, "
            f"R²={r_value**2:.2f}, p-value={p_value:.4f}"
        )
        lines.append(
            f"X Stats: Mean={x_vals.mean():.2f}, Std={x_vals.std():.2f}, "
            f"Min={x_vals.min():.2f}, Max={x_vals.max():.2f}"
        )
        lines.append(
            f"Y Stats: Mean={y_vals.mean():.2f}, Std={y_vals.std():.2f}, "
            f"Min={y_vals.min():.2f}, Max={y_vals.max():.2f}"
        )
    elif plot_type == "Histogram":
        x_vals = data[x_col]
        lines.append(
            f"Stats: Mean={x_vals.mean():.2f}, Median={x_vals.median():.2f}, "
            f"Std={x_vals.std():.2f}"
        )
        lines.append(f"Skewness: {x_vals.skew():.2f}")
        lines.append(f"Range: [{x_vals.min():.2f}, {x_vals.max():.2f}]")
    elif plot_type == "Box Plot" and y_col:
        y_vals = data[y_col]
        q1, q3 = y_vals.quantile(0.25), y_vals.quantile(0.75)
        iqr = q3 - q1
        # Values more than 1.5 * IQR outside the quartiles count as outliers.
        outlier_mask = (y_vals < q1 - 1.5 * iqr) | (y_vals > q3 + 1.5 * iqr)
        lines.append(
            f"Y Stats: Median={y_vals.median():.2f}, Q1={q1:.2f}, "
            f"Q3={q3:.2f}, IQR={iqr:.2f}"
        )
        lines.append(f"Outliers: {int(outlier_mask.sum())} potential outliers")
    elif plot_type == "Line Chart" and y_col:
        y_vals = data[y_col]
        # Trend compares only the last value with the first one.
        trend = 'increasing' if y_vals.iloc[-1] > y_vals.iloc[0] else 'decreasing'
        lines.append(
            f"Y Stats: Mean={y_vals.mean():.2f}, Std={y_vals.std():.2f}, Trend={trend}"
        )
    elif plot_type == "Bar Chart":
        lines.append(f"Counts: {data[x_col].value_counts().to_dict()}")
    elif plot_type == "Correlation Matrix":
        corr = data.corr()
        lines.append("Correlation Matrix:")
        for col1 in corr.columns:
            for col2 in corr.index:
                # Name ordering keeps each unordered pair exactly once.
                if col1 < col2:
                    lines.append(f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}")

    return "\n".join(lines) + "\n"
|
237 |
-
|
238 |
-
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
239 |
-
system_prompt = (
|
240 |
-
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
241 |
-
f"The user is on the '{app_mode}' page:\n"
|
242 |
-
"- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
243 |
-
"- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
|
244 |
-
"- **EDA**: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
245 |
-
"When analyzing plots, provide detailed insights based on numerical data extracted from them."
|
246 |
-
)
|
247 |
context = ""
|
248 |
-
if vector_store:
|
249 |
-
docs = vector_store.similarity_search(
|
250 |
-
|
251 |
-
context = "\n\nDataset and Plot Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
252 |
-
system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
|
253 |
-
else:
|
254 |
-
system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
|
255 |
try:
|
256 |
response = client.chat.completions.create(
|
257 |
-
model=
|
258 |
messages=[
|
259 |
-
{"role": "system", "content":
|
260 |
-
{"role": "user", "content":
|
261 |
-
]
|
262 |
-
|
263 |
-
|
264 |
-
)
|
265 |
-
return response.choices[0].message.content
|
266 |
except Exception as e:
|
267 |
return f"Error: {str(e)}"
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
if
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
return "No dataset loaded."
|
282 |
-
|
283 |
-
def generate_scatter_plot(params):
    """Draw a scatter plot for a "<x> vs <y>" request and remember it.

    Args:
        params: Text such as "age vs income", or None/empty when the chat
            command carried no arguments.

    Returns:
        A status message: success text naming both columns,
        "No dataset loaded." when the session has no cleaned data, or
        "Invalid columns for scatter plot." when the request cannot be
        matched to two existing columns.

    On success the figure is rendered with ``st.plotly_chart`` and the plot
    description (type, axes, JSON data) is stored in
    ``st.session_state.last_plot``.
    """
    # The chat dispatcher passes None when no arguments were given; the
    # regex search below would raise TypeError on it.
    if not params:
        return "Invalid columns for scatter plot."
    if "cleaned_data" not in st.session_state:
        return "No dataset loaded."

    df = st.session_state.cleaned_data
    match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
    if match is None:
        return "Invalid columns for scatter plot."

    x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
    if x_axis not in df.columns or y_axis not in df.columns:
        return "Invalid columns for scatter plot."

    fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
    st.plotly_chart(fig)
    st.session_state.last_plot = {
        "type": "Scatter Plot",
        "x": x_axis,
        "y": y_axis,
        "data": df[[x_axis, y_axis]].to_json(),
    }
    return f"Generated scatter plot of {x_axis} vs {y_axis}"
|
294 |
-
|
295 |
-
def generate_histogram(params):
    """Draw a histogram of one column and remember it as the last plot.

    Args:
        params: Column name (surrounding whitespace is ignored), or
            None/empty when the chat command carried no arguments.

    Returns:
        A status message: success text naming the column,
        "No dataset loaded." when the session has no cleaned data, or
        "Invalid column for histogram." when the column is missing.

    On success the figure is rendered with ``st.plotly_chart`` and the plot
    description is stored in ``st.session_state.last_plot``.
    """
    # The chat dispatcher passes None when no arguments were given;
    # calling .strip() on it would raise AttributeError.
    if not params:
        return "Invalid column for histogram."
    if "cleaned_data" not in st.session_state:
        return "No dataset loaded."

    df = st.session_state.cleaned_data
    x_axis = params.strip()
    if x_axis not in df.columns:
        return "Invalid column for histogram."

    fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
    st.plotly_chart(fig)
    st.session_state.last_plot = {
        "type": "Histogram",
        "x": x_axis,
        "data": df[[x_axis]].to_json(),
    }
    return f"Generated histogram of {x_axis}"
|
304 |
-
|
305 |
-
def analyze_plot():
    """Return a text analysis of the most recently stored plot.

    Reads ``st.session_state.last_plot``; when no plot has been stored the
    function returns "No plot available to analyze.". Otherwise the stored
    JSON data is decoded and passed, together with the plot description,
    to ``extract_plot_data``.
    """
    from io import StringIO

    if "last_plot" not in st.session_state:
        return "No plot available to analyze."
    plot_info = st.session_state.last_plot
    # Wrap the payload: handing pd.read_json a literal JSON string is
    # deprecated in recent pandas releases.
    df = pd.read_json(StringIO(plot_info["data"]))
    plot_text = extract_plot_data(plot_info, df)
    return f"Analysis of the last plot:\n{plot_text}"
|
312 |
-
|
313 |
-
def parse_command(command):
    """Map a chat message to a handler function and its argument text.

    Keywords are recognised case-insensitively, but the argument text keeps
    the user's original casing so that column names such as "Age" can still
    be matched against the DataFrame's columns by the handlers.

    Args:
        command: Raw chat input.

    Returns:
        A ``(handler, argument)`` tuple. ``handler`` is None when no
        command keyword is found; in that case ``argument`` is the
        lower-cased, stripped input (as before).
    """
    import re

    original = command.strip()
    lowered = original.lower()

    def _without(*phrases):
        # Remove each keyword phrase regardless of case, keeping the rest
        # of the text exactly as the user typed it.
        remainder = original
        for phrase in phrases:
            remainder = re.sub(re.escape(phrase), "", remainder, flags=re.IGNORECASE)
        return remainder.strip()

    if "drop columns" in lowered or "drop column" in lowered:
        return drop_columns, _without("drop columns", "drop column")
    if "show a scatter plot" in lowered or "scatter plot of" in lowered:
        return generate_scatter_plot, _without("show a scatter plot of", "scatter plot of")
    if "show a histogram" in lowered or "histogram of" in lowered:
        return generate_histogram, _without("show a histogram of", "histogram of")
    if "analyze plot" in lowered:
        # analyze_plot takes no arguments; the lambda absorbs the dispatcher's one.
        return lambda x: analyze_plot(), None
    return None, lowered
|
327 |
-
|
328 |
-
# Dataset Preview Function
|
329 |
-
def display_dataset_preview():
    """Show the first ten rows of the cleaned dataset, if one is loaded.

    Does nothing when ``cleaned_data`` is absent from the session state;
    otherwise renders a subheader, the preview table and a divider.
    """
    if 'cleaned_data' not in st.session_state:
        return
    preview = st.session_state.cleaned_data.head(10)
    st.subheader("Current Dataset Preview")
    st.dataframe(preview, use_container_width=True)
    st.markdown("---")
|
334 |
-
|
335 |
-
# Main App
|
336 |
def main():
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
# Sidebar Navigation
|
346 |
-
with st.sidebar:
|
347 |
-
st.markdown("### 🔮 Data-Vision Pro")
|
348 |
-
st.markdown("Your AI-powered data analysis suite with RAG.")
|
349 |
-
st.markdown("---")
|
350 |
-
app_mode = st.selectbox(
|
351 |
-
"Navigation",
|
352 |
-
["Data Upload", "Data Cleaning", "EDA"],
|
353 |
-
format_func=lambda x: f"📌 {x}"
|
354 |
-
)
|
355 |
-
model = st.selectbox(
|
356 |
-
"Select Groq Model",
|
357 |
-
["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"],
|
358 |
-
index=0
|
359 |
-
)
|
360 |
-
if app_mode == "Data Upload":
|
361 |
-
st.info("⬆️ Upload your CSV or XLSX dataset to begin.")
|
362 |
-
elif app_mode == "Data Cleaning":
|
363 |
-
st.info("🧹 Clean and preprocess your data.")
|
364 |
-
elif app_mode == "EDA":
|
365 |
-
st.info("🔍 Explore your data visually.")
|
366 |
-
|
367 |
-
if 'cleaned_data' in st.session_state:
|
368 |
-
csv = st.session_state.cleaned_data.to_csv(index=False)
|
369 |
-
st.download_button(
|
370 |
-
label="Download Cleaned Data",
|
371 |
-
data=csv,
|
372 |
-
file_name='cleaned_data.csv',
|
373 |
-
mime='text/csv',
|
374 |
-
)
|
375 |
-
st.markdown("---")
|
376 |
-
st.markdown("Built with <span class='tech-badge'>Streamlit</span> + <span class='tech-badge'>Groq</span>", unsafe_allow_html=True)
|
377 |
-
|
378 |
-
# Initialize Session State
|
379 |
-
if 'vector_store' not in st.session_state:
|
380 |
-
st.session_state.vector_store = None
|
381 |
-
if 'chat_history' not in st.session_state:
|
382 |
-
st.session_state.chat_history = []
|
383 |
-
|
384 |
-
# Display Dataset Preview
|
385 |
-
display_dataset_preview()
|
386 |
-
|
387 |
-
# App Pages
|
388 |
-
if app_mode == "Data Upload":
|
389 |
-
st.header("📤 Data Upload & Profiling")
|
390 |
-
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
|
391 |
if uploaded_file:
|
392 |
-
|
393 |
-
st.session_state.
|
394 |
-
st.
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
st.stop()
|
403 |
-
st.session_state.raw_data = df
|
404 |
-
st.session_state.cleaned_data = df.copy()
|
405 |
-
st.session_state.dataset_text = convert_df_to_text(df)
|
406 |
-
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
407 |
-
if 'data_versions' not in st.session_state:
|
408 |
-
st.session_state.data_versions = [df.copy()]
|
409 |
-
col1, col2, col3 = st.columns(3)
|
410 |
-
with col1: st.metric("Rows", df.shape[0])
|
411 |
-
with col2: st.metric("Columns", df.shape[1])
|
412 |
-
with col3: st.metric("Missing Values", df.isna().sum().sum())
|
413 |
-
if st.checkbox("Show Data Preview"):
|
414 |
-
st.dataframe(df.head(10), use_container_width=True)
|
415 |
-
if st.button("Generate Full Profile Report"):
|
416 |
-
with st.spinner("Generating report..."):
|
417 |
-
pr = ProfileReport(df, explorative=True)
|
418 |
-
st_profile_report(pr)
|
419 |
-
st.success("✅ Data loaded successfully!")
|
420 |
-
except Exception as e:
|
421 |
-
st.error(f"An error occurred: {str(e)}")
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
if
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
else:
|
431 |
-
st.session_state.
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
with st.spinner("Generating report..."):
|
442 |
-
profile = ProfileReport(df, minimal=True)
|
443 |
-
st_profile_report(profile)
|
444 |
-
if 'data_versions' in st.session_state and len(st.session_state.data_versions) > 1:
|
445 |
-
if st.button("Undo Last Action"):
|
446 |
-
st.session_state.data_versions.pop()
|
447 |
-
st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()
|
448 |
-
st.session_state.dataset_text = convert_df_to_text(st.session_state.cleaned_data)
|
449 |
-
st.session_state.vector_store = create_vector_store(st.session_state.dataset_text)
|
450 |
-
st.rerun()
|
451 |
-
|
452 |
-
with st.expander("🛠️ Data Cleaning Operations", expanded=True):
|
453 |
-
enhance_section_title("🔍 Missing Values Treatment")
|
454 |
-
missing_cols = df.columns[df.isna().any()].tolist()
|
455 |
-
if missing_cols:
|
456 |
-
cols = st.multiselect("Select columns with missing values", missing_cols)
|
457 |
-
method = st.selectbox("Choose imputation method", [
|
458 |
-
"Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
|
459 |
-
])
|
460 |
-
if method == "Fill with Custom Value":
|
461 |
-
custom_val = st.text_input("Enter custom value:")
|
462 |
-
if st.button("Apply Missing Value Treatment"):
|
463 |
-
new_df = df.copy()
|
464 |
-
if method == "Drop Missing Values":
|
465 |
-
new_df = new_df.dropna(subset=cols)
|
466 |
-
elif method == "Fill with Mean/Median":
|
467 |
-
for col in cols:
|
468 |
-
if pd.api.types.is_numeric_dtype(new_df[col]):
|
469 |
-
new_df[col] = new_df[col].fillna(new_df[col].median())
|
470 |
-
else:
|
471 |
-
new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
|
472 |
-
elif method == "Fill with Custom Value" and custom_val:
|
473 |
-
new_df[cols] = new_df[cols].fillna(custom_val)
|
474 |
-
elif method == "Forward Fill":
|
475 |
-
new_df[cols] = new_df[cols].ffill()
|
476 |
-
elif method == "Backward Fill":
|
477 |
-
new_df[cols] = new_df[cols].bfill()
|
478 |
-
update_cleaned_data(new_df)
|
479 |
else:
|
480 |
-
st.
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
enhance_section_title("🗑️ Drop Columns")
|
502 |
-
columns_to_drop = st.multiselect("Select columns to remove", df.columns)
|
503 |
-
if columns_to_drop and st.button("Confirm Column Removal"):
|
504 |
-
new_df = df.copy()
|
505 |
-
new_df = new_df.drop(columns=columns_to_drop)
|
506 |
-
update_cleaned_data(new_df)
|
507 |
-
|
508 |
-
enhance_section_title("🔢 Encoding Options")
|
509 |
-
encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
|
510 |
-
data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
|
511 |
-
if data_to_encode and st.button("Apply Encoding"):
|
512 |
-
new_df = df.copy()
|
513 |
-
if encoding_method == "Label Encoding":
|
514 |
-
for col in data_to_encode:
|
515 |
-
le = LabelEncoder()
|
516 |
-
new_df[col] = le.fit_transform(new_df[col].astype(str))
|
517 |
-
elif encoding_method == "One-Hot Encoding":
|
518 |
-
new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
|
519 |
-
update_cleaned_data(new_df)
|
520 |
-
|
521 |
-
enhance_section_title("📏 StandardScaler")
|
522 |
-
scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
|
523 |
-
if scale_cols and st.button("Apply StandardScaler"):
|
524 |
-
new_df = df.copy()
|
525 |
scaler = StandardScaler()
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
col1.metric("Total Rows", df.shape[0])
|
540 |
-
col2.metric("Total Columns", df.shape[1])
|
541 |
-
missing_percentage = df.isna().sum().sum() / df.size * 100
|
542 |
-
col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
|
543 |
-
col4.metric("Duplicates", df.duplicated().sum())
|
544 |
-
|
545 |
-
tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
|
546 |
-
with tab1:
|
547 |
-
st.write("First few rows of the dataset:")
|
548 |
-
st.dataframe(df.head(), use_container_width=True)
|
549 |
-
with tab2:
|
550 |
-
st.write("Column Data Types:")
|
551 |
-
type_counts = df.dtypes.value_counts().reset_index()
|
552 |
-
type_counts.columns = ['Type', 'Count']
|
553 |
-
st.dataframe(type_counts, use_container_width=True)
|
554 |
-
with tab3:
|
555 |
-
st.write("Missing Values Matrix:")
|
556 |
-
fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
|
557 |
-
fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
|
558 |
-
st.plotly_chart(fig_missing, use_container_width=True)
|
559 |
-
|
560 |
-
enhance_section_title("Interactive Visualization Builder")
|
561 |
-
with st.container():
|
562 |
-
col1, col2 = st.columns([1, 3])
|
563 |
-
with col1:
|
564 |
-
plot_type = st.selectbox("Choose visualization type", [
|
565 |
-
"Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix"
|
566 |
-
])
|
567 |
-
x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
|
568 |
-
y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart"] else None
|
569 |
-
color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
|
570 |
-
|
571 |
-
with col2:
|
572 |
-
try:
|
573 |
-
fig = None
|
574 |
-
if plot_type == "Scatter Plot" and x_axis and y_axis:
|
575 |
-
fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
576 |
-
elif plot_type == "Histogram" and x_axis:
|
577 |
-
fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, title=f'Histogram of {x_axis}')
|
578 |
-
elif plot_type == "Box Plot" and x_axis and y_axis:
|
579 |
-
fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
|
580 |
-
elif plot_type == "Line Chart" and x_axis and y_axis:
|
581 |
-
fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
|
582 |
-
elif plot_type == "Bar Chart" and x_axis:
|
583 |
-
fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
|
584 |
-
elif plot_type == "Correlation Matrix":
|
585 |
-
numeric_df = df.select_dtypes(include=np.number)
|
586 |
-
if len(numeric_df.columns) > 1:
|
587 |
-
corr = numeric_df.corr()
|
588 |
-
fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
|
589 |
-
|
590 |
-
if fig:
|
591 |
-
fig.update_layout(template="plotly_white")
|
592 |
-
st.plotly_chart(fig, use_container_width=True)
|
593 |
-
st.session_state.last_plot = {
|
594 |
-
"type": plot_type,
|
595 |
-
"x": x_axis,
|
596 |
-
"y": y_axis,
|
597 |
-
"data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
|
598 |
-
}
|
599 |
-
plot_text = extract_plot_data(st.session_state.last_plot, df)
|
600 |
-
st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
|
601 |
-
with st.expander("Extracted Plot Data"):
|
602 |
-
st.text(plot_text)
|
603 |
-
else:
|
604 |
-
st.error("Please provide required inputs for the selected plot type.")
|
605 |
-
except Exception as e:
|
606 |
-
st.error(f"Couldn't create visualization: {str(e)}")
|
607 |
-
|
608 |
-
# Chatbot Section
|
609 |
-
st.markdown("---")
|
610 |
-
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
|
611 |
-
st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
|
612 |
-
st.info("Ask about your data or app features! Try: 'drop columns X, Y', 'scatter plot of X vs Y', 'analyze plot'")
|
613 |
-
|
614 |
-
for message in st.session_state.chat_history:
|
615 |
-
with st.chat_message(message["role"]):
|
616 |
-
st.markdown(f'<div class="{message["role"]}-message">{message["content"]}</div>', unsafe_allow_html=True)
|
617 |
-
|
618 |
-
user_input = st.chat_input("Ask me anything...")
|
619 |
-
if user_input:
|
620 |
-
st.session_state.chat_history.append({"role": "user", "content": user_input})
|
621 |
-
with st.chat_message("user"):
|
622 |
-
st.markdown(f'<div class="user-message">{user_input}</div>', unsafe_allow_html=True)
|
623 |
-
with st.spinner("Processing..."):
|
624 |
-
func, param = parse_command(user_input)
|
625 |
-
if func:
|
626 |
-
response = func(param) if param else func(None)
|
627 |
-
else:
|
628 |
-
response = get_chatbot_response(user_input, app_mode, st.session_state.vector_store, model)
|
629 |
-
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
630 |
-
with st.chat_message("assistant"):
|
631 |
-
st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
|
632 |
-
|
633 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
634 |
-
|
635 |
-
# Footer
|
636 |
-
st.markdown("""
|
637 |
-
<div class="footer">
|
638 |
-
<div>Built with <span class="tech-badge">Streamlit</span> + <span class="tech-badge">Groq</span> + <span class="tech-badge">LangChain</span> + <span class="tech-badge">FAISS</span></div>
|
639 |
-
<div style="margin-top: 8px;">Fast inference for data insights</div>
|
640 |
-
</div>
|
641 |
-
""", unsafe_allow_html=True)
|
642 |
|
643 |
if __name__ == "__main__":
|
644 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
|
|
3 |
import plotly.express as px
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
7 |
+
from sklearn.cluster import KMeans
|
8 |
+
from sklearn.metrics import accuracy_score, r2_score, silhouette_score
|
9 |
+
from sklearn.preprocessing import StandardScaler
|
10 |
from ydata_profiling import ProfileReport
|
11 |
from streamlit_pandas_profiling import st_profile_report
|
|
|
|
|
12 |
from groq import Groq
|
13 |
from langchain_community.vectorstores import FAISS
|
|
|
14 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
15 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
16 |
+
from langchain_community.document_loaders import TextLoader
|
17 |
+
import os
|
|
|
18 |
import tempfile
|
19 |
|
20 |
+
# Initialize clients
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
|
|
|
22 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
23 |
|
24 |
+
# Set page config
|
25 |
+
st.set_page_config(page_title="Neural-Vision Enhanced", layout="wide")
|
26 |
+
|
27 |
+
# Custom CSS for Responsive Silver-Blue-Gold Theme with Top Nav
|
28 |
st.markdown("""
|
29 |
<style>
|
30 |
:root {
|
|
|
36 |
.stApp {
|
37 |
background-color: var(--silver);
|
38 |
font-family: 'Inter', sans-serif;
|
39 |
+
max-width: 1200px;
|
40 |
margin: 0 auto;
|
41 |
padding: 10px;
|
42 |
}
|
|
|
45 |
color: white;
|
46 |
padding: 15px;
|
47 |
border-radius: 5px;
|
|
|
48 |
text-align: center;
|
49 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
50 |
}
|
51 |
.header-title {
|
52 |
+
font-size: 1.8rem;
|
53 |
font-weight: 700;
|
54 |
margin: 0;
|
55 |
}
|
56 |
.header-subtitle {
|
57 |
+
font-size: 1rem;
|
58 |
margin-top: 5px;
|
59 |
}
|
60 |
+
.nav-bar {
|
61 |
background-color: white;
|
62 |
border-radius: 5px;
|
63 |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
64 |
padding: 15px;
|
65 |
+
margin-bottom: 20px;
|
66 |
+
display: flex;
|
67 |
+
justify-content: space-around;
|
68 |
+
align-items: center;
|
69 |
}
|
70 |
+
.nav-item {
|
71 |
+
color: var(--blue);
|
72 |
+
font-weight: 500;
|
73 |
+
cursor: pointer;
|
74 |
+
padding: 5px 10px;
|
75 |
+
border-radius: 5px;
|
76 |
+
}
|
77 |
+
.nav-item:hover {
|
78 |
+
background-color: var(--gold);
|
79 |
+
color: white;
|
80 |
+
}
|
81 |
+
.card {
|
82 |
background-color: white;
|
83 |
border-radius: 5px;
|
84 |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
85 |
+
padding: 20px;
|
86 |
+
margin-bottom: 20px;
|
87 |
+
}
|
88 |
+
.chat-container {
|
89 |
+
background-color: white;
|
90 |
+
border-radius: 5px;
|
91 |
padding: 15px;
|
92 |
margin-top: 20px;
|
93 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
94 |
}
|
95 |
.user-message {
|
96 |
background-color: var(--blue);
|
97 |
color: white;
|
98 |
+
border-radius: 15px 15px 5px 15px;
|
99 |
+
padding: 10px;
|
|
|
100 |
max-width: 80%;
|
101 |
+
margin-left: auto;
|
102 |
margin-bottom: 10px;
|
103 |
}
|
104 |
.bot-message {
|
105 |
background-color: #F0F0F0;
|
106 |
color: var(--text-color);
|
107 |
+
border-radius: 15px 15px 15px 5px;
|
108 |
+
padding: 10px;
|
|
|
109 |
max-width: 80%;
|
110 |
+
margin-right: auto;
|
111 |
margin-bottom: 10px;
|
112 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
.stButton > button {
|
114 |
background-color: var(--gold);
|
115 |
color: white;
|
|
|
123 |
}
|
124 |
@media (max-width: 768px) {
|
125 |
.header-title {
|
126 |
+
font-size: 1.4rem;
|
127 |
}
|
128 |
.header-subtitle {
|
129 |
+
font-size: 0.9rem;
|
130 |
+
}
|
131 |
+
.nav-bar {
|
132 |
+
flex-direction: column;
|
133 |
+
padding: 10px;
|
134 |
}
|
135 |
+
.nav-item {
|
136 |
+
margin: 5px 0;
|
137 |
+
width: 100%;
|
138 |
+
text-align: center;
|
139 |
+
}
|
140 |
+
.card, .chat-container {
|
141 |
padding: 10px;
|
142 |
}
|
143 |
.stApp {
|
144 |
padding: 5px;
|
145 |
}
|
|
|
|
|
|
|
146 |
}
|
147 |
+
# Footer
|
148 |
+
<footer style='text-align: center; padding: 10px; background-color: var(--blue); color: white; border-radius: 5px; margin-top: 20px;'>
|
149 |
+
<p>Created by Calvin Allen-Crawford</p>
|
150 |
+
</footer>
|
151 |
""", unsafe_allow_html=True)
|
152 |
|
153 |
+
# Session State Initialization
|
154 |
+
# Default value for every session key the app expects; a key is only
# written when it is not already present, so existing values survive reruns.
_SESSION_DEFAULTS = {
    'metrics': {},
    'chat_history': [],
    'vector_store': None,
    'custom_layers': [],
    'prebuilt_selection': None,
    'model_config': {},
    'model_builder_mode': "prebuilt",
    'custom_model_type': "classification",
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
170 |
+
|
171 |
+
# Prebuilt Models
|
172 |
+
# Catalogue of ready-made model configurations, keyed by display name.
# Each entry carries a short description, an "architecture" dict and the
# business domain it targets. "hidden_layers" entries are
# (units, activation) pairs.
PREBUILT_MODELS = {
    "Legal Document Classifier": {
        "description": "Optimized for legal document classification.",
        "architecture": {
            "type": "classification",
            "hidden_layers": [(128, "relu"), (64, "relu")],
            "dropout": 0.3,
            "optimizer": "adam",
            "learning_rate": 0.001,
        },
        "domain": "Legal",
    },
    "Financial Fraud Detector": {
        "description": "Detects anomalies in financial transactions.",
        "architecture": {
            "type": "classification",
            "hidden_layers": [(256, "relu"), (128, "relu"), (64, "relu")],
            "dropout": 0.4,
            "optimizer": "adam",
            "learning_rate": 0.0005,
        },
        "domain": "Financial",
    },
    "Customer Segmentation Engine": {
        "description": "Advanced customer segmentation.",
        "architecture": {
            "type": "clustering",
            "n_clusters": 5,
            "algorithm": "kmeans",
            "init": "k-means++",
            "n_init": 10,
        },
        "domain": "Marketing",
    },
}
|
189 |
|
190 |
+
# Helper Functions (unchanged)
|
191 |
def convert_df_to_text(df):
|
192 |
text = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
193 |
text += f"Missing Values: {df.isna().sum().sum()}\n"
|
|
|
194 |
for col in df.columns:
|
195 |
+
text += f"- {col} ({df[col].dtype}): Mean={df[col].mean():.2f if pd.api.types.is_numeric_dtype(df[col]) else 'N/A'}\n"
|
|
|
|
|
|
|
|
|
|
|
196 |
return text
|
197 |
|
198 |
def create_vector_store(df_text):
    """Build a FAISS vector store from a textual dataset summary.

    The text is cut into overlapping chunks (chunk_size=500,
    chunk_overlap=100) and indexed with the module-level ``embeddings``
    model.

    Args:
        df_text: The dataset description as a single string, such as the
            output of ``convert_df_to_text``.

    Returns:
        A ``FAISS`` vector store built from the chunks.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    # Chunk the string in memory rather than writing it to a temporary file
    # and loading it back: nothing is left on disk if splitting or embedding
    # raises, and no file loader is needed.
    chunks = splitter.create_documents([df_text])
    return FAISS.from_documents(chunks, embeddings)
|
208 |
|
209 |
+
def get_groq_response(prompt, mode):
    """Send a question to the Groq chat model and return its answer.

    When a vector store is present in the session state, the three chunks
    most similar to the prompt are appended to the system message as
    "Dataset Context".

    Args:
        prompt: The user's question.
        mode: Domain name interpolated into the system message.

    Returns:
        The text of the first completion choice, or ``"Error: <message>"``
        if the API call raises.
    """
    store = st.session_state.vector_store
    context = ""
    if store:
        matches = store.similarity_search(prompt, k=3)
        bullet_lines = [f"- {match.page_content}" for match in matches]
        context += "\nDataset Context:\n" + "\n".join(bullet_lines)

    try:
        system_message = {
            "role": "system",
            "content": f"You are an expert in {mode} data analysis.\n{context}",
        }
        user_message = {"role": "user", "content": prompt}
        completion = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[system_message, user_message],
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Best effort: the failure is reported to the caller as text.
        return f"Error: {str(e)}"
|
225 |
|
226 |
+
def build_model_from_config(config, X, y=None):
    """Create an unfitted scikit-learn estimator from a configuration dict.

    Args:
        config: Model configuration. ``"type"`` selects the estimator
            ("clustering" -> KMeans, "classification" -> MLPClassifier,
            anything else -> MLPRegressor); the remaining keys supply
            hyper-parameters and fall back to defaults when absent.
        X: Accepted for the caller's convenience; not used here.
        y: Accepted for the caller's convenience; not used here.

    Returns:
        The configured, unfitted estimator.
    """
    kind = config.get("type", "classification")

    if kind == "clustering":
        return KMeans(
            n_clusters=config.get("n_clusters", 3),
            init=config.get("init", "k-means++"),
            n_init=config.get("n_init", 10),
            random_state=42,
        )

    layers = config.get("hidden_layers", [(100, "relu")])
    # Only one activation is passed to the estimator, so the first layer's
    # choice is applied to the whole network.
    mlp_kwargs = {
        "hidden_layer_sizes": [units for units, _ in layers],
        "activation": layers[0][1] if layers else "relu",
        "solver": config.get("optimizer", "adam"),
        "learning_rate_init": config.get("learning_rate", 0.001),
        "random_state": 42,
    }
    estimator_cls = MLPClassifier if kind == "classification" else MLPRegressor
    return estimator_cls(**mlp_kwargs)
|
236 |
+
|
237 |
+
# Main Application
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
def main():
    """Render the app: header, top navigation bar and the selected page.

    The navigation bar holds the CSV uploader, the page selector and an info
    column. Depending on the selected page this then shows:

    * "Model Builder": pick a prebuilt configuration or assemble a custom
      one; the result is stored in ``st.session_state.model_config``.
    * "Chat": ask questions answered by ``get_groq_response``; the exchange
      is kept in ``st.session_state.chat_history``.
    * "Train Model": fit the configured model on the uploaded data and show
      the resulting metric. For non-clustering models the last CSV column is
      used as the target and the rest as features.
    """
    # Local import: only needed to escape chat text before it is embedded in HTML.
    import html

    st.markdown('<div class="header"><h1 class="header-title">Neural-Vision Enhanced</h1><p class="header-subtitle">Build & Train Neural Networks</p></div>', unsafe_allow_html=True)

    # Top Navigation Bar
    st.markdown('<div class="nav-bar">', unsafe_allow_html=True)
    col1, col2, col3 = st.columns([1, 2, 1])
    # Parsed once here and reused by the "Train Model" page: the uploaded file
    # is a stream, so reading it a second time in the same run would find it
    # already consumed.
    df = None
    with col1:
        st.markdown('<div class="nav-item">Data Input</div>', unsafe_allow_html=True)
        uploaded_file = st.file_uploader("Upload CSV Dataset", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
            st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
            st.success("Dataset uploaded!")
    with col2:
        st.markdown('<div class="nav-item">Navigation</div>', unsafe_allow_html=True)
        nav_option = st.selectbox("Navigate", ["Model Builder", "Chat", "Train Model"], label_visibility="collapsed")
    with col3:
        st.markdown('<div class="nav-item">Info</div>', unsafe_allow_html=True)
        st.write("Built with Streamlit & Groq")
    st.markdown('</div>', unsafe_allow_html=True)

    # Main Content
    if nav_option == "Model Builder":
        st.markdown('<div class="card"><h2>Model Builder</h2></div>', unsafe_allow_html=True)
        # The selected domain is displayed only; nothing on this page reads it.
        st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
        model_builder_mode = st.radio("Mode", ["Prebuilt", "Custom"])
        st.session_state.model_builder_mode = "prebuilt" if model_builder_mode == "Prebuilt" else "custom"

        if st.session_state.model_builder_mode == "prebuilt":
            for name, details in PREBUILT_MODELS.items():
                if st.button(f"{name}: {details['description']}", key=name):
                    st.session_state.prebuilt_selection = name
                    st.session_state.model_config = details["architecture"]
            if st.session_state.prebuilt_selection:
                st.json(st.session_state.model_config)
        else:
            st.session_state.custom_model_type = st.selectbox("Type", ["classification", "regression", "clustering"])
            if st.session_state.custom_model_type != "clustering":
                layer_count = st.number_input("Layers", min_value=1, value=1)
                # Rebuilt from the widgets on every run so the list always
                # matches the current number of layers.
                st.session_state.custom_layers = []
                for i in range(int(layer_count)):
                    size = st.number_input(f"Layer {i+1} Size", min_value=1, value=100, key=f"size_{i}")
                    activation = st.selectbox(f"Layer {i+1} Activation", ["relu", "tanh"], key=f"act_{i}")
                    st.session_state.custom_layers.append((size, activation))
                optimizer = st.selectbox("Optimizer", ["adam", "sgd"])
                st.session_state.model_config = {
                    "type": st.session_state.custom_model_type,
                    "hidden_layers": st.session_state.custom_layers,
                    "optimizer": optimizer,
                    "learning_rate": 0.001,
                }
            else:
                st.session_state.model_config = {
                    "type": "clustering",
                    "n_clusters": st.number_input("Clusters", min_value=2, value=3),
                }
            if st.button("Finalize"):
                st.json(st.session_state.model_config)

    elif nav_option == "Chat":
        st.markdown('<div class="chat-container"><h3>Chat with Grok</h3></div>', unsafe_allow_html=True)
        mode = st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
        prompt = st.text_input("Ask a question:")
        if prompt:
            response = get_groq_response(prompt, mode)
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            st.session_state.chat_history.append({"role": "bot", "content": response})
        for msg in st.session_state.chat_history:
            css_class = "user-message" if msg["role"] == "user" else "bot-message"
            # User input and model output are untrusted text: escape them so
            # they cannot inject markup into the page.
            safe_content = html.escape(str(msg["content"]))
            st.markdown(f'<div class="{css_class}">{safe_content}</div>', unsafe_allow_html=True)

    elif nav_option == "Train Model":
        if df is not None and st.session_state.model_config:
            st.markdown('<div class="card"><h2>Train Model</h2></div>', unsafe_allow_html=True)
            problem_type = st.session_state.model_config["type"]
            supervised = problem_type != "clustering"
            # Supervised models use the last column as the target; clustering
            # uses every column as a feature.
            X = df.drop(columns=[df.columns[-1]]) if supervised else df
            y = df[df.columns[-1]] if supervised else None
            if st.button("Train"):
                scaler = StandardScaler()
                X_scaled = scaler.fit_transform(X)
                model = build_model_from_config(st.session_state.model_config, X_scaled, y)
                if supervised:
                    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
                    model.fit(X_train, y_train)
                    y_pred = model.predict(X_test)
                    if problem_type == "classification":
                        st.session_state.metrics = {"accuracy": accuracy_score(y_test, y_pred)}
                    else:
                        st.session_state.metrics = {"r2_score": r2_score(y_test, y_pred)}
                else:
                    model.fit(X_scaled)
                    st.session_state.metrics = {"silhouette_score": silhouette_score(X_scaled, model.labels_)}
                st.json(st.session_state.metrics)
        else:
            st.warning("Upload a dataset and configure a model first!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|