CosmickVisions committed on
Commit
26c67fd
·
verified ·
1 Parent(s): c88010e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -228
app.py CHANGED
@@ -5,198 +5,123 @@ import numpy as np
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.neural_network import MLPClassifier, MLPRegressor
7
  from sklearn.cluster import KMeans
8
- from sklearn.metrics import accuracy_score, r2_score, silhouette_score
9
  from sklearn.preprocessing import StandardScaler
10
  from ydata_profiling import ProfileReport
11
  from streamlit_pandas_profiling import st_profile_report
12
  from groq import Groq
13
  from langchain_community.vectorstores import FAISS
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
- from langchain_huggingface import HuggingFaceEmbeddings
16
  from langchain_community.document_loaders import TextLoader
 
17
  import os
 
18
  import tempfile
19
 
20
- # Initialize clients
 
 
 
21
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
22
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
23
 
24
  # Set page config
25
  st.set_page_config(page_title="Neural-Vision Enhanced", layout="wide")
26
 
27
- # Custom CSS for Responsive Silver-Blue-Gold Theme with Top Nav
28
  st.markdown("""
29
  <style>
30
  :root {
31
- --silver: #D8D8D8;
32
- --blue: #5C89BC;
33
- --gold: #A87E01;
34
- --text-color: #333333;
 
 
 
35
  }
36
  .stApp {
37
- background-color: var(--silver);
38
  font-family: 'Inter', sans-serif;
39
  max-width: 1200px;
40
  margin: 0 auto;
41
- padding: 10px;
42
  }
43
  .header {
44
- background-color: var(--blue);
45
- color: white;
46
  padding: 15px;
47
- border-radius: 5px;
 
48
  text-align: center;
49
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
50
  }
51
  .header-title {
 
52
  font-size: 1.8rem;
53
  font-weight: 700;
54
  margin: 0;
55
  }
56
  .header-subtitle {
 
57
  font-size: 1rem;
58
  margin-top: 5px;
59
  }
60
- .nav-bar {
61
- background-color: white;
62
- border-radius: 5px;
63
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
64
  padding: 15px;
65
- margin-bottom: 20px;
66
- display: flex;
67
- justify-content: space-around;
68
- align-items: center;
69
- }
70
- .nav-item {
71
- color: var(--blue);
72
- font-weight: 500;
73
- cursor: pointer;
74
- padding: 5px 10px;
75
- border-radius: 5px;
76
- }
77
- .nav-item:hover {
78
- background-color: var(--gold);
79
- color: white;
80
- }
81
- .card {
82
- background-color: white;
83
- border-radius: 5px;
84
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
85
- padding: 20px;
86
- margin-bottom: 20px;
87
  }
88
  .chat-container {
89
- background-color: white;
90
- border-radius: 5px;
 
91
  padding: 15px;
92
  margin-top: 20px;
93
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
94
  }
95
  .user-message {
96
- background-color: var(--blue);
97
- color: white;
98
- border-radius: 15px 15px 5px 15px;
99
- padding: 10px;
100
- max-width: 80%;
101
  margin-left: auto;
 
102
  margin-bottom: 10px;
103
  }
104
  .bot-message {
105
- background-color: #F0F0F0;
106
- color: var(--text-color);
107
- border-radius: 15px 15px 15px 5px;
108
- padding: 10px;
109
- max-width: 80%;
110
  margin-right: auto;
 
111
  margin-bottom: 10px;
112
  }
113
- .stButton > button {
114
- background-color: var(--gold);
115
- color: white;
116
- border-radius: 5px;
117
- padding: 8px 16px;
118
- border: none;
119
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
120
- }
121
- .stButton > button:hover {
122
- background-color: #8C6B01;
123
- }
124
- @media (max-width: 768px) {
125
- .header-title {
126
- font-size: 1.4rem;
127
- }
128
- .header-subtitle {
129
- font-size: 0.9rem;
130
- }
131
- .nav-bar {
132
- flex-direction: column;
133
- padding: 10px;
134
- }
135
- .nav-item {
136
- margin: 5px 0;
137
- width: 100%;
138
- text-align: center;
139
- }
140
- .card, .chat-container {
141
- padding: 10px;
142
- }
143
- .stApp {
144
- padding: 5px;
145
- }
146
- }
147
- </style>
148
- <footer>
149
- <p>Created by Calvin Allen-Crawford</p>
150
- </footer>
151
  """, unsafe_allow_html=True)
152
 
153
- # Session State Initialization
154
  if 'metrics' not in st.session_state:
