asd
Browse files
app.py
CHANGED
@@ -3,18 +3,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
3 |
from huggingface_hub import HfApi
|
4 |
import gradio as gr
|
5 |
import os
|
|
|
6 |
|
7 |
-
# Function to merge models
|
8 |
def merge_models():
|
9 |
-
base_model_name = "meta-llama/Meta-Llama-3-8B
|
10 |
finetuned_model_name = "NoaiGPT/autotrain-14mrs-fc44l"
|
11 |
|
12 |
# Load the base model
|
13 |
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
14 |
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
15 |
|
16 |
-
# Load the fine-tuned model
|
17 |
-
|
|
|
18 |
|
19 |
# Merge the models (simple weight averaging here for demonstration; adjust as needed)
|
20 |
for param_base, param_finetuned in zip(base_model.parameters(), finetuned_model.parameters()):
|
@@ -57,4 +59,4 @@ with gr.Blocks() as demo:
|
|
57 |
merge_button.click(merge_button_clicked, outputs=output)
|
58 |
|
59 |
# Launch the Gradio interface
|
60 |
-
demo.launch()
|
|
|
3 |
from huggingface_hub import HfApi
|
4 |
import gradio as gr
|
5 |
import os
|
6 |
+
from peft import PeftModel, PeftConfig
|
7 |
|
8 |
+
# Function to merge models using PEFT
|
9 |
def merge_models():
|
10 |
+
base_model_name = "meta-llama/Meta-Llama-3-8B"
|
11 |
finetuned_model_name = "NoaiGPT/autotrain-14mrs-fc44l"
|
12 |
|
13 |
# Load the base model
|
14 |
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
15 |
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
16 |
|
17 |
+
# Load the fine-tuned model using PEFT
|
18 |
+
peft_config = PeftConfig.from_pretrained(finetuned_model_name)
|
19 |
+
finetuned_model = PeftModel.from_pretrained(base_model, peft_config)
|
20 |
|
21 |
# Merge the models (simple weight averaging here for demonstration; adjust as needed)
|
22 |
for param_base, param_finetuned in zip(base_model.parameters(), finetuned_model.parameters()):
|
|
|
59 |
merge_button.click(merge_button_clicked, outputs=output)
|
60 |
|
61 |
# Launch the Gradio interface
|
62 |
+
demo.launch()
|