HaileyStorm committed · Commit 7b54808 · verified · 1 Parent(s): 3000a46

Upload 2 files

Files changed (2):
  1. scripts/ckpt2hf.py +47 -0
  2. scripts/full.yaml +2 -2
scripts/ckpt2hf.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ print("Loading checkpoint...")
+
+ # Define the paths to your checkpoint files
+ checkpoint_paths = [
+     './llama3-5b/model-00001-of-00003.pt',
+     './llama3-5b/model-00002-of-00003.pt',
+     './llama3-5b/model-00003-of-00003.pt'
+ ]
+
+ # Initialize an empty state dictionary
+ merged_state_dict = {}
+
+ # Load each checkpoint and merge them
+ for checkpoint_path in checkpoint_paths:
+     checkpoint = torch.load(checkpoint_path, map_location='cpu')
+     merged_state_dict.update(checkpoint)
+
+ print("Loading original model...")
+
+ # Define the original model name or path
+ original_model_name = "../../slice_with_mergekit/merged/"
+
+ # Load the model configuration and create a new model instance
+ model = AutoModelForCausalLM.from_pretrained(original_model_name, state_dict=merged_state_dict)
+
+ print("Converting to fp16...")
+
+ # Convert model parameters to float16
+ model.half()
+
+ print("Saving model...")
+
+ # Save the model in the safetensors format
+ output_dir = './llama3-5b/hf/'
+ model.save_pretrained(output_dir, safe_serialization=True)
+
+ print("Saving tokenizer...")
+
+ # Save the tokenizer as well
+ tokenizer = AutoTokenizer.from_pretrained(original_model_name)
+ tokenizer.save_pretrained(output_dir)
+
+ print(f"Merged model saved to {output_dir}")
+
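Note: after running the script above, a minimal sanity check along these lines can confirm the exported model and tokenizer reload cleanly (paths match the script's output_dir; the test prompt and generation settings are arbitrary assumptions, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the converted model in fp16 and run a short generation to
# verify that the merged weights and tokenizer load without errors.
model = AutoModelForCausalLM.from_pretrained('./llama3-5b/hf/', torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained('./llama3-5b/hf/')

inputs = tokenizer("Hello, world", return_tensors="pt")  # arbitrary test prompt
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))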
scripts/full.yaml CHANGED
@@ -1,7 +1,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
-  path: ../original/tokenizer.model
+  path: ../tokenizer.model
 
 # Dataset and Sampler
 dataset:
@@ -29,7 +29,7 @@ model:
 
 checkpointer:
   _component_: torchtune.utils.FullModelHFCheckpointer
-  checkpoint_dir: ../merged/
+  checkpoint_dir: ../
   checkpoint_files: [
     model-00001-of-00003.safetensors,
     model-00002-of-00003.safetensors,
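The two config edits point the tokenizer and checkpoint directory one level up relative to their previous locations. A minimal sketch for checking that the updated relative paths resolve on disk (assumes PyYAML is installed, the config sits at scripts/full.yaml, and paths are resolved from the scripts/ directory; all of these are assumptions, not stated in this commit):

import os
import yaml  # PyYAML, assumed installed

# Load the torchtune config and check whether the updated relative
# paths (resolved from the scripts/ directory) exist on disk.
with open('scripts/full.yaml') as f:
    cfg = yaml.safe_load(f)

base = 'scripts'  # assumption: relative paths resolve from scripts/
for rel in (cfg['tokenizer']['path'], cfg['checkpointer']['checkpoint_dir']):
    path = os.path.normpath(os.path.join(base, rel))
    print(path, '->', 'exists' if os.path.exists(path) else 'MISSING')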