Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
16 |
from mergekit.config import MergeConfiguration
|
17 |
from mergekit.merge import Mergekit
|
18 |
from spectrum import SpectrumAnalyzer
|
19 |
-
import distilkit
|
20 |
import yaml
|
21 |
from dataclasses import dataclass
|
22 |
from typing import Optional, List
|
@@ -78,8 +77,7 @@ project_seeds = {
|
|
78 |
- Streamlit π
|
79 |
- Torch π₯
|
80 |
- Transformers π€
|
81 |
-
2.
|
82 |
-
- DistillKit π§ͺ
|
83 |
- MergeKit π
|
84 |
- Spectrum π
|
85 |
3. Transformers Diffusers Datasets
|
@@ -133,18 +131,6 @@ class ModelBuilder:
|
|
133 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
134 |
return self
|
135 |
|
136 |
-
@pipeline_stage
|
137 |
-
def apply_distillation(self, teacher_model: str, output_dir: str):
|
138 |
-
"""Apply DistilKit for model distillation"""
|
139 |
-
distiller = distilkit.Distiller(
|
140 |
-
teacher_model=teacher_model,
|
141 |
-
student_model=self.model,
|
142 |
-
tokenizer=self.tokenizer
|
143 |
-
)
|
144 |
-
distiller.distill(output_dir=output_dir)
|
145 |
-
self.model = distiller.student_model
|
146 |
-
return self
|
147 |
-
|
148 |
@pipeline_stage
|
149 |
def apply_merge(self, models_to_merge: List[str], output_dir: str):
|
150 |
"""Apply Mergekit for model merging"""
|
@@ -416,7 +402,6 @@ if st.button("Predict"):
|
|
416 |
)
|
417 |
model_name = st.text_input("Model Name", "custom-model")
|
418 |
domain = st.text_input("Target Domain", "general")
|
419 |
-
use_distillation = st.checkbox("Apply Distillation", True)
|
420 |
use_merging = st.checkbox("Apply Model Merging", False)
|
421 |
use_spectrum = st.checkbox("Apply Spectrum Specialization", True)
|
422 |
|
@@ -433,13 +418,6 @@ if st.button("Predict"):
|
|
433 |
with st.status("Building advanced model...", expanded=True) as status:
|
434 |
builder.load_base_model(config.base_model)
|
435 |
|
436 |
-
if use_distillation:
|
437 |
-
teacher_model = st.selectbox(
|
438 |
-
"Select Teacher Model",
|
439 |
-
["mistral-13b", "llama-2-13b"]
|
440 |
-
)
|
441 |
-
builder.apply_distillation(teacher_model, f"distilled_{config.name}")
|
442 |
-
|
443 |
if use_merging:
|
444 |
models_to_merge = st.multiselect(
|
445 |
"Select Models to Merge",
|
@@ -471,7 +449,7 @@ if st.button("Generate"):
|
|
471 |
"""
|
472 |
with open("advanced_app.py", "w") as f:
|
473 |
f.write(app_code)
|
474 |
-
reqs = "streamlit\ntorch\ntransformers\n"
|
475 |
with open("advanced_requirements.txt", "w") as f:
|
476 |
f.write(reqs)
|
477 |
readme = f"""
|
|
|
16 |
from mergekit.config import MergeConfiguration
|
17 |
from mergekit.merge import Mergekit
|
18 |
from spectrum import SpectrumAnalyzer
|
|
|
19 |
import yaml
|
20 |
from dataclasses import dataclass
|
21 |
from typing import Optional, List
|
|
|
77 |
- Streamlit π
|
78 |
- Torch π₯
|
79 |
- Transformers π€
|
80 |
+
2. MergeKit Spectrum
|
|
|
81 |
- MergeKit π
|
82 |
- Spectrum π
|
83 |
3. Transformers Diffusers Datasets
|
|
|
131 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
132 |
return self
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
@pipeline_stage
|
135 |
def apply_merge(self, models_to_merge: List[str], output_dir: str):
|
136 |
"""Apply Mergekit for model merging"""
|
|
|
402 |
)
|
403 |
model_name = st.text_input("Model Name", "custom-model")
|
404 |
domain = st.text_input("Target Domain", "general")
|
|
|
405 |
use_merging = st.checkbox("Apply Model Merging", False)
|
406 |
use_spectrum = st.checkbox("Apply Spectrum Specialization", True)
|
407 |
|
|
|
418 |
with st.status("Building advanced model...", expanded=True) as status:
|
419 |
builder.load_base_model(config.base_model)
|
420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
if use_merging:
|
422 |
models_to_merge = st.multiselect(
|
423 |
"Select Models to Merge",
|
|
|
449 |
"""
|
450 |
with open("advanced_app.py", "w") as f:
|
451 |
f.write(app_code)
|
452 |
+
reqs = "streamlit\ntorch\ntransformers\nmergekit\nspectrum\n"
|
453 |
with open("advanced_requirements.txt", "w") as f:
|
454 |
f.write(reqs)
|
455 |
readme = f"""
|