awacke1 committed on
Commit
24a2dd3
·
verified ·
1 Parent(s): 4aaca6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -89
app.py CHANGED
@@ -13,17 +13,10 @@ import base64
13
  import glob
14
  import time
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
- from mergekit.config import MergeConfiguration
17
- from mergekit.merge import Mergekit
18
- from spectrum import SpectrumAnalyzer
19
- import yaml
20
  from dataclasses import dataclass
21
- from typing import Optional, List
22
- import logging
23
-
24
- # Configure logging
25
- logging.basicConfig(level=logging.INFO)
26
- logger = logging.getLogger(__name__)
27
 
28
  # Page Configuration
29
  st.set_page_config(
@@ -77,13 +70,9 @@ project_seeds = {
77
  - Streamlit 🌐
78
  - Torch 🔥
79
  - Transformers 🤖
80
- 2. MergeKit Spectrum
81
- - MergeKit 🔄
82
- - Spectrum 📊
83
- 3. Transformers Diffusers Datasets
84
- - Transformers 🤖
85
- - Diffusers 🎨
86
- - Datasets 📊
87
  """,
88
  }
89
 
@@ -108,54 +97,96 @@ class ModelConfig(metaclass=ModelMeta):
108
  def model_path(self):
109
  return f"models/{self.name}"
110
 
111
- # Decorator for pipeline stages
112
- def pipeline_stage(func):
113
- def wrapper(*args, **kwargs):
114
- st.spinner(f"Running {func.__name__}...")
115
- result = func(*args, **kwargs)
116
- st.success(f"Completed {func.__name__}!")
117
- return result
118
- return wrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- # Model Builder Class
121
  class ModelBuilder:
122
  def __init__(self):
123
  self.config = None
124
  self.model = None
125
  self.tokenizer = None
126
 
127
- @pipeline_stage
128
  def load_base_model(self, model_name: str):
129
  """Load base model from Hugging Face"""
130
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
131
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
132
  return self
133
 
134
- @pipeline_stage
135
- def apply_merge(self, models_to_merge: List[str], output_dir: str):
136
- """Apply Mergekit for model merging"""
137
- merge_config = MergeConfiguration(
138
- models=models_to_merge,
139
- merge_method="linear",
140
- output_dir=output_dir
141
- )
142
- merger = Mergekit(merge_config)
143
- merger.run()
144
- self.model = AutoModelForCausalLM.from_pretrained(output_dir)
145
- return self
146
-
147
- @pipeline_stage
148
- def apply_spectrum(self, domain_data: str):
149
- """Apply Spectrum for domain specialization"""
150
- analyzer = SpectrumAnalyzer(self.model)
151
- analyzer.fit(domain_data)
152
- self.model = analyzer.specialized_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  return self
154
 
155
  def save_model(self, path: str):
156
- """Save the final model"""
157
- self.model.save_pretrained(path)
158
- self.tokenizer.save_pretrained(path)
 
 
159
 
160
  # Utility Functions
161
  def sanitize_label(label):
@@ -325,7 +356,7 @@ if st.button("Export Tree as Markdown"):
325
  # AI Project: Model Building Options
326
  if project_type == "AI Project":
327
  st.subheader("AI Model Building Options")
328
- model_option = st.radio("Choose Model Building Method", ["Minimal ML Model from CSV", "Advanced Model Pipeline"])
329
 
330
  if model_option == "Minimal ML Model from CSV":
331
  st.write("### Build Minimal ML Model from CSV")
@@ -391,46 +422,55 @@ if st.button("Predict"):
391
  st.markdown(get_download_link("requirements.txt", "text/plain"), unsafe_allow_html=True)
392
  st.markdown(get_download_link("README.md", "text/markdown"), unsafe_allow_html=True)
393
 
394
- elif model_option == "Advanced Model Pipeline":
395
- st.write("### Advanced Model Building Pipeline")
396
 
397
  # Model Configuration
398
  with st.expander("Model Configuration", expanded=True):
399
  base_model = st.selectbox(
400
  "Select Base Model",
401
- ["mistral-7b", "llama-2-7b", "gpt2-medium"]
 
402
  )
403
- model_name = st.text_input("Model Name", "custom-model")
404
  domain = st.text_input("Target Domain", "general")
405
- use_merging = st.checkbox("Apply Model Merging", False)
406
- use_spectrum = st.checkbox("Apply Spectrum Specialization", True)
407
 
408
- # Build Model
409
- if st.button("Build Advanced Model"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  config = ModelConfig(
411
  name=model_name,
412
  base_model=base_model,
413
- size="7B",
414
  domain=domain
415
  )
416
  builder = ModelBuilder()
417
 
418
- with st.status("Building advanced model...", expanded=True) as status:
 
 
 
 
 
 
419
  builder.load_base_model(config.base_model)
420
-
421
- if use_merging:
422
- models_to_merge = st.multiselect(
423
- "Select Models to Merge",
424
- ["mistral-7b", "llama-2-7b", "gpt2-medium"]
425
- )
426
- builder.apply_merge(models_to_merge, f"merged_{config.name}")
427
-
428
- if use_spectrum:
429
- domain_data = st.text_area("Enter domain-specific data", "Sample domain data")
430
- builder.apply_spectrum(domain_data)
431
-
432
  builder.save_model(config.model_path)
433
- status.update(label="Advanced model built successfully!", state="complete")
434
 
435
  # Generate deployment files
436
  app_code = f"""
@@ -440,32 +480,32 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
440
  model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
441
  tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")
442
 
443
- st.title("Advanced Model Demo")
444
- input_text = st.text_area("Enter text")
445
  if st.button("Generate"):
446
  inputs = tokenizer(input_text, return_tensors="pt")
447
- outputs = model.generate(**inputs)
448
  st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
449
  """
450
- with open("advanced_app.py", "w") as f:
451
  f.write(app_code)
452
- reqs = "streamlit\ntorch\ntransformers\nmergekit\nspectrum\n"
453
- with open("advanced_requirements.txt", "w") as f:
454
  f.write(reqs)
455
  readme = f"""
456
- # Advanced Model Demo
457
 
458
  ## How to run
459
- 1. Install requirements: `pip install -r advanced_requirements.txt`
460
- 2. Run the app: `streamlit run advanced_app.py`
461
- 3. Input text and click "Generate".
462
  """
463
- with open("advanced_README.md", "w") as f:
464
  f.write(readme)
465
 
466
- st.markdown(get_download_link("advanced_app.py", "text/plain"), unsafe_allow_html=True)
467
- st.markdown(get_download_link("advanced_requirements.txt", "text/plain"), unsafe_allow_html=True)
468
- st.markdown(get_download_link("advanced_README.md", "text/markdown"), unsafe_allow_html=True)
469
  st.write(f"Model saved at: {config.model_path}")
470
 
471
  if __name__ == "__main__":
 
13
  import glob
14
  import time
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import csv
 
 
18
  from dataclasses import dataclass
19
+ from typing import Optional
 
 
 
 
 
20
 
21
  # Page Configuration
22
  st.set_page_config(
 
70
  - Streamlit 🌐
71
  - Torch 🔥
72
  - Transformers 🤖
73
+ 2. SFT Fine-Tuning
74
+ - SFT 🤓
75
+ - Small Models 📉
 
 
 
 
76
  """,
77
  }
78
 
 
97
  def model_path(self):
98
  return f"models/{self.name}"
99
 
100
# Custom Dataset for SFT
class SFTDataset(Dataset):
    """Map-style dataset that tokenizes prompt/response pairs for causal-LM SFT.

    Each element of ``data`` is a mapping with ``"prompt"`` and ``"response"``
    keys.  The two strings are joined with a single space, tokenized, and
    padded/truncated to ``max_length`` tokens.

    Args:
        data: Sequence of ``{"prompt": ..., "response": ...}`` mappings.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding="max_length", truncation=True, return_tensors="pt")`` and
            returning a mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors that carry a leading batch dimension of size 1.
        max_length: Fixed sequence length of every returned example.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one row."""
        row = self.data[idx]
        input_text = f"{row['prompt']} {row['response']}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the batch dimension.  A bare ``squeeze()`` would also
        # collapse the sequence dimension when ``max_length == 1``.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # For causal LM the labels mirror the inputs, but padded positions
        # must not contribute to the loss: -100 is the index ignored by the
        # cross-entropy loss.  Without this, when the pad token is the EOS
        # token, short examples are dominated by "predict EOS" targets.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
126
 
127
# Model Builder Class with SFT
class ModelBuilder:
    """Load a base causal LM, fine-tune it on prompt/response CSV data, save it.

    ``load_base_model`` and ``fine_tune_sft`` return ``self`` so they can be
    chained.  Progress is reported through Streamlit (``st.spinner``,
    ``st.write``, ``st.success``).
    """

    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None

    def _require_loaded(self):
        """Raise a clear error if the model or tokenizer is not loaded yet."""
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("No model loaded; call load_base_model() first.")

    @staticmethod
    def _read_sft_rows(csv_path: str):
        """Read prompt/response rows from ``csv_path`` and validate them.

        Raises:
            ValueError: If the ``prompt``/``response`` columns are missing or
                the file contains no usable rows.
        """
        required = ("prompt", "response")
        # newline="" is what the csv module expects; utf-8-sig tolerates the
        # BOM that spreadsheet exports often prepend (it would otherwise turn
        # the first header into "\ufeffprompt").
        with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            header = reader.fieldnames or []
            missing = [col for col in required if col not in header]
            if missing:
                raise ValueError(
                    f"CSV is missing required column(s): {', '.join(missing)}"
                )
            # Skip rows with an empty/absent prompt or response so a short
            # row does not become the literal text "None".
            rows = [
                {"prompt": row["prompt"], "response": row["response"]}
                for row in reader
                if row["prompt"] and row["response"]
            ]
        if not rows:
            raise ValueError("CSV contains no prompt/response rows to train on")
        return rows

    def load_base_model(self, model_name: str):
        """Load base model from Hugging Face.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.

        Returns:
            ``self``, for chaining.
        """
        with st.spinner("Loading base model..."):
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Padding to a fixed length needs a pad token; fall back to EOS
            # when the tokenizer does not define one.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        st.success("Base model loaded!")
        return self

    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
        """Perform Supervised Fine-Tuning with CSV data.

        Args:
            csv_path: Path to a CSV file with ``prompt`` and ``response``
                columns.
            epochs: Number of passes over the data.
            batch_size: Examples per optimizer step.

        Returns:
            ``self``, for chaining.

        Raises:
            RuntimeError: If no model has been loaded.
            ValueError: If the CSV lacks the required columns or has no rows.
        """
        self._require_loaded()
        data = self._read_sft_rows(csv_path)

        # Imported here so the method does not depend on a module-level
        # ``optim`` alias being present in the file.
        from torch.optim import AdamW

        # Prepare dataset and dataloader
        dataset = SFTDataset(data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Set up optimizer
        optimizer = AdamW(self.model.parameters(), lr=2e-5)
        device = self.model.device

        # Training loop
        self.model.train()
        for epoch in range(epochs):
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}..."):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    outputs = self.model(
                        input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device),
                    )
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success("SFT Fine-tuning completed!")
        return self

    def save_model(self, path: str):
        """Save the fine-tuned model and its tokenizer under ``path``.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        self._require_loaded()
        with st.spinner("Saving model..."):
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
        st.success("Model saved!")
190
 
191
  # Utility Functions
192
  def sanitize_label(label):
 
356
  # AI Project: Model Building Options
357
  if project_type == "AI Project":
358
  st.subheader("AI Model Building Options")
359
+ model_option = st.radio("Choose Model Building Method", ["Minimal ML Model from CSV", "SFT Fine-Tuning"])
360
 
361
  if model_option == "Minimal ML Model from CSV":
362
  st.write("### Build Minimal ML Model from CSV")
 
422
  st.markdown(get_download_link("requirements.txt", "text/plain"), unsafe_allow_html=True)
423
  st.markdown(get_download_link("README.md", "text/markdown"), unsafe_allow_html=True)
424
 
425
+ elif model_option == "SFT Fine-Tuning":
426
+ st.write("### SFT Fine-Tuning with Small Models")
427
 
428
  # Model Configuration
429
  with st.expander("Model Configuration", expanded=True):
430
  base_model = st.selectbox(
431
  "Select Base Model",
432
+ ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"], # Small models suitable for SFT
433
+ help="Choose a small model for fine-tuning"
434
  )
435
+ model_name = st.text_input("Model Name", "sft-model")
436
  domain = st.text_input("Target Domain", "general")
 
 
437
 
438
+ # Generate Sample CSV
439
+ if st.button("Generate Sample CSV"):
440
+ sample_data = [
441
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
442
+ {"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
443
+ {"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
444
+ ]
445
+ with open("sft_data.csv", "w", newline="") as f:
446
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
447
+ writer.writeheader()
448
+ writer.writerows(sample_data)
449
+ st.markdown(get_download_link("sft_data.csv", "text/csv"), unsafe_allow_html=True)
450
+ st.success("Sample CSV generated as 'sft_data.csv'!")
451
+
452
+ # Fine-Tune with SFT
453
+ uploaded_csv = st.file_uploader("Upload CSV for SFT (or use generated sample)", type="csv")
454
+ if st.button("Fine-Tune Model") and (uploaded_csv or os.path.exists("sft_data.csv")):
455
  config = ModelConfig(
456
  name=model_name,
457
  base_model=base_model,
458
+ size="small",
459
  domain=domain
460
  )
461
  builder = ModelBuilder()
462
 
463
+ # Load CSV
464
+ csv_path = "sft_data.csv"
465
+ if uploaded_csv:
466
+ with open(csv_path, "wb") as f:
467
+ f.write(uploaded_csv.read())
468
+
469
+ with st.status("Fine-tuning model...", expanded=True) as status:
470
  builder.load_base_model(config.base_model)
471
+ builder.fine_tune_sft(csv_path)
 
 
 
 
 
 
 
 
 
 
 
472
  builder.save_model(config.model_path)
473
+ status.update(label="Model fine-tuning completed!", state="complete")
474
 
475
  # Generate deployment files
476
  app_code = f"""
 
