Commit · bac9d3f
Parent(s): 5a42e50
Initial commit

Files changed:
- app.py          +122  -0
- model.py        +357  -0
- smollm2_HF.pth  +3    -0
app.py
ADDED
@@ -0,0 +1,122 @@
import os
import torch
import gradio as gr
from typing import Optional
from dataclasses import dataclass
from transformers import AutoTokenizer
from model import Transformer


@dataclass
class ModelArgs:
    # Arch params
    dim: int = 576
    intermediate_dim: int = 1536
    n_layers: int = 30
    n_heads: int = 9
    n_kv_heads: Optional[int] = 3
    vocab_size: int = 49152  # defined later by tokenizer
    norm_eps: float = 1.0e-05
    init_scale: float = 0.041666666666666664
    rope_theta: int = 10000
    dropout: float = 0.1

    # Training params
    seed: int = 42
    max_batch_size: int = 2
    max_seq_len: int = 2048
    steps: int = 5050
    breakpoint_step: int = 5000
    warmup_steps_frac: float = 0.5
    save_interval: int = 1000
    eval_interval: int = 500
    log_interval: int = 1
    grad_accum_steps: int = 8
    checkpoint_path = os.path.join(os.getcwd(), "checkpoints")
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimizer
    initial_lr: float = 5e-4
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1.0e-08
    weight_decay: float = 0.01
    use_fused: bool = True


# Initialize model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    # Fall back to EOS so padding and the pad-token stop check in generate() still work
    tokenizer.pad_token = tokenizer.eos_token
config = ModelArgs()
config.device = device
model = Transformer(config)


# Load trained weights from the checkpoint file
def load_checkpoint(model, path, device):
    try:
        checkpoint = torch.load(path, map_location=device)
        # Strip the torch.compile prefix and drop the (non-parameter) KV-cache buffers
        model.load_state_dict({
            k.replace("_orig_mod.", ""): v
            for k, v in checkpoint.items()
            if "cached_keys" not in k and "cached_values" not in k
        })
        return model
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None


model = load_checkpoint(model, "smollm2_HF.pth", device)
if model is None:
    raise RuntimeError("Failed to load checkpoint 'smollm2_HF.pth'")
model.to(device)
model.eval()


def generate_text(prompt,
                  min_length: int = 28,
                  max_length: int = 40,
                  temperature: float = 0.7,
                  top_k: int = 50,
                  top_p: float = 0.7):
    """Generate text from a prompt."""
    input_ids = tokenizer(prompt,
                          padding=True,
                          truncation=True,
                          max_length=config.max_seq_len,
                          return_tensors="pt")["input_ids"].to(device)

    generated = model.generate(
        input_ids,
        max_length=max_length,
        min_length=min_length,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True)


# Gradio interface: forward all six input widgets to generate_text
def gradio_interface(prompt, min_length, max_length, temperature, top_k, top_p):
    return generate_text(prompt, int(min_length), int(max_length),
                         float(temperature), int(top_k), float(top_p))


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=500, label="Min Length"),
        gr.Slider(minimum=10, maximum=500, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, label="Top P")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2-135M Text Generation",
    description="SmolLM2-135M trained on cosmopedia-v2 with just 5000 steps",
    examples=[
        # prompt, min length, max length, temperature, top-k, top-p
        ["I found the love", 28, 50, 0.7, 50, 0.7],
        ["When the sun comes up", 28, 40, 0.8, 40, 0.7],
        ["The slow marching of ", 28, 60, 0.9, 45, 0.7]
    ],
)


if __name__ == "__main__":
    iface.launch()
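For a quick local check of the Space outside the Gradio UI, a minimal sketch is shown below. It is not part of the commit and assumes torch, gradio and transformers are installed and that smollm2_HF.pth sits next to app.py; note that importing app runs its module-level setup (tokenizer download and checkpoint load).

# Hypothetical local smoke test; not part of the committed Space.
# Importing app executes its module-level setup: it fetches the
# HuggingFaceTB/cosmo2-tokenizer and loads smollm2_HF.pth onto the available device.
from app import generate_text

sample = generate_text(
    "When the sun comes up",  # prompt
    min_length=28,
    max_length=40,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
)
print(sample)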
model.py
ADDED
@@ -0,0 +1,357 @@
import math
import logging
import torch
import torch.nn.functional as F
from typing import Optional
from torch import nn

logger = logging.getLogger(__name__)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root Mean Square Layer Normalization
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, theta=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs)

        t = torch.arange(max_seq_len, dtype=self.freqs.dtype)
        freqs = torch.outer(t, self.freqs)

        cos = freqs.cos()
        sin = freqs.sin()
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def rotate_half(self, x):
        rot_dim = x.shape[-1]
        x1 = x[..., :rot_dim // 2]
        x2 = x[..., rot_dim // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_emb(self, t, x):
        rot_dim = self.freqs.shape[-1]
        cos = self.cos[t, :rot_dim]
        sin = self.sin[t, :rot_dim]

        rotated_x = (x[..., :rot_dim] * cos) + (self.rotate_half(x[..., :rot_dim]) * sin)
        if x.shape[-1] > rot_dim:
            rotated_x = torch.cat((rotated_x, x[..., rot_dim:]), dim=-1)
        return rotated_x

    def forward(self, x, seq_dim=-2):
        seq_len = x.shape[seq_dim]
        t = torch.arange(seq_len, device=x.device)
        return self.apply_rotary_emb(t, x)


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.dim = args.dim
        self.num_heads = args.n_heads
        self.kv_heads = args.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.kv_head_dim = args.dim // args.n_kv_heads

        assert self.head_dim * args.n_heads == args.dim, "args.dim must be divisible by args.n_heads"
        assert self.kv_head_dim * args.n_kv_heads == args.dim, "args.dim must be divisible by args.n_kv_heads"

        self.query_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.key_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.value_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)

        self.rope = RotaryEmbedding(self.head_dim)

        self.out_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.dropout = nn.Dropout(args.dropout)

        # Caching storage for keys and values (populated only when use_cache=True)
        self.register_buffer("cached_keys", None)
        self.register_buffer("cached_values", None)

    def forward(self, x, mask=None, use_cache=False):
        batch_size, seq_len, _ = x.size()

        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        # Reshape for attention computation
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.kv_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.kv_heads, self.head_dim)

        # Transpose for attention computation
        query = query.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        key = key.transpose(1, 2)      # [batch, num_kv_heads, seq_len, head_dim]
        value = value.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]

        query = self.rope(query)
        key = self.rope(key)

        # Manual (unfused) attention path, kept for reference:
        # if self.kv_heads < self.num_heads:
        #     key = key.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
        #     value = value.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
        # attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # if mask is not None:
        #     attn_weights = attn_weights + mask
        # attn_weights = F.softmax(attn_weights, dim=-1)
        # output = torch.matmul(attn_weights, value)

        # Fused scaled-dot-product attention; enable_gqa handles the grouped KV heads
        # (requires a recent PyTorch) and dropout is applied only during training.
        output = F.scaled_dot_product_attention(
            query, key, value,
            is_causal=True,
            dropout_p=self.dropout.p if self.training else 0.0,
            enable_gqa=True,
        )

        # Update cache only if using cache
        if use_cache:
            self.cached_keys = key
            self.cached_values = value
        else:
            # Reset cached values during training (to prevent unwanted accumulation)
            self.cached_keys = None
            self.cached_values = None

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # [batch, seq_len, num_heads * head_dim]
        return self.out_proj(output)


class FeedForward(nn.Module):
    def __init__(self, args):
        """
        Initialize the FeedForward (SwiGLU) module.

        Args:
            args (ModelArgs): Model configuration; uses args.dim and args.intermediate_dim.

        Attributes:
            w1 (nn.Linear): Gate projection (dim -> intermediate_dim).
            w2 (nn.Linear): Down projection (intermediate_dim -> dim).
            w3 (nn.Linear): Up projection (dim -> intermediate_dim).
        """
        super().__init__()
        self.w1 = nn.Linear(args.dim, args.intermediate_dim, bias=False)
        self.w2 = nn.Linear(args.intermediate_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.intermediate_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization applied before attention.
            ffn_norm (RMSNorm): Layer normalization applied before the feedforward layer.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        use_cache: bool
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
            use_cache (bool): Whether to use the KV cache.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.
        """
        h = x + self.attention(self.attention_norm(x), mask=mask, use_cache=use_cache)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, args):
        """
        Initialize a Transformer model.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            args (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (nn.Embedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
        """
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(TransformerBlock(layer_id, args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        # The output projection is tied to the embedding matrix (see forward),
        # so no separate nn.Linear head is created here.
        # self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # self.output.weight = self.tok_embeddings.weight

        # weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = self.args.init_scale
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            # bias=False everywhere, so no bias initialization is needed
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = False):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
            use_cache (bool): Whether to use the KV cache.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.
        """
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)

        if mask is None:
            # Additive causal mask; the fused attention path uses is_causal=True,
            # so this is only consumed by the manual attention path kept in comments.
            mask = torch.triu(torch.ones((seqlen, seqlen),
                                         dtype=torch.bool,
                                         device=tokens.device),
                              diagonal=1)
            mask = mask.unsqueeze(0).unsqueeze(0)
            mask = mask * -1e4

        for layer in self.layers:
            h = layer(h, mask, use_cache)
        h = self.norm(h)
        # Weight-tied output projection: reuse the embedding matrix as the LM head
        output = F.linear(h, self.tok_embeddings.weight)
        return output

    def generate(self,
                 input_ids,
                 max_length,
                 min_length=None,
                 num_return_sequences=1,
                 pad_token_id=None,
                 do_sample=True,
                 temperature=0.8,
                 top_k=50,
                 top_p=0.95):
        self.eval()
        min_length = min_length if min_length is not None else input_ids.shape[1]

        with torch.no_grad():
            for ret_seq in range(num_return_sequences):
                logger.info(f"Sequence #{ret_seq + 1}:")
                for _ in range(max_length - input_ids.shape[1]):
                    outputs = self(input_ids, use_cache=True)
                    next_token_logits = outputs[:, -1, :]

                    # Apply temperature
                    next_token_logits = next_token_logits / temperature

                    # Apply top-k filtering
                    if top_k > 0:
                        indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Apply top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Sample from the filtered distribution
                    if do_sample:
                        probs = torch.softmax(next_token_logits, dim=-1)
                        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                    else:
                        next_tokens = torch.argmax(next_token_logits, dim=-1)

                    input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)

                    # Stop early only once every sequence has emitted the pad token
                    # and at least min_length tokens have been generated
                    if (pad_token_id is not None
                            and (next_tokens == pad_token_id).all()
                            and input_ids.shape[1] >= min_length):
                        break

        return input_ids
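As a sanity check for the architecture above, here is a minimal sketch (not part of the commit) that instantiates the Transformer with a config mirroring the ModelArgs values from app.py and runs a forward pass plus a short sampled generation on random token ids. It assumes a PyTorch version recent enough to support scaled_dot_product_attention with enable_gqa=True.

# Hypothetical smoke test for model.py; the config fields mirror ModelArgs in app.py.
from types import SimpleNamespace

import torch
from model import Transformer

config = SimpleNamespace(
    dim=576, intermediate_dim=1536, n_layers=30, n_heads=9, n_kv_heads=3,
    vocab_size=49152, norm_eps=1e-5, init_scale=0.041666666666666664, dropout=0.1,
)
model = Transformer(config).eval()

tokens = torch.randint(0, config.vocab_size, (1, 16))  # dummy prompt token ids
with torch.no_grad():
    logits = model(tokens)  # -> [1, 16, 49152]
    generated = model.generate(tokens, max_length=32, temperature=0.8, top_k=50, top_p=0.95)
print(logits.shape, generated.shape)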
smollm2_HF.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5f29f2b3a2075407a1d43c74cc25ff2af4efbce0e73f0fe2a10eb4f16a69044
size 553939176