155
  st.session_state.metrics = {}
156
  if 'chat_history' not in st.session_state:
157
  st.session_state.chat_history = []
158
  if 'vector_store' not in st.session_state:
159
  st.session_state.vector_store = None
160
- if 'custom_layers' not in st.session_state:
161
- st.session_state.custom_layers = []
162
- if 'prebuilt_selection' not in st.session_state:
163
- st.session_state.prebuilt_selection = None
164
- if 'model_config' not in st.session_state:
165
- st.session_state.model_config = {}
166
- if 'model_builder_mode' not in st.session_state:
167
- st.session_state.model_builder_mode = "prebuilt"
168
- if 'custom_model_type' not in st.session_state:
169
- st.session_state.custom_model_type = "classification"
170
-
171
- # Prebuilt Models
172
- PREBUILT_MODELS = {
173
- "Legal Document Classifier": {
174
- "description": "Optimized for legal document classification.",
175
- "architecture": {"type": "classification", "hidden_layers": [(128, "relu"), (64, "relu")], "dropout": 0.3, "optimizer": "adam", "learning_rate": 0.001},
176
- "domain": "Legal"
177
- },
178
- "Financial Fraud Detector": {
179
- "description": "Detects anomalies in financial transactions.",
180
- "architecture": {"type": "classification", "hidden_layers": [(256, "relu"), (128, "relu"), (64, "relu")], "dropout": 0.4, "optimizer": "adam", "learning_rate": 0.0005},
181
- "domain": "Financial"
182
- },
183
- "Customer Segmentation Engine": {
184
- "description": "Advanced customer segmentation.",
185
- "architecture": {"type": "clustering", "n_clusters": 5, "algorithm": "kmeans", "init": "k-means++", "n_init": 10},
186
- "domain": "Marketing"
187
- }
188
- }
189
 
190
  # Helper Functions
def convert_df_to_text(df):
    """Summarise a DataFrame as plain text for embedding into the vector store.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A newline-terminated string: a shape line, the total count of missing
        cells, then one line per column with its dtype and mean (``N/A`` for
        columns that are not numeric).
    """
    # Collect lines and join once instead of growing a string with `+=`.
    lines = [
        f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns",
        f"Missing Values: {df.isna().sum().sum()}",
    ]
    for col in df.columns:
        series = df[col]
        if pd.api.types.is_numeric_dtype(series):
            mean_value = f"{series.mean():.2f}"
        else:
            mean_value = "N/A"
        lines.append(f"- {col} ({series.dtype}): Mean={mean_value}")
    return "\n".join(lines) + "\n"
201
 
202
  def create_vector_store(df_text):
@@ -205,122 +130,215 @@ def create_vector_store(df_text):
205
  temp_path = temp_file.name
206
  loader = TextLoader(temp_path)
207
  documents = loader.load()
208
- texts = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(documents)
 
209
  vector_store = FAISS.from_documents(texts, embeddings)
210
  os.unlink(temp_path)
211
  return vector_store
212
 
213
def get_groq_response(prompt, mode):
    """Ask the Groq chat model a question, adding dataset context when available.

    Args:
        prompt: The user's question; also used as the similarity-search query.
        mode: Domain name interpolated into the system message.

    Returns:
        The model's reply text, or ``"Error: <message>"`` if the API call raises.
    """
    store = st.session_state.vector_store
    dataset_context = ""
    if store:
        # Pull the three chunks of the dataset summary closest to the question.
        matches = store.similarity_search(prompt, k=3)
        bullet_lines = [f"- {match.page_content}" for match in matches]
        dataset_context = "\nDataset Context:\n" + "\n".join(bullet_lines)

    system_message = {
        "role": "system",
        "content": f"You are an expert in {mode} data analysis.\n{dataset_context}",
    }
    user_message = {"role": "user", "content": prompt}

    try:
        completion = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[system_message, user_message],
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
def build_model_from_config(config, X, y=None):
    """Instantiate an unfitted scikit-learn estimator from a model config dict.

    Args:
        config: Mapping with an optional ``"type"`` key (``"clustering"``,
            ``"classification"``, anything else is treated as regression) plus
            the estimator settings read below.
        X: Unused here; kept for the existing call signature.
        y: Unused here; kept for the existing call signature.

    Returns:
        A ``KMeans``, ``MLPClassifier`` or ``MLPRegressor`` instance.
    """
    kind = config.get("type", "classification")

    if kind == "clustering":
        return KMeans(
            n_clusters=config.get("n_clusters", 3),
            init=config.get("init", "k-means++"),
            n_init=config.get("n_init", 10),
            random_state=42,
        )

    layers = config.get("hidden_layers", [(100, "relu")])
    mlp_kwargs = dict(
        hidden_layer_sizes=[units for units, _ in layers],
        # Only the first layer's activation is used for the whole network.
        activation=layers[0][1] if layers else "relu",
        solver=config.get("optimizer", "adam"),
        learning_rate_init=config.get("learning_rate", 0.001),
        random_state=42,
    )
    estimator_cls = MLPClassifier if kind == "classification" else MLPRegressor
    return estimator_cls(**mlp_kwargs)
