YiDuo1999 committed
Commit b051745 · verified · Parent(s): c63c010

Create model_soups_utils

Files changed (1)
  1. model_soups_utils +25 -0
model_soups_utils ADDED
@@ -0,0 +1,25 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ def average_two_model(model_path_1, model_path_2, update_num, base_path='/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b'):
+     # Path to save the averaged model and tokenizer; run-specific tags are
+     # stripped from the second checkpoint's name before concatenation.
+     averaged_model_path = (model_path_1 + model_path_2.split('/')[-1]).replace('00', '').replace('random', '').replace('naive_3k', '').replace('shuffle', '').replace('average', '')
+     # Load the two checkpoints and cache their state dicts once.
+     model_paths = [model_path_1, model_path_2]
+     state_dicts = [AutoModelForCausalLM.from_pretrained(p).state_dict() for p in model_paths]
+     # Running average: the first checkpoint already holds the mean of
+     # `update_num` models, so it is weighted update_num/(update_num+1)
+     # against 1/(update_num+1) for the new checkpoint; equivalent to the
+     # uniform mean sum(sd[key] for sd in state_dicts) / len(state_dicts).
+     avg_state_dict = {}
+     for key in state_dicts[0].keys():
+         avg_state_dict[key] = (update_num / (update_num + 1)) * state_dicts[0][key] + (1.0 / (update_num + 1)) * state_dicts[1][key]
+     # Load the base model configuration and write the averaged weights into it.
+     base_model = AutoModelForCausalLM.from_pretrained(base_path)
+     base_model.load_state_dict(avg_state_dict)
+     base_model.save_pretrained(averaged_model_path)  # Save the averaged model
+     # Save the tokenizer (assuming all models used the same tokenizer).
+     tokenizer = AutoTokenizer.from_pretrained(model_path_1)
+     tokenizer.save_pretrained(averaged_model_path)
+     return averaged_model_path
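
For context, `average_two_model` implements an incremental model soup: each call folds one more fine-tuned checkpoint into a running uniform average of the weights. A minimal usage sketch, assuming hypothetical checkpoint paths (placeholders, not paths from this repo):

# Hypothetical fine-tuned checkpoints to average; placeholder paths.
checkpoints = ['/ckpts/llama3-8b-ft-a', '/ckpts/llama3-8b-ft-b', '/ckpts/llama3-8b-ft-c']

soup_path = checkpoints[0]
for update_num, next_ckpt in enumerate(checkpoints[1:], start=1):
    # soup_path currently averages `update_num` checkpoints, so the running
    # average weights it update_num/(update_num+1) against the new checkpoint.
    soup_path = average_two_model(soup_path, next_ckpt, update_num)

After the loop, `soup_path` points at a saved model whose weights are the uniform mean of all three checkpoints, matching the one-shot sum-and-divide average noted in the function's inline comment.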