Upload export_llama1.py with huggingface_hub
export_llama1.py +273 -0
export_llama1.py
ADDED
@@ -0,0 +1,273 @@
import logging
import os
import warnings
from typing import List, Optional, Tuple

import coremltools as ct
import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
    LLAMA_ATTENTION_CLASSES,
    LlamaAttention,
    LlamaConfig,
    LlamaForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv,
)

warnings.filterwarnings("ignore")
logging.getLogger("coremltools").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
MODEL_ID: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
METADATA_TOKENIZER: str = "co.huggingface.exporters.name"

class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device="cpu",
        dtype=torch.float32,
    ) -> None:
        """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.max_length: int = shape[3]  # context_size dimension

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
        Return slice of key/value cache tensors from [0, slice_indices[1]).
        """
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: int | None = 0) -> int:
        """Get the sequence length of the cache."""
        return self.past_seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        return None


class SliceUpdateLlamaAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor | None, ...]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Slice update key/value cache
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step),
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None


class StatefulLlamaForCausalLM(torch.nn.Module):
    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
        super().__init__()

        # Custom attention implementation for stateful slice-update key/value cache; override
        # "sdpa" to comply with transformers.modeling_utils._autoset_attn_implementation
        LLAMA_ATTENTION_CLASSES["sdpa"] = SliceUpdateLlamaAttention
        self.model = LlamaForCausalLM.from_pretrained(model_path)

        # Register KV cache buffers to be recognized as Core ML states
        config: LlamaConfig = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.num_key_value_heads,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Compute past seen tokens used for updating key/value cache slices
        self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            past_key_values=self.kv_cache,
            use_cache=True,
        ).logits

def generate() -> None:
    # Construct the model from transformers
    max_context_size: int = 2048
    torch_model = StatefulLlamaForCausalLM(MODEL_ID, max_context_size=max_context_size)
    torch_model.eval()

    # Load the tokenizer to encode the prompt and decode output tokens
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # initial_prompt = "Write a christmas Carol"
    initial_prompt = "Write a poem on Apple "

    input_ids = tokenizer(initial_prompt, return_tensors="pt").input_ids
    causal_mask: torch.Tensor = torch.ones((1, 1, 1, input_ids.shape[-1] + 1), dtype=torch.float32)

    # Set the output length
    output_length = 20

    is_first_run = True

    # Initialize the output tensor
    output_tokens = input_ids

    # Loop until the desired output length is reached
    while output_tokens.shape[-1] < output_length + input_ids.shape[-1]:
        # Compute the past seen tokens used for updating key/value cache slices
        # torch_model.kv_cache.past_seen_tokens = causal_mask.shape[-1] - output_tokens.shape[-1]

        # Get the model output: the first run feeds the full prompt; later runs feed only
        # the most recent token and rely on the KV cache for the past context
        model_inp = output_tokens[:, -1:]
        if is_first_run:
            model_inp = input_ids
            is_first_run = False
        # print(f"model_inp: {model_inp} causal_mask: {causal_mask}")
        output = torch_model(model_inp, causal_mask)

        # Get the most likely token IDs
        output_ids = torch.argmax(output, dim=-1)

        # Append the generated token IDs to the output tensor
        output_tokens = torch.cat((output_tokens, output_ids[:, -1, None]), dim=-1)
        # print(f"output_tokens: {output_tokens}")

        # Update the causal mask
        causal_mask = torch.ones((1, 1, 1, output_tokens.shape[-1] + 1), dtype=torch.float32)

    decoded_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    print(f"input : {tokenizer.decode(input_ids[0])} output: {decoded_output}")

def export() -> None:
    # Construct model from transformers and trace to TorchScript
    max_context_size: int = 2048
    torch_model = StatefulLlamaForCausalLM(MODEL_ID, max_context_size=max_context_size)
    torch_model.eval()
    input_ids: torch.Tensor = torch.tensor([[19161, 253, 8216, 335, 10910, 216]], dtype=torch.int32)
    # input_ids: torch.Tensor = torch.tensor([[ 11 ]], dtype=torch.int32)
    causal_mask: torch.Tensor = torch.ones((1, 1, 1, 7), dtype=torch.float32)
    traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])

    # Convert traced TorchScript to Core ML format
    query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    inputs: List[ct.TensorType] = [
        ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
        ct.TensorType(
            shape=(1, 1, query_length, end_step_dim),
            dtype=np.float16,
            name="causalMask",
        ),
    ]
    outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
    states: List[ct.StateType] = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
            name="keyCache",
        ),
        ct.StateType(
            wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
            name="valueCache",
        ),
    ]

    # Convert model with FP16 precision
    mlmodel_fp16: ct.MLModel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        states=states,
        minimum_deployment_target=ct.target.iOS18,
        skip_model_load=True,
    )
    # mlmodel_fp16._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
    # mlmodel_fp16.save("Stateful_Llama_3_1_8B_InstructFP16.mlpackage")

    # Block-wise quantize model weights to int4
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype="int4",
        granularity="per_block",
        block_size=32,
    )
    config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
    mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
    mlmodel_int4.save("Stateful_Llama_3_1_8B_InstructInt4.mlpackage")


if __name__ == "__main__":
    export()
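
Once export() has run, the resulting .mlpackage can be exercised directly from Python. Below is a minimal sketch, not part of the uploaded script: it assumes coremltools >= 8.0 (which provides make_state() and stateful predict()), the Int4 package name written by export() above, and a toy prompt of hardcoded token IDs.

# Hypothetical usage sketch: load the exported stateful model and run one decoding step.
import numpy as np
import coremltools as ct

mlmodel = ct.models.MLModel("Stateful_Llama_3_1_8B_InstructInt4.mlpackage")
state = mlmodel.make_state()  # backing storage for the keyCache / valueCache states

# Toy prompt of 7 token IDs; the causal mask covers every position seen so far.
input_ids = np.array([[19161, 253, 8216, 335, 10910, 216, 11]], dtype=np.int32)
causal_mask = np.ones((1, 1, input_ids.shape[-1], input_ids.shape[-1]), dtype=np.float16)

outputs = mlmodel.predict({"inputIds": input_ids, "causalMask": causal_mask}, state=state)
next_token = int(outputs["logits"][0, -1].argmax())  # greedy pick of the next token ID
print(next_token)

Subsequent calls would pass only the newest token plus a mask widened by one column, mirroring the loop in generate() above.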