Update app.py
Browse files
app.py
CHANGED
@@ -124,12 +124,13 @@ class SFTDataset(Dataset):
|
|
124 |
"labels": encoding["input_ids"].squeeze() # For causal LM, labels are the same as input_ids
|
125 |
}
|
126 |
|
127 |
-
# Model Builder Class with SFT
|
128 |
class ModelBuilder:
|
129 |
def __init__(self):
|
130 |
self.config = None
|
131 |
self.model = None
|
132 |
self.tokenizer = None
|
|
|
133 |
|
134 |
def load_base_model(self, model_name: str):
|
135 |
"""Load base model from Hugging Face"""
|
@@ -144,14 +145,14 @@ class ModelBuilder:
|
|
144 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
145 |
"""Perform Supervised Fine-Tuning with CSV data"""
|
146 |
# Load CSV data
|
147 |
-
|
148 |
with open(csv_path, "r") as f:
|
149 |
reader = csv.DictReader(f)
|
150 |
for row in reader:
|
151 |
-
|
152 |
|
153 |
# Prepare dataset and dataloader
|
154 |
-
dataset = SFTDataset(
|
155 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
156 |
|
157 |
# Set up optimizer
|
@@ -188,6 +189,14 @@ class ModelBuilder:
|
|
188 |
self.tokenizer.save_pretrained(path)
|
189 |
st.success("Model saved!")
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
# Utility Functions
|
192 |
def sanitize_label(label):
|
193 |
"""Remove invalid characters for Mermaid labels."""
|
@@ -340,6 +349,7 @@ if st.button("Grow Tree 🌱") and new_node and parent_node:
|
|
340 |
# Also update the temporary current_tree.md for compatibility
|
341 |
with open("current_tree.md", "w") as f:
|
342 |
f.write(st.session_state['current_tree'])
|
|
|
343 |
|
344 |
# Display Mermaid Diagram
|
345 |
st.markdown("### Knowledge Tree Visualization")
|
@@ -429,13 +439,23 @@ if st.button("Predict"):
|
|
429 |
with st.expander("Model Configuration", expanded=True):
|
430 |
base_model = st.selectbox(
|
431 |
"Select Base Model",
|
432 |
-
["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
|
433 |
help="Choose a small model for fine-tuning"
|
434 |
)
|
435 |
-
model_name = st.text_input("Model Name", "sft-model")
|
436 |
domain = st.text_input("Target Domain", "general")
|
437 |
|
438 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
if st.button("Generate Sample CSV"):
|
440 |
sample_data = [
|
441 |
{"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
|
@@ -449,30 +469,24 @@ if st.button("Predict"):
|
|
449 |
st.markdown(get_download_link("sft_data.csv", "text/csv"), unsafe_allow_html=True)
|
450 |
st.success("Sample CSV generated as 'sft_data.csv'!")
|
451 |
|
452 |
-
# Fine-Tune
|
453 |
uploaded_csv = st.file_uploader("Upload CSV for SFT (or use generated sample)", type="csv")
|
454 |
if st.button("Fine-Tune Model") and (uploaded_csv or os.path.exists("sft_data.csv")):
|
455 |
-
|
456 |
-
|
457 |
-
base_model=base_model,
|
458 |
-
size="small",
|
459 |
-
domain=domain
|
460 |
-
)
|
461 |
-
builder = ModelBuilder()
|
462 |
|
463 |
-
# Load CSV
|
464 |
csv_path = "sft_data.csv"
|
465 |
if uploaded_csv:
|
466 |
with open(csv_path, "wb") as f:
|
467 |
f.write(uploaded_csv.read())
|
468 |
|
469 |
with st.status("Fine-tuning model...", expanded=True) as status:
|
470 |
-
builder.
|
471 |
-
builder.
|
472 |
-
builder.save_model(config.model_path)
|
473 |
status.update(label="Model fine-tuning completed!", state="complete")
|
474 |
|
475 |
# Generate deployment files
|
|
|
476 |
app_code = f"""
|
477 |
import streamlit as st
|
478 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
@@ -507,6 +521,23 @@ if st.button("Generate"):
|
|
507 |
st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
|
508 |
st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
|
509 |
st.write(f"Model saved at: {config.model_path}")
|
510 |
-
|
511 |
-
|
512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
"labels": encoding["input_ids"].squeeze() # For causal LM, labels are the same as input_ids
|
125 |
}
|
126 |
|
127 |
+
# Model Builder Class with SFT and Evaluation
|
128 |
class ModelBuilder:
|
129 |
def __init__(self):
|
130 |
self.config = None
|
131 |
self.model = None
|
132 |
self.tokenizer = None
|
133 |
+
self.sft_data = None
|
134 |
|
135 |
def load_base_model(self, model_name: str):
|
136 |
"""Load base model from Hugging Face"""
|
|
|
145 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
146 |
"""Perform Supervised Fine-Tuning with CSV data"""
|
147 |
# Load CSV data
|
148 |
+
self.sft_data = []
|
149 |
with open(csv_path, "r") as f:
|
150 |
reader = csv.DictReader(f)
|
151 |
for row in reader:
|
152 |
+
self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
|
153 |
|
154 |
# Prepare dataset and dataloader
|
155 |
+
dataset = SFTDataset(self.sft_data, self.tokenizer)
|
156 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
157 |
|
158 |
# Set up optimizer
|
|
|
189 |
self.tokenizer.save_pretrained(path)
|
190 |
st.success("Model saved!")
|
191 |
|
192 |
+
def evaluate(self, prompt: str):
    """Generate a completion for ``prompt`` with the currently loaded model.

    Puts the model into evaluation mode, tokenizes the prompt, generates up
    to 50 new tokens with gradient tracking disabled, and decodes the first
    returned sequence with special tokens stripped.

    Args:
        prompt: Text passed to the tokenizer as the generation input.

    Returns:
        The decoded text of the first generated sequence.
        NOTE(review): for decoder-only models ``generate`` presumably returns
        the prompt tokens followed by the new ones, so the result likely
        starts with the prompt itself -- confirm before comparing it with an
        expected response.

    Raises:
        RuntimeError: If no model or tokenizer has been loaded yet.
    """
    # __init__ leaves both attributes as None; fail with a clear message
    # instead of an AttributeError on NoneType further down.
    if self.model is None or self.tokenizer is None:
        raise RuntimeError(
            "No model loaded: call load_base_model() before evaluate()."
        )

    # NOTE(review): the model is left in eval mode afterwards; verify that
    # fine_tune_sft switches back to train mode before further training.
    self.model.eval()
    with torch.no_grad():
        # Move the encoded prompt to the same device as the model.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=50)
    return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
199 |
+
|
200 |
# Utility Functions
|
201 |
def sanitize_label(label):
|
202 |
"""Remove invalid characters for Mermaid labels."""
|
|
|
349 |
# Also update the temporary current_tree.md for compatibility
|
350 |
with open("current_tree.md", "w") as f:
|
351 |
f.write(st.session_state['current_tree'])
|
352 |
+
st.rerun()
|
353 |
|
354 |
# Display Mermaid Diagram
|
355 |
st.markdown("### Knowledge Tree Visualization")
|
|
|
439 |
with st.expander("Model Configuration", expanded=True):
|
440 |
base_model = st.selectbox(
|
441 |
"Select Base Model",
|
442 |
+
["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
|
443 |
help="Choose a small model for fine-tuning"
|
444 |
)
|
445 |
+
model_name = st.text_input("Model Name", f"sft-model-{int(time.time())}")
|
446 |
domain = st.text_input("Target Domain", "general")
|
447 |
|
448 |
+
# Initialize ModelBuilder
|
449 |
+
if 'builder' not in st.session_state:
|
450 |
+
st.session_state['builder'] = ModelBuilder()
|
451 |
+
|
452 |
+
# Load Sample Model
|
453 |
+
if st.button("Load Sample Model"):
|
454 |
+
st.session_state['builder'].load_base_model(base_model)
|
455 |
+
st.session_state['model_loaded'] = True
|
456 |
+
st.rerun()
|
457 |
+
|
458 |
+
# Generate and Export Sample CSV
|
459 |
if st.button("Generate Sample CSV"):
|
460 |
sample_data = [
|
461 |
{"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
|
|
|
469 |
st.markdown(get_download_link("sft_data.csv", "text/csv"), unsafe_allow_html=True)
|
470 |
st.success("Sample CSV generated as 'sft_data.csv'!")
|
471 |
|
472 |
+
# Upload CSV and Fine-Tune
|
473 |
uploaded_csv = st.file_uploader("Upload CSV for SFT (or use generated sample)", type="csv")
|
474 |
if st.button("Fine-Tune Model") and (uploaded_csv or os.path.exists("sft_data.csv")):
|
475 |
+
if not hasattr(st.session_state['builder'], 'model') or st.session_state['builder'].model is None:
|
476 |
+
st.session_state['builder'].load_base_model(base_model)
|
|
|
|
|
|
|
|
|
|
|
477 |
|
|
|
478 |
csv_path = "sft_data.csv"
|
479 |
if uploaded_csv:
|
480 |
with open(csv_path, "wb") as f:
|
481 |
f.write(uploaded_csv.read())
|
482 |
|
483 |
with st.status("Fine-tuning model...", expanded=True) as status:
|
484 |
+
st.session_state['builder'].fine_tune_sft(csv_path)
|
485 |
+
st.session_state['builder'].save_model(st.session_state['builder'].config.model_path)
|
|
|
486 |
status.update(label="Model fine-tuning completed!", state="complete")
|
487 |
|
488 |
# Generate deployment files
|
489 |
+
config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
|
490 |
app_code = f"""
|
491 |
import streamlit as st
|
492 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
521 |
st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
|
522 |
st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
|
523 |
st.write(f"Model saved at: {config.model_path}")
|
524 |
+
st.rerun()
|
525 |
+
|
526 |
+
# Test and Evaluate Model
# Show the panel whenever the shared builder actually holds a model. In the
# visible code the 'model_loaded' flag is only set by the "Load Sample Model"
# button, so gating on it hid this panel after a fine-tune run that loaded
# the model itself.
# NOTE(review): nesting below is reconstructed from a whitespace-stripped
# view; the custom-prompt widgets are assumed to live inside the model check
# because evaluate() needs a loaded model -- confirm against the real file.
active_builder = st.session_state.get('builder')
if active_builder is not None and active_builder.model is not None:
    st.write("### Test and Evaluate Fine-Tuned Model")

    # sft_data is filled by fine_tune_sft; it is None/empty before that.
    if active_builder.sft_data:
        st.write("Testing with SFT data:")
        for item in active_builder.sft_data[:3]:  # Show up to 3 examples
            prompt = item["prompt"]
            expected = item["response"]
            generated = active_builder.evaluate(prompt)
            st.write(f"**Prompt**: {prompt}")
            st.write(f"**Expected**: {expected}")
            st.write(f"**Generated**: {generated}")
            st.write("---")

    # Free-form prompt so the user can probe the model interactively.
    test_prompt = st.text_area("Enter a custom prompt to test", "What is AI?")
    if st.button("Test Model"):
        result = active_builder.evaluate(test_prompt)
        st.write(f"**Generated Response**: {result}")
|