awacke1 committed
Commit 1a5fcce · verified · 1 Parent(s): 18972c9

Update app.py

Files changed (1)
  1. app.py  +53 -22
app.py CHANGED
@@ -124,12 +124,13 @@ class SFTDataset(Dataset):
             "labels": encoding["input_ids"].squeeze()  # For causal LM, labels are the same as input_ids
         }
 
-# Model Builder Class with SFT
+# Model Builder Class with SFT and Evaluation
 class ModelBuilder:
     def __init__(self):
         self.config = None
         self.model = None
         self.tokenizer = None
+        self.sft_data = None
 
     def load_base_model(self, model_name: str):
         """Load base model from Hugging Face"""
@@ -144,14 +145,14 @@ class ModelBuilder:
     def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
         """Perform Supervised Fine-Tuning with CSV data"""
         # Load CSV data
-        data = []
+        self.sft_data = []
         with open(csv_path, "r") as f:
             reader = csv.DictReader(f)
             for row in reader:
-                data.append({"prompt": row["prompt"], "response": row["response"]})
+                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
 
         # Prepare dataset and dataloader
-        dataset = SFTDataset(data, self.tokenizer)
+        dataset = SFTDataset(self.sft_data, self.tokenizer)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
         # Set up optimizer
@@ -188,6 +189,14 @@ class ModelBuilder:
         self.tokenizer.save_pretrained(path)
         st.success("Model saved!")
 
+    def evaluate(self, prompt: str):
+        """Evaluate the model with a prompt"""
+        self.model.eval()
+        with torch.no_grad():
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            outputs = self.model.generate(**inputs, max_new_tokens=50)
+            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
 # Utility Functions
 def sanitize_label(label):
     """Remove invalid characters for Mermaid labels."""
@@ -340,6 +349,7 @@ if st.button("Grow Tree 🌱") and new_node and parent_node:
     # Also update the temporary current_tree.md for compatibility
     with open("current_tree.md", "w") as f:
         f.write(st.session_state['current_tree'])
+    st.rerun()
 
 # Display Mermaid Diagram
 st.markdown("### Knowledge Tree Visualization")
@@ -429,13 +439,23 @@ if st.button("Predict"):
 with st.expander("Model Configuration", expanded=True):
     base_model = st.selectbox(
         "Select Base Model",
-        ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],  # Small models suitable for SFT
+        ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
         help="Choose a small model for fine-tuning"
     )
-    model_name = st.text_input("Model Name", "sft-model")
+    model_name = st.text_input("Model Name", f"sft-model-{int(time.time())}")
     domain = st.text_input("Target Domain", "general")
 
-    # Generate Sample CSV
+    # Initialize ModelBuilder
+    if 'builder' not in st.session_state:
+        st.session_state['builder'] = ModelBuilder()
+
+    # Load Sample Model
+    if st.button("Load Sample Model"):
+        st.session_state['builder'].load_base_model(base_model)
+        st.session_state['model_loaded'] = True
+        st.rerun()
+
+    # Generate and Export Sample CSV
     if st.button("Generate Sample CSV"):
         sample_data = [
             {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
@@ -449,30 +469,24 @@ if st.button("Predict"):
         st.markdown(get_download_link("sft_data.csv", "text/csv"), unsafe_allow_html=True)
         st.success("Sample CSV generated as 'sft_data.csv'!")
 
-    # Fine-Tune with SFT
+    # Upload CSV and Fine-Tune
     uploaded_csv = st.file_uploader("Upload CSV for SFT (or use generated sample)", type="csv")
     if st.button("Fine-Tune Model") and (uploaded_csv or os.path.exists("sft_data.csv")):
-        config = ModelConfig(
-            name=model_name,
-            base_model=base_model,
-            size="small",
-            domain=domain
-        )
-        builder = ModelBuilder()
+        if not hasattr(st.session_state['builder'], 'model') or st.session_state['builder'].model is None:
+            st.session_state['builder'].load_base_model(base_model)
 
-        # Load CSV
         csv_path = "sft_data.csv"
         if uploaded_csv:
             with open(csv_path, "wb") as f:
                 f.write(uploaded_csv.read())
 
         with st.status("Fine-tuning model...", expanded=True) as status:
-            builder.load_base_model(config.base_model)
-            builder.fine_tune_sft(csv_path)
-            builder.save_model(config.model_path)
+            st.session_state['builder'].fine_tune_sft(csv_path)
+            st.session_state['builder'].save_model(st.session_state['builder'].config.model_path)
            status.update(label="Model fine-tuning completed!", state="complete")
 
        # Generate deployment files
+        config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
        app_code = f"""
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -507,6 +521,23 @@ if st.button("Generate"):
        st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
        st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
        st.write(f"Model saved at: {config.model_path}")
-
-if __name__ == "__main__":
-    st.run()
+        st.rerun()
+
+# Test and Evaluate Model
+if 'model_loaded' in st.session_state and st.session_state['builder'].model is not None:
+    st.write("### Test and Evaluate Fine-Tuned Model")
+    if st.session_state['builder'].sft_data:
+        st.write("Testing with SFT data:")
+        for item in st.session_state['builder'].sft_data[:3]:  # Show up to 3 examples
+            prompt = item["prompt"]
+            expected = item["response"]
+            generated = st.session_state['builder'].evaluate(prompt)
+            st.write(f"**Prompt**: {prompt}")
+            st.write(f"**Expected**: {expected}")
+            st.write(f"**Generated**: {generated}")
+            st.write("---")
+
+    test_prompt = st.text_area("Enter a custom prompt to test", "What is AI?")
+    if st.button("Test Model"):
+        result = st.session_state['builder'].evaluate(test_prompt)
+        st.write(f"**Generated Response**: {result}")
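For reference, fine_tune_sft() above relies on the SFTDataset whose __getitem__ tail is visible in the first hunk (labels mirror input_ids for causal-LM training). Below is a minimal sketch of a compatible dataset plus the same CSV loading path; the class body, max_length, and the way prompt and response are joined are assumptions for illustration, not part of this commit.

import csv
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

class SFTDataset(Dataset):
    """Sketch of a dataset compatible with fine_tune_sft(); details assumed."""
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data              # list of {"prompt": ..., "response": ...} dicts
        self.tokenizer = tokenizer
        self.max_length = max_length  # assumed cap; not specified in this diff

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = f"{item['prompt']} {item['response']}"  # assumed prompt/response join
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": encoding["input_ids"].squeeze(),  # For causal LM, labels are the same as input_ids
        }

# Same CSV loading path as fine_tune_sft(), using the sample file the app can generate
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers have no pad token by default
data = []
with open("sft_data.csv", "r") as f:
    for row in csv.DictReader(f):
        data.append({"prompt": row["prompt"], "response": row["response"]})
dataloader = DataLoader(SFTDataset(data, tokenizer), batch_size=4, shuffle=True)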
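The new evaluate() method is a thin greedy-generation wrapper around model.generate. A standalone sketch of the same pattern outside Streamlit, using one of the base models offered in the selectbox; the model name and test prompt are placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

def evaluate(prompt: str) -> str:
    """Mirror of ModelBuilder.evaluate(): greedy generation, up to 50 new tokens."""
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=50)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print(evaluate("What is AI?"))  # prompt taken from the sample CSV / test default in the app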