240
 
241
- # Main Application
242
def main():
    """Render the app: header, upload/navigation bar, and the selected page.

    Pages are "Model Builder" (pick a prebuilt config or assemble a custom
    one), "Chat" (question answering through ``get_groq_response``) and
    "Train Model" (fit the configured estimator on the uploaded CSV, using the
    last column as the target for supervised models).
    """
    # Function-scope imports keep this block self-contained.
    import hashlib
    import html

    st.markdown('<div class="header"><h1 class="header-title">Neural-Vision Enhanced</h1><p class="header-subtitle">Build & Train Neural Networks</p></div>', unsafe_allow_html=True)

    # Top Navigation Bar
    st.markdown('<div class="nav-bar">', unsafe_allow_html=True)
    col1, col2, col3 = st.columns([1, 2, 1])
    df = None
    with col1:
        st.markdown('<div class="nav-item">Data Input</div>', unsafe_allow_html=True)
        uploaded_file = st.file_uploader("Upload CSV Dataset", type=["csv"])
        if uploaded_file:
            # Parse once per run; the training page reuses this frame instead
            # of re-reading an upload buffer that is already at end-of-file.
            df = pd.read_csv(uploaded_file)
            # Re-embed only when the file content actually changes, not on
            # every widget interaction.
            digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
            if st.session_state.get("dataset_digest") != digest:
                st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
                st.session_state.dataset_digest = digest
            st.success("Dataset uploaded!")
    with col2:
        st.markdown('<div class="nav-item">Navigation</div>', unsafe_allow_html=True)
        nav_option = st.selectbox("Navigate", ["Model Builder", "Chat", "Train Model"], label_visibility="collapsed")
    with col3:
        st.markdown('<div class="nav-item">Info</div>', unsafe_allow_html=True)
        st.write("Built with Streamlit & Groq")
    st.markdown('</div>', unsafe_allow_html=True)

    # Main Content
    if nav_option == "Model Builder":
        st.markdown('<div class="card"><h2>Model Builder</h2></div>', unsafe_allow_html=True)
        mode = st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
        model_builder_mode = st.radio("Mode", ["Prebuilt", "Custom"])
        st.session_state.model_builder_mode = "prebuilt" if model_builder_mode == "Prebuilt" else "custom"

        if st.session_state.model_builder_mode == "prebuilt":
            for name, details in PREBUILT_MODELS.items():
                if st.button(f"{name}: {details['description']}", key=name):
                    st.session_state.prebuilt_selection = name
                    st.session_state.model_config = details["architecture"]
            if st.session_state.prebuilt_selection:
                st.json(st.session_state.model_config)
        else:
            st.session_state.custom_model_type = st.selectbox("Type", ["classification", "regression", "clustering"])
            if st.session_state.custom_model_type != "clustering":
                layer_count = st.number_input("Layers", min_value=1, value=1)
                st.session_state.custom_layers = []
                for i in range(int(layer_count)):
                    size = st.number_input(f"Layer {i+1} Size", min_value=1, value=100, key=f"size_{i}")
                    activation = st.selectbox(f"Layer {i+1} Activation", ["relu", "tanh"], key=f"act_{i}")
                    st.session_state.custom_layers.append((size, activation))
                optimizer = st.selectbox("Optimizer", ["adam", "sgd"])
                st.session_state.model_config = {
                    "type": st.session_state.custom_model_type,
                    "hidden_layers": st.session_state.custom_layers,
                    "optimizer": optimizer,
                    "learning_rate": 0.001,
                }
            else:
                st.session_state.model_config = {
                    "type": "clustering",
                    "n_clusters": st.number_input("Clusters", min_value=2, value=3),
                }
            if st.button("Finalize"):
                st.json(st.session_state.model_config)

    elif nav_option == "Chat":
        # The backing client is Groq; the heading previously said "Grok".
        st.markdown('<div class="chat-container"><h3>Chat with Groq</h3></div>', unsafe_allow_html=True)
        mode = st.selectbox("Domain", ["Legal", "Financial", "Marketing"])
        # A form submits exactly once per click. A bare text_input keeps its
        # value, so every rerun would resend the question and duplicate history.
        with st.form("chat_form", clear_on_submit=True):
            prompt = st.text_input("Ask a question:")
            submitted = st.form_submit_button("Send")
        if submitted and prompt:
            response = get_groq_response(prompt, mode)
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            st.session_state.chat_history.append({"role": "bot", "content": response})
        for msg in st.session_state.chat_history:
            css_class = "user-message" if msg["role"] == "user" else "bot-message"
            # Escape before injecting into raw HTML: both the question and the
            # model reply are untrusted text.
            safe_text = html.escape(str(msg["content"])).replace("\n", "<br>")
            st.markdown(f'<div class="{css_class}">{safe_text}</div>', unsafe_allow_html=True)

    elif nav_option == "Train Model":
        config = st.session_state.model_config
        if df is not None and config:
            st.markdown('<div class="card"><h2>Train Model</h2></div>', unsafe_allow_html=True)
            is_clustering = config["type"] == "clustering"
            # Supervised models treat the last column as the target.
            X = df if is_clustering else df.drop(columns=[df.columns[-1]])
            y = None if is_clustering else df[df.columns[-1]]
            if st.button("Train"):
                X_scaled = StandardScaler().fit_transform(X)
                model = build_model_from_config(config, X_scaled, y)
                if is_clustering:
                    model.fit(X_scaled)
                    metrics = {"silhouette_score": silhouette_score(X_scaled, model.labels_)}
                else:
                    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
                    model.fit(X_train, y_train)
                    y_pred = model.predict(X_test)
                    if config["type"] == "classification":
                        metrics = {"accuracy": accuracy_score(y_test, y_pred)}
                    else:
                        metrics = {"r2_score": r2_score(y_test, y_pred)}
                st.session_state.metrics = metrics
                st.json(st.session_state.metrics)
        else:
            st.warning("Upload a dataset and configure a model first!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
- if __name__ == "__main__":
326
- main()
 
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.neural_network import MLPClassifier, MLPRegressor
7
  from sklearn.cluster import KMeans
8
+ from sklearn.metrics import accuracy_score, r2_score, silhouette_score, confusion_matrix, classification_report, mean_squared_error
9
  from sklearn.preprocessing import StandardScaler
10
  from ydata_profiling import ProfileReport
11
  from streamlit_pandas_profiling import st_profile_report
12
  from groq import Groq
13
  from langchain_community.vectorstores import FAISS
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.embeddings import HuggingFaceEmbeddings
16
  from langchain_community.document_loaders import TextLoader
17
+ from langchain_community.tools.tavily_search import TavilySearchResults
18
  import os
19
+ from dotenv import load_dotenv
20
  import tempfile
21
 
22
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Set page config first: Streamlit requires it to precede other Streamlit
# commands, and the cached loader below could otherwise emit UI before it.
st.set_page_config(page_title="Neural-Vision Enhanced", layout="wide")

# Initialize Groq client
client = Groq(api_key=os.getenv("GROQ_API_KEY"))


@st.cache_resource(show_spinner=False)
def _load_embeddings():
    """Load the sentence-transformer embedding model once per server process.

    Streamlit re-executes this script on every interaction; without caching,
    the model would be re-instantiated on each rerun.
    """
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


# Initialize embeddings for FAISS
embeddings = _load_embeddings()
33
 
34
+ # Custom CSS matching previous theme
35
  st.markdown("""
36
  <style>
37
  :root {
38
+ --primary-blue: #3B82F6;
39
+ --dark-blue: #1E40AF;
40
+ --light-blue: #DBEAFE;
41
+ --medium-grey: #6B7280;
42
+ --light-grey: #F3F4F6;
43
+ --white: #FFFFFF;
44
+ --border-grey: #E5E7EB;
45
  }
46
  .stApp {
47
+ background-color: var(--light-grey);
48
  font-family: 'Inter', sans-serif;
49
  max-width: 1200px;
50
  margin: 0 auto;
 
51
  }
52
  .header {
53
+ background-color: var(--white);
54
+ border-bottom: 2px solid var(--border-grey);
55
  padding: 15px;
56
+ border-radius: 12px 12px 0 0;
57
+ box-shadow: 0 2px 4px rgba(0,0,0,0.05);
58
  text-align: center;
 
59
  }
60
  .header-title {
61
+ color: var(--dark-blue);
62
  font-size: 1.8rem;
63
  font-weight: 700;
64
  margin: 0;
65
  }
66
  .header-subtitle {
67
+ color: var(--medium-grey);
68
  font-size: 1rem;
69
  margin-top: 5px;
70
  }
71
+ .sidebar .sidebar-content {
72
+ background-color: var(--white);
73
+ border-radius: 12px;
74
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
75
  padding: 15px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  }
77
  .chat-container {
78
+ background-color: var(--white);
79
+ border-radius: 12px;
80
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
81
  padding: 15px;
82
  margin-top: 20px;
 
83
  }
84
  .user-message {
85
+ background-color: var(--primary-blue);
86
+ color: var(--white);
87
+ border-radius: 18px 18px 4px 18px;
88
+ padding: 12px 16px;
 
89
  margin-left: auto;
90
+ max-width: 80%;
91
  margin-bottom: 10px;
92
  }
93
  .bot-message {
94
+ background-color: var(--light-grey);
95
+ color: var(--medium-grey);
96
+ border-radius: 18px 18px 18px 4px;
97
+ padding: 12px 16px;
 
98
  margin-right: auto;
99
+ max-width: 80%;
100
  margin-bottom: 10px;
101
  }
102
+ </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  """, unsafe_allow_html=True)
104
 
105
# Initialize session state: seed each key only if an earlier run of this
# session has not already set it.
_SESSION_DEFAULTS = (
    ("metrics", dict),
    ("chat_history", list),
    ("vector_store", lambda: None),
)
for _key, _make_default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  # Helper Functions
def convert_df_to_text(df):
    """Summarise a DataFrame as plain text for embedding into the vector store.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A newline-terminated string: shape, total missing cells, a
        ``Columns:`` header, then one line per column. Numeric columns report
        mean/min/max; other columns report the number of unique values and the
        most frequent value (``N/A`` when there is none). Every column line
        ends with its own missing-value count.
    """
    # Collect lines and join once instead of growing a string with `+=`.
    lines = [
        f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns",
        f"Missing Values: {df.isna().sum().sum()}",
        "Columns:",
    ]
    for col in df.columns:
        series = df[col]
        if pd.api.types.is_numeric_dtype(series):
            stats = f"Mean={series.mean():.2f}, Min={series.min()}, Max={series.max()}"
        else:
            # Compute the mode once; it is empty when every value is missing.
            modes = series.mode()
            top = modes.iloc[0] if not modes.empty else 'N/A'
            stats = f"Unique={series.nunique()}, Top={top}"
        lines.append(f"- {col} ({series.dtype}): {stats}, Missing={series.isna().sum()}")
    return "\n".join(lines) + "\n"
126
 
127
  def create_vector_store(df_text):
 
130
  temp_path = temp_file.name
131
  loader = TextLoader(temp_path)
132
  documents = loader.load()
133
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
134
+ texts = text_splitter.split_documents(documents)
135
  vector_store = FAISS.from_documents(texts, embeddings)
136
  os.unlink(temp_path)
137
  return vector_store
138
 
139
def get_groq_response(prompt, mode, use_web_search=False):
    """Ask the Groq chat model a question with optional dataset and web context.

    Args:
        prompt: The user's question; also used as the retrieval/search query.
        mode: One of "Legal", "Financial", "Academic", "Technical"; any other
            value falls back to a generic system prompt.
        use_web_search: When true, append up to three Tavily search snippets
            to the system prompt.

    Returns:
        The model's reply text, or ``"Error: <message>"`` if the Groq request
        raises. A failing web search does not abort the request.
    """
    context = ""
    if st.session_state.vector_store:
        docs = st.session_state.vector_store.similarity_search(prompt, k=3)
        context = "\n\nDataset Context:\n" + "\n".join(f"- {doc.page_content}" for doc in docs)

    if use_web_search:
        # Web search is best-effort. Constructing or invoking the tool can
        # fail (for example when no Tavily API key is configured), and the
        # tool may hand back an error string instead of a list of results.
        try:
            web_results = TavilySearchResults(max_results=3).invoke(prompt)
        except Exception:
            web_results = None
        if isinstance(web_results, list):
            snippets = [
                f"- {res['content'][:200]}..."
                for res in web_results
                if isinstance(res, dict) and res.get("content")
            ]
            if snippets:
                context += "\n\nWeb Search Results:\n" + "\n".join(snippets)

    prompts = {
        "Legal": "You are a neural network expert specializing in legal data analysis.",
        "Financial": "You are a neural network expert specializing in financial data analysis.",
        "Academic": "You are a neural network expert specializing in academic data analysis.",
        "Technical": "You are a neural network expert specializing in technical data analysis."
    }
    system_prompt = prompts.get(mode, "You are a neural network development assistant.") + "\n" + context

    # Report API failures as a chat message rather than crashing the page.
    try:
        response = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            max_tokens=1024
        )
        return response.choices[0].message.content
    except Exception as exc:
        return f"Error: {exc}"
