awacke1 commited on
Commit
0559956
·
verified ·
1 Parent(s): 1a5fcce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -380
app.py CHANGED
@@ -1,88 +1,31 @@
1
  #!/usr/bin/env python3
2
  import os
3
- import re
4
  import streamlit as st
5
- import streamlit.components.v1 as components
6
- from urllib.parse import quote
7
  import pandas as pd
8
  import torch
9
- import torch.nn as nn
10
- import torch.optim as optim
11
- from torch.utils.data import DataLoader, TensorDataset
12
- import base64
13
- import glob
14
- import time
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
  from torch.utils.data import Dataset, DataLoader
17
  import csv
 
18
  from dataclasses import dataclass
19
  from typing import Optional
20
 
21
  # Page Configuration
22
  st.set_page_config(
23
- page_title="AI Knowledge Tree Builder 📈🌿",
24
- page_icon="🌳✨",
25
  layout="wide",
26
- initial_sidebar_state="auto",
27
  )
28
 
29
- # Predefined Knowledge Trees
30
- trees = {
31
- "ML Engineering": """
32
- 0. ML Engineering 🌐
33
- 1. Data Preparation
34
- - Load Data 📊
35
- - Preprocess Data 🛠️
36
- 2. Model Building
37
- - Train Model 🤖
38
- - Evaluate Model 📈
39
- 3. Deployment
40
- - Deploy Model 🚀
41
- """,
42
- "Health": """
43
- 0. Health and Wellness 🌿
44
- 1. Physical Health
45
- - Exercise 🏋️
46
- - Nutrition 🍎
47
- 2. Mental Health
48
- - Meditation 🧘
49
- - Therapy 🛋️
50
- """,
51
- }
52
-
53
- # Project Seeds
54
- project_seeds = {
55
- "Code Project": """
56
- 0. Code Project 📂
57
- 1. app.py 🐍
58
- 2. requirements.txt 📦
59
- 3. README.md 📄
60
- """,
61
- "Papers Project": """
62
- 0. Papers Project 📚
63
- 1. markdown 📝
64
- 2. mermaid 🖼️
65
- 3. huggingface.co 🤗
66
- """,
67
- "AI Project": """
68
- 0. AI Project 🤖
69
- 1. Streamlit Torch Transformers
70
- - Streamlit 🌐
71
- - Torch 🔥
72
- - Transformers 🤖
73
- 2. SFT Fine-Tuning
74
- - SFT 🤓
75
- - Small Models 📉
76
- """,
77
- }
78
-
79
  # Meta class for model configuration
80
  class ModelMeta(type):
81
  def __new__(cls, name, bases, attrs):
82
  attrs['registry'] = {}
83
  return super().__new__(cls, name, bases, attrs)
84
 
85
- # Base Model Configuration Class
86
  @dataclass
87
  class ModelConfig(metaclass=ModelMeta):
88
  name: str
@@ -121,10 +64,10 @@ class SFTDataset(Dataset):
121
  return {
122
  "input_ids": encoding["input_ids"].squeeze(),
123
  "attention_mask": encoding["attention_mask"].squeeze(),
124
- "labels": encoding["input_ids"].squeeze() # For causal LM, labels are the same as input_ids
125
  }
126
 
127
- # Model Builder Class with SFT and Evaluation
128
  class ModelBuilder:
129
  def __init__(self):
130
  self.config = None
@@ -132,62 +75,53 @@ class ModelBuilder:
132
  self.tokenizer = None
133
  self.sft_data = None
134
 
135
- def load_base_model(self, model_name: str):
136
- """Load base model from Hugging Face"""
137
- with st.spinner("Loading base model..."):
138
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
139
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
140
  if self.tokenizer.pad_token is None:
141
  self.tokenizer.pad_token = self.tokenizer.eos_token
142
- st.success("Base model loaded!")
143
  return self
144
 
145
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
146
  """Perform Supervised Fine-Tuning with CSV data"""
147
- # Load CSV data
148
  self.sft_data = []
149
  with open(csv_path, "r") as f:
150
  reader = csv.DictReader(f)
151
  for row in reader:
152
  self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
153
 
