jatingocodeo committed
Commit 1a1fb0e · verified · 1 Parent(s): 6fa5efe

Update app.py

Files changed (1)
  1. app.py +21 -251
app.py CHANGED
@@ -1,231 +1,8 @@
  import torch
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig
- import torch.nn as nn
- import torch.nn.functional as F
- import math
-
- class SmolLM2Config(PretrainedConfig):
-     model_type = "smollm2"
-
-     def __init__(
-         self,
-         vocab_size=49152,
-         hidden_size=576,
-         intermediate_size=1536,
-         num_hidden_layers=30,
-         num_attention_heads=9,
-         num_key_value_heads=3,
-         hidden_act="silu",
-         max_position_embeddings=2048,
-         initializer_range=0.041666666666666664,
-         rms_norm_eps=1e-5,
-         use_cache=True,
-         pad_token_id=None,
-         bos_token_id=0,
-         eos_token_id=0,
-         tie_word_embeddings=True,
-         rope_theta=10000.0,
-         **kwargs
-     ):
-         self.vocab_size = vocab_size
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.num_hidden_layers = num_hidden_layers
-         self.num_attention_heads = num_attention_heads
-         self.num_key_value_heads = num_key_value_heads
-         self.hidden_act = hidden_act
-         self.max_position_embeddings = max_position_embeddings
-         self.initializer_range = initializer_range
-         self.rms_norm_eps = rms_norm_eps
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         super().__init__(
-             pad_token_id=pad_token_id,
-             bos_token_id=bos_token_id,
-             eos_token_id=eos_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs
-         )
-
- class RMSNorm(nn.Module):
-     def __init__(self, hidden_size, eps=1e-5):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.eps = eps
-
-     def forward(self, x):
-         variance = x.pow(2).mean(-1, keepdim=True)
-         x = x * torch.rsqrt(variance + self.eps)
-         return self.weight * x
-
- class LlamaAttention(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.num_kv_heads = config.num_key_value_heads
-         self.head_dim = config.hidden_size // config.num_attention_heads
-
-         self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-         self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
-     def forward(self, hidden_states, attention_mask=None):
-         batch_size, seq_length, _ = hidden_states.size()
-
-         # Project and reshape
-         q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
-         k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
-         v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
-
-         # Repeat k/v heads if needed
-         if self.num_kv_heads < self.num_heads:
-             k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
-             v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
-
-         # Transpose for attention
-         q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
-         k = k.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
-         v = v.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
-
-         # Calculate attention scores
-         scale = 1.0 / math.sqrt(self.head_dim)
-         scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (batch, num_heads, seq_len, seq_len)
-
-         # Apply attention mask if provided
-         if attention_mask is not None:
-             # Ensure mask is broadcastable
-             if attention_mask.dim() == 2:
-                 attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, seq_len)
-             scores = scores + attention_mask
-
-         # Apply softmax and dropout
-         attention_weights = F.softmax(scores, dim=-1)
-
-         # Apply attention to values
-         output = torch.matmul(attention_weights, v)  # (batch, num_heads, seq_len, head_dim)
-
-         # Reshape and project back
-         output = output.transpose(1, 2).contiguous()  # (batch, seq_len, num_heads, head_dim)
-         output = output.view(batch_size, seq_length, -1)  # (batch, seq_len, hidden_size)
-         output = self.o_proj(output)
-
-         return output
-
- class LlamaMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-         self.act_fn = nn.SiLU()
-
-     def forward(self, x):
-         gate = self.act_fn(self.gate_proj(x))
-         up = self.up_proj(x)
-         return self.down_proj(gate * up)
-
- class LlamaDecoderLayer(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.self_attn = LlamaAttention(config)
-         self.mlp = LlamaMLP(config)
-         self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-         self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-
-     def forward(self, hidden_states, attention_mask=None):
-         residual = hidden_states
-         hidden_states = self.input_layernorm(hidden_states)
-         hidden_states = self.self_attn(hidden_states, attention_mask)
-         hidden_states = residual + hidden_states
-
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         return hidden_states
-
- class SmolLM2ForCausalLM(PreTrainedModel):
-     config_class = SmolLM2Config
-     _no_split_modules = ["LlamaDecoderLayer"]
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         if config.tie_word_embeddings:
-             self.lm_head.weight = self.embed_tokens.weight
-
-     def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None, **kwargs):
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         hidden_states = self.embed_tokens(input_ids)
-
-         # Create causal attention mask
-         batch_size, seq_length = input_ids.size()
-         device = input_ids.device
-
-         # Create causal mask
-         causal_mask = torch.triu(
-             torch.ones((seq_length, seq_length), dtype=torch.bool, device=device),
-             diagonal=1
-         )
-         causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
-         causal_mask = causal_mask.expand(batch_size, 1, seq_length, seq_length)
-         causal_mask = causal_mask.to(dtype=torch.float32) * -1e4
-
-         # Combine with attention mask if provided
-         if attention_mask is not None:
-             # Convert attention mask to float and unsqueeze
-             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
-             attention_mask = attention_mask.expand(batch_size, 1, seq_length, seq_length)
-             attention_mask = (1.0 - attention_mask) * -1e4
-             # Combine masks
-             causal_mask = causal_mask + attention_mask
-
-         # Process through layers
-         for layer in self.layers:
-             hidden_states = layer(hidden_states, causal_mask)
-
-         hidden_states = self.norm(hidden_states)
-         logits = self.lm_head(hidden_states)
-
-         loss = None
-         if labels is not None:
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-         if return_dict:
-             from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-             return CausalLMOutputWithCrossAttentions(
-                 loss=loss,
-                 logits=logits,
-                 past_key_values=None,
-                 hidden_states=None,
-                 attentions=None,
-                 cross_attentions=None,
-             )
-         return (loss, logits) if loss is not None else logits
-
-     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
-         # Only return what we need
-         inputs = {
-             "input_ids": input_ids,
-             "attention_mask": attention_mask
-         }
-         return inputs
-
- # Register the model architecture
- from transformers import AutoConfig
- AutoConfig.register("smollm2", SmolLM2Config)
- AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import hf_hub_download
+ import json

  # Cache for model and tokenizer
  MODEL = None
@@ -239,10 +16,12 @@ def initialize():
      model_id = "jatingocodeo/SmolLM2"

      try:
+         # Download model files from HF Hub
+         config_path = hf_hub_download(repo_id=model_id, filename="config.json")
+
          # Load tokenizer
-         print("\n1. Loading tokenizer...")
+         print("Loading tokenizer...")
          TOKENIZER = AutoTokenizer.from_pretrained(model_id)
-         print("✓ Tokenizer loaded successfully")

          # Add special tokens if needed
          special_tokens = {
@@ -250,25 +29,25 @@ def initialize():
              'eos_token': '</s>',
              'bos_token': '<s>'
          }
-         num_added = TOKENIZER.add_special_tokens(special_tokens)
-         print(f"✓ Added {num_added} special tokens")
+         TOKENIZER.add_special_tokens(special_tokens)

          # Load model
-         print("\n2. Loading model...")
+         print("Loading model...")
          MODEL = AutoModelForCausalLM.from_pretrained(
              model_id,
-             trust_remote_code=True,
              torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             trust_remote_code=True,
              low_cpu_mem_usage=True
          )

-         # Move model to appropriate device
+         # Move model to device
          device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         MODEL = MODEL.to(device)
-         print(f"✓ Model loaded successfully and moved to {device}")
+         MODEL.to(device)
+
+         print(f"Model loaded successfully on {device}")

      except Exception as e:
-         print(f"Error initializing model: {str(e)}")
+         print(f"Error initializing: {str(e)}")
          raise

  def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
@@ -281,6 +60,7 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
          if not prompt.strip():
              return "Please enter a prompt."

+         # Add BOS token if needed
          if not prompt.startswith(TOKENIZER.bos_token):
              prompt = TOKENIZER.bos_token + prompt

@@ -288,36 +68,26 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
          input_ids = TOKENIZER.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
          input_ids = input_ids.to(MODEL.device)

-         # Create attention mask
-         attention_mask = torch.ones_like(input_ids)
-
          # Generate
          with torch.no_grad():
-             output_ids = MODEL.generate(
+             outputs = MODEL.generate(
                  input_ids,
-                 attention_mask=attention_mask,
                  max_length=min(max_length + len(input_ids[0]), 2048),
                  temperature=temperature,
                  top_k=top_k,
                  do_sample=True,
                  pad_token_id=TOKENIZER.pad_token_id,
                  eos_token_id=TOKENIZER.eos_token_id,
-                 num_return_sequences=1,
-                 use_cache=True
+                 num_return_sequences=1
              )

          # Decode and return
-         generated_text = TOKENIZER.decode(output_ids[0], skip_special_tokens=True)
+         generated_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
          return generated_text.strip()

      except Exception as e:
-         print(f"Generation error details: {str(e)}")
-         import traceback
-         traceback.print_exc()
-         return f"Error generating text: {str(e)}"
-
- # Initialize on startup
- initialize()
+         print(f"Error generating text: {str(e)}")
+         return f"An error occurred: {str(e)}"

  # Create Gradio interface
  iface = gr.Interface(
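
For context on what this change does: app.py no longer defines the SmolLM2 architecture locally and instead lets transformers resolve the custom model code from the jatingocodeo/SmolLM2 Hub repo at load time via trust_remote_code=True. Below is a minimal, illustrative sketch of that loading and generation path outside the Gradio app; the prompt string and sampling settings are placeholders chosen to mirror generate_text(), not part of this commit.

# Minimal sketch (not part of the commit) of the code path the updated app.py relies on.
# Assumes the jatingocodeo/SmolLM2 repo hosts the custom "smollm2" model code, which is
# why trust_remote_code=True is needed; the prompt below is purely illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "jatingocodeo/SmolLM2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Prepend the BOS token, mirroring generate_text() in app.py.
prompt = "Once upon a time"
if tokenizer.bos_token and not prompt.startswith(tokenizer.bos_token):
    prompt = tokenizer.bos_token + prompt

input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
input_ids = input_ids.to(model.device)

with torch.no_grad():
    outputs = model.generate(
        input_ids,
        max_length=min(100 + input_ids.shape[1], 2048),
        temperature=0.7,
        top_k=50,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True).strip())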