168
 
169
+ # Visualization Functions
170
def plot_confusion_matrix(y_true, y_pred):
    """Build a Plotly heatmap of the confusion matrix with labelled axes.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels, same length as ``y_true``.

    Returns:
        A Plotly figure; rows are actual classes, columns predicted classes.
    """
    cm = confusion_matrix(y_true, y_pred)
    # With no explicit `labels`, confusion_matrix orders rows and columns by
    # the sorted labels that appear in either input; rebuild that ordering so
    # the axes show class names instead of bare indices.
    observed = np.concatenate([np.asarray(y_true), np.asarray(y_pred)])
    class_names = [str(label) for label in np.unique(observed)]
    fig = px.imshow(
        cm,
        x=class_names,
        y=class_names,
        labels={"x": "Predicted", "y": "Actual", "color": "Count"},
        text_auto=True,
        color_continuous_scale='Blues',
        title="Confusion Matrix",
    )
    # Keep numeric-looking class names as discrete ticks.
    fig.update_xaxes(type="category")
    fig.update_yaxes(type="category")
    return fig
 
 
 
 
 
174
 
175
def plot_feature_importance(model, X):
    """Build a bar chart of per-feature importance for a fitted model.

    The score comes from the first attribute the model exposes:
    ``feature_importances_``; absolute ``coef_`` (averaged over rows when it
    is 2-D); or, for MLP models, the mean absolute first-layer weight of each
    input in ``coefs_[0]``. If none exist, every feature gets a weight of one.

    Args:
        model: A fitted estimator.
        X: DataFrame whose columns name the features, in model input order.

    Returns:
        A Plotly bar figure with one bar per column of ``X``.
    """
    if hasattr(model, 'feature_importances_'):
        importance = np.asarray(model.feature_importances_)
    elif hasattr(model, 'coef_'):
        importance = np.abs(np.asarray(model.coef_))
        if importance.ndim > 1:
            # Collapse a 2-D coefficient matrix to one score per feature.
            importance = importance.mean(axis=0)
    elif getattr(model, 'coefs_', None):
        # MLPClassifier/MLPRegressor expose `coefs_`, a list of weight
        # matrices; the first has shape (n_features, n_hidden_units).
        importance = np.abs(np.asarray(model.coefs_[0])).mean(axis=1)
    else:
        importance = np.ones(X.shape[1])
    fig = px.bar(x=X.columns, y=importance, title="Feature Importance")
    return fig
