MartialTerran committed on
Commit 0d7b558 · verified · 1 Parent(s): c5692af

Create SmolLM2_360M_model_debugging.py

Files changed (1)
  1. SmolLM2_360M_model_debugging.py  +506 -0
SmolLM2_360M_model_debugging.py ADDED
@@ -0,0 +1,506 @@
+ # SmolLM2_360M_model.py
+ # Standalone Python script for SmolLM2-360M model inference on Windows 10.
+
+ # --- Configuration ---
+ # List of default prompts
+ DEFAULT_PROMPT = ["Provide 3 reasons why cats make good pets?", "Why should I consider using an LLM?"]
+ MAX_GENERATION_LENGTH = 100 # Default maximum generation length
+
+
+ # ############## Key improvements and additions in this version:
+
+ # Comprehensive Error Handling: Includes try-except blocks for safetensors loading and sentencepiece import, providing informative error messages and exit codes.
+
+ # Detailed Comments: Improved comments throughout for better understanding.
+
+ # Type Hinting: Added type hints for enhanced code readability and maintainability.
+
+ # Special Token Handling: More robust handling of special tokens, including loading them from SentencePiece with a fallback when it is not available, plus support for additional special tokens. These are printed at boot time.
+
+ # Rudimentary BPE Tokenizer: Implemented a basic BPE tokenizer as a fallback if sentencepiece is not installed. It is functional for basic English text and well-commented for potential replacement with a full sentencepiece implementation.
+
+ # Safetensors Loading: Improved weight loading with clear error handling. Prints timing information.
+
+ # Device Management: Explicitly moves tensors and the model to the specified device and defaults to CPU if CUDA is not available. Falls back gracefully to float32 when FP16/BF16 dtypes are requested without a CUDA device.
+
+ # Default Prompt(s) and Hyperparameter Display: Implements default prompts (can be a list) and shows how to display hyperparameters on user request.
+
+ # Timing Information: Added timing measurements for key steps using the timed_step function to assess performance.
+
+ # Clearer User Interaction: Improved the user input loop with clear instructions and an exit condition.
+
+ # Position ID Management: More robust handling of position IDs, especially when using past key/value caching. Limits position IDs to max_position_embeddings.
+
+ # This revised script addresses many potential issues and incorporates best practices for a more robust and user-friendly implementation. It provides a foundation for further development and experimentation. Expected input files and usage notes are summarized below.
+
+
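+ # --- Expected files and usage ---
+ # The script expects the SmolLM2-360M files in the working directory:
+ #   config.json, model.safetensors, special_tokens_map.json, and either
+ #   tokenizer.model (SentencePiece) or vocab.json + merges.txt (BPE fallback).
+ # A minimal pre-flight check could look like this (illustrative sketch only,
+ # not wired into the main flow; the file list mirrors the paths used in
+ # __main__ below):
+ #
+ #   for required in ("config.json", "model.safetensors", "special_tokens_map.json"):
+ #       if not os.path.exists(required):
+ #           sys.exit(f"Missing required file: {required}")
+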
+ import os
+ import sys
+ import json
+ import time
+ import struct
+ import math
+ from typing import List, Tuple, Dict, Union, Optional
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # --- Utility Functions ---
+
+ def load_json(file_path: str) -> Dict:
+     """Load JSON data from a file."""
+     with open(file_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+
+ def timed_step(start: float, step_name: str) -> float:
+     """Print time taken for a step and return new start time."""
+     end = time.time()
+     print(f"Time taken for {step_name}: {end - start:.4f} seconds")
+     return end
+
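+ # Typical usage of timed_step (as done in __main__ below):
+ #   start = time.time()
+ #   ...do work...
+ #   start = timed_step(start, "step name")  # prints the elapsed time, returns a fresh start
+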
+ # --- Model Architecture ---
+
+ class RMSNorm(nn.Module):
+     """Root Mean Square Normalization."""
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply RMS normalization."""
+         norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return self.weight * norm_x
+
+ def silu(x: torch.Tensor) -> torch.Tensor:
+     """SiLU activation function."""
+     return x * torch.sigmoid(x)
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Positional Embedding."""
+     def __init__(self, dim: int, base: int = 10000):
+         super().__init__()
+         self.dim = dim
+         self.base = base
+         self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+
+     def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """Generate the rotary angle table of shape (seq_len, dim) for a given sequence length."""
+         inv_freq = self.inv_freq.to(device)  # keep the frequency table on the active device
+         t = torch.arange(seq_len, device=device).type_as(inv_freq)
+         freqs = torch.outer(t, inv_freq)
+         return torch.cat((freqs, freqs), dim=-1)
+
+ def apply_rotary_emb(pos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+     """Apply rotary embeddings (angle tensor pos) to the given tensor."""
+     return (t * torch.cos(pos)) + (rotate_half(t) * torch.sin(pos))
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotate half of the tensor."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
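+ # Note on the rotary helpers above: RotaryEmbedding produces the raw angle table,
+ # apply_rotary_emb() turns those angles into the cos/sin rotation applied to the
+ # query/key tensors, and rotate_half() supplies the 90-degree-shifted companion
+ # required by the rotation identity rot(t) = t*cos(theta) + rotate_half(t)*sin(theta).
+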
+ class LlamaAttention(nn.Module):
+     """Multi-headed attention layer for LLaMA."""
+     def __init__(self, config: Dict):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config['hidden_size']
+         self.num_heads = config['num_attention_heads']
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config["num_key_value_heads"]
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.rope_theta = config['rope_theta']
+
+         # With grouped-query attention, K/V project to num_key_value_heads * head_dim,
+         # which is smaller than hidden_size whenever num_key_value_heads < num_heads.
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.rotary_emb = RotaryEmbedding(self.head_dim, base=self.rope_theta)
+         self.attn_dropout = nn.Dropout(config['attention_dropout'])
+
+     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         """Compute multi-headed attention."""
+
+         batch_size, seq_length, _ = hidden_states.size()
+         query_states = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         if position_ids is not None:
+             # Build an angle table large enough to cover the largest position id
+             # (positions are shifted by the cache length when past_key_value is used).
+             max_pos = int(position_ids.max().item()) + 1
+             angles = self.rotary_emb(max_pos, device=position_ids.device)       # (max_pos, head_dim)
+             pos = angles[position_ids].unsqueeze(1).to(query_states.dtype)      # (batch_size, 1, seq_len, head_dim)
+             query_states = apply_rotary_emb(pos, query_states)
+             key_states = apply_rotary_emb(pos, key_states)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         if use_cache:
+             present_key_value = (key_states, value_states)
+         else:
+             present_key_value = None
+
+         seq_length_k = key_states.shape[-2]
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         if attn_weights.size() != (batch_size, self.num_heads, seq_length, seq_length_k):
+             raise ValueError(
+                 f"Attention weights should be of size {(batch_size, self.num_heads, seq_length, seq_length_k)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = self.attn_dropout(attn_weights)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+         return attn_output, present_key_value
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """Repeat hidden states n_rep times for key/value heads."""
+     #Stitch1
+     batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
+
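+ # repeat_kv expands the grouped K/V heads so they line up with the query heads.
+ # For example, if a config specifies 15 attention heads and 5 key/value heads,
+ # each K/V head is shared by n_rep = 3 query heads.
+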
+ class LlamaMLP(nn.Module):
+     """Multi-Layer Perceptron for LLaMA."""
+     def __init__(self, config: Dict):
+         super().__init__()
+         hidden_size = config['hidden_size']
+         intermediate_size = config['intermediate_size']
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+         self.act_fn = silu if config['hidden_act'] == 'silu' else getattr(F, config['hidden_act'])
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply MLP to the input tensor."""
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
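+ # The gated MLP above is the SwiGLU-style block used by LLaMA-family models:
+ # down_proj(silu(gate_proj(x)) * up_proj(x)).
+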
+ class LlamaBlock(nn.Module):
+     """LLaMA block containing attention and MLP layers."""
+     def __init__(self, config: Dict):
+         super().__init__()
+         self.hidden_size = config['hidden_size']
+         self.self_attn = LlamaAttention(config)
+         self.mlp = LlamaMLP(config)
+         self.input_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps'])
+         self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps'])
+
+     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         """Apply the LLaMA block."""
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
+         hidden_states = residual + hidden_states
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states, present_key_value
+
+ class SmolLM2_360M(nn.Module):
+     """SmolLM2-360M model implementation."""
+     def __init__(self, config_path: str):
+         super().__init__()
+         self.config = load_json(config_path)
+         self.hidden_size = self.config['hidden_size']
+         self.vocab_size = self.config['vocab_size']
+         self.num_hidden_layers = self.config['num_hidden_layers']
+         self.max_position_embeddings = self.config['max_position_embeddings']
+         self.torch_dtype = self.config.get('torch_dtype', 'bfloat16')
+         self.use_cache = self.config.get('use_cache', True)
+         if self.torch_dtype == "bfloat16":
+             if not torch.cuda.is_available():
+                 print("Warning: System does not have a CUDA device, using torch.float32 dtype instead of bfloat16.")
+                 self.torch_dtype = torch.float32
+             else:
+                 self.torch_dtype = torch.bfloat16
+         elif self.torch_dtype == "float16":
+             if not torch.cuda.is_available():
+                 print("Warning: System does not have a CUDA device, using torch.float32 dtype instead of float16.")
+                 self.torch_dtype = torch.float32
+             else:
+                 self.torch_dtype = torch.float16
+         else:
+             self.torch_dtype = torch.float32
+         self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
+         self.layers = nn.ModuleList([LlamaBlock(self.config) for _ in range(self.num_hidden_layers)])
+         self.norm = RMSNorm(self.hidden_size, eps=self.config['rms_norm_eps'])
+         self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
+         self.past_keys_values = None
+
+     def load_weights(self, weights_path: str):
+         """Load weights from a safetensors file."""
+         start = time.time()
+         try:
+             from safetensors import safe_open
+             with safe_open(weights_path, framework="pt", device='cpu') as f:
+                 weights = f.get_tensor("model.embed_tokens.weight")
+                 self.embed_tokens.weight = nn.Parameter(weights)
+                 if "lm_head.weight" in f.keys():
+                     self.lm_head.weight = nn.Parameter(f.get_tensor("lm_head.weight"))
+                 else:
+                     # Checkpoints with tied word embeddings omit lm_head.weight;
+                     # reuse the input embedding matrix in that case.
+                     self.lm_head.weight = self.embed_tokens.weight
+                 for i in range(self.num_hidden_layers):
+                     self.layers[i].input_layernorm.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.input_layernorm.weight"))
+                     self.layers[i].post_attention_layernorm.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.post_attention_layernorm.weight"))
+                     self.layers[i].self_attn.q_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.q_proj.weight"))
+                     self.layers[i].self_attn.k_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.k_proj.weight"))
+                     self.layers[i].self_attn.v_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.v_proj.weight"))
+                     self.layers[i].self_attn.o_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.o_proj.weight"))
+                     self.layers[i].mlp.gate_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.gate_proj.weight"))
+                     self.layers[i].mlp.up_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.up_proj.weight"))
+                     self.layers[i].mlp.down_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.down_proj.weight"))
+         except ImportError:
+             print("Error: Safetensors library not found. Please install it with 'pip install safetensors'.")
+             sys.exit(1)
+         except Exception as e:
+             print(f"An error occurred while loading weights: {e}")
+             sys.exit(1)
+         end = timed_step(start, "Weight Loading")
+
+     def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: Optional[bool] = None) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+         """Forward pass of the model."""
+         use_cache = use_cache if use_cache is not None else self.use_cache
+         batch_size, seq_length = input_ids.shape
+         if position_ids is None:
+             #Stitch2
+             position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
+             if past_key_values is not None:
+                 position_ids = position_ids + past_key_values[0][0].shape[-2]
+             if position_ids.shape[-1] > self.max_position_embeddings:
+                 position_ids = position_ids[:, -self.max_position_embeddings:]
+         inputs_embeds = self.embed_tokens(input_ids)
+         hidden_states = inputs_embeds
+
+         if attention_mask is None and seq_length > 1:
+             # Build an additive causal mask so each token attends only to itself and earlier tokens.
+             attention_mask = torch.full((seq_length, seq_length), torch.finfo(inputs_embeds.dtype).min, device=input_ids.device, dtype=inputs_embeds.dtype)
+             attention_mask = torch.triu(attention_mask, diagonal=1)[None, None, :, :]
+
+         if past_key_values is None:
+             past_key_values = [None] * len(self.layers)
+
+         present_key_values = [] if use_cache else None
+
+         for i in range(self.num_hidden_layers):
+             hidden_states, present_key_value = self.layers[i](
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values[i],
+                 use_cache=use_cache,
+             )
+             if use_cache:
+                 present_key_values.append(present_key_value)
+
+         hidden_states = self.norm(hidden_states)
+         logits = self.lm_head(hidden_states)
+
+         return logits, present_key_values
+
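+ # Note on caching: each entry of past_key_values is a (key, value) pair of shape
+ # (batch, num_key_value_heads, seq_len_so_far, head_dim). Passing the cache back in
+ # lets generation feed only the newest token per step instead of the whole sequence.
+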
+ # --- Tokenizer ---
+
+ class SmolLM2Tokenizer:
+     """Tokenizer for SmolLM2-360M using SentencePiece or a rudimentary BPE."""
+     def __init__(self, tokenizer_path: str = ".", special_tokens_map_path: str = ".", config_path: str = "."):
+         self.tokenizer_path = tokenizer_path
+         self.special_tokens_map_path = special_tokens_map_path
+         self.config = load_json(config_path) if config_path else None
+         self.vocab_size = self.config['vocab_size'] if self.config else None
+         self.use_sentencepiece = True
+         self.special_tokens_map = load_json(special_tokens_map_path) if special_tokens_map_path else {}
+         #self.inv_special_tokens_map = {v['content']: k for k, v in self.special_tokens_map.items()}
+         #self.additional_special_tokens = self.special_tokens_map.get("additional_special_tokens",[]) #buggy
+         self.additional_special_tokens = self.special_tokens_map.get("additional_special_tokens", [])
+         self.inv_special_tokens_map = {v['content']: k for k, v in self.special_tokens_map.items() if isinstance(v, dict)}
+         self.additional_special_tokens_inv_map = {token: f"additional_special_tokens_{i}" for i, token in enumerate(self.additional_special_tokens)}
+
+         try:
+             import sentencepiece as spm
+             self.sp_model = spm.SentencePieceProcessor(model_file=os.path.join(tokenizer_path, 'tokenizer.model'))
+             # Load special tokens and IDs from SentencePiece
+             self.bos_token_id = self.sp_model.bos_id()
+             self.eos_token_id = self.sp_model.eos_id()
+             self.pad_token_id = self.sp_model.pad_id() if self.sp_model.pad_id() >= 0 else self.eos_token_id
+             self.unk_token_id = self.sp_model.unk_id()
+             self.additional_special_tokens_ids = [self.sp_model.piece_to_id(token) for token in self.additional_special_tokens]
+             # Adjust special tokens if they are in the SentencePiece model
+             self.update_special_tokens_from_sp()
+         except (ImportError, OSError, RuntimeError):
+             # Fall back to the rudimentary BPE tokenizer when sentencepiece is not
+             # installed or tokenizer.model is not present in tokenizer_path.
+             print("Warning: SentencePiece not available, using rudimentary BPE tokenizer. Install SentencePiece and provide tokenizer.model for better results.")
+             self.use_sentencepiece = False
+             self.vocab = load_json(os.path.join(tokenizer_path, 'vocab.json'))
+             self.merges = open(os.path.join(tokenizer_path, 'merges.txt'), 'r', encoding='utf-8').read().split('\n')[:-1]
+             self.merges = [tuple(merge.split()) for merge in self.merges]
+             self.token_to_id = self.vocab  # vocab.json already maps token -> id
+             self.id_to_token = {id: token for token, id in self.token_to_id.items()}
+             self.bos_token = self.special_tokens_map.get('bos_token', {}).get('content')
+             self.eos_token = self.special_tokens_map.get('eos_token', {}).get('content')
+             self.unk_token = self.special_tokens_map.get('unk_token', {}).get('content')
+             self.pad_token = '<PAD>' # Simple PAD token
+             self.bos_token_id = self.token_to_id.get(self.bos_token, -1)
+             self.eos_token_id = self.token_to_id.get(self.eos_token, -1)
+             self.unk_token_id = self.token_to_id.get(self.unk_token, -1)
+             self.pad_token_id = self.token_to_id.get(self.pad_token, -1) # Assuming you add <PAD> to vocab
+             self.additional_special_tokens_ids = [self.token_to_id.get(token, -1) for token in self.additional_special_tokens]
+
+     def update_special_tokens_from_sp(self):
+         """Update special token IDs from SentencePiece model, if present."""
+         for token_name, token_data in self.special_tokens_map.items():
+             if not isinstance(token_data, dict):
+                 continue  # skip entries such as "additional_special_tokens" (a list)
+             sp_id = self.sp_model.piece_to_id(token_data['content'])
+             if sp_id != self.sp_model.unk_id():
+                 if token_name == 'bos_token':
+                     self.bos_token_id = sp_id
+                 elif token_name == 'eos_token':
+                     self.eos_token_id = sp_id
+                 elif token_name == 'unk_token':
+                     self.unk_token_id = sp_id
+
+
+     def get_special_tokens_dict(self) -> Dict[str, Union[str, int]]:
+
+         # Add the additional special tokens to the dictionary
+         result_dict = {
+             'bos_token': self.inv_special_tokens_map.get(self.sp_model.id_to_piece(self.bos_token_id), None) if self.use_sentencepiece else self.bos_token,
+             'eos_token': self.inv_special_tokens_map.get(self.sp_model.id_to_piece(self.eos_token_id), None) if self.use_sentencepiece else self.eos_token,
+             'unk_token': self.inv_special_tokens_map.get(self.sp_model.id_to_piece(self.unk_token_id), None) if self.use_sentencepiece else self.unk_token,
+             'pad_token': self.inv_special_tokens_map.get(self.sp_model.id_to_piece(self.pad_token_id), None) if self.use_sentencepiece and hasattr(self, 'pad_token_id') else self.pad_token if hasattr(self, 'pad_token') else None,
+             'bos_token_id': self.bos_token_id,
+             'eos_token_id': self.eos_token_id,
+             'unk_token_id': self.unk_token_id,
+             'pad_token_id': self.pad_token_id if hasattr(self, 'pad_token_id') else None,
+             'additional_special_tokens': self.additional_special_tokens,
+             'additional_special_tokens_ids': self.additional_special_tokens_ids,
+         }
+         result_dict.update(self.additional_special_tokens_inv_map)
+         return result_dict
+
+     def bpe(self, token: str) -> List[str]:
+         """Rudimentary BPE tokenization."""
+         if not self.use_sentencepiece:
+             word = list(token)
+             while len(word) > 1:
+                 pairs = [(word[i], word[i+1]) for i in range(len(word) - 1)]
+                 bigram = min(pairs, key=lambda pair: self.merges.index(pair) if pair in self.merges else float('inf'))
+                 if bigram not in self.merges:
+                     break
+                 first, second = bigram
+                 new_word = []
+                 i = 0
+                 while i < len(word):
+                     if i < len(word) - 1 and word[i] == first and word[i+1] == second:
+                         new_word.append(first + second)
+                         i += 2
+                     else:
+                         new_word.append(word[i])
+                         # Stitch 3 Last stitch but was an error, switched to Gemini 1.5 Pro.
+                         i += 1
+                 word = new_word
+             return word
+         else:
+             return [] # If SentencePiece is used, this function is not called.
+
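+ # Illustrative BPE example (hypothetical merge list): with merges [("l", "o"), ("lo", "w")],
+ # bpe("low") first joins "l"+"o" -> ["lo", "w"], then "lo"+"w" -> ["low"].
+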
+     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
+         """Encode text to token IDs."""
+         if self.use_sentencepiece:
+             if add_special_tokens:
+                 return self.sp_model.encode(text, out_type=int) # add_bos=True, add_eos=True if needed; adjust per model requirements
+             else:
+                 return self.sp_model.encode_as_ids(text)
+         else:
+             tokens = []
+             for word in text.split():
+                 tokens.extend(self.bpe(word))
+             token_ids = [self.token_to_id.get(token, self.unk_token_id) for token in tokens]
+             if add_special_tokens and self.bos_token_id != -1 and self.eos_token_id != -1:
+                 token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
+             return token_ids
+
+     def decode(self, token_ids: List[int]) -> str:
+         """Decode token IDs to text."""
+         if self.use_sentencepiece:
+             return self.sp_model.decode(token_ids)
+         else:
+             tokens = [self.id_to_token.get(token_id, self.unk_token) for token_id in token_ids]
+             return " ".join(tokens)
+
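+ # Round-trip usage (illustrative): ids = tokenizer.encode("Hello world"); text = tokenizer.decode(ids).
+ # Note that the BPE fallback joins decoded tokens with spaces, so whitespace may not
+ # exactly match the original text.
+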
+
+ # --- Inference ---
+
+ def generate_text(model: SmolLM2_360M, tokenizer: SmolLM2Tokenizer, prompt: str, MAX_GENERATION_LENGTH: int = 100, device: torch.device = torch.device('cpu')) -> str:
+     """Generate text using greedy decoding."""
+     input_ids = tokenizer.encode(prompt, add_special_tokens=True)
+     input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
+
+     past_key_values = None
+     for _ in range(MAX_GENERATION_LENGTH):
+         # After the first pass, the KV cache already holds all previous tokens, so
+         # only the newest token needs to be fed to the model.
+         model_inputs = input_ids if past_key_values is None else input_ids[:, -1:]
+         with torch.no_grad():
+             logits, past_key_values = model(input_ids=model_inputs, past_key_values=past_key_values)
+         next_token_logits = logits[:, -1, :]
+         next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
+         input_ids = torch.cat([input_ids, next_token_id], dim=1)
+         if next_token_id.item() == tokenizer.eos_token_id:
+             break
+     generated_ids = input_ids[0].tolist()
+     generated_text = tokenizer.decode(generated_ids)
+     return generated_text
+
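+ # Greedy argmax decoding is deterministic. Temperature / top-k sampling could be
+ # swapped in along these lines (illustrative sketch only):
+ #   probs = torch.softmax(next_token_logits / temperature, dim=-1)
+ #   next_token_id = torch.multinomial(probs, num_samples=1)
+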
+
+ # --- Main Execution ---
+ if __name__ == "__main__":
+     start = time.time()
+     config_path = "config.json"
+     weights_path = "model.safetensors"
+     tokenizer_path = "." # Current directory
+     special_tokens_map_path = "special_tokens_map.json"
+
+     config = load_json(config_path)
+     tokenizer = SmolLM2Tokenizer(tokenizer_path, special_tokens_map_path, config_path)
+
+     model = SmolLM2_360M(config_path)
+     model.load_weights(weights_path)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Print special tokens information
+     special_tokens = tokenizer.get_special_tokens_dict()
+     print("Special Tokens:")
+     for k, v in special_tokens.items():
+         print(f"\t{k}: {v}")
+
+     model.to(device, dtype=model.torch_dtype).eval()
+
+     end = timed_step(start, "Model initialization")
+
+     start = time.time()
+     # Default prompts (loop if multiple)
+     for prompt in DEFAULT_PROMPT:
+         print(f"\nDefault Prompt: {prompt}")
+         generated_text = generate_text(model, tokenizer, prompt, MAX_GENERATION_LENGTH=MAX_GENERATION_LENGTH, device=device)
+         print(f"Generated Text: {generated_text}")
+     end = timed_step(start, "Default Prompt Generation")
+
+     # User input loop
+     while True:
+         user_input = input("\nEnter prompt (or 'exit' to quit, 'hyper' for hyperparameters): ")
+         if user_input.lower() == "exit":
+             break
+         elif "hyper" in user_input.lower():
+             print("\nHyperparameters:")
+             for key, value in config.items():
+                 print(f"\t{key}: {value}")
+         else:
+             start = time.time()
+             generated_text = generate_text(model, tokenizer, user_input, MAX_GENERATION_LENGTH=MAX_GENERATION_LENGTH, device=device)
+             print(f"Generated Text: {generated_text}")
+             end = timed_step(start, "Prompt Generation")
+