leafspark committed on
Commit
c17b4e4
·
verified ·
1 Parent(s): 85e1dcd

docs: add conversion script

Browse files
Files changed (1) hide show
  1. convert_qwen2_to_llama.py +182 -0
convert_qwen2_to_llama.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Converts version 2 of the Qwen models into the same format as LLaMA2.
3
+ # Usage: python convert_qwen2_to_llama.py --input_dir magnum-72b-v1 --output_dir magnum-72b-v1-llamaify --save_safetensors --continue_conversion
4
+ # Original script: https://github.com/Minami-su/character_AI_open/blob/main/llamafy_qwen_v2.py
5
+
6
+ import json
7
+ import os
8
+ from collections import OrderedDict
9
+ from typing import Any, Dict, Optional
10
+
11
+ import fire
12
+ import torch
13
+ from safetensors import safe_open
14
+ from safetensors.torch import save_file
15
+ from tqdm import tqdm
16
+ from transformers.modeling_utils import (
17
+ SAFE_WEIGHTS_INDEX_NAME,
18
+ SAFE_WEIGHTS_NAME,
19
+ WEIGHTS_INDEX_NAME,
20
+ WEIGHTS_NAME,
21
+ shard_checkpoint,
22
+ )
23
+ from transformers.utils import check_min_version
24
+
25
# Abort at import time if `check_min_version` rejects the installed
# `transformers`; any failure is re-raised as a ValueError with an upgrade hint.
try:
    check_min_version("4.34.0")
except Exception:
    raise ValueError("Please upgrade `transformers` to 4.34.0")

# File name of the model configuration; read from the input directory and
# written to the output directory by `save_config`.
CONFIG_NAME = "config.json"
31
+
32
+
33
def load_existing_shards(
    output_dir: str, save_safetensors: bool
) -> Dict[str, torch.Tensor]:
    """Load the tensors of a sharded checkpoint previously written to ``output_dir``.

    The shards are discovered through the checkpoint index file
    (``SAFE_WEIGHTS_INDEX_NAME`` or ``WEIGHTS_INDEX_NAME`` depending on
    ``save_safetensors``). If no index file exists, nothing is loaded; a
    checkpoint that was saved as a single file without an index is therefore
    not picked up.

    Args:
        output_dir: Directory holding the index file and the shard files.
        save_safetensors: ``True`` to read safetensors shards, ``False`` to
            read shards with ``torch.load``.

    Returns:
        An ordered mapping from tensor name to tensor (loaded on CPU). Empty
        when there is no index file. Shard files listed in the index but
        missing on disk are skipped silently.
    """
    existing_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
    index_path = os.path.join(output_dir, index_name)

    if not os.path.exists(index_path):
        return existing_state_dict

    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    # `weight_map` maps every tensor name to its shard file, so each file name
    # appears once per tensor. De-duplicate (keeping first-seen order) so each
    # shard is read from disk exactly once instead of once per tensor.
    shard_files = list(dict.fromkeys(index["weight_map"].values()))

    for shard_file in tqdm(shard_files, desc="Loading existing shards"):
        shard_path = os.path.join(output_dir, shard_file)
        if not os.path.exists(shard_path):
            continue

        if save_safetensors:
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    existing_state_dict[key] = f.get_tensor(key)
        else:
            # NOTE: torch.load unpickles the file; only resume from an output
            # directory whose contents you trust.
            existing_state_dict.update(torch.load(shard_path, map_location="cpu"))

    return existing_state_dict