182
 
183
def plot_residuals(y_true, y_pred):
    """Build a scatter plot of residuals against predicted values.

    Args:
        y_true: Observed target values.
        y_pred: Predicted target values, same length as ``y_true``.

    Returns:
        A Plotly scatter figure with predictions on the x-axis and
        ``y_true - y_pred`` on the y-axis.
    """
    errors = y_true - y_pred
    axis_titles = {"x": "Predicted", "y": "Residuals"}
    return px.scatter(x=y_pred, y=errors, title="Residual Plot", labels=axis_titles)
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
def plot_clusters(X, labels):
    """Build a scatter plot of the first two feature columns, coloured by cluster.

    Args:
        X: DataFrame with at least two columns; the first two are plotted.
        labels: Cluster assignment for each row of ``X``.

    Returns:
        A Plotly scatter figure with one discrete colour per cluster.
    """
    # Integer labels would be drawn on a continuous colour scale; strings make
    # Plotly treat each cluster as its own category.
    cluster_ids = [str(label) for label in labels]
    fig = px.scatter(
        X,
        x=X.columns[0],
        y=X.columns[1],
        color=cluster_ids,
        labels={"color": "Cluster"},
        title="Cluster Visualization",
    )
    return fig
 
 
 
191
 
192
+ # Pages
193
def data_upload_page():
    """Render the upload page: ingest a CSV, show health metrics, offer an EDA report.

    A new upload is stored in ``st.session_state.df``, indexed into
    ``st.session_state.vector_store``, and resets ``st.session_state.metrics``.
    """
    import hashlib  # function-scope import keeps this block self-contained

    st.header("📤 Data Upload & Analysis")
    uploaded_file = st.file_uploader("Upload Dataset", type=["csv"])

    if not uploaded_file:
        return

    # Streamlit reruns the script on every interaction, including each chat
    # message. Re-parse, re-embed and reset metrics only when the uploaded
    # bytes actually change.
    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    if st.session_state.get("dataset_digest") != digest or 'df' not in st.session_state:
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file)
        st.session_state.df = df
        st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
        st.session_state.metrics = {}
        st.session_state.dataset_digest = digest
    df = st.session_state.df

    st.subheader("Dataset Health Check")
    col1, col2, col3 = st.columns(3)
    col1.metric("Total Samples", df.shape[0])
    col2.metric("Features", df.shape[1])
    col3.metric("Missing Values", df.isna().sum().sum())

    if st.button("Generate Full EDA Report"):
        with st.spinner("Generating comprehensive analysis..."):
            profile = ProfileReport(df, explorative=True)
            st_profile_report(profile)
