Commit · 10aca20
Parent(s): 78f6f3b

Fixed weight loading from original Phi2 model

- config.json +1 -1
- phi2_model.py +4 -7
- streaming_inference.py +14 -13
config.json
CHANGED
@@ -13,7 +13,7 @@
   "torch_dtype": "float16",
   "transformers_version": "4.29.0",

-  "vocab_size":
+  "vocab_size": 51200,
   "vocab_chunk_for_gpu_efficiency": 64,
   "initial_cos_sin_cache_len": 2048,
   "d_embedding": 2560,
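
For context, the new value 51200 is an exact multiple of the `vocab_chunk_for_gpu_efficiency` setting of 64 a few lines below (800 × 64), which suggests the vocabulary is padded to a GPU-friendly size rather than set to the tokenizer's raw token count. A minimal, hypothetical sketch of that rounding (the function name and the 50300 example are illustrative only, not part of this commit):

# Hypothetical sketch: round a tokenizer's true vocabulary size up to the nearest
# multiple of the chunk size, so the embedding and lm_head matrices keep
# GPU-friendly dimensions.
def padded_vocab_size(true_vocab_size: int, chunk: int = 64) -> int:
    return ((true_vocab_size + chunk - 1) // chunk) * chunk

print(padded_vocab_size(51200))  # 51200 is already aligned: 800 * 64
print(padded_vocab_size(50300))  # 50304, the next multiple of 64
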
phi2_model.py
CHANGED
@@ -13,11 +13,6 @@ class Phi2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = False
     # _no_split_modules = ["ParallelAttentionBlock"]

-    # weight loading
-    # base_model_prefix = "transformer"
-    # _keys_to_ignore_on_load_missing = [""]
-    # _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
-
     def __init__(self, config: Phi2Config):
         super().__init__(config)
         self.config = config
@@ -42,6 +37,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
         input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
+        **kwargs,
     ) -> dict[str, Any]:
         if not kv_cache:
             kv_cache = KVCache(
@@ -142,7 +138,7 @@ class Phi2Model(Phi2PreTrainedModel):
 class Phi2ModelForCausalLM(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
-        self.
+        self.model = Phi2Model(config)
         self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
         self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
         self.loss_fn = nn.CrossEntropyLoss()
@@ -154,8 +150,9 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
         labels: torch.LongTensor | None = None,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
-        x = self.
+        x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
         x = self.lm_head_layer_norm(x)
         logits = self.lm_head_linear(x).to(torch.float32)
         loss = (
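
To make the new wiring concrete: `Phi2ModelForCausalLM` now owns the backbone as `self.model`, and the added `**kwargs` on both forward() signatures lets extra Hugging Face-style keyword arguments pass through harmlessly. The snippet below is a hypothetical usage sketch, not taken from this repo's tests; it assumes `Phi2Config()` can be constructed with defaults matching config.json and that both classes are importable from phi2_model.py:

import torch

# Hypothetical usage sketch of the rebuilt causal LM head.
config = Phi2Config()                       # assumed to default to the config.json values
model = Phi2ModelForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch_size, seq_len)
with torch.no_grad():
    # return_dict is not a declared parameter of forward(); the new **kwargs
    # swallows it instead of raising TypeError.
    output = model(input_ids, return_dict=True)

print(output.logits.shape)  # (batch_size, seq_len, vocab_size) -> (1, 8, 51200), in float32
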
streaming_inference.py
CHANGED
@@ -20,22 +20,23 @@ if __name__ == "__main__":
     phi_model_state_dict = phi_model.state_dict()
     model_state_dict = {}
     for key, value in phi_model_state_dict.items():
-        # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
-        # transformer.h.0.mlp.fc1.weight -> pretrained_model.parallel_blocks.0.mlp.fc1.weight
-        # transformer.h.0.ln.weight -> pretrained_model.parallel_blocks.0.layer_norm.weight
-        # transformer.h.0.mixer.Wqkv.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.Wqkv.weight
-        # transformer.h.0.mixer.out_proj.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.fc_out.weight
         # lm_head.ln.weight -> lm_head_layer_norm.weight
         # lm_head.linear.weight -> lm_head_linear.weight
+        # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
+        # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
+        # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
+        # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
+        # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
         if key.startswith("transformer"):
-            key.replace("transformer.", "model.")
-            key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
-            key.replace(".h.", ".parallel_blocks")
-            key.replace(".ln.", ".layer_norm.")
-            key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
-            key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
-
-            key.replace("
+            key = key.replace("transformer.", "model.")
+            key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
+            key = key.replace(".h.", ".parallel_blocks.")
+            key = key.replace(".ln.", ".layer_norm.")
+            key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
+            key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
+        else:
+            key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
+            key = key.replace("lm_head.linear.", "lm_head_linear.")
         model_state_dict[key] = value
     model.load_state_dict(model_state_dict)

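
As a sanity check on the renaming above, here is a small hypothetical sketch that could run right before `model.load_state_dict(...)`: it compares the remapped checkpoint keys against the target model's own state dict, so any gap in the mapping fails loudly instead of surfacing later as an obscure key error. The `missing` and `unexpected` names are local to this sketch, not part of the commit:

# Hypothetical verification sketch: every remapped key should have a destination
# parameter, and every destination parameter should receive a value.
expected_keys = set(model.state_dict().keys())
remapped_keys = set(model_state_dict.keys())

missing = expected_keys - remapped_keys      # parameters the checkpoint did not provide
unexpected = remapped_keys - expected_keys   # checkpoint entries with no matching parameter
if missing or unexpected:
    raise RuntimeError(
        f"state_dict mismatch: missing={sorted(missing)}, unexpected={sorted(unexpected)}"
    )

# Note: load_state_dict defaults to strict=True and raises on any mismatch anyway;
# strict=False instead returns the missing/unexpected key lists for inspection.
model.load_state_dict(model_state_dict)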