63
+
64
+
65
def save_weight(
    input_dir: str,
    output_dir: str,
    shard_size: str,
    save_safetensors: bool,
    continue_conversion: bool,
) -> str:
    """Convert the Qwen2 weights in ``input_dir`` and write them to ``output_dir``.

    Every tensor is copied unchanged. For each ``self_attn.o_proj`` tensor a
    zero bias is added under the matching ``.bias`` key when that key is not
    already present (the config written by ``save_config`` sets
    ``attention_bias`` to ``True``).

    Args:
        input_dir: Directory containing the source ``*.safetensors`` files.
        output_dir: Existing directory that receives the shards and the index.
        shard_size: Maximum shard size, forwarded to ``shard_checkpoint``.
        save_safetensors: Write safetensors shards instead of ``torch.save``.
        continue_conversion: Start from the tensors already stored in
            ``output_dir``. Freshly converted tensors still overwrite entries
            with the same key; only pre-existing ``o_proj`` biases are kept.

    Returns:
        The dtype name of the first source tensor with the ``torch.`` prefix
        stripped.

    Raises:
        ValueError: If ``input_dir`` contains no tensors in ``.safetensors``
            files.
    """
    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    # os.listdir returns entries in arbitrary order; sorting keeps the tensor
    # order, and therefore the resulting shard layout, reproducible.
    for filepath in tqdm(sorted(os.listdir(input_dir)), desc="Load weights"):
        full_path = os.path.join(input_dir, filepath)
        if os.path.isfile(full_path) and filepath.endswith(".safetensors"):
            with safe_open(full_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    qwen_state_dict[key] = f.get_tensor(key)

    if not qwen_state_dict:
        # Without this guard an empty checkpoint would be written and the
        # config would record the dtype as the string "None".
        raise ValueError(
            f"No tensors found in any `.safetensors` file under {input_dir}"
        )

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    if continue_conversion:
        llama2_state_dict = load_existing_shards(output_dir, save_safetensors)

    torch_dtype = None
    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
        if torch_dtype is None:
            # The first tensor's dtype is reported for the whole checkpoint.
            torch_dtype = value.dtype

        llama2_state_dict[key] = value
        if "self_attn.o_proj" in key:
            bias_key = key.replace(".weight", ".bias")
            if bias_key not in llama2_state_dict:
                # One zero per output row of the projection weight.
                llama2_state_dict[bias_key] = torch.zeros_like(value[:, 0]).squeeze()

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(
        llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name
    )

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        shard_path = os.path.join(output_dir, shard_file)
        if save_safetensors:
            save_file(shard, shard_path, metadata={"format": "pt"})
        else:
            torch.save(shard, shard_path)

    if index is None:
        # Single-shard checkpoint: no index file is written.
        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print(f"Model weights saved in {output_dir}")

    return str(torch_dtype).replace("torch.", "")
121
+
122
+
123
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
    """Write a LLaMA-style ``config.json`` derived from the Qwen2 config.

    Reads ``CONFIG_NAME`` from ``input_dir``, copies the shared
    hyper-parameters, fills in the LLaMA-specific constants, and writes the
    result to ``CONFIG_NAME`` in ``output_dir``.

    Args:
        input_dir: Directory containing the source Qwen2 config.
        output_dir: Directory that receives the converted config.
        torch_dtype: Dtype name stored under the ``torch_dtype`` key.

    Raises:
        KeyError: If the source config lacks one of the copied keys.
    """
    source_path = os.path.join(input_dir, CONFIG_NAME)
    with open(source_path, "r", encoding="utf-8") as reader:
        src: Dict[str, Any] = json.load(reader)

    # Entry order is the key order of the emitted JSON file.
    converted: Dict[str, Any] = OrderedDict(
        [
            ("architectures", ["LlamaForCausalLM"]),
            ("attention_bias", True),
            ("attention_dropout", src["attention_dropout"]),
            ("hidden_act", "silu"),
            ("hidden_size", src["hidden_size"]),
            ("initializer_range", src["initializer_range"]),
            ("intermediate_size", src["intermediate_size"]),
            # Hard-coded rather than copied from the source config
            # (value chosen for Qwen2-72B-Instruct).
            ("max_position_embeddings", 32767),
            ("max_window_layers", src["max_window_layers"]),
            ("model_type", "llama"),
            ("num_attention_heads", src["num_attention_heads"]),
            ("num_hidden_layers", src["num_hidden_layers"]),
            ("num_key_value_heads", src["num_key_value_heads"]),
            ("pretraining_tp", 1),
            ("rms_norm_eps", src["rms_norm_eps"]),
            ("rope_theta", src["rope_theta"]),
            ("rope_scaling", None),
            ("sliding_window", src["sliding_window"]),
            ("tie_word_embeddings", src["tie_word_embeddings"]),
            ("torch_dtype", torch_dtype),
            ("transformers_version", "4.37.0"),
            ("use_cache", True),
            ("use_sliding_window", src["use_sliding_window"]),
            ("vocab_size", src["vocab_size"]),
        ]
    )

    target_path = os.path.join(output_dir, CONFIG_NAME)
    with open(target_path, "w", encoding="utf-8") as writer:
        json.dump(converted, writer, indent=2)
    print(f"Model config saved in {target_path}")
156
+
157
+
158
def llamafy_qwen_v2(
    input_dir: str,
    output_dir: str,
    shard_size: Optional[str] = "4GB",
    save_safetensors: Optional[bool] = False,
    continue_conversion: Optional[bool] = False,
):
    """Convert a Qwen2 checkpoint directory into a LLaMA-format checkpoint.

    Creates ``output_dir``, converts and saves the weights with
    ``save_weight``, then writes the matching config with ``save_config``.

    Args:
        input_dir: Directory with the Qwen2 safetensors files and config.
        output_dir: Destination directory. Must not exist yet unless
            ``continue_conversion`` is set.
        shard_size: Maximum size of each written shard.
        save_safetensors: Write safetensors shards instead of ``torch.save``.
        continue_conversion: Reuse an existing ``output_dir`` and the shards
            already stored in it.

    Raises:
        ValueError: If ``output_dir`` already exists and
            ``continue_conversion`` is not set, or if the directory cannot be
            created for any other reason.
    """
    if continue_conversion:
        os.makedirs(output_dir, exist_ok=True)
    else:
        try:
            os.makedirs(output_dir, exist_ok=False)
        except FileExistsError as e:
            raise ValueError(
                "Output dir already exists. Use --continue_conversion to resume."
            ) from e
        except Exception as e:
            # Still a ValueError for backward compatibility, but report the
            # real cause (permissions, invalid path, ...) instead of claiming
            # that the directory already exists.
            raise ValueError(
                f"Could not create output dir {output_dir!r}: {e}"
            ) from e

    torch_dtype = save_weight(
        input_dir, output_dir, shard_size, save_safetensors, continue_conversion
    )
    save_config(input_dir, output_dir, torch_dtype)
179
+
180
+
181
# Script entry point: hand `llamafy_qwen_v2` to `fire.Fire` when run directly.
if __name__ == "__main__":
    fire.Fire(llamafy_qwen_v2)