jatingocodeo committed
Commit 1ab2f15 · verified · 1 Parent(s): bbaff56

Update app.py

Files changed (1): app.py (+106 -2)
app.py CHANGED
@@ -5,7 +5,92 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 
-# Model architecture definition
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, x):
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.eps)
+        return self.weight * x
+
+class LlamaAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
+
+    def forward(self, hidden_states, attention_mask=None):
+        batch_size, seq_length, _ = hidden_states.size()
+
+        q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
+        k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+        v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+
+        if self.num_kv_heads < self.num_heads:
+            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
+            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
+
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        attention_probs = F.softmax(attention_scores, dim=-1)
+        context = torch.matmul(attention_probs, v)
+
+        context = context.transpose(1, 2).contiguous()
+        context = context.view(batch_size, seq_length, -1)
+
+        return self.o_proj(context)
+
+class LlamaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        gate = self.act_fn(self.gate_proj(x))
+        up = self.up_proj(x)
+        return self.down_proj(gate * up)
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.self_attn = LlamaAttention(config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None):
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(hidden_states, attention_mask)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
 class SmolLM2Config(PretrainedConfig):
     model_type = "smollm2"
 
@@ -58,6 +143,8 @@ class SmolLM2ForCausalLM(PreTrainedModel):
         self.config = config
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         if config.tie_word_embeddings:
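The LlamaAttention block added above projects keys and values with fewer heads than queries and then repeats them per query head. A minimal sketch of that grouped-query expansion, not part of the commit and assuming only PyTorch:

import torch

batch, seq, num_heads, num_kv_heads, head_dim = 2, 4, 8, 2, 16

# K is projected with only num_kv_heads heads (grouped-query attention) ...
k = torch.randn(batch, seq, num_kv_heads, head_dim)

# ... and repeat_interleave expands it so every query head gets a matching K head,
# mirroring the `if self.num_kv_heads < self.num_heads:` branch above.
k = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
print(k.shape)  # torch.Size([2, 4, 8, 16])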
 
@@ -65,6 +152,20 @@ class SmolLM2ForCausalLM(PreTrainedModel):
 
     def forward(self, input_ids, attention_mask=None, labels=None):
         hidden_states = self.embed_tokens(input_ids)
+
+        # Create causal attention mask if none provided
+        if attention_mask is None:
+            attention_mask = torch.triu(
+                torch.ones((input_ids.size(1), input_ids.size(1)), dtype=torch.bool, device=input_ids.device),
+                diagonal=1
+            )
+            attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
+            attention_mask = attention_mask * -1e4
+
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask)
+
+        hidden_states = self.norm(hidden_states)
         logits = self.lm_head(hidden_states)
 
         loss = None
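When no attention_mask is passed, forward() builds an additive causal mask: an upper-triangular boolean matrix scaled by -1e4 and broadcast to (1, 1, seq, seq), so softmax gives future positions near-zero weight. A standalone sketch of that construction, not part of the commit:

import torch
import torch.nn.functional as F

seq_len = 4
mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
mask = mask.unsqueeze(0).unsqueeze(0) * -1e4  # (1, 1, seq_len, seq_len); 0 on and below the diagonal

scores = torch.zeros(1, 1, seq_len, seq_len)  # stand-in for q @ k^T / sqrt(head_dim)
probs = F.softmax(scores + mask, dim=-1)
print(probs[0, 0, 0])  # ~[1., 0., 0., 0.]: position 0 attends only to itself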
 
175
  return logits if loss is None else (loss, logits)
176
 
177
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
178
+ return {
179
+ "input_ids": input_ids,
180
+ "attention_mask": kwargs.get("attention_mask", None)
181
+ }
182
 
183
  # Register the model architecture
184
  from transformers import AutoConfig, AutoModelForCausalLM
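Once the architecture is registered, the new decoder stack can be exercised end to end. A rough sketch only: the SmolLM2Config keyword arguments below are an assumption (its __init__ is not shown in this diff), chosen to match the fields the modules above read.

import torch

# Hypothetical tiny config; field names mirror those used by the modules above.
config = SmolLM2Config(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    rms_norm_eps=1e-5,
    tie_word_embeddings=False,
)
model = SmolLM2ForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
logits = model(input_ids)  # no labels, so forward() returns logits only
print(logits.shape)        # expected: torch.Size([1, 8, 1000])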