Sin2pi committed (verified)
Commit 9e76dcc · Parent(s): eaeb462

Update model_simple.py

Files changed (1):
  model_simple.py (+196 -116)
model_simple.py CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from torch.nn.functional import scaled_dot_product_attention
-
+from echoutils import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
warnings.filterwarnings("ignore")
@@ -55,42 +55,34 @@ class rotary(nn.Module):
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)

-    def shape(self, tensor: torch.Tensor, ctx: int, batch: int):
-        return tensor.view(batch, ctx, self.head, self.head_dim).transpose(1, 2).contiguous()
-
-    def reshape_to_output(self, attn_output, batch, ctx):
-        return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, self.dims).contiguous()
-
-def qkv_init(dims: int, head: int):
+def qkvinit(dims: int, head: int):
    head_dim = dims // head
+    scale = head_dim ** -0.5
    q = nn.Linear(dims, dims)
    k = nn.Linear(dims, dims, bias=False)
    v = nn.Linear(dims, dims)
    o = nn.Linear(dims, dims)
-    lna = nn.LayerNorm(dims, bias=False)
-    lnb = nn.LayerNorm(head_dim, bias=False)
-    return q, k, v, o, lna, lnb
+    return q, k, v, o, scale

-def create_qkv(dims, head, q, k, v, x, xa=None):
-    z = default(xa, x)
+def create_qkv(dims, head, q, k, v, x, xa):
    head_dim = dims // head
    scale = head_dim ** -0.25
    q = q(x) * scale
-    k = k(z) * scale
-    v = v(z)
-    batch, ctx, dims = q.shape
+    k = k(xa) * scale
+    v = v(xa)
+    batch, ctx, dims = x.shape
    def _shape(tensor):
        return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)

-def calculate_attention(q, k, v, mask=None, temperature=1.0):
-    batch, head, ctx, dims = q.shape
+def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
    scaled_q = q
    if temperature != 1.0 and temperature > 0:
        scaled_q = q * (1.0 / temperature)**.5
-    a = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
-    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
-    return out, None
+
+    out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
+    # out = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
+    return out

class LocalAttentionModule(nn.Module):
    def __init__(self, head_dim: int):
@@ -105,43 +97,56 @@ class LocalAttentionModule(nn.Module):
        return x

class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int, max_iters: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
+    def __init__(self, dims: int, head: int, max_iterations: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
        super(attentiona, self).__init__()
-
-        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+        # self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
-        self.dropout = dropout
-        self.max_iters = max_iters
-        self.rope = rotary(dims=dims, head=head)
-
+        self.max_iterations = max_iterations
        self.threshold = nn.Parameter(torch.tensor(threshold))
        self.factor = nn.Parameter(torch.tensor(factor))
+        self.dropout = dropout
+
+        self.q = nn.Linear(dims, dims)
+        self.k = nn.Linear(dims, dims, bias=False)
+        self.v = nn.Linear(dims, dims)
+        self.o = nn.Linear(dims, dims)
+
+        self.lna = nn.LayerNorm(dims, bias=False)
+        self.lnb = nn.LayerNorm(dims, bias=False)
+        self.lnc = nn.LayerNorm(self.head_dim, bias=False)
+        self.lnd = nn.LayerNorm(self.head_dim, bias=False)
        self.attn_local = LocalAttentionModule(self.head_dim)

    def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
-        z = default(xa, x)
-
-        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+
        iteration = 0
-        prev_attn = torch.zeros_like(q)
-        attn_out = torch.zeros_like(q)
+        prev_out = torch.zeros_like(query)
+        attn_out = torch.zeros_like(query)
        threshold = self.threshold.item()
        factor = self.factor.item()
+        qcur = query

-        q_cur = q
-        while iteration < self.max_iters:
-            eff_span = z.shape[1]
+        while iteration < self.max_iterations:
+            eff_span = min(x.shape[1], qcur.size(1), key.size(1))
+            if xa is not None:
+                eff_span = min(eff_span, xa.shape[1])
            if eff_span == 0:
                break

-            q_iter = q_cur[:, :, :eff_span, :]
-            k_iter = k[:, :, :eff_span, :]
-            v_iter = v[:, :, :eff_span, :]
-            q = self.attn_local.query_module(q_iter)
-            k = self.attn_local.key_module(k_iter)
-            v = self.attn_local.value_module(v_iter)
+            qiter = qcur[:, :, :eff_span, :]
+            kiter = key[:, :, :eff_span, :]
+            viter = value[:, :, :eff_span, :]
+            q = self.attn_local.query_module(qiter)
+            k = self.attn_local.key_module(kiter)
+            v = self.attn_local.value_module(viter)

            iter_mask = None
            if mask is not None:
@@ -150,78 +155,63 @@ class attentiona(nn.Module):
                elif mask.dim() == 2:
                    iter_mask = mask[:eff_span, :eff_span]

-            q = self.rope(q, q.shape[2])
-            k = self.rope(k, k.shape[2])
-
-            attn_iter, _ = calculate_attention(
-                self.lnb(q), self.lnb(k), v, mask=iter_mask)
-
-            out_span = self.attn_local._reshape_to_output(attn_iter)
-            if out_span.dim() == 4:
-                b, h, s, d = out_span.shape
-                proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, h, s, -1)
-            elif out_span.dim() == 3:
-                b, s, d = out_span.shape
-                if d == self.head_dim:
-                    proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, 1, s, -1)
-                elif d == self.head * self.head_dim:
-                    proj_span = out_span.view(b, self.head, s, self.head_dim)
-                else:
-                    raise RuntimeError(f"Cannot reshape out_span of shape {out_span.shape} to [b, h, s, head_dim]")
-            else:
-                raise RuntimeError(f"Unexpected out_span shape: {out_span.shape}")
-
-            iter_out = torch.zeros_like(q_cur)
-            iter_out[:, :, :eff_span, :] = proj_span
-            diff = torch.abs(iter_out - prev_attn).mean()
+            attn_iter = calculate_attention(
+                self.lnc(q), self.lnd(k), v,
+                mask=iter_mask,
+                is_causal=True)
+
+            iter_out = torch.zeros_like(qcur)
+            iter_out[:, :, :eff_span, :] = attn_iter
+            diff = torch.abs(iter_out - prev_out).mean()
            dthresh = threshold + factor * diff
+
            if diff < dthresh and iteration > 0:
                attn_out = iter_out
                break

-            prev_attn = iter_out.clone()
-            q_cur = q_cur + iter_out
+            prev_out = iter_out.clone()
+            qcur = qcur + iter_out
            attn_out = iter_out
            iteration += 1

        output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
        return self.o(output), None

-    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
-                         mask: Optional[Tensor] = None) -> Tensor:
-        batch, ctx, dims = x.shape
+    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor:
+
+        batch, ctx, dims = x.size()
        output = torch.zeros_like(x)
        num_win = (ctx + win_size - 1) // win_size

        for i in range(num_win):
-            q_start = i * win_size
-            q_end = min(q_start + win_size, ctx)
-            q_len = q_end - q_start
-            if q_len == 0:
+            qstart = i * win_size
+            qend = min(qstart + win_size, ctx)
+            current_win_qlen = qend - qstart
+            if current_win_qlen == 0:
                continue

-            kv_start = max(0, q_end - span_len)
-            kv_end = q_end
-            query_win = x[:, q_start:q_end, :]
-            key_win = x[:, kv_start:kv_end, :]
+            kvstart = max(0, qend - span_len)
+            kvend = qend
+            qwin = x[:, qstart:qend, :]
+            kwin = x[:, kvstart:kvend, :]

            win_mask = None
            if mask is not None:
                if mask.dim() == 4:
-                    win_mask = mask[:, :, q_start:q_end, kv_start:kv_end]
+                    win_mask = mask[:, :, qstart:qend, kvstart:kvend]
                elif mask.dim() == 2:
-                    win_mask = mask[q_start:q_end, kv_start:kv_end]
+                    win_mask = mask[qstart:qend, kvstart:kvend]

-            attn_out_win, _ = self._focus(
-                x=query_win,
-                xa=key_win,
+            attn_out, _ = self._focus(
+                x=qwin,
+                xa=kwin,
                mask=win_mask)
-            output[:, q_start:q_end, :] = attn_out_win
+            output[:, qstart:qend, :] = attn_out
        return output

    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
-                use_sliding_window: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
-        if use_sliding_window:
+                use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
+        if use_sliding_win:
            return self._slide_win_local(x, win_size, span_len, mask)
        else:
            output, _ = self._focus(x, xa, mask)
@@ -230,7 +220,6 @@ class attentiona(nn.Module):
class attentionb(nn.Module):
    def __init__(self, dims: int, head: int):
        super(attentionb, self).__init__()
-
        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
        self.dims = dims
        self.head = head
@@ -240,10 +229,8 @@ class attentionb(nn.Module):
    def forward(self, x: Tensor, xa = None, mask = None):
        z = default(xa, x)
        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
-
        q = self.rope(q, q.shape[2])
        k = self.rope(k, k.shape[2])
-
        a = scaled_dot_product_attention(self.lnb(q), self.lnb(k), v, is_causal=mask is not None and q.shape[1] > 1)
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
        return self.o(out)
@@ -254,56 +241,61 @@ class Residual(nn.Module):

        self.lna = nn.LayerNorm(dims, bias=False)
        self.attnb = attentionb(dims, head)
-        self.attna = attentiona(dims, head, max_iters=3)
+        self.attna = attentiona(dims, head, max_iterations=3)
        self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))

-    def forward(self, x, xa = None, mask = None) -> Tensor:
-
-        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
+    def forward(self, x, xa = None, mask = None) -> Tensor:
+        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
        if xa is not None:
-            x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_window=True, win_size=500, span_len=1500)
+            x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=500, span_len=1500)
        x = x + self.mlp(self.lna(x))
        return x
-
+
class processor(nn.Module):
    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
        super(processor, self).__init__()

-        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
-        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.posin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.ln = nn.LayerNorm(dims)
+        self.blend = nn.Parameter(torch.tensor(0.5), requires_grad=True)
+        self.token_emb = nn.Embedding(vocab, dims)
+        self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
+        self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

        act_fn = get_activation(act)
-        self.encoder = nn.Sequential(
+        self.audio_enc = nn.Sequential(
            Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

-        self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
-        self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
+        self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])

        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
-        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)

    def forward(self, x, xa, sequential=False) -> Tensor:

-        x = self.token(x.long()) + self.positional[:x.shape[1]]
-        xa = self.encoder(xa).permute(0, 2, 1)
-        xa = xa + self.posin(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
+        x = self.token_emb(x.long()) + self.positions[:x.shape[1]]
+        xa = self.audio_enc(xa).permute(0, 2, 1)
+        xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

        for b in chain(self.bA or []):
            xa = b(x=xa, xa=None, mask=None)
+            x = b(x=x, xa=None, mask=self.mask)
+            x = b(x=x, xa=xa, mask=None)
+            # xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask)
+            # x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None)

-        for b in chain(self.bB or []):
-            x = b(x=x, xa=None, mask=self.mask)
-            x = b(x, xa=xa, mask=None)
+            # if sequential:
+            #     x = y
+            # else:
+            #     a = torch.sigmoid(self.blend)
+            #     x = a * y + (1 - a) * x

        x = nn.functional.dropout(x, p=0.001, training=self.training)
        x = self.ln(x)
-        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
+        x = x @ torch.transpose(self.token_emb.weight.to(dtype), 0, 1).float()
        return x
-
+
    def init_weights(self):
        print("Initializing model weights...")
        self.apply(self._init_weights)
@@ -338,7 +330,7 @@ class Model(nn.Module):
    def _init_weights(self, module):
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "processor": 0, "attentiona": 0, "attentionb": 0, "Residual": 0}
+            "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
        for name, module in self.named_modules():
            if isinstance(module, RMSNorm):
                nn.init.ones_(module.weight)
@@ -359,10 +351,9 @@
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                self.init_counts["Conv2d"] += 1
-            elif isinstance(module, attentiona):
-                self.init_counts["attentiona"] += 1
-            elif isinstance(module, attentionb):
-                self.init_counts["attentionb"] += 1
+            elif isinstance(module, Residual):
+                self.init_counts["Residual"] += 1
+            elif isinstance(module, processor):
                self.init_counts["processor"] += 1

    def init_weights(self):
@@ -373,3 +364,92 @@
            if count > 0:
                print(f"{module_type}: {count}")

+def main():
+    token = ""
+    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
+    os.makedirs(log_dir, exist_ok=True)
+    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
+
+    extract_args = {
+        "waveform": False,
+        "spec": False,
+        "f0": False,
+        "f0t": False,
+        "pitch": True,
+        "harmonics": False,
+        "aperiodics": False,
+        "phase_mod": False,
+        "crepe": False,
+        "sample_rate": 16000,
+        "hop_length": 256,
+        "mode": "mean",
+        "debug": False,
+    }
+
+    param = Dimensions(
+        vocab=40000,
+        mels=128,
+        ctx=2048,
+        dims=512,
+        head=4,
+        layer=4,
+        act="swish",
+    )
+
+    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
+        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
+
+    model = Model(param).to('cuda')
+    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
+    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    from functools import partial
+    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir=log_dir,
+        per_device_train_batch_size=1,
+        per_device_eval_batch_size=1,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=1000,
+        warmup_steps=100,
+        logging_steps=10,
+        logging_dir=log_dir,
+        logging_strategy="steps",
+        eval_strategy="steps",
+        save_strategy="no",
+        report_to=["tensorboard"],
+        push_to_hub=False,
+        save_total_limit=1,
+        label_names=["labels"],
+        save_safetensors=False,
+        eval_on_start=False,
+        batch_eval_metrics=False,
+        disable_tqdm=False,
+        include_tokens_per_second=True,
+        include_num_input_tokens_seen=True,
+        learning_rate=0.00025,
+        weight_decay=0.025,
+    )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
+        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
+
+    trainer = Seq2SeqTrainer(
+        args=training_args,
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+        data_collator=DataCollator(tokenizer=tokenizer),
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        compute_metrics=metrics_fn,
+        optimizers=(optimizer, scheduler)
+    )
+
+    model.init_weights()
+    trainer.train()
+if __name__ == "__main__":
+
+    main()
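
The core behavioural change in this commit is the iterative refinement loop in attentiona._focus: attention is recomputed up to max_iterations times, each output is folded back into the query, and the loop stops early once successive outputs stop changing. The sketch below is illustrative only and is not part of model_simple.py; it isolates that loop under simplifying assumptions (plain scaled_dot_product_attention, no per-head LayerNorm, LocalAttentionModule, masking, or sliding window), and the name iterative_attention is hypothetical.

    # Illustrative sketch, not part of the commit.
    import torch
    from torch.nn.functional import scaled_dot_product_attention

    def iterative_attention(q, k, v, max_iterations=3, threshold=0.01, factor=0.1):
        # q, k, v: [batch, head, ctx, head_dim], the layout _focus uses after its head reshape
        prev_out = torch.zeros_like(q)
        attn_out = torch.zeros_like(q)
        qcur = q
        for iteration in range(max_iterations):
            iter_out = scaled_dot_product_attention(qcur, k, v)
            diff = torch.abs(iter_out - prev_out).mean()  # how much the output moved this round
            dthresh = threshold + factor * diff           # same adaptive stopping rule as _focus
            if diff < dthresh and iteration > 0:
                attn_out = iter_out
                break
            prev_out = iter_out.clone()
            qcur = qcur + iter_out                        # feed the output back into the query
            attn_out = iter_out
        return attn_out

    q = k = v = torch.randn(2, 4, 16, 32)
    print(iterative_attention(q, k, v).shape)  # torch.Size([2, 4, 16, 32])

Note on the stopping rule: since diff appears on both sides of the comparison, diff < threshold + factor * diff rearranges to diff * (1 - factor) < threshold, so for a factor between 0 and 1 it in effect loosens the fixed threshold rather than adapting it to the data.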