File size: 2,165 Bytes
a261ff8 6cc4234 a261ff8 453187f 6cc4234 453187f a261ff8 453187f a261ff8 453187f a261ff8 6cc4234 a261ff8 453187f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import HfApi
import gradio as gr
import os
from peft import PeftModel, PeftConfig
def merge_models(
    base_model_name="meta-llama/Meta-Llama-3-8B",
    finetuned_model_name="NoaiGPT/autotrain-14mrs-fc44l",
    merged_model_name="./merged_model",
):
    """Merge a PEFT adapter into its base model and save the result locally.

    Args:
        base_model_name: Hub id or local path of the base causal-LM model.
        finetuned_model_name: Hub id or local path of the PEFT adapter to apply
            on top of the base model.
        merged_model_name: Local directory the merged model and tokenizer are
            written to.

    Returns:
        The directory path the merged model was saved to (``merged_model_name``).
    """
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # PeftModel.from_pretrained expects the adapter's repo id / path as its
    # second argument (it loads the PeftConfig itself), not a PeftConfig object.
    finetuned_model = PeftModel.from_pretrained(base_model, finetuned_model_name)

    # Fold the adapter weights into the base weights and drop the PEFT wrapper.
    # The previous element-wise averaging over zip(base, peft) parameters was
    # wrong: the PeftModel shares the base tensors and interleaves extra adapter
    # parameters, so the pairs did not line up.
    # NOTE(review): assumes the adapter type supports merging (e.g. LoRA) — confirm.
    merged_model = finetuned_model.merge_and_unload()

    merged_model.save_pretrained(merged_model_name)
    base_tokenizer.save_pretrained(merged_model_name)
    return merged_model_name
def upload_to_hf(repo_id, merged_model_name):
    """Upload a local model directory to a Hugging Face Hub model repository.

    Args:
        repo_id: Target Hub repository id, e.g. ``"user/model-name"``.
        merged_model_name: Local directory to upload, such as the path
            returned by ``merge_models``.

    Returns:
        A status message suitable for display in the Gradio output box.
    """
    api = HfApi()
    # Make sure the target repo exists; this is a no-op if it already does.
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    # upload_folder walks subdirectories, which os.listdir + upload_file could
    # not handle. It also pushes everything in a single commit instead of one
    # commit per file. Files keep their paths relative to the folder root.
    api.upload_folder(
        folder_path=merged_model_name,
        repo_id=repo_id,
        repo_type="model",
    )
    return f"Model uploaded to Hugging Face Hub at {repo_id}."
def merge_button_clicked():
    """Gradio click handler: build the merged model, then push it to the Hub.

    Returns:
        The status string produced by ``upload_to_hf``.
    """
    # NOTE(review): this is the same repo id that merge_models() uses for the
    # fine-tuned adapter, so the merged weights are pushed into the adapter's
    # own repo — confirm that is intended.
    target_repo = "NoaiGPT/autotrain-14mrs-fc44l"
    local_dir = merge_models()
    return upload_to_hf(target_repo, local_dir)
# UI layout: a single row holding the trigger button and a textbox that
# displays the status string returned by the merge-and-upload handler.
with gr.Blocks() as demo:
    with gr.Row():
        merge_button = gr.Button("Merge Models")
        output = gr.Textbox(label="Output")
    # No inputs: the handler takes no arguments; its return value fills `output`.
    merge_button.click(merge_button_clicked, outputs=output)

# Start serving the Gradio app.
demo.launch()