154
- # Prepare dataset and dataloader
155
  dataset = SFTDataset(self.sft_data, self.tokenizer)
156
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
157
 
158
- # Set up optimizer
159
- optimizer = optim.AdamW(self.model.parameters(), lr=2e-5)
160
-
161
- # Training loop
162
  self.model.train()
163
  for epoch in range(epochs):
164
- with st.spinner(f"Training epoch {epoch + 1}/{epochs}..."):
165
  total_loss = 0
166
  for batch in dataloader:
167
  optimizer.zero_grad()
168
  input_ids = batch["input_ids"].to(self.model.device)
169
  attention_mask = batch["attention_mask"].to(self.model.device)
170
  labels = batch["labels"].to(self.model.device)
171
-
172
- outputs = self.model(
173
- input_ids=input_ids,
174
- attention_mask=attention_mask,
175
- labels=labels
176
- )
177
  loss = outputs.loss
178
  loss.backward()
179
  optimizer.step()
180
  total_loss += loss.item()
181
  st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
182
- st.success("SFT Fine-tuning completed!")
183
  return self
184
 
185
  def save_model(self, path: str):
186
  """Save the fine-tuned model"""
187
- with st.spinner("Saving model..."):
 
188
  self.model.save_pretrained(path)
189
  self.tokenizer.save_pretrained(path)
190
- st.success("Model saved!")
191
 
192
  def evaluate(self, prompt: str):
193
  """Evaluate the model with a prompt"""
@@ -198,295 +132,115 @@ class ModelBuilder:
198
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
199
 
200
  # Utility Functions
201
- def sanitize_label(label):
202
- """Remove invalid characters for Mermaid labels."""
203
- return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
204
-
205
- def sanitize_filename(label):
206
- """Make a valid filename from a label."""
207
- return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
208
-
209
- def parse_outline_to_mermaid(outline_text, search_agent):
210
- """Convert tree outline to Mermaid syntax with clickable nodes."""
211
- lines = outline_text.strip().split('\n')
212
- nodes, edges, clicks, stack = [], [], [], []
213
- for line in lines:
214
- indent = len(line) - len(line.lstrip())
215
- level = indent // 4
216
- label = re.sub(r'^[#*\->\d\.\s]+', '', line.strip())
217
- if label:
218
- node_id = f"N{len(nodes)}"
219
- sanitized_label = sanitize_label(label)
220
- nodes.append(f'{node_id}["{label}"]')
221
- search_url = search_urls[search_agent](label)
222
- clicks.append(f'click {node_id} "{search_url}" _blank')
223
- if stack:
224
- parent_level = stack[-1][0]
225
- if level > parent_level:
226
- edges.append(f"{stack[-1][1]} --> {node_id}")
227
- stack.append((level, node_id))
228
- else:
229
- while stack and stack[-1][0] >= level:
230
- stack.pop()
231
- if stack:
232
- edges.append(f"{stack[-1][1]} --> {node_id}")
233
- stack.append((level, node_id))
234
- else:
235
- stack.append((level, node_id))
236
- return "%%{init: {'themeVariables': {'fontSize': '18px'}}}%%\nflowchart LR\n" + "\n".join(nodes + edges + clicks)
237
-
238
- def generate_mermaid_html(mermaid_code):
239
- """Generate HTML to display Mermaid diagram."""
240
- return f"""
241
- <html><head><script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
242
- <style>.centered-mermaid{{display:flex;justify-content:center;margin:20px auto;}}</style></head>
243
- <body><div class="mermaid centered-mermaid">{mermaid_code}</div>
244
- <script>mermaid.initialize({{startOnLoad:true}});</script></body></html>
245
- """
246
-
247
- def grow_tree(base_tree, new_node_name, parent_node):
248
- """Add a new node to the tree under a specified parent."""
249
- lines = base_tree.strip().split('\n')
250
- new_lines = []
251
- added = False
252
- for line in lines:
253
- new_lines.append(line)
254
- if parent_node in line and not added:
255
- indent = len(line) - len(line.lstrip())
256
- new_lines.append(f"{' ' * (indent + 4)}- {new_node_name} 🌱")
257
- added = True
258
- return "\n".join(new_lines)
259
-
260
- def get_download_link(file_path, mime_type="text/plain"):
261
  """Generate a download link for a file."""
