MilindChawre committed
Commit e5e63f7 · 1 Parent(s): 6441deb

Adding code for SmolLM2 text generator app

Files changed (5)
  1. README.md +65 -1
  2. app.py +80 -0
  3. model.py +245 -0
  4. requirements.txt +3 -0
  5. smollm2_final.pt +3 -0
README.md CHANGED
@@ -9,4 +9,68 @@ app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SmolLM2 Text Generator
+
+ This is a Gradio application for generating text with the trained SmolLM2 model. Users enter a text prompt and generate one or more sequences of text from it; the number of sequences and the length of the generated text are adjusted with sliders.
+
+ ## Features
+
+ - **Text Generation**: Generate text from a user-provided prompt using the SmolLM2 model.
+ - **Adjustable Length**: Control the length of the generated text.
+ - **Multiple Sequences**: Generate multiple sequences of text in one go.
+
+ ## Requirements
+
+ To run this application, you need the following Python packages:
+
+ - `torch`
+ - `transformers`
+ - `gradio`
+
+ You can install them with pip:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ 1. **Run the App**: Launch the Gradio app by running the following command in your terminal:
+
+ ```bash
+ python app.py
+ ```
+
+ 2. **Input Prompt**: Enter your desired text prompt in the textbox.
+
+ 3. **Adjust Sliders**:
+    - Use the "Predict Additional Text of Length" slider to set the length of the generated text.
+    - Use the "Number of Sequences to Generate" slider to choose how many sequences to generate.
+
+ 4. **Generate Text**: Click the "Generate Text" button to produce the text sequences.
+
+ 5. **View Output**: The generated sequences appear in the output textbox, each prefixed with "Sequence X:" for clarity.
+
+ ## Example
+
+ - **Prompt**: "Once upon a time"
+ - **Number of Sequences**: 2
+
+ **Output**:
+ ```
+ Sequence 1:
+ Once upon a time, there is a cat ....
+
+ Sequence 2:
+ Once upon a time in a small village ....
+ ```
+
+ ## License
+
+ This project is licensed under the MIT License. See the LICENSE file for more details.
+
+ ## Acknowledgments
+
+ - Hugging Face for the Transformers library and model support.
+ - Gradio for providing an easy-to-use interface for machine learning applications.
+ - The SmolLM2 model for enabling advanced text generation capabilities.
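
A minimal sketch of the prompt round-trip through the tokenizer that `app.py` loads (assuming the `HuggingFaceTB/cosmo2-tokenizer` repo is reachable; the printed values are illustrative):

```python
from transformers import AutoTokenizer

# Same tokenizer repo id as used in app.py below
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")

# Encode the README's example prompt into model-ready token ids
input_ids = tokenizer("Once upon a time", return_tensors="pt")["input_ids"]
print(input_ids.shape)  # torch.Size([1, <number of prompt tokens>])

# Generated token ids are decoded back to text the same way
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))  # Once upon a time
```
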
app.py ADDED
@@ -0,0 +1,80 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer
+ from model import SmolLM2  # Custom model definition in model.py
+
+ # Load the model and tokenizer
+ model_path = "smollm2_final.pt"
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")  # Tokenizer used to train the model (adjust if necessary)
+
+ # Model configuration (must match the trained checkpoint)
+ model_config = {
+     "bos_token_id": 0,
+     "eos_token_id": 0,
+     "hidden_act": "silu",
+     "hidden_size": 576,
+     "initializer_range": 0.041666666666666664,
+     "intermediate_size": 1536,
+     "is_llama_config": True,
+     "max_position_embeddings": 2048,
+     "num_attention_heads": 9,
+     "num_hidden_layers": 30,
+     "num_key_value_heads": 3,
+     "pad_token_id": None,
+     "pretraining_tp": 1,
+     "rms_norm_eps": 1.0e-05,
+     "rope_interleaved": False,
+     "rope_scaling": None,
+     "rope_theta": 10000.0,
+     "tie_word_embeddings": True,
+     "use_cache": True,
+     "vocab_size": 49152
+ }
+
+ # Initialize the model with the configuration
+ model = SmolLM2(model_config)
+
+ # Load the model weights; map_location handles CPU-only environments
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+ model.eval()  # Set the model to evaluation mode
+
+ def generate_text(prompt, length, num_sequences):
+     input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
+
+     generated_texts = []
+     for _ in range(int(num_sequences)):  # Cast in case the slider delivers a float
+         generated_sequence = model.generate(
+             input_ids,
+             max_length=int(length) + input_ids.shape[1],  # Generate `length` tokens beyond the prompt
+             pad_token_id=tokenizer.pad_token_id,
+             do_sample=True,
+             temperature=0.8,
+             top_k=50,
+             top_p=0.95
+         )
+
+         # Decode the generated sequence
+         generated_text = tokenizer.decode(generated_sequence[0], skip_special_tokens=True)
+         generated_texts.append(generated_text)
+
+     # Format the output
+     formatted_output = "\n\n".join([f"Sequence {i + 1}:\n{text}" for i, text in enumerate(generated_texts)])
+     return formatted_output
+
+ # Create the Gradio interface
+ with gr.Blocks() as app:
+     gr.Markdown("# SmolLM2 Text Generator")
+     prompt_input = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...")
+     length_slider = gr.Slider(minimum=10, maximum=200, label="Predict Additional Text of Length", value=50)
+     num_sequences_slider = gr.Slider(minimum=1, maximum=5, label="Number of Sequences to Generate", value=1, step=1)  # step=1 keeps the count an integer
+     generate_button = gr.Button("Generate Text")
+     output_text = gr.Textbox(label="Generated Text", interactive=False)
+
+     generate_button.click(
+         fn=generate_text,
+         inputs=[prompt_input, length_slider, num_sequences_slider],
+         outputs=output_text
+     )
+
+ # Launch the app when executed as a script
+ if __name__ == "__main__":
+     app.launch()
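
The `generate_text` helper in `app.py` can also be exercised without the UI for a quick smoke test; a minimal sketch, assuming `smollm2_final.pt` and the tokenizer are available locally (the prompt and values below are illustrative):

```python
# Importing app builds the model and tokenizer but does not launch the Gradio server
from app import generate_text

# Mirrors the sliders: 30 extra tokens, 2 sequences
output = generate_text("Once upon a time", length=30, num_sequences=2)
print(output)  # "Sequence 1: ...\n\nSequence 2: ..."
```
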
model.py ADDED
@@ -0,0 +1,245 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def _init_weights(module, std=0.041666666666666664):
+     # Initialize linear and embedding weights from a zero-mean normal distribution
+     if isinstance(module, nn.Linear):
+         module.weight.data.normal_(mean=0.0, std=std)
+     elif isinstance(module, nn.Embedding):
+         module.weight.data.normal_(mean=0.0, std=std)
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         # Normalize by the root mean square of the features, then rescale
+         norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return x * norm * self.weight
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, theta=10000.0):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         self.theta = theta
+
+         # Precompute cos/sin tables for every position up to the maximum length
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+         t = torch.arange(self.max_position_embeddings).type_as(self.inv_freq)
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
+         self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
+
+     def forward(self, x, seq_len=None):
+         if seq_len is None or seq_len > self.max_position_embeddings:
+             seq_len = self.max_position_embeddings
+
+         return (
+             self.cos_cached[:, :, :seq_len, :],
+             self.sin_cached[:, :, :seq_len, :]
+         )
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     # Slice the cached tables to the current sequence length; shapes broadcast as [1, 1, seq_len, head_dim]
+     cos = cos[:, :, :q.size(2), :]
+     sin = sin[:, :, :q.size(2), :]
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ class Attention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config["hidden_size"]
+         self.num_attention_heads = config["num_attention_heads"]
+         self.num_key_value_heads = config["num_key_value_heads"]
+         self.head_dim = self.hidden_size // self.num_attention_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+         self.kv_cache = None  # Reserved for key/value caching; not used yet
+
+     def forward(self, hidden_states, cos, sin, attention_mask=None, use_cache=False):
+         batch_size, seq_length, _ = hidden_states.shape
+
+         q = self.q_proj(hidden_states)
+         k = self.k_proj(hidden_states)
+         v = self.v_proj(hidden_states)
+
+         # Reshape for attention computation
+         q = q.view(batch_size, seq_length, self.num_attention_heads, self.head_dim)
+         k = k.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+         v = v.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim)
+
+         # Transpose for attention computation
+         q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
+         k = k.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
+         v = v.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
+
+         # Apply rotary embeddings
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # Repeat k/v heads if num_key_value_heads < num_attention_heads (grouped-query attention)
+         if self.num_key_value_heads != self.num_attention_heads:
+             k = k.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+             v = v.repeat_interleave(self.num_attention_heads // self.num_key_value_heads, dim=1)
+
+         # Compute scaled dot-product attention scores
+         attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+         # Causal mask: each position may only attend to itself and earlier positions
+         causal_mask = torch.triu(torch.full((seq_length, seq_length), float("-inf"), device=hidden_states.device), diagonal=1)
+         attn_weights = attn_weights + causal_mask
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = F.softmax(attn_weights, dim=-1)
+
+         # Compute output
+         output = torch.matmul(attn_weights, v)
+         output = output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
+         output = output.view(batch_size, seq_length, -1)
+
+         return self.o_proj(output)
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)
+         self.up_proj = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)
+         self.down_proj = nn.Linear(config["intermediate_size"], config["hidden_size"], bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         # SwiGLU-style feed-forward: gated activation, then projection back to hidden size
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config)
+         self.input_layernorm = RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
+         self.post_attention_layernorm = RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
+
+     def forward(self, hidden_states, cos, sin, attention_mask=None, use_cache=False):
+         # Pre-norm attention block with residual connection
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.self_attn(hidden_states, cos, sin, attention_mask, use_cache)
+         hidden_states = residual + hidden_states
+
+         # Pre-norm MLP block with residual connection
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+ class SmolLM2(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(config["vocab_size"], config["hidden_size"])
+         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config["num_hidden_layers"])])
+         self.norm = RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
+         self.rotary_emb = RotaryEmbedding(
+             config["hidden_size"] // config["num_attention_heads"],
+             max_position_embeddings=config["max_position_embeddings"],
+             theta=config.get("rope_theta", 10000.0)
+         )
+
+         # A separate output head is only needed when the word embeddings are not tied
+         if not config.get("tie_word_embeddings", True):
+             self.lm_head = nn.Linear(config["hidden_size"], config["vocab_size"], bias=False)
+
+         # Initialize weights
+         self.apply(lambda module: _init_weights(module, std=config.get("initializer_range", 0.041666666666666664)))
+
+     def forward(self, input_ids, attention_mask=None, use_cache=False):
+         hidden_states = self.embed_tokens(input_ids)
+
+         seq_length = input_ids.shape[1]
+         cos, sin = self.rotary_emb(hidden_states, seq_length)
+
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, cos, sin, attention_mask, use_cache)
+
+         hidden_states = self.norm(hidden_states)
+
+         # Use tied weights for the output projection
+         if self.config.get("tie_word_embeddings", True):
+             logits = F.linear(hidden_states, self.embed_tokens.weight)
+         else:
+             logits = self.lm_head(hidden_states)
+
+         return logits
+
+     def generate(
+         self,
+         input_ids,
+         max_length,
+         min_length=None,
+         num_return_sequences=1,
+         pad_token_id=None,
+         do_sample=True,
+         temperature=0.8,
+         top_k=50,
+         top_p=0.95
+     ):
+         self.eval()
+         batch_size = input_ids.shape[0]
+         min_length = min_length if min_length is not None else input_ids.shape[1]
+
+         # Clear KV cache
+         for layer in self.layers:
+             layer.self_attn.kv_cache = None
+
+         with torch.no_grad():
+             for _ in range(max_length - input_ids.shape[1]):
+                 outputs = self(input_ids, use_cache=True)
+                 next_token_logits = outputs[:, -1, :]
+
+                 # Apply temperature
+                 next_token_logits = next_token_logits / temperature
+
+                 # Apply top-k filtering
+                 if top_k > 0:
+                     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 # Apply top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 # Sample from the filtered distribution (or take the argmax when not sampling)
+                 if do_sample:
+                     probs = torch.softmax(next_token_logits, dim=-1)
+                     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+                 else:
+                     next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+                 input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
+
+                 # Keep sampling until at least min_length tokens have been produced
+                 if input_ids.shape[1] < min_length:
+                     continue
+
+                 # Stop once every sequence has emitted the pad token
+                 if pad_token_id is not None and (next_tokens == pad_token_id).all():
+                     break
+
+         return input_ids
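
The nucleus (top-p) filtering step inside `SmolLM2.generate` is easiest to follow on toy logits; a standalone sketch of that step with illustrative values:

```python
import torch

# Toy logits for a 4-token vocabulary (illustrative values only)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_p = 0.95

# Sort, accumulate probabilities, and mask everything past the top-p mass
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()  # keep the token that crosses the threshold
sorted_indices_to_remove[..., 0] = False                                        # never drop the most likely token
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

logits[indices_to_remove] = float("-inf")  # masked tokens get zero probability after softmax
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
print(next_token)  # one of the surviving token ids (0, 1, or 2 with these values)
```
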
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ gradio
smollm2_final.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c96efcee9cd1f94cf2d072647409d1bfce940859d08e89cade4fd48b9502ad2b
+ size 269663830