NoaiGPT committed on
Commit
453187f
1 Parent(s): a261ff8
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -3,18 +3,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from huggingface_hub import HfApi
4
  import gradio as gr
5
  import os
 
6
 
7
- # Function to merge models
8
  def merge_models():
9
- base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
10
  finetuned_model_name = "NoaiGPT/autotrain-14mrs-fc44l"
11
 
12
  # Load the base model
13
  base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
14
  base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
15
 
16
- # Load the fine-tuned model
17
- finetuned_model = AutoModelForCausalLM.from_pretrained(finetuned_model_name)
 
18
 
19
  # Merge the models (simple weight averaging here for demonstration; adjust as needed)
20
  for param_base, param_finetuned in zip(base_model.parameters(), finetuned_model.parameters()):
@@ -57,4 +59,4 @@ with gr.Blocks() as demo:
57
  merge_button.click(merge_button_clicked, outputs=output)
58
 
59
  # Launch the Gradio interface
60
- demo.launch()
 
3
  from huggingface_hub import HfApi
4
  import gradio as gr
5
  import os
6
+ from peft import PeftModel, PeftConfig
7
 
8
+ # Function to merge models using PEFT
9
  def merge_models():
10
+ base_model_name = "meta-llama/Meta-Llama-3-8B"
11
  finetuned_model_name = "NoaiGPT/autotrain-14mrs-fc44l"
12
 
13
  # Load the base model
14
  base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
15
  base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
16
 
17
+ # Load the fine-tuned model using PEFT
18
+ peft_config = PeftConfig.from_pretrained(finetuned_model_name)
19
+ finetuned_model = PeftModel.from_pretrained(base_model, peft_config)
20
 
21
  # Merge the models (simple weight averaging here for demonstration; adjust as needed)
22
  for param_base, param_finetuned in zip(base_model.parameters(), finetuned_model.parameters()):
 
59
  merge_button.click(merge_button_clicked, outputs=output)
60
 
61
  # Launch the Gradio interface
62
+ demo.launch()