262
  with open(file_path, 'rb') as f:
263
  data = f.read()
264
  b64 = base64.b64encode(data).decode()
265
- return f'<a href="data:{mime_type};base64,{b64}" download="{file_path}">Download {file_path}</a>'
266
-
267
- def save_tree_to_file(tree_text, parent_node, new_node):
268
- """Save tree to a markdown file with name based on nodes."""
269
- root_node = tree_text.strip().split('\n')[0].split('.')[1].strip() if tree_text.strip() else "Knowledge_Tree"
270
- filename = f"{sanitize_filename(root_node)}_{sanitize_filename(parent_node)}_{sanitize_filename(new_node)}_{int(time.time())}.md"
271
-
272
- mermaid_code = parse_outline_to_mermaid(tree_text, "🔮Google") # Default search engine for saved trees
273
- export_md = f"# Knowledge Tree: {root_node}\n\n## Outline\n{tree_text}\n\n## Mermaid Diagram\n```mermaid\n{mermaid_code}\n```"
274
-
275
- with open(filename, "w") as f:
276
- f.write(export_md)
277
- return filename
278
-
279
- def load_trees_from_files():
280
- """Load all saved tree markdown files."""
281
- tree_files = glob.glob("*.md")
282
- trees_dict = {}
283
-
284
- for file in tree_files:
285
- if file != "README.md" and file != "knowledge_tree.md": # Skip project README and temp export
286
- try:
287
- with open(file, 'r') as f:
288
- content = f.read()
289
- # Extract the tree name from the first line
290
- match = re.search(r'# Knowledge Tree: (.*)', content)
291
- if match:
292
- tree_name = match.group(1)
293
- else:
294
- tree_name = os.path.splitext(file)[0]
295
-
296
- # Extract the outline section
297
- outline_match = re.search(r'## Outline\n(.*?)(?=\n## |$)', content, re.DOTALL)
298
- if outline_match:
299
- tree_outline = outline_match.group(1).strip()
300
- trees_dict[f"{tree_name} ({file})"] = tree_outline
301
- except Exception as e:
302
- print(f"Error loading {file}: {e}")
303
-
304
- return trees_dict
305
 
306
- # Search Agents (Highest resolution social network default: X)
307
- search_urls = {
308
- "📚📖ArXiv": lambda k: f"/?q={quote(k)}",
309
- "🔮Google": lambda k: f"https://www.google.com/search?q={quote(k)}",
310
- "📺Youtube": lambda k: f"https://www.youtube.com/results?search_query={quote(k)}",
311
- "🔭Bing": lambda k: f"https://www.bing.com/search?q={quote(k)}",
312
- "💡Truth": lambda k: f"https://truthsocial.com/search?q={quote(k)}",
313
- "📱X": lambda k: f"https://twitter.com/search?q={quote(k)}",
314
- }
315
 
316
  # Main App
317
- st.title("🌳 AI Knowledge Tree Builder 🌱")
318
-
319
- # Sidebar with saved trees
320
- st.sidebar.title("Saved Trees")
321
- saved_trees = load_trees_from_files()
322
- selected_saved_tree = st.sidebar.selectbox("Select a saved tree", ["None"] + list(saved_trees.keys()))
323
-
324
- # Select Project Type
325
- project_type = st.selectbox("Select Project Type", ["Code Project", "Papers Project", "AI Project"])
326
-
327
- # Initialize or load tree
328
- if 'current_tree' not in st.session_state:
329
- if selected_saved_tree != "None" and selected_saved_tree in saved_trees:
330
- st.session_state['current_tree'] = saved_trees[selected_saved_tree]
331
- else:
332
- st.session_state['current_tree'] = trees.get("ML Engineering", project_seeds[project_type])
333
- elif selected_saved_tree != "None" and selected_saved_tree in saved_trees:
334
- st.session_state['current_tree'] = saved_trees[selected_saved_tree]
335
-
336
- # Select Search Agent for Node Links
337
- search_agent = st.selectbox("Select Search Agent for Node Links", list(search_urls.keys()), index=5) # Default to X
338
-
339
- # Tree Growth
340
- new_node = st.text_input("Add New Node")
341
- parent_node = st.text_input("Parent Node")
342
- if st.button("Grow Tree 🌱") and new_node and parent_node:
343
- st.session_state['current_tree'] = grow_tree(st.session_state['current_tree'], new_node, parent_node)
344
-
345
- # Save to a new file with the node names
346
- saved_file = save_tree_to_file(st.session_state['current_tree'], parent_node, new_node)
347
- st.success(f"Added '{new_node}' under '{parent_node}' and saved to {saved_file}!")
348
-
349
- # Also update the temporary current_tree.md for compatibility
350
- with open("current_tree.md", "w") as f:
351
- f.write(st.session_state['current_tree'])
352
  st.rerun()
