leafspark committed on
Commit 56811f1 · verified · 1 Parent(s): 6be7fa2
Files changed (6)
  1. README.md +3 -3
  2. SCRIPT_README.md +22 -0
  3. generate.py +177 -0
  4. modello_italia.py +403 -0
  5. requirements.txt +5 -0
  6. tokenizer.model +3 -0
README.md CHANGED
@@ -1,3 +1,3 @@
- ---
- license: mit
- ---
+ ### Instructions
+
+ To run the model `italia.bin` along with its tokenizer `tokenizer.model`, you'll need the inference script. Once you have it, either move these two files into the `inference_script` folder or specify the correct path within the script.
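A minimal sketch of the check implied above (the `inference_script` folder name comes from these instructions; adjust the path if you keep the files elsewhere):

```python
# Sketch only: verify that italia.bin and tokenizer.model sit where generate.py will look for them.
from pathlib import Path

checkpoint_dir = Path("inference_script")  # or whatever directory you point the script at
for name in ("italia.bin", "tokenizer.model"):
    assert (checkpoint_dir / name).is_file(), f"missing {name} in {checkpoint_dir}"
```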
SCRIPT_README.md ADDED
@@ -0,0 +1,22 @@
+ ```python
+ # Modello Italia inference script and model
+ # Copyright 2024 iGenius
+ #
+ # Licensed under the MIT License (see LICENSE-MIT).
+ # This code also contains code from the original project licensed under the Apache License 2.0 (see LICENSE-APACHE).
+ # This script contains modifications of the original code from Lightning AI.
+ ```
+
+ ### Instructions
+
+ 1. First, move the model and the tokenizer from `/modello_italia_9b` to the current directory, or ensure that the path is correctly specified.
+
+ 2. Install dependencies by running the following command in the terminal:
+ ```terminal
+ pip install -r requirements.txt
+ ```
+
+ 3. To run the generation, use the following command:
+ ```terminal
+ python generate.py --checkpoint_dir <model_path> --max_new_tokens 500 --temperature 0.2 --prompt "Ciao, chi sei?"
+ ```
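For reference, a rough programmatic equivalent of the command above, written as a sketch that calls `main` from `generate.py` directly (the actual CLI is wired up through `jsonargparse`):

```python
# Sketch: run generation from Python instead of the shell.
from pathlib import Path

from generate import main

main(
    prompt="Ciao, chi sei?",
    max_new_tokens=500,
    temperature=0.2,
    checkpoint_dir=Path("."),  # directory containing italia.bin and tokenizer.model
)
```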
generate.py ADDED
@@ -0,0 +1,177 @@
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+ # Derived from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/generate/base.py
+
+ import os
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Any, Optional
+
+ import torch
+
+ # support running without installing as a package
+ wd = Path(__file__).parent.parent.resolve()
+ sys.path.append(str(wd))
+
+ from modello_italia import Italia, ItaliaConfig, Tokenizer
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ MI_SYSTEM_PROMPT_SHORT = (
+     "Tu sei Modello Italia, un modello di linguaggio naturale addestrato da iGenius."
+ )
+
+
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
+     if torch._dynamo.is_compiling():
+         # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
+         distribution = torch.empty_like(probs).exponential_(1)
+         return torch.argmax(probs / distribution, dim=-1, keepdim=True)
+     return torch.multinomial(probs, num_samples=1)
+
+
+ def sample(
+     logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
+ ) -> torch.Tensor:
+     logits = logits[0, -1]
+     # optionally crop the logits to only the top k options
+     if top_k is not None:
+         v, i = torch.topk(logits, min(top_k, logits.size(-1)))
+         # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
+         logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
+     # optionally scale the logits and sample from a probability distribution
+     if temperature > 0.0:
+         probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
+         return multinomial_num_samples_1(probs)
+     return torch.argmax(logits, dim=-1, keepdim=True)
+
+
+ def next_token(
+     model: Italia, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
+ ) -> torch.Tensor:
+     logits = model(x, input_pos)
+     next = sample(logits, **kwargs)
+     return next.to(dtype=x.dtype)
+
+
+ @torch.inference_mode()
+ def generate(
+     model: Italia,
+     prompt: torch.Tensor,
+     tokenizer: Tokenizer,
+     max_returned_tokens: int,
+     *,
+     temperature: float = 1.0,
+     top_k: Optional[int] = None,
+     eos_id: Optional[int] = None,
+ ) -> torch.Tensor:
+     """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+
+     The implementation of this function is modified from A. Karpathy's nanoGPT.
+
+     Args:
+         model: The model to use.
+         prompt: Tensor of shape (T) with indices of the prompt sequence.
+         max_returned_tokens: The maximum number of tokens to return (given plus generated).
+         tokenizer: Tokenizer instance to decode generated tokens.
+         temperature: Scales the predicted logits by 1 / temperature.
+         top_k: If specified, only sample among the tokens with the k highest probabilities.
+     """
+     T = prompt.size(0)
+     assert max_returned_tokens > T
+
+     device = prompt.device
+     tokens = [prompt]
+     input_pos = torch.tensor([T], device=device)
+     token = next_token(
+         model,
+         torch.arange(0, T, device=device),
+         prompt.view(1, -1),
+         temperature=temperature,
+         top_k=top_k,
+     ).clone()
+     tokens.append(token)
+     for _ in range(2, max_returned_tokens - T + 1):
+         token = next_token(
+             model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
+         ).clone()
+         tokens.append(token)
+
+         if token == tokenizer.eos_id:
+             break
+         os.system('cls' if os.name == 'nt' else 'clear')
+         print(tokenizer.decode(torch.cat(tokens)[T:]))
+         input_pos = input_pos.add_(1)
+     return torch.cat(tokens)
+
+
+ @torch.inference_mode()
+ def main(
+     prompt: str = "Ciao, chi sei?",
+     *,
+     num_samples: int = 1,
+     max_new_tokens: int = 200,
+     top_k: Optional[int] = 200,
+     temperature: float = 0.4,
+     checkpoint_dir: Path = Path("."),
+ ) -> None:
+     """Generates text samples based on a pre-trained model and tokenizer.
+
+     Args:
+         prompt: The prompt string to use for generating the samples.
+         num_samples: The number of text samples to generate.
+         max_new_tokens: The number of generation steps to take.
+         top_k: The number of top most probable tokens to consider in the sampling process.
+         temperature: A value controlling the randomness of the sampling process. Higher values result in more random
+             samples.
+         checkpoint_dir: The checkpoint directory to load.
+     """
+
+     config = ItaliaConfig()
+     checkpoint_path = checkpoint_dir / "italia.bin"
+     tokenizer = Tokenizer(checkpoint_dir)
+     prompt = f"<|system|>{MI_SYSTEM_PROMPT_SHORT}\n<|user|>{prompt}\n<|assistant|>"
+     encoded = tokenizer.encode(prompt, device=device)
+     prompt_length = encoded.size(0)
+     max_returned_tokens = prompt_length + max_new_tokens
+
+     print(f"Loading model {str(checkpoint_path)!r}")
+
+     t0 = time.perf_counter()
+
+     model = Italia(config)
+     model.load_state_dict(torch.load(checkpoint_path, mmap=True))
+     model.to(device)
+
+     print(
+         f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.",
+         file=sys.stderr,
+     )
+     model.max_seq_length = max_returned_tokens
+     model.set_kv_cache(batch_size=1, device=device)
+     model.eval()
+
+     for _ in range(num_samples):
+         t0 = time.perf_counter()
+         y = generate(
+             model,
+             encoded,
+             tokenizer,
+             max_returned_tokens,
+             temperature=temperature,
+             top_k=top_k,
+         )
+         t = time.perf_counter() - t0
+         for block in model.transformer.h:
+             block.attn.kv_cache.reset_parameters()
+
+         #print(tokenizer.decode(y))
+         tokens_generated = y.size(0) - prompt_length
+         print(f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec")
+
+
+ if __name__ == "__main__":
+     from jsonargparse import CLI
+
+     torch.set_float32_matmul_precision("high")
+     CLI(main)
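For context, `main` above wraps the user prompt in a system/user/assistant template before encoding it. A small illustration of the resulting string, with values copied from the script (shown only as a sketch):

```python
# Sketch: the prompt string that main() builds before tokenization.
MI_SYSTEM_PROMPT_SHORT = (
    "Tu sei Modello Italia, un modello di linguaggio naturale addestrato da iGenius."
)
user_prompt = "Ciao, chi sei?"
prompt = f"<|system|>{MI_SYSTEM_PROMPT_SHORT}\n<|user|>{user_prompt}\n<|assistant|>"
print(prompt)
```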
modello_italia.py ADDED
@@ -0,0 +1,403 @@
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+ # Derived from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ from dataclasses import dataclass
+
+
+ from pathlib import Path
+ from typing import Optional, Union
+ from sentencepiece import SentencePieceProcessor
+ import torch
+
+
+ @dataclass
+ class ItaliaConfig:
+     block_size: int = 4096
+     vocab_size: int = 50_000
+     padding_multiple: int = 512
+     padded_vocab_size: int = 50176
+     head_size: int = 160
+     n_layer: int = 34
+     n_head: int = 32
+     n_embd: int = 5120
+     rotary_percentage: float = 0.4
+     parallel_residual: bool = True
+     bias: bool = True
+     lm_head_bias: bool = True
+     n_query_groups: int = 32
+     shared_attention_norm: bool = True
+     norm_eps: float = 1e-5
+     intermediate_size: int = 12800
+     rope_condense_ratio: int = 1
+     rope_n_elem: int = 64
+     rope_base: int = 10000
+
+
+ class Tokenizer:
+     def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
+         checkpoint_dir = Path(checkpoint_dir)
+         if not checkpoint_dir.exists():
+             raise NotADirectoryError(
+                 f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
+             )
+
+         self.use_bos = True
+         self.bos_id = None
+         self.eos_id = None
+
+         if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
+             self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
+             self.backend = "sentencepiece"
+             self.bos_id = self.processor.bos_id()
+             self.eos_id = self.processor.eos_id()
+         else:
+             raise FileNotFoundError(
+                 f"tokenizer.model not found in {str(checkpoint_dir)}"
+             )
+
+     @property
+     def vocab_size(self) -> int:
+         return self.processor.vocab_size()
+
+     def token_to_id(self, token: str) -> int:
+         return self.processor.piece_to_id(token)
+
+     def encode(
+         self,
+         string: str,
+         device: Optional[torch.device] = None,
+         max_length: int = -1,
+     ) -> torch.Tensor:
+
+         tokens = self.processor.encode(string)
+         tokens = [self.bos_id] + tokens
+
+         if max_length > 0:
+             tokens = tokens[:max_length]
+         return torch.tensor(tokens, dtype=torch.int, device=device)
+
+     def decode(self, tensor: torch.Tensor) -> str:
+         tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
+         return self.processor.decode(tokens).strip()
+
+
+ class Italia(nn.Module):
+     def __init__(self, config: ItaliaConfig) -> None:
+         super().__init__()
+         assert config.padded_vocab_size is not None
+         self.config = config
+
+         self.lm_head = nn.Linear(
+             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
+         )
+         self.transformer = nn.ModuleDict(
+             dict(
+                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
+                 h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
+                 ln_f=nn.LayerNorm(config.n_embd, eps=config.norm_eps),
+             )
+         )
+         self.max_seq_length = self.config.block_size
+         self.mask_cache: Optional[torch.Tensor] = None
+
+     @property
+     def max_seq_length(self) -> int:
+         return self._max_seq_length
+
+     @max_seq_length.setter
+     def max_seq_length(self, value: int) -> None:
+         """
+         When doing inference, the sequences used might be shorter than the model's context length.
+         This allows setting a smaller number to avoid allocating unused memory.
+         """
+         if value > self.config.block_size:
+             raise ValueError(
+                 f"Cannot attend to {value}, block size is only {self.config.block_size}"
+             )
+         self._max_seq_length = value
+         if not hasattr(self, "cos"):
+             cos, sin = self.rope_cache()
+             self.register_buffer("cos", cos, persistent=False)
+             self.register_buffer("sin", sin, persistent=False)
+
+         elif value != self.cos.size(0):
+             self.cos, self.sin = self.rope_cache(device=self.cos.device)
+
+     def reset_parameters(self) -> None:
+         self.cos, self.sin = self.rope_cache()
+
+     def forward(
+         self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         T = idx.size(1)
+         if self.max_seq_length < T:
+             raise ValueError(
+                 f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
+             )
+
+         if input_pos is not None:  # use the kv cache
+             cos = self.cos.index_select(0, input_pos)
+             sin = self.sin.index_select(0, input_pos)
+             if self.mask_cache is None:
+                 raise TypeError("You need to call `gpt.set_kv_cache()`")
+             mask = self.mask_cache.index_select(2, input_pos)
+         else:
+             cos = self.cos[:T]
+             sin = self.sin[:T]
+             mask = None
+
+         x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+         for block in self.transformer.h:
+             x = block(x, cos, sin, mask, input_pos)
+         x = self.transformer.ln_f(x)
+         return self.lm_head(x)  # (b, t, vocab_size)
+
+     def rope_cache(
+         self, device: Optional[torch.device] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         return build_rope_cache(
+             seq_len=self.max_seq_length,
+             n_elem=self.config.rope_n_elem,
+             device=device,
+             condense_ratio=self.config.rope_condense_ratio,
+             base=self.config.rope_base,
+         )
+
+     def set_kv_cache(
+         self,
+         batch_size: int,
+         rope_cache_length: Optional[int] = None,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         if rope_cache_length is None:
+             rope_cache_length = self.cos.size(-1)
+         max_seq_length = self.max_seq_length
+
+         for block in self.transformer.h:
+             block.attn.kv_cache = block.attn.build_kv_cache(
+                 batch_size, max_seq_length, rope_cache_length, device, dtype
+             )
+
+         if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
+             self.mask_cache = build_mask_cache(max_seq_length, device)
+
+     def clear_kv_cache(self) -> None:
+         self.mask_cache = None
+         for block in self.transformer.h:
+             block.attn.kv_cache = None
+
+
+ class Block(nn.Module):
+     def __init__(self, config: ItaliaConfig) -> None:
+         super().__init__()
+         self.norm_1 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
+         self.attn = CausalSelfAttention(config)
+         self.mlp = MLP(config)
+         self.config = config
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         input_pos: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         n_1 = self.norm_1(x)
+         h = self.attn(n_1, cos, sin, mask, input_pos)
+         n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
+         x = self.mlp(n_2) + h + x
+         return x
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config: ItaliaConfig) -> None:
+         super().__init__()
+         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
+         linear_module = nn.Linear
+         self.attn = linear_module(config.n_embd, shape, bias=config.bias)
+         self.proj = linear_module(config.n_embd, config.n_embd, bias=config.bias)
+         self.kv_cache: Optional[KVCache] = None
+
+         self.config = config
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         input_pos: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         B, T, _ = (
+             x.size()
+         )  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         qkv = self.attn(x)
+
+         # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
+         q_per_kv = self.config.n_head // self.config.n_query_groups
+         total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
+         qkv = qkv.view(
+             B, T, self.config.n_query_groups, total_qkv, self.config.head_size
+         )
+         qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
+
+         # split batched computation into three
+         q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
+
+         q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
+         k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
+         v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
+
+         q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
+         k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
+         q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
+         k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
+
+         if input_pos is not None:
+             if not isinstance(self.kv_cache, KVCache):
+                 raise TypeError("You need to call `gpt.set_kv_cache()`")
+             k, v = self.kv_cache(input_pos, k, v)
+
+         y = self.scaled_dot_product_attention(q, k, v, mask)
+
+         y = y.reshape(
+             B, T, self.config.n_embd
+         )  # re-assemble all head outputs side by side
+
+         # output projection
+         return self.proj(y)
+
+     def scaled_dot_product_attention(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         scale = 1.0 / math.sqrt(self.config.head_size)
+         y = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+         )
+         return y.transpose(1, 2)
+
+     def build_kv_cache(
+         self,
+         batch_size: int,
+         max_seq_length: int,
+         rope_cache_length: Optional[int] = None,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> "KVCache":
+         heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
+         v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
+         if rope_cache_length is None:
+             if self.config.rotary_percentage != 1.0:
+                 raise TypeError(
+                     "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
+                 )
+             k_shape = v_shape
+         else:
+             k_shape = (
+                 batch_size,
+                 heads,
+                 max_seq_length,
+                 rope_cache_length + self.config.head_size - self.config.rope_n_elem,
+             )
+         return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: ItaliaConfig) -> None:
+         super().__init__()
+         self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+         self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+
+         self.config = config
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc(x)
+         x = torch.nn.functional.gelu(x, approximate="tanh")
+         return self.proj(x)
+
+
+ def build_rope_cache(
+     seq_len: int,
+     n_elem: int,
+     device: Optional[torch.device] = None,
+     base: int = 10000,
+     condense_ratio: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Enhanced Transformer with Rotary Position Embedding.
+
+     Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+     transformers/rope/__init__.py. MIT License:
+     https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+     """
+     # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
+
+     # Create position indexes `[0, 1, ..., seq_len - 1]`
+     seq_idx = torch.arange(seq_len, device=device) / condense_ratio
+
+     # Calculate the product of position index and $\theta_i$
+     idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
+
+     return torch.cos(idx_theta), torch.sin(idx_theta)
+
+
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     head_size = x.size(-1)
+     x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+     x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+     rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+     roped = (x * cos) + (rotated * sin)
+     return roped.to(dtype=x.dtype)
+
+
+ class KVCache(nn.Module):
+     def __init__(
+         self,
+         k_shape: Tuple[int, int, int, int],
+         v_shape: Tuple[int, int, int, int],
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         super().__init__()
+         self.register_buffer(
+             "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
+         )
+         self.register_buffer(
+             "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
+         )
+
+     def forward(
+         self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # move the buffer to the activation dtype for when AMP is used
+         self.k = self.k.to(k.dtype)
+         self.v = self.v.to(v.dtype)
+         # update the cache
+         k = self.k.index_copy_(2, input_pos, k)
+         v = self.v.index_copy_(2, input_pos, v)
+         return k, v
+
+     def reset_parameters(self) -> None:
+         torch.nn.init.zeros_(self.k)
+         torch.nn.init.zeros_(self.v)
+
+
+ def build_mask_cache(
+     max_seq_length: int, device: Optional[torch.device] = None
+ ) -> torch.Tensor:
+     ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
+     return torch.tril(ones).unsqueeze(0).unsqueeze(0)
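As a quick sanity check of `ItaliaConfig` (a sketch, not part of the shipped code): with 32 heads of size 160 the embedding width is 5120, and the fused attention projection packs queries, keys, and values for all 32 query groups.

```python
# Sketch: derived dimensions of the configuration defined in modello_italia.py.
from modello_italia import ItaliaConfig

cfg = ItaliaConfig()
assert cfg.n_head * cfg.head_size == cfg.n_embd  # 32 * 160 == 5120
qkv_width = (cfg.n_head + 2 * cfg.n_query_groups) * cfg.head_size
print(qkv_width)  # 15360: output width of CausalSelfAttention.attn
```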
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --find-links https://download.pytorch.org/whl/torch_stable.html
+
+ torch>=2.2.0
+ jsonargparse[cli]
+ sentencepiece
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd74bea2ba620d87e0a2127d9a21196b862a5cc7942ba4638eb2159bbab3340c
+ size 1090536