awacke1 commited on
Commit
d8d14b1
·
verified ·
1 Parent(s): 1e53277

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +494 -0
app.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import re
4
+ import streamlit as st
5
+ import streamlit.components.v1 as components
6
+ from urllib.parse import quote
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import DataLoader, TensorDataset
12
+ import base64
13
+ import glob
14
+ import time
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+ from mergekit.config import MergeConfiguration
17
+ from mergekit.merge import Mergekit
18
+ from spectrum import SpectrumAnalyzer
19
+ import distilkit
20
+ import yaml
21
+ from dataclasses import dataclass
22
+ from typing import Optional, List
23
+ import logging
24
+
25
+ # Configure logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Page Configuration
30
+ st.set_page_config(
31
+ page_title="AI Knowledge Tree Builder 📈🌿",
32
+ page_icon="🌳✨",
33
+ layout="wide",
34
+ initial_sidebar_state="auto",
35
+ )
36
+
37
+ # Predefined Knowledge Trees
38
+ trees = {
39
+ "ML Engineering": """
40
+ 0. ML Engineering 🌐
41
+ 1. Data Preparation
42
+ - Load Data 📊
43
+ - Preprocess Data 🛠️
44
+ 2. Model Building
45
+ - Train Model 🤖
46
+ - Evaluate Model 📈
47
+ 3. Deployment
48
+ - Deploy Model 🚀
49
+ """,
50
+ "Health": """
51
+ 0. Health and Wellness 🌿
52
+ 1. Physical Health
53
+ - Exercise 🏋️
54
+ - Nutrition 🍎
55
+ 2. Mental Health
56
+ - Meditation 🧘
57
+ - Therapy 🛋️
58
+ """,
59
+ }
60
+
61
+ # Project Seeds
62
+ project_seeds = {
63
+ "Code Project": """
64
+ 0. Code Project 📂
65
+ 1. app.py 🐍
66
+ 2. requirements.txt 📦
67
+ 3. README.md 📄
68
+ """,
69
+ "Papers Project": """
70
+ 0. Papers Project 📚
71
+ 1. markdown 📝
72
+ 2. mermaid 🖼️
73
+ 3. huggingface.co 🤗
74
+ """,
75
+ "AI Project": """
76
+ 0. AI Project 🤖
77
+ 1. Streamlit Torch Transformers
78
+ - Streamlit 🌐
79
+ - Torch 🔥
80
+ - Transformers 🤖
81
+ 2. DistillKit MergeKit Spectrum
82
+ - DistillKit 🧪
83
+ - MergeKit 🔄
84
+ - Spectrum 📊
85
+ 3. Transformers Diffusers Datasets
86
+ - Transformers 🤖
87
+ - Diffusers 🎨
88
+ - Datasets 📊
89
+ """,
90
+ }
91
+
92
+ # Meta class for model configuration
93
+ class ModelMeta(type):
94
+ def __new__(cls, name, bases, attrs):
95
+ attrs['registry'] = {}
96
+ return super().__new__(cls, name, bases, attrs)
97
+
98
+ # Base Model Configuration Class
99
+ @dataclass
100
+ class ModelConfig(metaclass=ModelMeta):
101
+ name: str
102
+ base_model: str
103
+ size: str
104
+ domain: Optional[str] = None
105
+
106
+ def __init_subclass__(cls):
107
+ ModelConfig.registry[cls.__name__] = cls
108
+
109
+ @property
110
+ def model_path(self):
111
+ return f"models/{self.name}"
112
+
113
+ # Decorator for pipeline stages
114
+ def pipeline_stage(func):
115
+ def wrapper(*args, **kwargs):
116
+ st.spinner(f"Running {func.__name__}...")
117
+ result = func(*args, **kwargs)
118
+ st.success(f"Completed {func.__name__}!")
119
+ return result
120
+ return wrapper
121
+
122
+ # Model Builder Class
123
+ class ModelBuilder:
124
+ def __init__(self):
125
+ self.config = None
126
+ self.model = None
127
+ self.tokenizer = None
128
+
129
+ @pipeline_stage
130
+ def load_base_model(self, model_name: str):
131
+ """Load base model from Hugging Face"""
132
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
133
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
134
+ return self
135
+
136
+ @pipeline_stage
137
+ def apply_distillation(self, teacher_model: str, output_dir: str):
138
+ """Apply DistilKit for model distillation"""
139
+ distiller = distilkit.Distiller(
140
+ teacher_model=teacher_model,
141
+ student_model=self.model,
142
+ tokenizer=self.tokenizer
143
+ )
144
+ distiller.distill(output_dir=output_dir)
145
+ self.model = distiller.student_model
146
+ return self
147
+
148
+ @pipeline_stage
149
+ def apply_merge(self, models_to_merge: List[str], output_dir: str):
150
+ """Apply Mergekit for model merging"""
151
+ merge_config = MergeConfiguration(
152
+ models=models_to_merge,
153
+ merge_method="linear",
154
+ output_dir=output_dir
155
+ )
156
+ merger = Mergekit(merge_config)
157
+ merger.run()
158
+ self.model = AutoModelForCausalLM.from_pretrained(output_dir)
159
+ return self
160
+
161
+ @pipeline_stage
162
+ def apply_spectrum(self, domain_data: str):
163
+ """Apply Spectrum for domain specialization"""
164
+ analyzer = SpectrumAnalyzer(self.model)
165
+ analyzer.fit(domain_data)
166
+ self.model = analyzer.specialized_model
167
+ return self
168
+
169
+ def save_model(self, path: str):
170
+ """Save the final model"""
171
+ self.model.save_pretrained(path)
172
+ self.tokenizer.save_pretrained(path)
173
+
174
+ # Utility Functions
175
+ def sanitize_label(label):
176
+ """Remove invalid characters for Mermaid labels."""
177
+ return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
178
+
179
+ def sanitize_filename(label):
180
+ """Make a valid filename from a label."""
181
+ return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
182
+
183
+ def parse_outline_to_mermaid(outline_text, search_agent):
184
+ """Convert tree outline to Mermaid syntax with clickable nodes."""
185
+ lines = outline_text.strip().split('\n')
186
+ nodes, edges, clicks, stack = [], [], [], []
187
+ for line in lines:
188
+ indent = len(line) - len(line.lstrip())
189
+ level = indent // 4
190
+ label = re.sub(r'^[#*\->\d\.\s]+', '', line.strip())
191
+ if label:
192
+ node_id = f"N{len(nodes)}"
193
+ sanitized_label = sanitize_label(label)
194
+ nodes.append(f'{node_id}["{label}"]')
195
+ search_url = search_urls[search_agent](label)
196
+ clicks.append(f'click {node_id} "{search_url}" _blank')
197
+ if stack:
198
+ parent_level = stack[-1][0]
199
+ if level > parent_level:
200
+ edges.append(f"{stack[-1][1]} --> {node_id}")
201
+ stack.append((level, node_id))
202
+ else:
203
+ while stack and stack[-1][0] >= level:
204
+ stack.pop()
205
+ if stack:
206
+ edges.append(f"{stack[-1][1]} --> {node_id}")
207
+ stack.append((level, node_id))
208
+ else:
209
+ stack.append((level, node_id))
210
+ return "%%{init: {'themeVariables': {'fontSize': '18px'}}}%%\nflowchart LR\n" + "\n".join(nodes + edges + clicks)
211
+
212
+ def generate_mermaid_html(mermaid_code):
213
+ """Generate HTML to display Mermaid diagram."""
214
+ return f"""
215
+ <html><head><script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
216
+ <style>.centered-mermaid{{display:flex;justify-content:center;margin:20px auto;}}</style></head>
217
+ <body><div class="mermaid centered-mermaid">{mermaid_code}</div>
218
+ <script>mermaid.initialize({{startOnLoad:true}});</script></body></html>
219
+ """
220
+
221
+ def grow_tree(base_tree, new_node_name, parent_node):
222
+ """Add a new node to the tree under a specified parent."""
223
+ lines = base_tree.strip().split('\n')
224
+ new_lines = []
225
+ added = False
226
+ for line in lines:
227
+ new_lines.append(line)
228
+ if parent_node in line and not added:
229
+ indent = len(line) - len(line.lstrip())
230
+ new_lines.append(f"{' ' * (indent + 4)}- {new_node_name} 🌱")
231
+ added = True
232
+ return "\n".join(new_lines)
233
+
234
+ def get_download_link(file_path, mime_type="text/plain"):
235
+ """Generate a download link for a file."""
236
+ with open(file_path, 'rb') as f:
237
+ data = f.read()
238
+ b64 = base64.b64encode(data).decode()
239
+ return f'<a href="data:{mime_type};base64,{b64}" download="{file_path}">Download {file_path}</a>'
240
+
241
+ def save_tree_to_file(tree_text, parent_node, new_node):
242
+ """Save tree to a markdown file with name based on nodes."""
243
+ root_node = tree_text.strip().split('\n')[0].split('.')[1].strip() if tree_text.strip() else "Knowledge_Tree"
244
+ filename = f"{sanitize_filename(root_node)}_{sanitize_filename(parent_node)}_{sanitize_filename(new_node)}_{int(time.time())}.md"
245
+
246
+ mermaid_code = parse_outline_to_mermaid(tree_text, "🔮Google") # Default search engine for saved trees
247
+ export_md = f"# Knowledge Tree: {root_node}\n\n## Outline\n{tree_text}\n\n## Mermaid Diagram\n```mermaid\n{mermaid_code}\n```"
248
+
249
+ with open(filename, "w") as f:
250
+ f.write(export_md)
251
+ return filename
252
+
253
+ def load_trees_from_files():
254
+ """Load all saved tree markdown files."""
255
+ tree_files = glob.glob("*.md")
256
+ trees_dict = {}
257
+
258
+ for file in tree_files:
259
+ if file != "README.md" and file != "knowledge_tree.md": # Skip project README and temp export
260
+ try:
261
+ with open(file, 'r') as f:
262
+ content = f.read()
263
+ # Extract the tree name from the first line
264
+ match = re.search(r'# Knowledge Tree: (.*)', content)
265
+ if match:
266
+ tree_name = match.group(1)
267
+ else:
268
+ tree_name = os.path.splitext(file)[0]
269
+
270
+ # Extract the outline section
271
+ outline_match = re.search(r'## Outline\n(.*?)(?=\n## |$)', content, re.DOTALL)
272
+ if outline_match:
273
+ tree_outline = outline_match.group(1).strip()
274
+ trees_dict[f"{tree_name} ({file})"] = tree_outline
275
+ except Exception as e:
276
+ print(f"Error loading {file}: {e}")
277
+
278
+ return trees_dict
279
+
280
+ # Search Agents (Highest resolution social network default: X)
281
+ search_urls = {
282
+ "📚📖ArXiv": lambda k: f"/?q={quote(k)}",
283
+ "🔮Google": lambda k: f"https://www.google.com/search?q={quote(k)}",
284
+ "📺Youtube": lambda k: f"https://www.youtube.com/results?search_query={quote(k)}",
285
+ "🔭Bing": lambda k: f"https://www.bing.com/search?q={quote(k)}",
286
+ "💡Truth": lambda k: f"https://truthsocial.com/search?q={quote(k)}",
287
+ "📱X": lambda k: f"https://twitter.com/search?q={quote(k)}",
288
+ }
289
+
290
+ # Main App
291
+ st.title("🌳 AI Knowledge Tree Builder 🌱")
292
+
293
+ # Sidebar with saved trees
294
+ st.sidebar.title("Saved Trees")
295
+ saved_trees = load_trees_from_files()
296
+ selected_saved_tree = st.sidebar.selectbox("Select a saved tree", ["None"] + list(saved_trees.keys()))
297
+
298
+ # Select Project Type
299
+ project_type = st.selectbox("Select Project Type", ["Code Project", "Papers Project", "AI Project"])
300
+
301
+ # Initialize or load tree
302
+ if 'current_tree' not in st.session_state:
303
+ if selected_saved_tree != "None" and selected_saved_tree in saved_trees:
304
+ st.session_state['current_tree'] = saved_trees[selected_saved_tree]
305
+ else:
306
+ st.session_state['current_tree'] = trees.get("ML Engineering", project_seeds[project_type])
307
+ elif selected_saved_tree != "None" and selected_saved_tree in saved_trees:
308
+ st.session_state['current_tree'] = saved_trees[selected_saved_tree]
309
+
310
+ # Select Search Agent for Node Links
311
+ search_agent = st.selectbox("Select Search Agent for Node Links", list(search_urls.keys()), index=5) # Default to X
312
+
313
+ # Tree Growth
314
+ new_node = st.text_input("Add New Node")
315
+ parent_node = st.text_input("Parent Node")
316
+ if st.button("Grow Tree 🌱") and new_node and parent_node:
317
+ st.session_state['current_tree'] = grow_tree(st.session_state['current_tree'], new_node, parent_node)
318
+
319
+ # Save to a new file with the node names
320
+ saved_file = save_tree_to_file(st.session_state['current_tree'], parent_node, new_node)
321
+ st.success(f"Added '{new_node}' under '{parent_node}' and saved to {saved_file}!")
322
+
323
+ # Also update the temporary current_tree.md for compatibility
324
+ with open("current_tree.md", "w") as f:
325
+ f.write(st.session_state['current_tree'])
326
+
327
+ # Display Mermaid Diagram
328
+ st.markdown("### Knowledge Tree Visualization")
329
+ mermaid_code = parse_outline_to_mermaid(st.session_state['current_tree'], search_agent)
330
+ components.html(generate_mermaid_html(mermaid_code), height=600)
331
+
332
+ # Export Tree
333
+ if st.button("Export Tree as Markdown"):
334
+ export_md = f"# Knowledge Tree\n\n## Outline\n{st.session_state['current_tree']}\n\n## Mermaid Diagram\n```mermaid\n{mermaid_code}\n```"
335
+ with open("knowledge_tree.md", "w") as f:
336
+ f.write(export_md)
337
+ st.markdown(get_download_link("knowledge_tree.md", "text/markdown"), unsafe_allow_html=True)
338
+
339
+ # AI Project: Model Building Options
340
+ if project_type == "AI Project":
341
+ st.subheader("AI Model Building Options")
342
+ model_option = st.radio("Choose Model Building Method", ["Minimal ML Model from CSV", "Advanced Model Pipeline"])
343
+
344
+ if model_option == "Minimal ML Model from CSV":
345
+ st.write("### Build Minimal ML Model from CSV")
346
+ uploaded_file = st.file_uploader("Upload CSV", type="csv")
347
+ if uploaded_file:
348
+ df = pd.read_csv(uploaded_file)
349
+ st.write("Columns:", df.columns.tolist())
350
+ feature_cols = st.multiselect("Select feature columns", df.columns)
351
+ target_col = st.selectbox("Select target column", df.columns)
352
+ if st.button("Train Model"):
353
+ X = df[feature_cols].values
354
+ y = df[target_col].values
355
+ X_tensor = torch.tensor(X, dtype=torch.float32)
356
+ y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)
357
+ dataset = TensorDataset(X_tensor, y_tensor)
358
+ loader = DataLoader(dataset, batch_size=32, shuffle=True)
359
+ model = nn.Linear(X.shape[1], 1)
360
+ criterion = nn.MSELoss()
361
+ optimizer = optim.Adam(model.parameters(), lr=0.01)
362
+ for epoch in range(10):
363
+ for batch_X, batch_y in loader:
364
+ optimizer.zero_grad()
365
+ outputs = model(batch_X)
366
+ loss = criterion(outputs, batch_y)
367
+ loss.backward()
368
+ optimizer.step()
369
+ torch.save(model.state_dict(), "model.pth")
370
+ app_code = f"""
371
+ import streamlit as st
372
+ import torch
373
+ import torch.nn as nn
374
+
375
+ model = nn.Linear({len(feature_cols)}, 1)
376
+ model.load_state_dict(torch.load("model.pth"))
377
+ model.eval()
378
+
379
+ st.title("ML Model Demo")
380
+ inputs = []
381
+ for col in {feature_cols}:
382
+ inputs.append(st.number_input(col))
383
+ if st.button("Predict"):
384
+ input_tensor = torch.tensor([inputs], dtype=torch.float32)
385
+ prediction = model(input_tensor).item()
386
+ st.write(f"Predicted {target_col}: {{prediction}}")
387
+ """
388
+ with open("app.py", "w") as f:
389
+ f.write(app_code)
390
+ reqs = "streamlit\ntorch\npandas\n"
391
+ with open("requirements.txt", "w") as f:
392
+ f.write(reqs)
393
+ readme = """
394
+ # ML Model Demo
395
+
396
+ ## How to run
397
+ 1. Install requirements: `pip install -r requirements.txt`
398
+ 2. Run the app: `streamlit run app.py`
399
+ 3. Input feature values and click "Predict".
400
+ """
401
+ with open("README.md", "w") as f:
402
+ f.write(readme)
403
+ st.markdown(get_download_link("model.pth", "application/octet-stream"), unsafe_allow_html=True)
404
+ st.markdown(get_download_link("app.py", "text/plain"), unsafe_allow_html=True)
405
+ st.markdown(get_download_link("requirements.txt", "text/plain"), unsafe_allow_html=True)
406
+ st.markdown(get_download_link("README.md", "text/markdown"), unsafe_allow_html=True)
407
+
408
+ elif model_option == "Advanced Model Pipeline":
409
+ st.write("### Advanced Model Building Pipeline")
410
+
411
+ # Model Configuration
412
+ with st.expander("Model Configuration", expanded=True):
413
+ base_model = st.selectbox(
414
+ "Select Base Model",
415
+ ["mistral-7b", "llama-2-7b", "gpt2-medium"]
416
+ )
417
+ model_name = st.text_input("Model Name", "custom-model")
418
+ domain = st.text_input("Target Domain", "general")
419
+ use_distillation = st.checkbox("Apply Distillation", True)
420
+ use_merging = st.checkbox("Apply Model Merging", False)
421
+ use_spectrum = st.checkbox("Apply Spectrum Specialization", True)
422
+
423
+ # Build Model
424
+ if st.button("Build Advanced Model"):
425
+ config = ModelConfig(
426
+ name=model_name,
427
+ base_model=base_model,
428
+ size="7B",
429
+ domain=domain
430
+ )
431
+ builder = ModelBuilder()
432
+
433
+ with st.status("Building advanced model...", expanded=True) as status:
434
+ builder.load_base_model(config.base_model)
435
+
436
+ if use_distillation:
437
+ teacher_model = st.selectbox(
438
+ "Select Teacher Model",
439
+ ["mistral-13b", "llama-2-13b"]
440
+ )
441
+ builder.apply_distillation(teacher_model, f"distilled_{config.name}")
442
+
443
+ if use_merging:
444
+ models_to_merge = st.multiselect(
445
+ "Select Models to Merge",
446
+ ["mistral-7b", "llama-2-7b", "gpt2-medium"]
447
+ )
448
+ builder.apply_merge(models_to_merge, f"merged_{config.name}")
449
+
450
+ if use_spectrum:
451
+ domain_data = st.text_area("Enter domain-specific data", "Sample domain data")
452
+ builder.apply_spectrum(domain_data)
453
+
454
+ builder.save_model(config.model_path)
455
+ status.update(label="Advanced model built successfully!", state="complete")
456
+
457
+ # Generate deployment files
458
+ app_code = f"""
459
+ import streamlit as st
460
+ from transformers import AutoModelForCausalLM, AutoTokenizer
461
+
462
+ model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
463
+ tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")
464
+
465
+ st.title("Advanced Model Demo")
466
+ input_text = st.text_area("Enter text")
467
+ if st.button("Generate"):
468
+ inputs = tokenizer(input_text, return_tensors="pt")
469
+ outputs = model.generate(**inputs)
470
+ st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
471
+ """
472
+ with open("advanced_app.py", "w") as f:
473
+ f.write(app_code)
474
+ reqs = "streamlit\ntorch\ntransformers\n"
475
+ with open("advanced_requirements.txt", "w") as f:
476
+ f.write(reqs)
477
+ readme = f"""
478
+ # Advanced Model Demo
479
+
480
+ ## How to run
481
+ 1. Install requirements: `pip install -r advanced_requirements.txt`
482
+ 2. Run the app: `streamlit run advanced_app.py`
483
+ 3. Input text and click "Generate".
484
+ """
485
+ with open("advanced_README.md", "w") as f:
486
+ f.write(readme)
487
+
488
+ st.markdown(get_download_link("advanced_app.py", "text/plain"), unsafe_allow_html=True)
489
+ st.markdown(get_download_link("advanced_requirements.txt", "text/plain"), unsafe_allow_html=True)
490
+ st.markdown(get_download_link("advanced_README.md", "text/markdown"), unsafe_allow_html=True)
491
+ st.write(f"Model saved at: {config.model_path}")
492
+
493
+ if __name__ == "__main__":
494
+ st.run()