213
 
214
def model_training_page():
    """Render the training page and fit the selected model on the uploaded data.

    On a successful run the following are written to ``st.session_state``:
    ``metrics``, ``best_model``, ``X_test``, ``y_test``, ``y_pred``,
    ``problem_type`` and ``feature_names``. For clustering, ``X_test`` holds
    the full scaled feature matrix and ``y_pred`` the cluster labels.
    """
    st.header("🧠 Neural Network Training Studio")

    if 'df' not in st.session_state:
        st.warning("Upload data first!")
        return

    df = st.session_state.df
    problem_type = st.selectbox("Select Problem Type", ["Classification", "Regression", "Clustering"])
    mode = st.selectbox("Domain Specialization", ["Legal", "Financial", "Academic", "Technical"])

    if problem_type != "Clustering":
        target = st.selectbox("Select Target Variable", df.columns)
        candidates = df.drop(columns=[target])
        y = df[target]
    else:
        candidates = df
        y = None

    # StandardScaler cannot convert text columns, so train on the
    # numeric/boolean features only.
    X = candidates.select_dtypes(include=["number", "bool"])

    if not st.button("Train Neural Network"):
        return

    if X.shape[1] == 0:
        st.error("No numeric feature columns are available for training.")
        return

    with st.spinner("Training in progress..."):
        scaler = StandardScaler()

        if y is None:  # Clustering
            X_all = scaler.fit_transform(X)
            model = KMeans(n_clusters=3, random_state=42)
            predictions = model.fit_predict(X_all)
            # Store the scaled matrix so the evaluation page has points to plot.
            X_test, y_test = X_all, None
            metrics = {
                "Silhouette Score": silhouette_score(X_all, predictions)
            }
        else:
            X_train_raw, X_test_raw, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            # Fit the scaler on the training split only, so test-set
            # statistics do not leak into training.
            X_train = scaler.fit_transform(X_train_raw)
            X_test = scaler.transform(X_test_raw)

            if problem_type == "Classification":
                model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
                model.fit(X_train, y_train)
                predictions = model.predict(X_test)
                metrics = {
                    "Accuracy": accuracy_score(y_test, predictions),
                    "Classification Report": classification_report(y_test, predictions, output_dict=True)
                }
            else:  # Regression
                model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
                model.fit(X_train, y_train)
                predictions = model.predict(X_test)
                metrics = {
                    "R2 Score": r2_score(y_test, predictions),
                    "Mean Squared Error": mean_squared_error(y_test, predictions)
                }

        st.session_state.metrics = metrics
        st.session_state.best_model = model
        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.y_pred = predictions
        st.session_state.problem_type = problem_type
        # Record the columns the model was trained on, in input order.
        st.session_state.feature_names = list(X.columns)
        st.success(f"Model trained successfully in {mode} mode!")
