File size: 1,600 Bytes
b051745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from transformers import AutoModelForCausalLM,AutoTokenizer
from transformers import LlamaTokenizer
from vllm import LLM, SamplingParams

def average_two_model(model_path_1: str, model_path_2: str, update_num, base_path: str = '/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b') -> str:
    """Blend the weights of two causal-LM checkpoints and save the result.

    Every floating-point entry of the two state dicts is combined as
    ``update_num / (update_num + 1) * w1 + 1 / (update_num + 1) * w2``,
    where ``w1`` comes from ``model_path_1`` and ``w2`` from ``model_path_2``.
    (Presumably ``model_path_1`` is itself a running average of
    ``update_num`` earlier models -- confirm against the caller.)

    The blended weights are loaded into a model instantiated from
    ``base_path`` and saved, together with the tokenizer found at
    ``model_path_1``, under a directory name derived from the two input
    paths.

    Args:
        model_path_1: Path of the first checkpoint; also the tokenizer source.
        model_path_2: Path of the second checkpoint.
        update_num: Relative weight of the first checkpoint (the second
            checkpoint always has relative weight 1).
        base_path: Checkpoint used to instantiate the model that receives
            the blended weights.

    Returns:
        The directory the blended model and tokenizer were saved to.

    Raises:
        ZeroDivisionError: If ``update_num + 1`` is zero. Raised before any
            checkpoint is loaded.
        KeyError: If the second checkpoint lacks a key present in the first.
    """
    # Output directory: first path concatenated with the last component of the
    # second path, then stripped of a fixed list of substrings (in this order).
    averaged_model_path = model_path_1 + model_path_2.split('/')[-1]
    for fragment in ('00', 'random', 'naive_3k', 'shuffle', 'average'):
        averaged_model_path = averaged_model_path.replace(fragment, '')

    # Compute the blend weights up front so a bad ``update_num`` fails before
    # the (expensive) checkpoint loads rather than after them.
    denominator = update_num + 1
    weight_1 = update_num / denominator
    weight_2 = 1.0 / denominator

    # Take each state dict exactly once; calling ``state_dict()`` per key
    # rebuilds the whole mapping on every iteration.
    state_1 = AutoModelForCausalLM.from_pretrained(model_path_1).state_dict()
    state_2 = AutoModelForCausalLM.from_pretrained(model_path_2).state_dict()

    avg_state_dict = {}
    for key, tensor_1 in state_1.items():
        tensor_2 = state_2[key]
        if tensor_1.is_floating_point() or tensor_1.is_complex():
            avg_state_dict[key] = weight_1 * tensor_1 + weight_2 * tensor_2
        else:
            # Integer/boolean buffers are not meaningful to interpolate;
            # keep the first model's values unchanged.
            avg_state_dict[key] = tensor_1

    # Drop the source weights before loading the base model to lower peak memory.
    del state_1, state_2, tensor_1, tensor_2

    # The base checkpoint supplies the architecture/config that receives the
    # blended weights.
    base_model = AutoModelForCausalLM.from_pretrained(base_path)
    base_model.load_state_dict(avg_state_dict)
    base_model.save_pretrained(averaged_model_path)

    # The tokenizer is copied from the first checkpoint (assumes both
    # checkpoints were trained with the same tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(model_path_1)
    tokenizer.save_pretrained(averaged_model_path)
    return averaged_model_path