kevin36524 committed
Commit c5afa42 · verified · 1 Parent(s): ae5bac8

Upload export_llama1.py with huggingface_hub

Files changed (1)
  1. export_llama1.py +273 -0
export_llama1.py ADDED
@@ -0,0 +1,273 @@
+ import logging
+ import os
+ import warnings
+ from typing import List, Optional, Tuple
+
+ import coremltools as ct
+ import numpy as np
+ import torch
+ from transformers.cache_utils import Cache
+ from transformers.models.llama.modeling_llama import (
+     LLAMA_ATTENTION_CLASSES,
+     LlamaAttention,
+     LlamaConfig,
+     LlamaForCausalLM,
+     apply_rotary_pos_emb,
+     repeat_kv,
+ )
+
+ warnings.filterwarnings("ignore")
+ logging.getLogger("coremltools").setLevel(logging.ERROR)
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
+ MODEL_ID: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ METADATA_TOKENIZER: str = "co.huggingface.exporters.name"
+
+ class SliceUpdateKeyValueCache(Cache):
+     def __init__(
+         self,
+         shape: Tuple[int, ...],
+         device="cpu",
+         dtype=torch.float32,
+     ) -> None:
+         """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
+         super().__init__()
+         self.past_seen_tokens: int = 0
+         self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
+         self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
+         self.max_length: int = shape[3]  # context_size dimension
+
+     def update(
+         self,
+         k_state: torch.Tensor,
+         v_state: torch.Tensor,
+         layer_idx: int,
+         slice_indices: torch.LongTensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
+         Return slice of key/value cache tensors from [0, slice_indices[1]).
+         """
+         if len(slice_indices) != 2:
+             raise ValueError(f"Expected a tuple of two integers [start, end), got {slice_indices=}.")
+         begin, end = slice_indices
+         self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
+         self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
+         k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
+         v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
+         return k_cache, v_cache
+
+     def get_seq_length(self, _: Optional[int] = 0) -> int:
+         """Get the sequence length of the cache."""
+         return self.past_seen_tokens
+
+     def get_max_length(self) -> Optional[int]:
+         """Returns the maximum sequence length of the cached states, if there is any."""
+         return None
+
+
+ class SliceUpdateLlamaAttention(LlamaAttention):
+     def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+         super().__init__(config=config, layer_idx=layer_idx)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor | None, ...]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
+             1, 2
+         )
+         value_states = value_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+
+         cos, sin = self.rotary_emb(value_states, position_ids)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         # Slice update key/value cache
+         end_step = attention_mask.shape[-1]
+         key_states, value_states = past_key_value.update(
+             key_states,
+             value_states,
+             self.layer_idx,
+             slice_indices=(end_step - q_len, end_step),
+         )
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, None, None
+
+
+ class StatefulLlamaForCausalLM(torch.nn.Module):
+     def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
+         super().__init__()
+
+         # Custom attention implementation for the stateful slice-update key/value cache; override
+         # "sdpa" to comply with transformers.modeling_utils._autoset_attn_implementation
+         LLAMA_ATTENTION_CLASSES["sdpa"] = SliceUpdateLlamaAttention
+         self.model = LlamaForCausalLM.from_pretrained(model_path)
+
+         # Register KV cache buffers to be recognized as Core ML states
+         config: LlamaConfig = self.model.config
+         self.kv_cache_shape: Tuple[int, ...] = (
+             config.num_hidden_layers,
+             batch_size,
+             config.num_key_value_heads,
+             max_context_size,
+             config.hidden_size // config.num_attention_heads,
+         )
+         self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
+         self.register_buffer("keyCache", self.kv_cache.k_cache)
+         self.register_buffer("valueCache", self.kv_cache.v_cache)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         causal_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         # Compute past seen tokens used for updating key/value cache slices
+         self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
+         return self.model(
+             input_ids,
+             attention_mask=causal_mask,
+             past_key_values=self.kv_cache,
+             use_cache=True,
+         ).logits
+
+ def generate() -> None:
+     # Construct model from transformers and trace to TorchScript
+     max_context_size: int = 2048
+     torch_model = StatefulLlamaForCausalLM(MODEL_ID, max_context_size=max_context_size)
+     torch_model.eval()
+
+     # Decode output tokens using the tokenizer
+     from transformers import AutoTokenizer
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     # initial_prompt = "Write a Christmas carol"
+     initial_prompt = "Write a poem on Apple "
+
+     input_ids = tokenizer(initial_prompt, return_tensors='pt').input_ids
+     causal_mask: torch.Tensor = torch.ones((1, 1, 1, input_ids.shape[-1] + 1), dtype=torch.float32)
+
+     # Set the output length
+     output_length = 20
+
+     is_first_run = True
+
+     # Initialize the output tensor
+     output_tokens = input_ids
+
+     # Loop until the desired output length is reached
+     while output_tokens.shape[-1] < output_length + input_ids.shape[-1]:
+         # Past seen tokens are derived inside the model from the causal mask width
+         # torch_model.kv_cache.past_seen_tokens = causal_mask.shape[-1] - output_tokens.shape[-1]
+
+         # Get the model output: feed the full prompt on the first pass,
+         # then only the last generated token on subsequent passes
+         model_inp = output_tokens[:, -1:]
+         if is_first_run:
+             model_inp = input_ids
+             is_first_run = False
+         # print(f"KEVINDEBUG model_inp: {model_inp} causal_mask: {causal_mask}")
+         output = torch_model(model_inp, causal_mask)
+
+         # Get the most likely token IDs
+         output_ids = torch.argmax(output, dim=-1)
+
+         # Append the generated token IDs to the output tensor
+         output_tokens = torch.cat((output_tokens, output_ids[:, -1, None]), dim=-1)
+         # print(f"KEVINDEBUG output_tokens: {output_tokens}")
+
+         # Update the causal mask to cover all tokens seen so far
+         causal_mask = torch.ones((1, 1, 1, output_tokens.shape[-1] + 1), dtype=torch.float32)
+
+     decoded_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+     print(f"input : {tokenizer.decode(input_ids[0])} output: {decoded_output}")
+
+ def export() -> None:
+     # Construct model from transformers and trace to TorchScript
+     max_context_size: int = 2048
+     torch_model = StatefulLlamaForCausalLM(MODEL_ID, max_context_size=max_context_size)
+     torch_model.eval()
+     input_ids: torch.Tensor = torch.tensor([[19161, 253, 8216, 335, 10910, 216]], dtype=torch.int32)
+     # input_ids: torch.Tensor = torch.tensor([[11]], dtype=torch.int32)
+     causal_mask: torch.Tensor = torch.ones((1, 1, 1, 7), dtype=torch.float32)
+     traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])
+
+     # Convert traced TorchScript to Core ML format
+     query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
+     end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
+     inputs: List[ct.TensorType] = [
+         ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
+         ct.TensorType(
+             shape=(1, 1, query_length, end_step_dim),
+             dtype=np.float16,
+             name="causalMask",
+         ),
+     ]
+     outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
+     states: List[ct.StateType] = [
+         ct.StateType(
+             wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
+             name="keyCache",
+         ),
+         ct.StateType(
+             wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
+             name="valueCache",
+         ),
+     ]
+
+     # Convert model with FP16 precision
+     mlmodel_fp16: ct.MLModel = ct.convert(
+         traced_model,
+         inputs=inputs,
+         outputs=outputs,
+         states=states,
+         minimum_deployment_target=ct.target.iOS18,
+         skip_model_load=True,
+     )
+     # mlmodel_fp16._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
+     # mlmodel_fp16.save("Stateful_Llama_3_1_8B_InstructFP16.mlpackage")
+
+     # Block-wise quantize model weights to int4
+     op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
+         mode="linear_symmetric",
+         dtype="int4",
+         granularity="per_block",
+         block_size=32,
+     )
+     config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
+     mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
+     mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
+     mlmodel_int4.save("Stateful_Llama_3_1_8B_InstructInt4.mlpackage")
+
+
+ if __name__ == "__main__":
+     export()
+
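
For reference (not part of the uploaded file): a minimal sketch of how the exported stateful package might be driven from Python with coremltools' state API, mirroring the greedy loop in generate() above. It assumes coremltools 8+ on macOS 15+, that the Stateful_Llama_3_1_8B_InstructInt4.mlpackage produced by export() exists locally, and uses an illustrative prompt and 20-token budget.

    # Sketch only: load the exported stateful model and decode greedily with coremltools.
    import numpy as np
    import coremltools as ct
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    mlmodel = ct.models.MLModel("Stateful_Llama_3_1_8B_InstructInt4.mlpackage")

    # One state object holds keyCache/valueCache across predict() calls
    kv_state = mlmodel.make_state()

    input_ids = tokenizer("Write a poem on Apple ", return_tensors="np").input_ids.astype(np.int32)
    output_tokens = input_ids
    is_first_run = True

    while output_tokens.shape[-1] < input_ids.shape[-1] + 20:
        # Full prompt on the first pass, only the newest token afterwards
        model_inp = input_ids if is_first_run else output_tokens[:, -1:]
        is_first_run = False
        # Same mask convention as generate(): width = tokens seen so far + 1
        causal_mask = np.ones((1, 1, 1, output_tokens.shape[-1] + 1), dtype=np.float16)
        logits = mlmodel.predict(
            {"inputIds": model_inp, "causalMask": causal_mask}, state=kv_state
        )["logits"]
        next_id = logits[:, -1, :].argmax(axis=-1).astype(np.int32)
        output_tokens = np.concatenate([output_tokens, next_id[:, None]], axis=-1)

    print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))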