267
 
268
def visualization_page():
    """Render evaluation plots and metrics for the most recently trained model.

    Reads ``best_model``, ``problem_type``, ``X_test``, ``y_test``, ``y_pred``,
    ``metrics`` and ``df`` from ``st.session_state``, plus ``feature_names``
    when it has been stored.
    """
    st.header("🔍 Neural Network Evaluation Center")

    if 'best_model' not in st.session_state:
        st.warning("Train a model first!")
        return

    state = st.session_state
    problem_type = state.problem_type
    X_eval = state.X_test

    # Rebuild a labelled feature frame. Prefer stored feature names; otherwise
    # fall back to the positional guess (every column, minus the last one for
    # supervised problems) only if its length fits, else use generic names.
    features = None
    if X_eval is not None:
        n_features = np.shape(X_eval)[1]
        names = state.get("feature_names")
        if names is None or len(names) != n_features:
            guess = list(state.df.columns) if problem_type == "Clustering" else list(state.df.columns[:-1])
            names = guess if len(guess) == n_features else [f"feature_{i}" for i in range(n_features)]
        features = pd.DataFrame(X_eval, columns=names)

    st.subheader("Performance Analysis")
    if problem_type == "Classification":
        st.plotly_chart(plot_confusion_matrix(state.y_test, state.y_pred))
        if features is not None:
            st.plotly_chart(plot_feature_importance(state.best_model, features))
    elif problem_type == "Regression":
        st.plotly_chart(plot_residuals(state.y_test, state.y_pred))
        if features is not None:
            st.plotly_chart(plot_feature_importance(state.best_model, features))
    else:  # Clustering
        # The scatter plot needs stored points and two columns to plot.
        if features is None or features.shape[1] < 2:
            st.info("Cluster visualization needs at least two stored feature columns.")
        else:
            st.plotly_chart(plot_clusters(features, state.y_pred))

    st.subheader("Metrics")
    st.write(state.metrics)
