Commit · 78f6f3b

1 Parent(s): 16cc769

Renaming state dict keys from Phi2

Files changed:
- phi2_model.py +8 -8
- streaming_inference.py +23 -33
phi2_model.py CHANGED

@@ -91,7 +91,7 @@ class Embedding(nn.Module):
 class Phi2Model(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
-        self.
+        self.rotary_embedding = Embedding(
             vocab_size=config.vocab_size,
             d_embedding=config.d_embedding,
             embd_pdrop=config.embd_pdrop,
@@ -117,10 +117,10 @@ class Phi2Model(Phi2PreTrainedModel):

     """
     def get_input_embeddings(self) -> nn.Embedding:
-        return self.
+        return self.rotary_embedding.embeddings

     def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
-        self.
+        self.rotary_embedding.embeddings = new_embeddings
     """

     def forward(
@@ -129,7 +129,7 @@ class Phi2Model(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
     ) -> torch.FloatTensor:
-        x = self.
+        x = self.rotary_embedding(input_ids)
         for block in self.parallel_blocks:
             x = block(
                 x,
@@ -143,8 +143,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
         self.pretrained_model = Phi2Model(config)
-        self.
-        self.
+        self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
+        self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
         self.loss_fn = nn.CrossEntropyLoss()
         self.post_init()  # calls self._init_weights() for all modules

@@ -156,8 +156,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         labels: torch.LongTensor | None = None,
     ) -> CausalLMOutputWithPast:
         x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
-        x = self.
-        logits = self.
+        x = self.lm_head_layer_norm(x)
+        logits = self.lm_head_linear(x).to(torch.float32)
         loss = (
             self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
             if labels is not None
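The renamed modules determine the state dict key names that streaming_inference.py has to map the microsoft/phi-2 checkpoint onto. A quick way to see the expected target names is to print the simplified model's own state dict keys. This is a minimal sketch, assuming the package is importable as simplified_phi2 and that config.json holds the Phi2Config fields; the example key names in the comment are inferred from this commit's module names.

import json

from simplified_phi2.phi2_configuration import Phi2Config
from simplified_phi2.phi2_model import Phi2ModelForCausalLM

if __name__ == "__main__":
    # Build the simplified model and list its parameter names, which should
    # include keys such as pretrained_model.rotary_embedding.embeddings.weight,
    # pretrained_model.parallel_blocks.0.layer_norm.weight,
    # lm_head_layer_norm.weight and lm_head_linear.weight.
    model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
    model = Phi2ModelForCausalLM(model_config)
    for key in sorted(model.state_dict().keys()):
        print(key)

These printed names are exactly what the renaming loop in streaming_inference.py needs to produce from the upstream checkpoint keys.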
streaming_inference.py CHANGED

@@ -1,43 +1,11 @@
 import json
 from threading import Thread
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-import torch

 from .phi2_configuration import Phi2Config
 from .phi2_model import Phi2ModelForCausalLM


-# This works, but is not streaming
-"""
-if __name__ == "__main__":
-    device = "cuda"
-
-    model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
-    model = Phi2ModelForCausalLM(model_config).to(device)
-    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-    model.load_state_dict(phi_model.state_dict())
-
-    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-
-    text = "Write an essay on sea monkeys: "
-    tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
-    outputs = model.generate(**tokens, max_length=200)
-    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-    print(text)
-"""
-
-
-# This is streaming, but does not work because you can't set trust_remote_code=True
-"""
-if __name__ == "__main__":
-    client = InferenceClient(model="microsoft/phi-2")
-    text = "How do you make cheese?"
-    for token in client.text_generation(text, max_new_tokens=500, stream=True):
-        print(token, end="")
-"""
-
-
-# This is trying the TextIteratorStreamer class
 if __name__ == "__main__":
     # make and load tokenizer, use tokenizer to initialize token_streamer
     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
@@ -48,7 +16,29 @@ if __name__ == "__main__":
     model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
     model = Phi2ModelForCausalLM(model_config).to(device)
     phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-
+
+    phi_model_state_dict = phi_model.state_dict()
+    model_state_dict = {}
+    for key, value in phi_model_state_dict.items():
+        # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
+        # transformer.h.0.mlp.fc1.weight -> pretrained_model.parallel_blocks.0.mlp.fc1.weight
+        # transformer.h.0.ln.weight -> pretrained_model.parallel_blocks.0.layer_norm.weight
+        # transformer.h.0.mixer.Wqkv.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.Wqkv.weight
+        # transformer.h.0.mixer.out_proj.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.fc_out.weight
+        # lm_head.ln.weight -> lm_head_layer_norm.weight
+        # lm_head.linear.weight -> lm_head_linear.weight
+        if key.startswith("transformer"):
+            key.replace("transformer.", "model.")
+            key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
+            key.replace(".h.", ".parallel_blocks")
+            key.replace(".ln.", ".layer_norm.")
+            key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
+            key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
+            key.replace(".lm_head.ln.", ".lm_head_layer_norm.")
+            key.replace(".lm_head.linear.", ".lm_head_linear.")
+        model_state_dict[key] = value
+    model.load_state_dict(model_state_dict)
+
     thread = Thread(
         target=model.generate,
         kwargs=dict(
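Note that, as committed, the renaming loop has no effect: str.replace returns a new string, so each result is discarded and every key is copied over unchanged. The ".h." pattern is also missing its trailing dot, and the lm_head patterns sit inside the startswith("transformer") branch, where they can never match. Below is a corrected sketch of the same mapping; it assumes the intended prefix is pretrained_model. (the attribute name in Phi2ModelForCausalLM) rather than the model. used in the first inline comment.

def rename_phi2_key(key: str) -> str:
    # Map a microsoft/phi-2 checkpoint key onto the simplified model's
    # module names, reassigning the result of every str.replace call.
    if key.startswith("transformer."):
        key = key.replace("transformer.", "pretrained_model.")
        key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
        key = key.replace(".h.", ".parallel_blocks.")
        key = key.replace(".ln.", ".layer_norm.")
        key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
        key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
    elif key.startswith("lm_head."):
        key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
        key = key.replace("lm_head.linear.", "lm_head_linear.")
    return key


model_state_dict = {rename_phi2_key(k): v for k, v in phi_model.state_dict().items()}
model.load_state_dict(model_state_dict)

Whether load_state_dict then succeeds still depends on the two module trees agreeing key for key; running it with strict=True (the default) will report any names the mapping misses.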