import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading checkpoint...") |
|
|
|
|
|
checkpoint_paths = [ |
|
'./llama3-5b/model-00001-of-00003.pt', |
|
'./llama3-5b/model-00002-of-00003.pt', |
|
'./llama3-5b/model-00003-of-00003.pt' |
|
] |

# Merge all shards into a single state dict on the CPU.
merged_state_dict = {}
for checkpoint_path in checkpoint_paths:
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    merged_state_dict.update(checkpoint)
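
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True, which is
# fine here assuming each shard is a plain dict of tensors.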

print("Loading original model...")

# Directory holding the sliced model's config and tokenizer (produced with
# mergekit); it supplies the architecture, the merged state dict the weights.
original_model_name = "../../slice_with_mergekit/merged/"

model = AutoModelForCausalLM.from_pretrained(original_model_name, state_dict=merged_state_dict)
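# (state_dict= makes from_pretrained take the weights from the merged dict
# rather than from the weight files in the directory.)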
print("Converting to fp16...") |
|
|
|
|
|
model.half() |
|
|
|
print("Saving model...") |
|
|
|
|
|
output_dir = './llama3-5b/hf/' |
|
model.save_pretrained(output_dir, safe_serialization=True) |
|
|
|
print("Saving tokenizer...") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(original_model_name) |
|
tokenizer.save_pretrained(output_dir) |
|
|
|
print(f"Merged model saved to {output_dir}") |
|
|
|
|
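
# Quick sanity check (hypothetical usage, not part of the original script):
#   model = AutoModelForCausalLM.from_pretrained('./llama3-5b/hf/', torch_dtype=torch.float16)
#   tokenizer = AutoTokenizer.from_pretrained('./llama3-5b/hf/')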