287
+
288
+ # Chatbot Interface
289
def ai_assistant():
    """Render the chat assistant: options, message history, and the input box.

    Messages live in ``st.session_state.chat_history`` as dicts with a
    ``role`` ("user" or "assistant") and ``content``.
    """
    import html  # function-scope import keeps this block self-contained

    # The stylesheet defines `.user-message` and `.bot-message`; there is no
    # `.assistant-message`, so map roles to the real class names.
    css_classes = {"user": "user-message", "assistant": "bot-message"}

    def render_message(role, content):
        """Draw one chat bubble with its text escaped for raw-HTML injection."""
        # User input and model output are untrusted; escape them, and turn
        # newlines into <br> so multi-line replies stay inside the bubble.
        safe_text = html.escape(str(content)).replace("\n", "<br>")
        css_class = css_classes.get(role, "bot-message")
        with st.chat_message(role):
            st.markdown(f'<div class="{css_class}">{safe_text}</div>', unsafe_allow_html=True)

    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    st.subheader("🧠 Neural Insight Assistant (RAG + Web Search)")

    use_web_search = st.checkbox("Enable Tavily Web Search", value=False)
    mode = st.selectbox("Domain Mode", ["Legal", "Financial", "Academic", "Technical"], key="chat_mode")

    for msg in st.session_state.chat_history:
        render_message(msg["role"], msg["content"])

    if prompt := st.chat_input("Ask about data, models, or web insights..."):
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        render_message("user", prompt)

        with st.spinner("Processing..."):
            # UI boundary: show a backend failure as a chat reply instead of
            # letting it take down the whole page.
            try:
                response = get_groq_response(prompt, mode, use_web_search)
            except Exception as exc:
                response = f"Error: {exc}"
        st.session_state.chat_history.append({"role": "assistant", "content": response})

        render_message("assistant", response)

    st.markdown('</div>', unsafe_allow_html=True)
313
+
314
# Main App Layout
st.markdown("""
<div class="header">
<h1 class="header-title">Neural-Vision Enhanced</h1>
<div class="header-subtitle">Neural Network Development for Domain-Specialized Analysis</div>
</div>
""", unsafe_allow_html=True)

with st.sidebar:
    st.title("🔮 Neural-Vision Enhanced")
    page = st.selectbox("Navigation", [
        "Data Upload & Analysis",
        "Neural Network Training Studio",
        "Neural Network Evaluation Center"
    ])
    st.session_state.active_page = page
    st.markdown("---")
    st.markdown("**Environment Setup**")
    tavily_key = st.text_input("Tavily API Key", type="password", help="For web search functionality")
    # Only override the environment when a key was typed; a blank field must
    # not replace a key that load_dotenv() already provided.
    if tavily_key:
        os.environ["TAVILY_API_KEY"] = tavily_key
    st.markdown("---")
    st.markdown("v5.0 | © 2025 Neural-Vision")

# Page Routing
if page == "Data Upload & Analysis":
    data_upload_page()
elif page == "Neural Network Training Studio":
    model_training_page()
else:
    visualization_page()

# The assistant is shown beneath every page.
ai_assistant()