jatingocodeo commited on
Commit
e061e9b
·
verified ·
1 Parent(s): 25c11ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -1
app.py CHANGED
@@ -1,6 +1,189 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # Cache for model and tokenizer
6
  MODEL = None
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ class SmolLM2Config(PretrainedConfig):
9
+ model_type = "smollm2"
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size=49152,
14
+ hidden_size=576,
15
+ intermediate_size=1536,
16
+ num_hidden_layers=30,
17
+ num_attention_heads=9,
18
+ num_key_value_heads=3,
19
+ hidden_act="silu",
20
+ max_position_embeddings=2048,
21
+ initializer_range=0.041666666666666664,
22
+ rms_norm_eps=1e-5,
23
+ use_cache=True,
24
+ pad_token_id=None,
25
+ bos_token_id=0,
26
+ eos_token_id=0,
27
+ tie_word_embeddings=True,
28
+ rope_theta=10000.0,
29
+ **kwargs
30
+ ):
31
+ self.vocab_size = vocab_size
32
+ self.hidden_size = hidden_size
33
+ self.intermediate_size = intermediate_size
34
+ self.num_hidden_layers = num_hidden_layers
35
+ self.num_attention_heads = num_attention_heads
36
+ self.num_key_value_heads = num_key_value_heads
37
+ self.hidden_act = hidden_act
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.initializer_range = initializer_range
40
+ self.rms_norm_eps = rms_norm_eps
41
+ self.use_cache = use_cache
42
+ self.rope_theta = rope_theta
43
+ super().__init__(
44
+ pad_token_id=pad_token_id,
45
+ bos_token_id=bos_token_id,
46
+ eos_token_id=eos_token_id,
47
+ tie_word_embeddings=tie_word_embeddings,
48
+ **kwargs
49
+ )
50
+
51
+ class RMSNorm(nn.Module):
52
+ def __init__(self, hidden_size, eps=1e-5):
53
+ super().__init__()
54
+ self.weight = nn.Parameter(torch.ones(hidden_size))
55
+ self.eps = eps
56
+
57
+ def forward(self, x):
58
+ variance = x.pow(2).mean(-1, keepdim=True)
59
+ x = x * torch.rsqrt(variance + self.eps)
60
+ return self.weight * x
61
+
62
+ class LlamaAttention(nn.Module):
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ self.hidden_size = config.hidden_size
66
+ self.num_heads = config.num_attention_heads
67
+ self.num_kv_heads = config.num_key_value_heads
68
+ self.head_dim = config.hidden_size // config.num_attention_heads
69
+
70
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
71
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
72
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
73
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
74
+
75
+ def forward(self, hidden_states, attention_mask=None):
76
+ batch_size, seq_length, _ = hidden_states.size()
77
+
78
+ q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
79
+ k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
80
+ v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
81
+
82
+ if self.num_kv_heads < self.num_heads:
83
+ k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
84
+ v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
85
+
86
+ q = q.transpose(1, 2)
87
+ k = k.transpose(1, 2)
88
+ v = v.transpose(1, 2)
89
+
90
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
91
+
92
+ if attention_mask is not None:
93
+ attention_scores = attention_scores + attention_mask
94
+
95
+ attention_probs = F.softmax(attention_scores, dim=-1)
96
+ context = torch.matmul(attention_probs, v)
97
+
98
+ context = context.transpose(1, 2).contiguous()
99
+ context = context.view(batch_size, seq_length, -1)
100
+
101
+ return self.o_proj(context)
102
+
103
+ class LlamaMLP(nn.Module):
104
+ def __init__(self, config):
105
+ super().__init__()
106
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
107
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
108
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
109
+ self.act_fn = nn.SiLU()
110
+
111
+ def forward(self, x):
112
+ gate = self.act_fn(self.gate_proj(x))
113
+ up = self.up_proj(x)
114
+ return self.down_proj(gate * up)
115
+
116
+ class LlamaDecoderLayer(nn.Module):
117
+ def __init__(self, config):
118
+ super().__init__()
119
+ self.self_attn = LlamaAttention(config)
120
+ self.mlp = LlamaMLP(config)
121
+ self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
122
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
123
+
124
+ def forward(self, hidden_states, attention_mask=None):
125
+ residual = hidden_states
126
+ hidden_states = self.input_layernorm(hidden_states)
127
+ hidden_states = self.self_attn(hidden_states, attention_mask)
128
+ hidden_states = residual + hidden_states
129
+
130
+ residual = hidden_states
131
+ hidden_states = self.post_attention_layernorm(hidden_states)
132
+ hidden_states = self.mlp(hidden_states)
133
+ hidden_states = residual + hidden_states
134
+
135
+ return hidden_states
136
+
137
+ class SmolLM2ForCausalLM(PreTrainedModel):
138
+ config_class = SmolLM2Config
139
+ _no_split_modules = ["LlamaDecoderLayer"]
140
+
141
+ def __init__(self, config):
142
+ super().__init__(config)
143
+ self.config = config
144
+
145
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
146
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
147
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
148
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
149
+
150
+ if config.tie_word_embeddings:
151
+ self.lm_head.weight = self.embed_tokens.weight
152
+
153
+ def forward(self, input_ids, attention_mask=None, labels=None):
154
+ hidden_states = self.embed_tokens(input_ids)
155
+
156
+ # Create causal attention mask if none provided
157
+ if attention_mask is None:
158
+ attention_mask = torch.triu(
159
+ torch.ones((input_ids.size(1), input_ids.size(1)), dtype=torch.bool, device=input_ids.device),
160
+ diagonal=1
161
+ )
162
+ attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
163
+ attention_mask = attention_mask * -1e4
164
+
165
+ for layer in self.layers:
166
+ hidden_states = layer(hidden_states, attention_mask)
167
+
168
+ hidden_states = self.norm(hidden_states)
169
+ logits = self.lm_head(hidden_states)
170
+
171
+ loss = None
172
+ if labels is not None:
173
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
174
+
175
+ return logits if loss is None else (loss, logits)
176
+
177
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
178
+ return {
179
+ "input_ids": input_ids,
180
+ "attention_mask": kwargs.get("attention_mask", None)
181
+ }
182
+
183
+ # Register the model architecture
184
+ from transformers import AutoConfig
185
+ AutoConfig.register("smollm2", SmolLM2Config)
186
+ AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
187
 
188
  # Cache for model and tokenizer
189
  MODEL = None