353
 
354
- # Display Mermaid Diagram
355
- st.markdown("### Knowledge Tree Visualization")
356
- mermaid_code = parse_outline_to_mermaid(st.session_state['current_tree'], search_agent)
357
- components.html(generate_mermaid_html(mermaid_code), height=600)
358
-
359
- # Export Tree
360
- if st.button("Export Tree as Markdown"):
361
- export_md = f"# Knowledge Tree\n\n## Outline\n{st.session_state['current_tree']}\n\n## Mermaid Diagram\n```mermaid\n{mermaid_code}\n```"
362
- with open("knowledge_tree.md", "w") as f:
363
- f.write(export_md)
364
- st.markdown(get_download_link("knowledge_tree.md", "text/markdown"), unsafe_allow_html=True)
365
-
366
- # AI Project: Model Building Options
367
- if project_type == "AI Project":
368
- st.subheader("AI Model Building Options")
369
- model_option = st.radio("Choose Model Building Method", ["Minimal ML Model from CSV", "SFT Fine-Tuning"])
370
-
371
- if model_option == "Minimal ML Model from CSV":
372
- st.write("### Build Minimal ML Model from CSV")
373
- uploaded_file = st.file_uploader("Upload CSV", type="csv")
374
- if uploaded_file:
375
- df = pd.read_csv(uploaded_file)
376
- st.write("Columns:", df.columns.tolist())
377
- feature_cols = st.multiselect("Select feature columns", df.columns)
378
- target_col = st.selectbox("Select target column", df.columns)
379
- if st.button("Train Model"):
380
- X = df[feature_cols].values
381
- y = df[target_col].values
382
- X_tensor = torch.tensor(X, dtype=torch.float32)
383
- y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)
384
- dataset = TensorDataset(X_tensor, y_tensor)
385
- loader = DataLoader(dataset, batch_size=32, shuffle=True)
386
- model = nn.Linear(X.shape[1], 1)
387
- criterion = nn.MSELoss()
388
- optimizer = optim.Adam(model.parameters(), lr=0.01)
389
- for epoch in range(10):
390
- for batch_X, batch_y in loader:
391
- optimizer.zero_grad()
392
- outputs = model(batch_X)
393
- loss = criterion(outputs, batch_y)
394
- loss.backward()
395
- optimizer.step()
396
- torch.save(model.state_dict(), "model.pth")
397
- app_code = f"""
398
- import streamlit as st
399
- import torch
400
- import torch.nn as nn
401
-
402
- model = nn.Linear({len(feature_cols)}, 1)
403
- model.load_state_dict(torch.load("model.pth"))
404
- model.eval()
405
-
406
- st.title("ML Model Demo")
407
- inputs = []
408
- for col in {feature_cols}:
409
- inputs.append(st.number_input(col))
410
- if st.button("Predict"):
411
- input_tensor = torch.tensor([inputs], dtype=torch.float32)
412
- prediction = model(input_tensor).item()
413
- st.write(f"Predicted {target_col}: {{prediction}}")
414
- """
415
- with open("app.py", "w") as f:
416
- f.write(app_code)
417
- reqs = "streamlit\ntorch\npandas\n"
418
- with open("requirements.txt", "w") as f:
419
- f.write(reqs)
420
- readme = """
421
- # ML Model Demo
422
-
423
- ## How to run
424
- 1. Install requirements: `pip install -r requirements.txt`
425
- 2. Run the app: `streamlit run app.py`
426
- 3. Input feature values and click "Predict".
427
- """
428
- with open("README.md", "w") as f:
429
- f.write(readme)
430
- st.markdown(get_download_link("model.pth", "application/octet-stream"), unsafe_allow_html=True)
431
- st.markdown(get_download_link("app.py", "text/plain"), unsafe_allow_html=True)
432
- st.markdown(get_download_link("requirements.txt", "text/plain"), unsafe_allow_html=True)
433
- st.markdown(get_download_link("README.md", "text/markdown"), unsafe_allow_html=True)
434
-
435
- elif model_option == "SFT Fine-Tuning":
436
- st.write("### SFT Fine-Tuning with Small Models")
437
-
438
- # Model Configuration
439
- with st.expander("Model Configuration", expanded=True):
440
- base_model = st.selectbox(
441
- "Select Base Model",
442
- ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
443
- help="Choose a small model for fine-tuning"
444
- )
445
- model_name = st.text_input("Model Name", f"sft-model-{int(time.time())}")
446
- domain = st.text_input("Target Domain", "general")
447
-
448
- # Initialize ModelBuilder
449
- if 'builder' not in st.session_state:
450
- st.session_state['builder'] = ModelBuilder()
451
-
452
- # Load Sample Model
453
- if st.button("Load Sample Model"):
454
- st.session_state['builder'].load_base_model(base_model)
455
- st.session_state['model_loaded'] = True
456
- st.rerun()
457
-
458
- # Generate and Export Sample CSV
459
- if st.button("Generate Sample CSV"):
460
  sample_data = [
461
  {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
462
  {"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
463
  {"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
464
  ]
465
- with open("sft_data.csv", "w", newline="") as f:
 
466
  writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
467
  writer.writeheader()
468
  writer.writerows(sample_data)
469
- st.markdown(get_download_link("sft_data.csv", "text/csv"), unsafe_allow_html=True)
470
- st.success("Sample CSV generated as 'sft_data.csv'!")
471
 
472
  # Upload CSV and Fine-Tune
473
- uploaded_csv = st.file_uploader("Upload CSV for SFT (or use generated sample)", type="csv")
474
- if st.button("Fine-Tune Model") and (uploaded_csv or os.path.exists("sft_data.csv")):
475
- if not hasattr(st.session_state['builder'], 'model') or st.session_state['builder'].model is None:
476
- st.session_state['builder'].load_base_model(base_model)
477
-
478
- csv_path = "sft_data.csv"
479
- if uploaded_csv:
480
- with open(csv_path, "wb") as f:
481
- f.write(uploaded_csv.read())
482
-
483
- with st.status("Fine-tuning model...", expanded=True) as status:
484
  st.session_state['builder'].fine_tune_sft(csv_path)
485
- st.session_state['builder'].save_model(st.session_state['builder'].config.model_path)
486
- status.update(label="Model fine-tuning completed!", state="complete")
487
-
488
- # Generate deployment files
489
- config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  app_code = f"""
491
  import streamlit as st
492
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -517,27 +271,7 @@ if st.button("Generate"):
517
  with open("sft_README.md", "w") as f:
518
  f.write(readme)
519
 
520
- st.markdown(get_download_link("sft_app.py", "text/plain"), unsafe_allow_html=True)
521
- st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
522
- st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
523
- st.write(f"Model saved at: {config.model_path}")
524
- st.rerun()
525
-
526
- # Test and Evaluate Model
527
- if 'model_loaded' in st.session_state and st.session_state['builder'].model is not None:
528
- st.write("### Test and Evaluate Fine-Tuned Model")
529
- if st.session_state['builder'].sft_data:
530
- st.write("Testing with SFT data:")
531
- for item in st.session_state['builder'].sft_data[:3]: # Show up to 3 examples
532
- prompt = item["prompt"]
533
- expected = item["response"]
534
- generated = st.session_state['builder'].evaluate(prompt)
535
- st.write(f"**Prompt**: {prompt}")
536
- st.write(f"**Expected**: {expected}")
537
- st.write(f"**Generated**: {generated}")
538
- st.write("---")
539
-
540
- test_prompt = st.text_area("Enter a custom prompt to test", "What is AI?")
541
- if st.button("Test Model"):
542
- result = st.session_state['builder'].evaluate(test_prompt)
543
- st.write(f"**Generated Response**: {result}")
 
1
  #!/usr/bin/env python3
2
  import os
3
+ import shutil
4
  import streamlit as st
 
 
5
  import pandas as pd
6
  import torch
 
 
 
 
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from torch.utils.data import Dataset, DataLoader
9
  import csv
10
+ import time
11
  from dataclasses import dataclass
12
  from typing import Optional
13
 
14
  # Page Configuration
15
  st.set_page_config(
16
+ page_title="SFT Model Builder 🚀",
17
+ page_icon="🤖",
18
  layout="wide",
19
+ initial_sidebar_state="expanded",
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Meta class for model configuration
23
  class ModelMeta(type):
24
  def __new__(cls, name, bases, attrs):
25
  attrs['registry'] = {}
26
  return super().__new__(cls, name, bases, attrs)
27
 
28
+ # Model Configuration Class
29
  @dataclass
30
  class ModelConfig(metaclass=ModelMeta):
31
  name: str
 
64
  return {
65
  "input_ids": encoding["input_ids"].squeeze(),
66
  "attention_mask": encoding["attention_mask"].squeeze(),
67
+ "labels": encoding["input_ids"].squeeze()
68
  }
69
 
70
+ # Model Builder Class
71
  class ModelBuilder:
72
  def __init__(self):
73
  self.config = None
 
75
  self.tokenizer = None
76
  self.sft_data = None
77
 
78
+ def load_model(self, model_path: str):
79
+ """Load a model from a path"""
80
+ with st.spinner("Loading model..."):
81
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
82
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
83
  if self.tokenizer.pad_token is None:
84
  self.tokenizer.pad_token = self.tokenizer.eos_token
85
+ st.success("Model loaded!")
86
  return self
87
 
88
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
89
  """Perform Supervised Fine-Tuning with CSV data"""
 
90
  self.sft_data = []
91
  with open(csv_path, "r") as f:
92
  reader = csv.DictReader(f)
93
  for row in reader:
94
  self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
95
 
 
96
  dataset = SFTDataset(self.sft_data, self.tokenizer)
97
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
98
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
99
 
 
 
 
 
100
  self.model.train()
101
  for epoch in range(epochs):
102
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
103
  total_loss = 0
104
  for batch in dataloader:
105
  optimizer.zero_grad()
106
  input_ids = batch["input_ids"].to(self.model.device)
107
  attention_mask = batch["attention_mask"].to(self.model.device)
108
  labels = batch["labels"].to(self.model.device)
109
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
 
 
 
 
 
110
  loss = outputs.loss
111
  loss.backward()
112
  optimizer.step()
113
  total_loss += loss.item()
114
  st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
115
+ st.success("SFT Fine-tuning completed! 🎉")
116
  return self
117
 
118
  def save_model(self, path: str):
119
  """Save the fine-tuned model"""
120
+ with st.spinner("Saving model... 💾"):
121
+ os.makedirs(os.path.dirname(path), exist_ok=True)
122
  self.model.save_pretrained(path)
123
  self.tokenizer.save_pretrained(path)
124
+ st.success(f"Model saved at {path}!")
125
 
126
  def evaluate(self, prompt: str):
127
  """Evaluate the model with a prompt"""
 
132
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
133
 
134
  # Utility Functions
135
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  """Generate a download link for a file."""
137
  with open(file_path, 'rb') as f:
138
  data = f.read()
139
  b64 = base64.b64encode(data).decode()
140
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ def get_model_files():
143
+ """List all saved model directories."""
144
+ return [d for d in glob.glob("models/*") if os.path.isdir(d)]
 
 
 
 
 
 
145
 
146
  # Main App
147
+ st.title("SFT Model Builder 🤖🚀")
148
+
149
+ # Sidebar for Model Management
150
+ st.sidebar.header("Model Management 🗂️")
151
+ model_dirs = get_model_files()
152
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
153
+
154
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
155
+ if 'builder' not in st.session_state:
156
+ st.session_state['builder'] = ModelBuilder()
157
+ st.session_state['builder'].load_model(selected_model)
158
+ st.session_state['model_loaded'] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  st.rerun()
160
 
161
+ # Main UI with Tabs
162
+ tab1, tab2, tab3 = st.tabs(["Build New Model 🌱", "Fine-Tune Model 🔧", "Test Model 🧪"])
163
+
164
+ with tab1:
165
+ st.header("Build New Model 🌱")
166
+ base_model = st.selectbox(
167
+ "Select Base Model",
168
+ ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
169
+ help="Choose a small model to start with"
170
+ )
171
+ model_name = st.text_input("Model Name", f"new-model-{int(time.time())}")
172
+ domain = st.text_input("Target Domain", "general")
173
+
174
+ if st.button("Download Model ⬇️"):
175
+ config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
176
+ builder = ModelBuilder()
177
+ builder.load_model(base_model)
178
+ builder.save_model(config.model_path)
179
+ st.session_state['builder'] = builder
180
+ st.session_state['model_loaded'] = True
181
+ st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
182
+ st.rerun()
183
+
184
+ with tab2:
185
+ st.header("Fine-Tune Model 🔧")
186
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
187
+ st.warning("Please download or load a model first! ⚠️")
188
+ else:
189
+ # Generate Sample CSV
190
+ if st.button("Generate Sample CSV 📝"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  sample_data = [
192
  {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
193
  {"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
194
  {"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
195
  ]
196
+ csv_path = f"sft_data_{int(time.time())}.csv"
197
+ with open(csv_path, "w", newline="") as f:
198
  writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
199
  writer.writeheader()
200
  writer.writerows(sample_data)
201
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
202
+ st.success(f"Sample CSV generated as {csv_path}!")
203
 
204
  # Upload CSV and Fine-Tune
205
+ uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
206
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
207
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
208
+ with open(csv_path, "wb") as f:
209
+ f.write(uploaded_csv.read())
210
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
211
+ new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
212
+ st.session_state['builder'].config = new_config
213
+ with st.status("Fine-tuning model... ⏳", expanded=True) as status:
 
 
214
  st.session_state['builder'].fine_tune_sft(csv_path)
215
+ st.session_state['builder'].save_model(new_config.model_path)
216
+ status.update(label="Fine-tuning completed! 🎉", state="complete")
217
+ st.markdown(get_download_link(f"{new_config.model_path}/pytorch_model.bin", "application/octet-stream", "Download Fine-Tuned Model"), unsafe_allow_html=True)
218
+ st.rerun()
219
+
220
+ with tab3:
221
+ st.header("Test Model 🧪")
222
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
223
+ st.warning("Please download or load a model first! ⚠️")
224
+ else:
225
+ if st.session_state['builder'].sft_data:
226
+ st.write("Testing with SFT Data:")
227
+ for item in st.session_state['builder'].sft_data[:3]:
228
+ prompt = item["prompt"]
229
+ expected = item["response"]
230
+ generated = st.session_state['builder'].evaluate(prompt)
231
+ st.write(f"**Prompt**: {prompt}")
232
+ st.write(f"**Expected**: {expected}")
233
+ st.write(f"**Generated**: {generated}")
234
+ st.write("---")
235
+
236
+ test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
237
+ if st.button("Run Test ▶️"):
238
+ result = st.session_state['builder'].evaluate(test_prompt)
239
+ st.write(f"**Generated Response**: {result}")
240
+
241
+ # Export Model Files
242
+ if st.button("Export Model Files 📦"):
243
+ config = st.session_state['builder'].config
244
  app_code = f"""
245
  import streamlit as st
246
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
271
  with open("sft_README.md", "w") as f:
272
  f.write(readme)
273
 
274
+ st.markdown(get_download_link("sft_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
275
+ st.markdown(get_download_link("sft_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
276
+ st.markdown(get_download_link("sft_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
277
+ st.success("Model files exported! ")