Update app.py
Browse files
app.py
CHANGED
@@ -1,88 +1,31 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
import os
|
3 |
-
import
|
4 |
import streamlit as st
|
5 |
-
import streamlit.components.v1 as components
|
6 |
-
from urllib.parse import quote
|
7 |
import pandas as pd
|
8 |
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
import torch.optim as optim
|
11 |
-
from torch.utils.data import DataLoader, TensorDataset
|
12 |
-
import base64
|
13 |
-
import glob
|
14 |
-
import time
|
15 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
16 |
from torch.utils.data import Dataset, DataLoader
|
17 |
import csv
|
|
|
18 |
from dataclasses import dataclass
|
19 |
from typing import Optional
|
20 |
|
21 |
# Page Configuration
|
22 |
st.set_page_config(
|
23 |
-
page_title="
|
24 |
-
page_icon="
|
25 |
layout="wide",
|
26 |
-
initial_sidebar_state="
|
27 |
)
|
28 |
|
29 |
-
# Predefined Knowledge Trees
|
30 |
-
trees = {
|
31 |
-
"ML Engineering": """
|
32 |
-
0. ML Engineering 🌐
|
33 |
-
1. Data Preparation
|
34 |
-
- Load Data 📊
|
35 |
-
- Preprocess Data 🛠️
|
36 |
-
2. Model Building
|
37 |
-
- Train Model 🤖
|
38 |
-
- Evaluate Model 📈
|
39 |
-
3. Deployment
|
40 |
-
- Deploy Model 🚀
|
41 |
-
""",
|
42 |
-
"Health": """
|
43 |
-
0. Health and Wellness 🌿
|
44 |
-
1. Physical Health
|
45 |
-
- Exercise 🏋️
|
46 |
-
- Nutrition 🍎
|
47 |
-
2. Mental Health
|
48 |
-
- Meditation 🧘
|
49 |
-
- Therapy 🛋️
|
50 |
-
""",
|
51 |
-
}
|
52 |
-
|
53 |
-
# Project Seeds
|
54 |
-
project_seeds = {
|
55 |
-
"Code Project": """
|
56 |
-
0. Code Project 📂
|
57 |
-
1. app.py 🐍
|
58 |
-
2. requirements.txt 📦
|
59 |
-
3. README.md 📄
|
60 |
-
""",
|
61 |
-
"Papers Project": """
|
62 |
-
0. Papers Project 📚
|
63 |
-
1. markdown 📝
|
64 |
-
2. mermaid 🖼️
|
65 |
-
3. huggingface.co 🤗
|
66 |
-
""",
|
67 |
-
"AI Project": """
|
68 |
-
0. AI Project 🤖
|
69 |
-
1. Streamlit Torch Transformers
|
70 |
-
- Streamlit 🌐
|
71 |
-
- Torch 🔥
|
72 |
-
- Transformers 🤖
|
73 |
-
2. SFT Fine-Tuning
|
74 |
-
- SFT 🤓
|
75 |
-
- Small Models 📉
|
76 |
-
""",
|
77 |
-
}
|
78 |
-
|
79 |
# Meta class for model configuration
|
80 |
class ModelMeta(type):
|
81 |
def __new__(cls, name, bases, attrs):
|
82 |
attrs['registry'] = {}
|
83 |
return super().__new__(cls, name, bases, attrs)
|
84 |
|
85 |
-
#
|
86 |
@dataclass
|
87 |
class ModelConfig(metaclass=ModelMeta):
|
88 |
name: str
|
@@ -121,10 +64,10 @@ class SFTDataset(Dataset):
|
|
121 |
return {
|
122 |
"input_ids": encoding["input_ids"].squeeze(),
|
123 |
"attention_mask": encoding["attention_mask"].squeeze(),
|
124 |
-
"labels": encoding["input_ids"].squeeze()
|
125 |
}
|
126 |
|
127 |
-
# Model Builder Class
|
128 |
class ModelBuilder:
|
129 |
def __init__(self):
|
130 |
self.config = None
|
@@ -132,62 +75,53 @@ class ModelBuilder:
|
|
132 |
self.tokenizer = None
|
133 |
self.sft_data = None
|
134 |
|
135 |
-
def
|
136 |
-
"""Load
|
137 |
-
with st.spinner("Loading
|
138 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
139 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
140 |
if self.tokenizer.pad_token is None:
|
141 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
142 |
-
st.success("
|
143 |
return self
|
144 |
|
145 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
146 |
"""Perform Supervised Fine-Tuning with CSV data"""
|
147 |
-
# Load CSV data
|
148 |
self.sft_data = []
|
149 |
with open(csv_path, "r") as f:
|
150 |
reader = csv.DictReader(f)
|
151 |
for row in reader:
|
152 |
self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
|
153 |
|
154 |
-
# Prepare dataset and dataloader
|
155 |
dataset = SFTDataset(self.sft_data, self.tokenizer)
|
156 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
157 |
|
158 |
-
# Set up optimizer
|
159 |
-
optimizer = optim.AdamW(self.model.parameters(), lr=2e-5)
|
160 |
-
|
161 |
-
# Training loop
|
162 |
self.model.train()
|
163 |
for epoch in range(epochs):
|
164 |
-
with st.spinner(f"Training epoch {epoch + 1}/{epochs}..."):
|
165 |
total_loss = 0
|
166 |
for batch in dataloader:
|
167 |
optimizer.zero_grad()
|
168 |
input_ids = batch["input_ids"].to(self.model.device)
|
169 |
attention_mask = batch["attention_mask"].to(self.model.device)
|
170 |
labels = batch["labels"].to(self.model.device)
|
171 |
-
|
172 |
-
outputs = self.model(
|
173 |
-
input_ids=input_ids,
|
174 |
-
attention_mask=attention_mask,
|
175 |
-
labels=labels
|
176 |
-
)
|
177 |
loss = outputs.loss
|
178 |
loss.backward()
|
179 |
optimizer.step()
|
180 |
total_loss += loss.item()
|
181 |
st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
|
182 |
-
st.success("SFT Fine-tuning completed!")
|
183 |
return self
|
184 |
|
185 |
def save_model(self, path: str):
|
186 |
"""Save the fine-tuned model"""
|
187 |
-
with st.spinner("Saving model..."):
|
|
|
188 |
self.model.save_pretrained(path)
|
189 |
self.tokenizer.save_pretrained(path)
|
190 |
-
st.success("Model saved!")
|
191 |
|
192 |
def evaluate(self, prompt: str):
|
193 |
"""Evaluate the model with a prompt"""
|
@@ -198,295 +132,115 @@ class ModelBuilder:
|
|
198 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
199 |
|
200 |
# Utility Functions
|
201 |
-
def
|
202 |
-
"""Remove invalid characters for Mermaid labels."""
|
203 |
-
return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
|
204 |
-
|
205 |
-
def sanitize_filename(label):
|
206 |
-
"""Make a valid filename from a label."""
|
207 |
-
return re.sub(r'[^\w\s-]', '', label).replace(' ', '_')
|
208 |
-
|
209 |
-
def parse_outline_to_mermaid(outline_text, search_agent):
|
210 |
-
"""Convert tree outline to Mermaid syntax with clickable nodes."""
|
211 |
-
lines = outline_text.strip().split('\n')
|
212 |
-
nodes, edges, clicks, stack = [], [], [], []
|
213 |
-
for line in lines:
|
214 |
-
indent = len(line) - len(line.lstrip())
|
215 |
-
level = indent // 4
|
216 |
-
label = re.sub(r'^[#*\->\d\.\s]+', '', line.strip())
|
217 |
-
if label:
|
218 |
-
node_id = f"N{len(nodes)}"
|
219 |
-
sanitized_label = sanitize_label(label)
|
220 |
-
nodes.append(f'{node_id}["{label}"]')
|
221 |
-
search_url = search_urls[search_agent](label)
|
222 |
-
clicks.append(f'click {node_id} "{search_url}" _blank')
|
223 |
-
if stack:
|
224 |
-
parent_level = stack[-1][0]
|
225 |
-
if level > parent_level:
|
226 |
-
edges.append(f"{stack[-1][1]} --> {node_id}")
|
227 |
-
stack.append((level, node_id))
|
228 |
-
else:
|
229 |
-
while stack and stack[-1][0] >= level:
|
230 |
-
stack.pop()
|
231 |
-
if stack:
|
232 |
-
edges.append(f"{stack[-1][1]} --> {node_id}")
|
233 |
-
stack.append((level, node_id))
|
234 |
-
else:
|
235 |
-
stack.append((level, node_id))
|
236 |
-
return "%%{init: {'themeVariables': {'fontSize': '18px'}}}%%\nflowchart LR\n" + "\n".join(nodes + edges + clicks)
|
237 |
-
|
238 |
-
def generate_mermaid_html(mermaid_code):
|
239 |
-
"""Generate HTML to display Mermaid diagram."""
|
240 |
-
return f"""
|
241 |
-
<html><head><script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
|
242 |
-
<style>.centered-mermaid{{display:flex;justify-content:center;margin:20px auto;}}</style></head>
|
243 |
-
<body><div class="mermaid centered-mermaid">{mermaid_code}</div>
|
244 |
-
<script>mermaid.initialize({{startOnLoad:true}});</script></body></html>
|
245 |
-
"""
|
246 |
-
|
247 |
-
def grow_tree(base_tree, new_node_name, parent_node):
|
248 |
-
"""Add a new node to the tree under a specified parent."""
|
249 |
-
lines = base_tree.strip().split('\n')
|
250 |
-
new_lines = []
|
251 |
-
added = False
|
252 |
-
for line in lines:
|
253 |
-
new_lines.append(line)
|
254 |
-
if parent_node in line and not added:
|
255 |
-
indent = len(line) - len(line.lstrip())
|
256 |
-
new_lines.append(f"{' ' * (indent + 4)}- {new_node_name} 🌱")
|
257 |
-
added = True
|
258 |
-
return "\n".join(new_lines)
|
259 |
-
|
260 |
-
def get_download_link(file_path, mime_type="text/plain"):
|
261 |
"""Generate a download link for a file."""
|
262 |
with open(file_path, 'rb') as f:
|
263 |
data = f.read()
|
264 |
b64 = base64.b64encode(data).decode()
|
265 |
-
return f'<a href="data:{mime_type};base64,{b64}" download="{file_path}">
|
266 |
-
|
267 |
-
def save_tree_to_file(tree_text, parent_node, new_node):
|
268 |
-
"""Save tree to a markdown file with name based on nodes."""
|
269 |
-
root_node = tree_text.strip().split('\n')[0].split('.')[1].strip() if tree_text.strip() else "Knowledge_Tree"
|
270 |
-
filename = f"{sanitize_filename(root_node)}_{sanitize_filename(parent_node)}_{sanitize_filename(new_node)}_{int(time.time())}.md"
|
271 |
-
|
272 |
-
mermaid_code = parse_outline_to_mermaid(tree_text, "🔮Google") # Default search engine for saved trees
|
273 |
-
export_md = f"# Knowledge Tree: {root_node}\n\n## Outline\n{tree_text}\n\n## Mermaid Diagram\n```mermaid\n{mermaid_code}\n```"
|
274 |
-
|
275 |
-
with open(filename, "w") as f:
|
276 |
-
f.write(export_md)
|
277 |
-
return filename
|
278 |
-
|
279 |
-
def load_trees_from_files():
|
280 |
-
"""Load all saved tree markdown files."""
|
281 |
-
tree_files = glob.glob("*.md")
|
282 |
-
trees_dict = {}
|
283 |
-
|
284 |
-
for file in tree_files:
|
285 |
-
if file != "README.md" and file != "knowledge_tree.md": # Skip project README and temp export
|
286 |
-
try:
|
287 |
-
with open(file, 'r') as f:
|
288 |
-
content = f.read()
|
289 |
-
# Extract the tree name from the first line
|
290 |
-
match = re.search(r'# Knowledge Tree: (.*)', content)
|
291 |
-
if match:
|
292 |
-
tree_name = match.group(1)
|
293 |
-
else:
|
294 |
-
tree_name = os.path.splitext(file)[0]
|
295 |
-
|
296 |
-
# Extract the outline section
|
297 |
-
outline_match = re.search(r'## Outline\n(.*?)(?=\n## |$)', content, re.DOTALL)
|
298 |
-
if outline_match:
|
299 |
-
tree_outline = outline_match.group(1).strip()
|
300 |
-
trees_dict[f"{tree_name} ({file})"] = tree_outline
|
301 |
-
except Exception as e:
|
302 |
-
print(f"Error loading {file}: {e}")
|
303 |
-
|
304 |
-
return trees_dict
|
305 |
|
306 |
-
|
307 |
-
|
308 |
-
"
|
309 |
-
"🔮Google": lambda k: f"https://www.google.com/search?q={quote(k)}",
|
310 |
-
"📺Youtube": lambda k: f"https://www.youtube.com/results?search_query={quote(k)}",
|
311 |
-
"🔭Bing": lambda k: f"https://www.bing.com/search?q={quote(k)}",
|
312 |
-
"💡Truth": lambda k: f"https://truthsocial.com/search?q={quote(k)}",
|
313 |
-
"📱X": lambda k: f"https://twitter.com/search?q={quote(k)}",
|
314 |
-
}
|
315 |
|
316 |
# Main App
|
317 |
-
st.title("
|
318 |
-
|
319 |
-
# Sidebar
|
320 |
-
st.sidebar.
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
if selected_saved_tree != "None" and selected_saved_tree in saved_trees:
|
330 |
-
st.session_state['current_tree'] = saved_trees[selected_saved_tree]
|
331 |
-
else:
|
332 |
-
st.session_state['current_tree'] = trees.get("ML Engineering", project_seeds[project_type])
|
333 |
-
elif selected_saved_tree != "None" and selected_saved_tree in saved_trees:
|
334 |
-
st.session_state['current_tree'] = saved_trees[selected_saved_tree]
|
335 |
-
|
336 |
-
# Select Search Agent for Node Links
|
337 |
-
search_agent = st.selectbox("Select Search Agent for Node Links", list(search_urls.keys()), index=5) # Default to X
|
338 |
-
|
339 |
-
# Tree Growth
|
340 |
-
new_node = st.text_input("Add New Node")
|
341 |
-
parent_node = st.text_input("Parent Node")
|
342 |
-
if st.button("Grow Tree 🌱") and new_node and parent_node:
|
343 |
-
st.session_state['current_tree'] = grow_tree(st.session_state['current_tree'], new_node, parent_node)
|
344 |
-
|
345 |
-
# Save to a new file with the node names
|
346 |
-
saved_file = save_tree_to_file(st.session_state['current_tree'], parent_node, new_node)
|
347 |
-
st.success(f"Added '{new_node}' under '{parent_node}' and saved to {saved_file}!")
|
348 |
-
|
349 |
-
# Also update the temporary current_tree.md for compatibility
|
350 |
-
with open("current_tree.md", "w") as f:
|
351 |
-
f.write(st.session_state['current_tree'])
|
352 |
st.rerun()
|
353 |
|
354 |
-
#
|
355 |
-
st.
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
st.
|
365 |
-
|
366 |
-
|
367 |
-
if
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
st.
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
dataset = TensorDataset(X_tensor, y_tensor)
|
385 |
-
loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
386 |
-
model = nn.Linear(X.shape[1], 1)
|
387 |
-
criterion = nn.MSELoss()
|
388 |
-
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
389 |
-
for epoch in range(10):
|
390 |
-
for batch_X, batch_y in loader:
|
391 |
-
optimizer.zero_grad()
|
392 |
-
outputs = model(batch_X)
|
393 |
-
loss = criterion(outputs, batch_y)
|
394 |
-
loss.backward()
|
395 |
-
optimizer.step()
|
396 |
-
torch.save(model.state_dict(), "model.pth")
|
397 |
-
app_code = f"""
|
398 |
-
import streamlit as st
|
399 |
-
import torch
|
400 |
-
import torch.nn as nn
|
401 |
-
|
402 |
-
model = nn.Linear({len(feature_cols)}, 1)
|
403 |
-
model.load_state_dict(torch.load("model.pth"))
|
404 |
-
model.eval()
|
405 |
-
|
406 |
-
st.title("ML Model Demo")
|
407 |
-
inputs = []
|
408 |
-
for col in {feature_cols}:
|
409 |
-
inputs.append(st.number_input(col))
|
410 |
-
if st.button("Predict"):
|
411 |
-
input_tensor = torch.tensor([inputs], dtype=torch.float32)
|
412 |
-
prediction = model(input_tensor).item()
|
413 |
-
st.write(f"Predicted {target_col}: {{prediction}}")
|
414 |
-
"""
|
415 |
-
with open("app.py", "w") as f:
|
416 |
-
f.write(app_code)
|
417 |
-
reqs = "streamlit\ntorch\npandas\n"
|
418 |
-
with open("requirements.txt", "w") as f:
|
419 |
-
f.write(reqs)
|
420 |
-
readme = """
|
421 |
-
# ML Model Demo
|
422 |
-
|
423 |
-
## How to run
|
424 |
-
1. Install requirements: `pip install -r requirements.txt`
|
425 |
-
2. Run the app: `streamlit run app.py`
|
426 |
-
3. Input feature values and click "Predict".
|
427 |
-
"""
|
428 |
-
with open("README.md", "w") as f:
|
429 |
-
f.write(readme)
|
430 |
-
st.markdown(get_download_link("model.pth", "application/octet-stream"), unsafe_allow_html=True)
|
431 |
-
st.markdown(get_download_link("app.py", "text/plain"), unsafe_allow_html=True)
|
432 |
-
st.markdown(get_download_link("requirements.txt", "text/plain"), unsafe_allow_html=True)
|
433 |
-
st.markdown(get_download_link("README.md", "text/markdown"), unsafe_allow_html=True)
|
434 |
-
|
435 |
-
elif model_option == "SFT Fine-Tuning":
|
436 |
-
st.write("### SFT Fine-Tuning with Small Models")
|
437 |
-
|
438 |
-
# Model Configuration
|
439 |
-
with st.expander("Model Configuration", expanded=True):
|
440 |
-
base_model = st.selectbox(
|
441 |
-
"Select Base Model",
|
442 |
-
["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
|
443 |
-
help="Choose a small model for fine-tuning"
|
444 |
-
)
|
445 |
-
model_name = st.text_input("Model Name", f"sft-model-{int(time.time())}")
|
446 |
-
domain = st.text_input("Target Domain", "general")
|
447 |
-
|
448 |
-
# Initialize ModelBuilder
|
449 |
-
if 'builder' not in st.session_state:
|
450 |
-
st.session_state['builder'] = ModelBuilder()
|
451 |
-
|
452 |
-
# Load Sample Model
|
453 |
-
if st.button("Load Sample Model"):
|
454 |
-
st.session_state['builder'].load_base_model(base_model)
|
455 |
-
st.session_state['model_loaded'] = True
|
456 |
-
st.rerun()
|
457 |
-
|
458 |
-
# Generate and Export Sample CSV
|
459 |
-
if st.button("Generate Sample CSV"):
|
460 |
sample_data = [
|
461 |
{"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
|
462 |
{"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
|
463 |
{"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
|
464 |
]
|
465 |
-
|
|
|
466 |
writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
|
467 |
writer.writeheader()
|
468 |
writer.writerows(sample_data)
|
469 |
-
st.markdown(get_download_link(
|
470 |
-
st.success("Sample CSV generated as
|
471 |
|
472 |
# Upload CSV and Fine-Tune
|
473 |
-
uploaded_csv = st.file_uploader("Upload CSV for SFT
|
474 |
-
if st.button("Fine-Tune
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
with st.status("Fine-tuning model...", expanded=True) as status:
|
484 |
st.session_state['builder'].fine_tune_sft(csv_path)
|
485 |
-
st.session_state['builder'].save_model(
|
486 |
-
status.update(label="
|
487 |
-
|
488 |
-
|
489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
app_code = f"""
|
491 |
import streamlit as st
|
492 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
@@ -517,27 +271,7 @@ if st.button("Generate"):
|
|
517 |
with open("sft_README.md", "w") as f:
|
518 |
f.write(readme)
|
519 |
|
520 |
-
st.markdown(get_download_link("sft_app.py", "text/plain"), unsafe_allow_html=True)
|
521 |
-
st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
|
522 |
-
st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
|
523 |
-
st.
|
524 |
-
st.rerun()
|
525 |
-
|
526 |
-
# Test and Evaluate Model
|
527 |
-
if 'model_loaded' in st.session_state and st.session_state['builder'].model is not None:
|
528 |
-
st.write("### Test and Evaluate Fine-Tuned Model")
|
529 |
-
if st.session_state['builder'].sft_data:
|
530 |
-
st.write("Testing with SFT data:")
|
531 |
-
for item in st.session_state['builder'].sft_data[:3]: # Show up to 3 examples
|
532 |
-
prompt = item["prompt"]
|
533 |
-
expected = item["response"]
|
534 |
-
generated = st.session_state['builder'].evaluate(prompt)
|
535 |
-
st.write(f"**Prompt**: {prompt}")
|
536 |
-
st.write(f"**Expected**: {expected}")
|
537 |
-
st.write(f"**Generated**: {generated}")
|
538 |
-
st.write("---")
|
539 |
-
|
540 |
-
test_prompt = st.text_area("Enter a custom prompt to test", "What is AI?")
|
541 |
-
if st.button("Test Model"):
|
542 |
-
result = st.session_state['builder'].evaluate(test_prompt)
|
543 |
-
st.write(f"**Generated Response**: {result}")
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
import os
|
3 |
+
import shutil
|
4 |
import streamlit as st
|
|
|
|
|
5 |
import pandas as pd
|
6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import csv
|
10 |
+
import time
|
11 |
from dataclasses import dataclass
|
12 |
from typing import Optional
|
13 |
|
14 |
# Page Configuration
|
15 |
st.set_page_config(
|
16 |
+
page_title="SFT Model Builder 🚀",
|
17 |
+
page_icon="🤖",
|
18 |
layout="wide",
|
19 |
+
initial_sidebar_state="expanded",
|
20 |
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# Meta class for model configuration
|
23 |
class ModelMeta(type):
|
24 |
def __new__(cls, name, bases, attrs):
|
25 |
attrs['registry'] = {}
|
26 |
return super().__new__(cls, name, bases, attrs)
|
27 |
|
28 |
+
# Model Configuration Class
|
29 |
@dataclass
|
30 |
class ModelConfig(metaclass=ModelMeta):
|
31 |
name: str
|
|
|
64 |
return {
|
65 |
"input_ids": encoding["input_ids"].squeeze(),
|
66 |
"attention_mask": encoding["attention_mask"].squeeze(),
|
67 |
+
"labels": encoding["input_ids"].squeeze()
|
68 |
}
|
69 |
|
70 |
+
# Model Builder Class
|
71 |
class ModelBuilder:
|
72 |
def __init__(self):
|
73 |
self.config = None
|
|
|
75 |
self.tokenizer = None
|
76 |
self.sft_data = None
|
77 |
|
78 |
+
def load_model(self, model_path: str):
|
79 |
+
"""Load a model from a path"""
|
80 |
+
with st.spinner("Loading model... ⏳"):
|
81 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_path)
|
82 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
83 |
if self.tokenizer.pad_token is None:
|
84 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
85 |
+
st.success("Model loaded! ✅")
|
86 |
return self
|
87 |
|
88 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
89 |
"""Perform Supervised Fine-Tuning with CSV data"""
|
|
|
90 |
self.sft_data = []
|
91 |
with open(csv_path, "r") as f:
|
92 |
reader = csv.DictReader(f)
|
93 |
for row in reader:
|
94 |
self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
|
95 |
|
|
|
96 |
dataset = SFTDataset(self.sft_data, self.tokenizer)
|
97 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
98 |
+
optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
|
99 |
|
|
|
|
|
|
|
|
|
100 |
self.model.train()
|
101 |
for epoch in range(epochs):
|
102 |
+
with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
|
103 |
total_loss = 0
|
104 |
for batch in dataloader:
|
105 |
optimizer.zero_grad()
|
106 |
input_ids = batch["input_ids"].to(self.model.device)
|
107 |
attention_mask = batch["attention_mask"].to(self.model.device)
|
108 |
labels = batch["labels"].to(self.model.device)
|
109 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
|
|
|
|
|
|
|
|
|
|
110 |
loss = outputs.loss
|
111 |
loss.backward()
|
112 |
optimizer.step()
|
113 |
total_loss += loss.item()
|
114 |
st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
|
115 |
+
st.success("SFT Fine-tuning completed! 🎉")
|
116 |
return self
|
117 |
|
118 |
def save_model(self, path: str):
|
119 |
"""Save the fine-tuned model"""
|
120 |
+
with st.spinner("Saving model... 💾"):
|
121 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
122 |
self.model.save_pretrained(path)
|
123 |
self.tokenizer.save_pretrained(path)
|
124 |
+
st.success(f"Model saved at {path}! ✅")
|
125 |
|
126 |
def evaluate(self, prompt: str):
|
127 |
"""Evaluate the model with a prompt"""
|
|
|
132 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
133 |
|
134 |
# Utility Functions
|
135 |
+
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
"""Generate a download link for a file."""
|
137 |
with open(file_path, 'rb') as f:
|
138 |
data = f.read()
|
139 |
b64 = base64.b64encode(data).decode()
|
140 |
+
return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
+
def get_model_files():
|
143 |
+
"""List all saved model directories."""
|
144 |
+
return [d for d in glob.glob("models/*") if os.path.isdir(d)]
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
# Main App
|
147 |
+
st.title("SFT Model Builder 🤖🚀")
|
148 |
+
|
149 |
+
# Sidebar for Model Management
|
150 |
+
st.sidebar.header("Model Management 🗂️")
|
151 |
+
model_dirs = get_model_files()
|
152 |
+
selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
|
153 |
+
|
154 |
+
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
155 |
+
if 'builder' not in st.session_state:
|
156 |
+
st.session_state['builder'] = ModelBuilder()
|
157 |
+
st.session_state['builder'].load_model(selected_model)
|
158 |
+
st.session_state['model_loaded'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
st.rerun()
|
160 |
|
161 |
+
# Main UI with Tabs
|
162 |
+
tab1, tab2, tab3 = st.tabs(["Build New Model 🌱", "Fine-Tune Model 🔧", "Test Model 🧪"])
|
163 |
+
|
164 |
+
with tab1:
|
165 |
+
st.header("Build New Model 🌱")
|
166 |
+
base_model = st.selectbox(
|
167 |
+
"Select Base Model",
|
168 |
+
["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
|
169 |
+
help="Choose a small model to start with"
|
170 |
+
)
|
171 |
+
model_name = st.text_input("Model Name", f"new-model-{int(time.time())}")
|
172 |
+
domain = st.text_input("Target Domain", "general")
|
173 |
+
|
174 |
+
if st.button("Download Model ⬇️"):
|
175 |
+
config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
|
176 |
+
builder = ModelBuilder()
|
177 |
+
builder.load_model(base_model)
|
178 |
+
builder.save_model(config.model_path)
|
179 |
+
st.session_state['builder'] = builder
|
180 |
+
st.session_state['model_loaded'] = True
|
181 |
+
st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
|
182 |
+
st.rerun()
|
183 |
+
|
184 |
+
with tab2:
|
185 |
+
st.header("Fine-Tune Model 🔧")
|
186 |
+
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
187 |
+
st.warning("Please download or load a model first! ⚠️")
|
188 |
+
else:
|
189 |
+
# Generate Sample CSV
|
190 |
+
if st.button("Generate Sample CSV 📝"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
sample_data = [
|
192 |
{"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
|
193 |
{"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
|
194 |
{"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
|
195 |
]
|
196 |
+
csv_path = f"sft_data_{int(time.time())}.csv"
|
197 |
+
with open(csv_path, "w", newline="") as f:
|
198 |
writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
|
199 |
writer.writeheader()
|
200 |
writer.writerows(sample_data)
|
201 |
+
st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
|
202 |
+
st.success(f"Sample CSV generated as {csv_path}! ✅")
|
203 |
|
204 |
# Upload CSV and Fine-Tune
|
205 |
+
uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
|
206 |
+
if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
|
207 |
+
csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
|
208 |
+
with open(csv_path, "wb") as f:
|
209 |
+
f.write(uploaded_csv.read())
|
210 |
+
new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
|
211 |
+
new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
|
212 |
+
st.session_state['builder'].config = new_config
|
213 |
+
with st.status("Fine-tuning model... ⏳", expanded=True) as status:
|
|
|
|
|
214 |
st.session_state['builder'].fine_tune_sft(csv_path)
|
215 |
+
st.session_state['builder'].save_model(new_config.model_path)
|
216 |
+
status.update(label="Fine-tuning completed! 🎉", state="complete")
|
217 |
+
st.markdown(get_download_link(f"{new_config.model_path}/pytorch_model.bin", "application/octet-stream", "Download Fine-Tuned Model"), unsafe_allow_html=True)
|
218 |
+
st.rerun()
|
219 |
+
|
220 |
+
with tab3:
|
221 |
+
st.header("Test Model 🧪")
|
222 |
+
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
223 |
+
st.warning("Please download or load a model first! ⚠️")
|
224 |
+
else:
|
225 |
+
if st.session_state['builder'].sft_data:
|
226 |
+
st.write("Testing with SFT Data:")
|
227 |
+
for item in st.session_state['builder'].sft_data[:3]:
|
228 |
+
prompt = item["prompt"]
|
229 |
+
expected = item["response"]
|
230 |
+
generated = st.session_state['builder'].evaluate(prompt)
|
231 |
+
st.write(f"**Prompt**: {prompt}")
|
232 |
+
st.write(f"**Expected**: {expected}")
|
233 |
+
st.write(f"**Generated**: {generated}")
|
234 |
+
st.write("---")
|
235 |
+
|
236 |
+
test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
|
237 |
+
if st.button("Run Test ▶️"):
|
238 |
+
result = st.session_state['builder'].evaluate(test_prompt)
|
239 |
+
st.write(f"**Generated Response**: {result}")
|
240 |
+
|
241 |
+
# Export Model Files
|
242 |
+
if st.button("Export Model Files 📦"):
|
243 |
+
config = st.session_state['builder'].config
|
244 |
app_code = f"""
|
245 |
import streamlit as st
|
246 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
271 |
with open("sft_README.md", "w") as f:
|
272 |
f.write(readme)
|
273 |
|
274 |
+
st.markdown(get_download_link("sft_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
|
275 |
+
st.markdown(get_download_link("sft_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
|
276 |
+
st.markdown(get_download_link("sft_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
|
277 |
+
st.success("Model files exported! ✅")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|