480
  model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
481
  tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")
482
 
483
+ st.title("SFT Model Demo")
484
+ input_text = st.text_area("Enter prompt")
485
  if st.button("Generate"):
486
  inputs = tokenizer(input_text, return_tensors="pt")
487
+ outputs = model.generate(**inputs, max_new_tokens=50)
488
  st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
489
  """
490
+ with open("sft_app.py", "w") as f:
491
  f.write(app_code)
492
+ reqs = "streamlit\ntorch\ntransformers\n"
493
+ with open("sft_requirements.txt", "w") as f:
494
  f.write(reqs)
495
  readme = f"""
496
+ # SFT Model Demo
497
 
498
  ## How to run
499
+ 1. Install requirements: `pip install -r sft_requirements.txt`
500
+ 2. Run the app: `streamlit run sft_app.py`
501
+ 3. Input a prompt and click "Generate".
502
  """
503
+ with open("sft_README.md", "w") as f:
504
  f.write(readme)
505
 
506
+ st.markdown(get_download_link("sft_app.py", "text/plain"), unsafe_allow_html=True)
507
+ st.markdown(get_download_link("sft_requirements.txt", "text/plain"), unsafe_allow_html=True)
508
+ st.markdown(get_download_link("sft_README.md", "text/markdown"), unsafe_allow_html=True)
509
  st.write(f"Model saved at: {config.model_path}")
510
 
511
  if __name__ == "__main__":