anjikum committed
Commit c3fe4b2 · verified · 1 Parent(s): 6f5c335

Upload 4 files

Files changed (4)
  1. app.py +56 -0
  2. checkpoint_step_5500.pt +3 -0
  3. model.py +175 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from transformers import AutoTokenizer
+ from model import SmolLM2, SmolLM2Config
+ import gradio as gr
+
+ # Initialize model and tokenizer
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ model = SmolLM2(SmolLM2Config())
+
+ # Load trained weights (the checkpoint uploaded with this commit sits at the repo root)
+ checkpoint = torch.load('checkpoint_step_5500.pt', map_location=device)  # adjust path if the checkpoint moves
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.to(device)
+ model.eval()
+
+ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
+     """Generate text from a prompt."""
+     # Tokenize the prompt
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids,
+             max_new_tokens=max_length,
+             temperature=temperature,
+             top_k=top_k
+         )
+
+     # Decode and return the generated text
+     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return generated_text
+
+ # Gradio interface
+ def gradio_interface(prompt, max_length, temperature, top_k):
+     return generate_text(prompt, int(max_length), float(temperature), int(top_k))
+
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+         gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Length"),
+         gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
+     ],
+     outputs=gr.Textbox(label="Generated Text"),
+     title="SmolLM2 Text Generation",
+     description="Generate text using the SmolLM2 model"
+ )
+
+ # For Hugging Face Spaces (Gradio SDK), launching the interface directly is sufficient
+ if __name__ == "__main__":
+     iface.launch()
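Once the Space is running, the `gr.Interface` above can also be queried programmatically. A minimal sketch using the `gradio_client` package (not listed in requirements.txt); the Space id `anjikum/smollm2-demo` is a hypothetical placeholder:

```python
# Sketch: call the deployed Gradio app from another machine.
from gradio_client import Client

client = Client("anjikum/smollm2-demo")  # hypothetical Space id; replace with the actual repo id

result = client.predict(
    "Once upon a time",   # Prompt
    100,                  # Max Length
    0.7,                  # Temperature
    50,                   # Top K
    api_name="/predict",  # default endpoint name for a gr.Interface
)
print(result)
```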
checkpoint_step_5500.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b947d53bfe31093fd71b2e839f143cede77db8d4fa2e2edeaac676d18936dc0c
+ size 248122116
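The checkpoint is stored via Git LFS (roughly 248 MB), so only this pointer appears in the diff. A hedged sketch of resolving it at runtime with `huggingface_hub` instead of a hard-coded local path; the repo id is a hypothetical placeholder and `huggingface_hub` is not in requirements.txt:

```python
# Sketch: download the LFS-backed checkpoint and inspect its contents.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="anjikum/smollm2-demo",  # hypothetical repo id
                            filename="checkpoint_step_5500.pt")
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.keys())  # app.py expects at least 'model_state_dict'
```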
model.py ADDED
@@ -0,0 +1,175 @@
+ import math
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ class LlamaRMSNorm(nn.Module):
+     """Root-mean-square layer norm: no mean centering, learnable scale only."""
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+
+     def forward(self, x):
+         rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+         x_norm = x / rms
+         return self.weight * x_norm
+
+ class LlamaRotaryEmbedding(nn.Module):
+     """Computes rotary position-embedding angles for a given head dimension."""
+     def __init__(self, dim, max_position_embeddings=2048, base=10000):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.max_position_embeddings = max_position_embeddings
+         self.dim = dim
+
+     def forward(self, x, seq_len):
+         t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
+         return emb
+
+ def rotate_half(x):
+     x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+     # position_ids is unused here; positions are implicitly 0..seq_len-1
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     cos = cos.expand(q.shape[0], q.shape[1], -1, -1)
+     sin = sin.expand(k.shape[0], k.shape[1], -1, -1)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ class LlamaSdpaAttention(nn.Module):
+     """Grouped-query attention built on torch scaled_dot_product_attention."""
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.n_embd
+         self.num_heads = config.n_head
+         self.head_dim = config.n_embd // config.n_head
+         self.num_key_value_heads = config.num_key_value_heads  # fewer K/V heads than query heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
+
+     def forward(self, x, attention_mask=None):
+         B, T, C = x.size()
+         q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
+         k = self.k_proj(x).view(B, T, self.num_key_value_heads, self.head_dim)
+         v = self.v_proj(x).view(B, T, self.num_key_value_heads, self.head_dim)
+
+         # Expand shared K/V heads so every query head has a matching K/V head
+         k = k.repeat_interleave(self.num_key_value_groups, dim=2)
+         v = v.repeat_interleave(self.num_key_value_groups, dim=2)
+
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+
+         rotary_emb = self.rotary_emb(x, T)
+         cos, sin = rotary_emb.cos(), rotary_emb.sin()
+         q, k = apply_rotary_pos_emb(q, k, cos, sin, None)
+
+         out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         return self.o_proj(out)
+
+ class LlamaMLP(nn.Module):
+     """SwiGLU-style feed-forward block: SiLU(gate) * up, projected back down."""
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.n_embd, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ class LlamaDecoderLayer(nn.Module):
+     """Pre-norm transformer block: RMSNorm -> attention -> RMSNorm -> MLP, with residuals."""
+     def __init__(self, config):
+         super().__init__()
+         self.input_layernorm = LlamaRMSNorm(config.n_embd)
+         self.self_attn = LlamaSdpaAttention(config)
+         self.post_attention_layernorm = LlamaRMSNorm(config.n_embd)
+         self.mlp = LlamaMLP(config)
+
+     def forward(self, x):
+         residual = x
+         x = self.input_layernorm(x)
+         x = self.self_attn(x)
+         x = residual + x
+
+         residual = x
+         x = self.post_attention_layernorm(x)
+         x = self.mlp(x)
+         x = residual + x
+         return x
+
+ @dataclass
+ class SmolLM2Config:
+     block_size: int = 2048
+     vocab_size: int = 49152
+     n_layer: int = 30
+     n_head: int = 9
+     n_embd: int = 576
+     intermediate_size: int = 1536
+     num_key_value_heads: int = 3
+     rms_norm_eps: float = 1e-5
+     rope_theta: float = 10000.0
+     initializer_range: float = 0.041666666666666664
+     use_cache: bool = True
+
+ class SmolLM2(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
+         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])
+         self.norm = LlamaRMSNorm(config.n_embd, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.embed_tokens.weight = self.lm_head.weight  # tie input and output embeddings
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.size()
+         x = self.embed_tokens(idx)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         # Autoregressive sampling with temperature scaling and optional top-k filtering
+         for _ in range(max_new_tokens):
+             idx_cond = idx[:, -self.config.block_size:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float('-inf')
+
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
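A quick way to sanity-check the wiring of the modules above, without the trained checkpoint, is a random-input forward and generate pass. A minimal sketch; the expected shapes follow directly from the SmolLM2Config defaults:

```python
# Sketch: shape/smoke test for SmolLM2 with random token ids (untrained weights).
import torch
from model import SmolLM2, SmolLM2Config

config = SmolLM2Config()
model = SmolLM2(config).eval()

idx = torch.randint(0, config.vocab_size, (1, 16))  # (batch, seq_len)

with torch.no_grad():
    logits, loss = model(idx, targets=idx)           # loss is None when targets is omitted
    assert logits.shape == (1, 16, config.vocab_size)

    out = model.generate(idx, max_new_tokens=8, temperature=0.8, top_k=50)
    assert out.shape == (1, 24)                      # 16 prompt tokens + 8 sampled tokens
```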
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ gradio