Neu256 committed on
Commit 488f6f3 · verified · 1 Parent(s): 0b86f63

Update model.py

Files changed (1)
  1. model.py +85 -118
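
The substance of this commit is a rework of rotary position embeddings: the per-head inv_freq buffer and the apply_rope/reshape_for_broadcast methods on Attention are replaced by module-level helpers (precompute_freqs_cis, reshape_for_broadcast, apply_rope), and Transformer now precomputes freqs_cis once and passes it into every block. The sketch below is not part of the commit; it simply exercises those helpers with hypothetical shapes (batch 1, 2 heads, 4 positions, head size 8). The function bodies are copied from the new model.py shown in the diff, and theta=500000 mirrors the value passed in Transformer.__init__.

import torch

# Copied from the updated model.py introduced by this commit.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)
    return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
    batch_size, num_heads, seq_len, head_size = x.shape
    freqs_cis = freqs_cis[:seq_len]
    shape = [1, 1, seq_len, head_size // 2]
    return freqs_cis.view(*shape)

def apply_rope(x, position, freqs_cis):
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, x)
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return x_out.type_as(x)

# Hypothetical driver (not part of the commit): rotate a dummy query tensor.
xq = torch.randn(1, 2, 4, 8)                          # (batch, heads, seq_len, head_size)
freqs_cis = precompute_freqs_cis(8, 16, theta=500000)
xq_rot = apply_rope(xq, 4, freqs_cis)
print(xq_rot.shape)                                   # torch.Size([1, 2, 4, 8])
print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-4))  # True: rotations preserve norms
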
model.py CHANGED
@@ -1,8 +1,26 @@
- import math
import torch
import torch.nn as nn
- from torch.nn import functional as F
- from utils import DEVICE

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
@@ -17,13 +35,12 @@ class RMSNorm(torch.nn.Module):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

-
class Attention(nn.Module):
    """
    Multi-head Self-Attention with RoPE
    """

-     def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
@@ -32,27 +49,8 @@ class Attention(nn.Module):
        self.wk = nn.Linear(num_embed, num_heads * head_size, bias = False)
        self.wv = nn.Linear(num_embed, num_heads * head_size, bias = False)
        self.wo = nn.Linear(num_heads * head_size, num_embed, bias = False)
-
-         inv_freq = 1 / (500000 ** (torch.arange(0, head_size, 2)[: (head_size // 2)].float() / head_size))
-         self.register_buffer('inv_freq', inv_freq)
-
-         self.dropout = nn.Dropout(dropout)
-
-     def reshape_for_broadcast(self, freq_cis, x):
-         ndim = x.ndim
-         shape = [1] * (ndim - 2) + list(freq_cis.shape)
-         return freq_cis.view(*shape)
-
-     def apply_rope(self, x, position, freq):
-         t = torch.arange(position, device=freq.device, dtype=torch.float32)
-         freq = torch.outer(t, freq)
-         freq_cis = torch.polar(torch.ones_like(freq), freq)
-         x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-         freq_cis = self.reshape_for_broadcast(freq_cis, x)
-         x_out = torch.view_as_real(x_ * freq_cis).flatten(3)
-         return x_out.type_as(x)
-
-     def forward(self, x):
        B, T, C = x.shape

        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)
@@ -67,91 +65,62 @@ class Attention(nn.Module):
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

-         xq = self.apply_rope(xq, T, self.inv_freq)
-         xk = self.apply_rope(xk, T, self.inv_freq)

        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights += mask
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        output = output.transpose(1, 2).contiguous().view(B, T, C)
-         return self.dropout(self.wo(output))
-

class MLP(nn.Module):
-     """
-     Implementation of a Multi-Layer Perceptron (MLP) sub-module.
-
-     This module is a simple feed-forward network with two hidden layers
-     used in various Transformer components like the Mixture of Experts layer.
-     """
-
    def __init__(self, num_embed, dropout):
-         """
-         Constructor for the MLP.
-
-         Args:
-             num_embed (int): The number of embedding dimensions.
-         """
-
        super().__init__()
-         hidden = int(4 * num_embed * 2 / 3)

-         # Define linear layers for the MLP
-         self.w1 = nn.Linear(num_embed, hidden, bias=False)
-         self.w2 = nn.Linear(hidden, num_embed, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
-         """
-         Forward pass of the MLP.
-
-         Args:
-             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_embed).
-
-         Returns:
-             torch.Tensor: Output tensor after passing through the MLP (shape: batch_size, seq_len, num_embed).
-         """
-         return self.dropout(self.w2(F.silu(self.w1(x))))
-
class TransformerBlock(nn.Module):
    """
    This calss will group together MultiHead Attention and
-     MLP, so that we can copy it in Transformer
    """

-     def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()
-
-         self.mha = Attention(
            num_heads=num_heads,
            head_size=head_size,
-             num_embed=num_embed,
-             dropout=dropout
        )
-
-         self.mlp = MLP(num_embed = num_embed, dropout = dropout)
-
        # add the layer normalization
-         self.norm1 = RMSNorm(num_embed)
-         self.norm2 = RMSNorm(num_embed)
-
-     def forward(self, x):
-         """
-         Decodes the input sequence.
-
-         Args:
-             x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).
-             memory (torch.Tensor): A tensor of shape (batch_size, memory_length, embedding_dim).
-
-         Returns:
-             torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
-         """
-         #print(x.shape)
-         x = x + self.mha(self.norm1(x))
-         x = x + self.mlp(self.norm2(x))
-
        return x

@@ -161,82 +130,80 @@ class Transformer(nn.Module):
        # a simple lookup table that stores embeddings of a fixed dictionary and size
        # each token directly reads off the logits for the next token from a lookup table
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
-         self.model_type = 'Prome'
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
-         self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
-         self.head_size = kwargs.get("head_size", 128)
        self.num_layers = kwargs.get("num_layers", 4)
        self.dropout = kwargs.get("dropout", 0.2)
-         self.max_seq_len = kwargs.get("max_sqe_len", 1024)
        # each token reads the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # each position from 0 to block_size-1 will get its embedding
-         #self.position_embedding_table = nn.Embedding(self.max_seq_len, self.num_embed)
-
-         self.decoder = nn.Sequential(
-             *[
-                 TransformerBlock(
-                     num_heads=self.num_heads,
-                     head_size=self.head_size,
-                     num_embed=self.num_embed,
-                     dropout=self.dropout,
-                 )
-                 for _ in range(self.num_layers)
-             ]
-         )
-
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B,T) tensor of integers
        # the token_emb is (B, T, C), C = NUM_EMBED
        x = self.token_embedding_table(idx)
-         # (T, C)
-         #posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
-
-         #x = token_emb + posit_emb

-         x = self.decoder(x)

        # (B, T, vocab_size)
        logits = self.lm_head(x)
-
-         # Compute the loss
        if targets != None:
            # cross_entropy accepts inputs in a (batch_size, num_classes)
            # so we need to reformat our logits dimensions to
            # (batch_size * time, dim_vocabulary), time = block_size
-             #logits = logits.to(dtype=torch.float32)
-
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None
-
        return logits, loss

-     def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.6, top_p: float = 0.9):
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.max_seq_len:]

            logits, loss = self.forward(idx_crop)
            logits = logits[:, -1, :]

            if temperature > 0:
-                 probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
-         return idx

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
        # Create a mask for top-p filtering
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()

+ import math
import torch
import torch.nn as nn
+ import torch.nn.functional as F
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+ def reshape_for_broadcast(freqs_cis, x):
+     batch_size, num_heads, seq_len, head_size = x.shape
+     freqs_cis = freqs_cis[:seq_len]
+     shape = [1, 1, seq_len, head_size // 2]
+     return freqs_cis.view(*shape)
+
+ def apply_rope(x, position, freqs_cis):
+     x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, x)
+     x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
+     return x_out.type_as(x)

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):

        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class Attention(nn.Module):
    """
    Multi-head Self-Attention with RoPE
    """

+     def __init__(self, num_heads, head_size, num_embed):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

        self.wk = nn.Linear(num_embed, num_heads * head_size, bias = False)
        self.wv = nn.Linear(num_embed, num_heads * head_size, bias = False)
        self.wo = nn.Linear(num_heads * head_size, num_embed, bias = False)
+
+     def forward(self, x, freqs_cis):
        B, T, C = x.shape

        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)

        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

+         xq = apply_rope(xq, T, freqs_cis)
+         xk = apply_rope(xk, T, freqs_cis)

        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights += mask
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        output = output.transpose(1, 2).contiguous().view(B, T, C)
+         return self.wo(output)

class MLP(nn.Module):
    def __init__(self, num_embed, dropout):
        super().__init__()
+         self.num_embed = num_embed

+         hidden_dim = 3 * int(num_embed * 2 / 3)

+         self.linear1 = nn.Linear(num_embed, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
+         x = self.linear1(x)
+         x = F.silu(x)
+         x = self.linear2(x)
+         x = self.dropout(x)
+         return x
+

class TransformerBlock(nn.Module):
    """
    This calss will group together MultiHead Attention and
+     FeedForward NN, so that we can copy it in Transformer
    """

+     def __init__(self, num_heads, num_embed, dropout):
        super().__init__()
+         self.num_heads = num_heads
+         self.num_embed = num_embed
+         head_size = num_embed // num_heads
+         self.sa = Attention(
            num_heads=num_heads,
            head_size=head_size,
+             num_embed=num_embed
        )
+         self.ffwd = MLP(num_embed=num_embed, dropout=dropout)
        # add the layer normalization
+         self.ln1 = RMSNorm(num_embed)
+         self.ln2 = RMSNorm(num_embed)
+
+     def forward(self, x, freqs_cis):
+         # "x +" is the skip (or residual) connection
+         # it helps with optimization
+         # also we apply layer normalization before self-attention
+         # and feed-forward (a reshufle from original paper)
+         x = x + self.sa(self.ln1(x), freqs_cis)
+         x = x + self.ffwd(self.ln2(x))
        return x

        # a simple lookup table that stores embeddings of a fixed dictionary and size
        # each token directly reads off the logits for the next token from a lookup table
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
+         self.max_seq_len = kwargs.get("max_seq_len", 1024)
        self.dropout = kwargs.get("dropout", 0.2)
        # each token reads the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # each position from 0 to block_size-1 will get its embedding
+         #self.position_embedding_table = nn.Embedding(self.block_size, self.num_embed)
+         self.blocks = nn.ModuleList([
+             TransformerBlock(
+                 num_heads=self.num_heads,
+                 num_embed=self.num_embed,
+                 dropout=self.dropout
+             )
+             for _ in range(self.num_layers)
+         ])
+         # we add the layer norm before the Linear layer
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)
+         self.norm = RMSNorm(self.num_embed)
+
+         self.freqs_cis = precompute_freqs_cis(
+             self.num_embed//self.num_heads,
+             self.max_seq_len * 2,
+             500000,
+         )

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B,T) tensor of integers
        # the token_emb is (B, T, C), C = NUM_EMBED
        x = self.token_embedding_table(idx)

+         freq = self.freqs_cis[:self.max_seq_len]
+         # apply one head of self-attention
+         for block in self.blocks:
+             x = block(x, freq)

+         x = self.norm(x)
+
        # (B, T, vocab_size)
        logits = self.lm_head(x)
+         # compute the loss
        if targets != None:
            # cross_entropy accepts inputs in a (batch_size, num_classes)
            # so we need to reformat our logits dimensions to
            # (batch_size * time, dim_vocabulary), time = block_size
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None
        return logits, loss

+     def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.7, top_p: float = 0.9):
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.max_seq_len:]

+             freq = self.freqs_cis[:self.max_seq_len]
            logits, loss = self.forward(idx_crop)
            logits = logits[:, -1, :]

            if temperature > 0:
+                 probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+         return idx[0]

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
        # Create a mask for top-p filtering
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()
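
For orientation, here is a minimal usage sketch of the updated model (not part of the commit). The hyperparameters are hypothetical, the keyword names follow the kwargs.get calls in Transformer.__init__ (which suggest the constructor accepts **kwargs), and generate now returns the ids of the first sequence in the batch, per the new return idx[0].

import torch
from model import Transformer  # assumes the updated model.py shown above is importable

model = Transformer(
    vocab_size=32000,   # hypothetical tokenizer size
    num_embed=512,
    num_heads=8,
    num_layers=8,
    max_seq_len=1024,
    dropout=0.2,
)
model.eval()

# Dummy token ids; a real run would take ids from a tokenizer.
idx = torch.randint(0, 32000, (2, 16))      # (B, T)
targets = torch.randint(0, 32000, (2, 16))

logits, loss = model(idx, targets)          # logits: (2, 16, 32000); loss: scalar cross-entropy
print(logits.shape, loss.item())

with torch.no_grad():
    out = model.generate(idx[:1], max_new_tokens=8, temperature=0.7, top_p=0.9)
print(out.shape)                            # (16 + 8,): ids of the first (and only) sequence

One thing to note from the diff: freqs_cis is assigned as a plain tensor attribute rather than a registered buffer, so model.to(device) will not move it; on GPU it would need to be moved explicitly.
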