Sin2pi committed
Commit 14b7fc4 · verified · 1 Parent(s): 2b8f805

Update model_simple.py

Files changed (1)
  1. model_simple.py (+305, -47)
model_simple.py CHANGED
@@ -55,73 +55,225 @@ class rotary(nn.Module):
55
  x1 = x1.view(orig_shape)
56
  return torch.cat([x1.type_as(x), x2], dim=-1)
57
 
58
- class attention(nn.Module):
59
- def __init__(self, dims: int, head: int):
60
  super(attention, self).__init__()
61
  self.dims = dims
62
  self.head = head
63
  self.head_dim = dims // head
64
- self.q = nn.Linear(dims, dims).to(device, dtype)
65
- self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
66
- self.v = nn.Linear(dims, dims).to(device, dtype)
67
- self.o = nn.Linear(dims, dims).to(device, dtype)
68
- self.rope = rotary(dims=dims, head=head)
69
- self.lny = nn.LayerNorm(self.head_dim, bias = False)
70
- self.lnx = nn.LayerNorm(dims, bias = False)
71
-
72
  def forward(self, x: Tensor, xa = None, mask = None):
73
- q = self.q(self.lnx(x))
74
- k = self.k(self.lnx(x if xa is None else xa))
75
- v = self.v(self.lnx(x if xa is None else xa))
76
- q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
77
- k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
78
- v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
79
  q = self.rope(q, q.shape[2])
80
  k = self.rope(k, k.shape[2])
81
- a = scaled_dot_product_attention(self.lny(q), self.lny(k), v, is_causal=mask is not None and q.shape[1] > 1)
82
  out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
83
  return self.o(out)
84
 
85
- class tgate(nn.Module):
86
- def __init__(self, dims, num_types=4):
87
- super().__init__()
88
- self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
89
- self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))
90
- def forward(self, x):
91
- types = self.classifier(x)
92
- gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
93
- cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
94
- return cgate
95
-
96
- class Residual(nn.Module):
97
- _seen = set()
98
  def __init__(self, dims: int, head: int, act: str = "silu"):
99
  super().__init__()
100
- self.ln = nn.LayerNorm(dims, bias = False)
101
- self.blend = nn.Parameter(torch.tensor(0.5))
102
- self.attn = attention(dims, head)
103
  self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
104
- self.tgate = tgate(dims=dims, num_types=4*2)
105
 
106
- def forward(self, x, xa=None, mask=None) -> Tensor:
107
- xb = x + self.attn(self.ln(x), xa=None, mask=mask)
108
  if xa is not None:
109
- x = x + self.attn(self.ln(x), xa=xa, mask=None)
110
- b = torch.sigmoid(self.blend)
111
- x = b * xb + (1 - b) * x
112
- out = self.mlp(self.ln(x))
113
- gate = self.tgate(self.ln(x))
114
- x = x + gate * out
115
  return x
116
-
117
  class processor(nn.Module):
118
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
119
  super(processor, self).__init__()
120
  self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
121
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
122
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
123
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
124
- self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
125
 
126
  act_fn = get_activation(act)
127
  self.encoder = nn.Sequential(
@@ -131,24 +283,41 @@ class processor(nn.Module):
131
 
132
  self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
133
  self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
134
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
135
  self.register_buffer("mask", mask, persistent=False)
136
 
137
- def forward(self, x, xa) -> Tensor:
138
 
139
  x = self.token(x.long()) + self.positional[:x.shape[1]]
140
  xa = self.encoder(xa).permute(0, 2, 1)
141
- xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 10000.0).to(device, dtype)
142
  for b in chain(self.bA or []):
143
  xa = b(x=xa, xa=None, mask=None)
144
  for b in chain(self.bB or []):
145
  x = b(x=x, xa=None, mask=self.mask)
146
- x = b(x, xa=xa, mask=None)
147
  x = nn.functional.dropout(x, p=0.001, training=self.training)
148
  x = self.ln(x)
149
  x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
150
  return x
151
 
152
  class Model(nn.Module):
153
  def __init__(self, param: Dimensions):
154
  super().__init__()
@@ -211,3 +380,92 @@ class Model(nn.Module):
211
  if count > 0:
212
  print(f"{module_type}: {count}")
213
 
55
  x1 = x1.view(orig_shape)
56
  return torch.cat([x1.type_as(x), x2], dim=-1)
57
 
58
+ def shape(self, tensor: torch.Tensor, ctx: int, batch: int):
59
+ return tensor.view(batch, ctx, self.head, self.head_dim).transpose(1, 2).contiguous()
60
+
61
+ def reshape_to_output(self, attn_output, batch, ctx):
62
+ return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, self.dims).contiguous()
63
+
64
+ def qkv_init(dims: int, head: int):
65
+ head_dim = dims // head
66
+ q = nn.Linear(dims, dims)
67
+ k = nn.Linear(dims, dims, bias=False)
68
+ v = nn.Linear(dims, dims)
69
+ o = nn.Linear(dims, dims)
70
+ lna = nn.LayerNorm(dims, bias=False)
71
+ lnb = nn.LayerNorm(head_dim, bias=False)
72
+ return q, k, v, o, lna, lnb
73
+
74
+ def create_qkv(dims, head, q, k, v, x, xa=None):
75
+ z = default(xa, x)
76
+ head_dim = dims // head
77
+ scale = head_dim ** -0.25
78
+ q = q(x) * scale
79
+ k = k(z) * scale
80
+ v = v(z)
81
+ batch, ctx, dims = q.shape
82
+ def _shape(tensor):
83
+ return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
84
+ return _shape(q), _shape(k), _shape(v)
85
+
86
+ def calculate_attention(q, k, v, mask=None, temperature=1.0):
87
+ batch, head, ctx, dims = q.shape
88
+ scaled_q = q
89
+ if temperature != 1.0 and temperature > 0:
90
+ scaled_q = q * (1.0 / temperature)**.5
91
+ a = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
92
+ out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
93
+ return out, None
94
+
95
+ class LocalAttentionModule(nn.Module):
96
+ def __init__(self, head_dim: int):
97
+ super().__init__()
98
+ self.head_dim = head_dim
99
+ self.query_module = nn.Linear(head_dim, head_dim)
100
+ self.key_module = nn.Linear(head_dim, head_dim)
101
+ self.value_module = nn.Linear(head_dim, head_dim)
102
+ self.out_proj = nn.Linear(head_dim, head_dim)
103
+
104
+ def _reshape_to_output(self, x):
105
+ return x
106
+
107
+ class attentiona(nn.Module):
108
+ def __init__(self, dims: int, head: int, max_iters: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
109
  super(attention, self).__init__()
110
+
111
+ self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
112
  self.dims = dims
113
  self.head = head
114
  self.head_dim = dims // head
115
+ self.dropout = dropout
116
+ self.max_iters = max_iters
117
+ self.rope = rotary(dims=dims, head=head)
118
+
119
+ self.threshold = nn.Parameter(torch.tensor(threshold))
120
+ self.factor = nn.Parameter(torch.tensor(factor))
121
+ self.attn_local = LocalAttentionModule(self.head_dim)
122
+
123
+ def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
124
+ z = default(xa, x)
125
+ q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
126
+ # q=self.lnb(q)
127
+ # k=self.lnb(k)
128
+ iteration = 0
129
+ prev_attn = torch.zeros_like(q)
130
+ attn_out = torch.zeros_like(q)
131
+ threshold = self.threshold.item()
132
+ factor = self.factor.item()
133
+
134
+ q_cur = q
135
+ while iteration < self.max_iters:
136
+ eff_span = z.shape[1]
137
+ if eff_span == 0:
138
+ break
139
+
140
+ q_iter = q_cur[:, :, :eff_span, :]
141
+ k_iter = k[:, :, :eff_span, :]
142
+ v_iter = v[:, :, :eff_span, :]
143
+ q = self.attn_local.query_module(q_iter)
144
+ k = self.attn_local.key_module(k_iter)
145
+ v = self.attn_local.value_module(v_iter)
146
+
147
+ iter_mask = None
148
+ if mask is not None:
149
+ if mask.dim() == 4:
150
+ iter_mask = mask[:, :, :eff_span, :eff_span]
151
+ elif mask.dim() == 2:
152
+ iter_mask = mask[:eff_span, :eff_span]
153
+
154
+ q = self.rope(q, q.shape[2])
155
+ k = self.rope(k, k.shape[2])
156
+
157
+ attn_iter, _ = calculate_attention(
158
+ self.lnb(q), self.lnb(k), v, mask=iter_mask)
159
+
160
+ out_span = self.attn_local._reshape_to_output(attn_iter)
161
+ if out_span.dim() == 4:
162
+ b, h, s, d = out_span.shape
163
+ proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, h, s, -1)
164
+ elif out_span.dim() == 3:
165
+ b, s, d = out_span.shape
166
+ if d == self.head_dim:
167
+ proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, 1, s, -1)
168
+ elif d == self.head * self.head_dim:
169
+ proj_span = out_span.view(b, self.head, s, self.head_dim)
170
+ else:
171
+ raise RuntimeError(f"Cannot reshape out_span of shape {out_span.shape} to [b, h, s, head_dim]")
172
+ else:
173
+ raise RuntimeError(f"Unexpected out_span shape: {out_span.shape}")
174
+
175
+ iter_out = torch.zeros_like(q_cur)
176
+ iter_out[:, :, :eff_span, :] = proj_span
177
+ diff = torch.abs(iter_out - prev_attn).mean()
178
+ dthresh = threshold + factor * diff
179
+ if diff < dthresh and iteration > 0:
180
+ attn_out = iter_out
181
+ break
182
+
183
+ prev_attn = iter_out.clone()
184
+ q_cur = q_cur + iter_out
185
+ attn_out = iter_out
186
+ iteration += 1
187
+
188
+ output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
189
+ return self.o(output), None
190
+
191
+ def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
192
+ mask: Optional[Tensor] = None) -> Tensor:
193
+ batch, ctx, dims = x.shape
194
+ output = torch.zeros_like(x)
195
+ num_win = (ctx + win_size - 1) // win_size
196
+
197
+ for i in range(num_win):
198
+ q_start = i * win_size
199
+ q_end = min(q_start + win_size, ctx)
200
+ q_len = q_end - q_start
201
+ if q_len == 0:
202
+ continue
203
+
204
+ kv_start = max(0, q_end - span_len)
205
+ kv_end = q_end
206
+ query_win = x[:, q_start:q_end, :]
207
+ key_win = x[:, kv_start:kv_end, :]
208
+
209
+ win_mask = None
210
+ if mask is not None:
211
+ if mask.dim() == 4:
212
+ win_mask = mask[:, :, q_start:q_end, kv_start:kv_end]
213
+ elif mask.dim() == 2:
214
+ win_mask = mask[q_start:q_end, kv_start:kv_end]
215
+
216
+ attn_out_win, _ = self._focus(
217
+ x=query_win,
218
+ xa=key_win,
219
+ mask=win_mask)
220
+ output[:, q_start:q_end, :] = attn_out_win
221
+ return output
222
+
223
+ def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
224
+ use_sliding_window: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
225
+ if use_sliding_window:
226
+ return self._slide_win_local(x, win_size, span_len, mask)
227
+ else:
228
+ output, _ = self._focus(x, xa, mask)
229
+ return output
230
+
231
+ class attentionb(nn.Module):
232
+ def __init__(self, dims: int, head: int):
233
+ super(attentionb, self).__init__()
234
+ self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
235
+ self.dims = dims
236
+ self.head = head
237
+ self.head_dim = dims // head
238
+ self.rope = rotary(dims=dims, head=head)
239
+
240
  def forward(self, x: Tensor, xa = None, mask = None):
241
+ z = default(xa, x)
242
+ q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
243
+
 
244
  q = self.rope(q, q.shape[2])
245
  k = self.rope(k, k.shape[2])
246
+
247
+ a = scaled_dot_product_attention(self.lnb(q), self.lnb(k), v, is_causal=mask is not None and q.shape[1] > 1)
248
  out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
249
  return self.o(out)
250
 
251
+ class Residual(nn.Module):
252
  def __init__(self, dims: int, head: int, act: str = "silu"):
253
  super().__init__()
254
+
255
+ self.lna = nn.LayerNorm(dims, bias=False)
256
+ self.attnb = attentionb(dims, head)
257
+ self.attna = attentiona(dims, head, max_iters=3)
258
  self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
259
 
260
+ def forward(self, x, xa = None, mask = None) -> Tensor:
261
+
262
+ x = x + self.attnb(self.lna(x), xa=None, mask=mask)
263
  if xa is not None:
264
+ x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_window=True, win_size=500, span_len=1500)
265
+ x = x + self.mlp(self.lna(x))
266
  return x
267
+
268
  class processor(nn.Module):
269
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
270
  super(processor, self).__init__()
271
+
272
  self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
273
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
274
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
275
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
276
+ self.posin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
277
 
278
  act_fn = get_activation(act)
279
  self.encoder = nn.Sequential(
 
283
 
284
  self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
285
  self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
286
+
287
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
288
  self.register_buffer("mask", mask, persistent=False)
289
 
290
+ def forward(self, x, xa, sequential=False) -> Tensor:
291
 
292
  x = self.token(x.long()) + self.positional[:x.shape[1]]
293
  xa = self.encoder(xa).permute(0, 2, 1)
294
+ xa = xa + self.posin(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
295
+
296
  for b in chain(self.bA or []):
297
  xa = b(x=xa, xa=None, mask=None)
298
+
299
  for b in chain(self.bB or []):
300
  x = b(x=x, xa=None, mask=self.mask)
301
+ y = b(x, xa=xa, mask=None)
302
+ if sequential:
303
+ x = y
304
+ else:
305
+ a = torch.sigmoid(self.blend)
306
+ x = a * y + (1 - a) * x
307
+
308
  x = nn.functional.dropout(x, p=0.001, training=self.training)
309
  x = self.ln(x)
310
  x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
311
  return x
312
 
313
+ def init_weights(self):
314
+ print("Initializing model weights...")
315
+ self.apply(self._init_weights)
316
+ print("Initialization summary:")
317
+ for module_type, count in self.init_counts.items():
318
+ if count > 0:
319
+ print(f"{module_type}: {count}")
320
+
321
  class Model(nn.Module):
322
  def __init__(self, param: Dimensions):
323
  super().__init__()
 
380
  if count > 0:
381
  print(f"{module_type}: {count}")
382
 
383
+ def main():
384
+ token = ""
385
+ log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
386
+ os.makedirs(log_dir, exist_ok=True)
387
+ tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
388
+
389
+ extract_args = {
390
+ "waveform": False,
391
+ "spec": False,
392
+ "f0": False,
393
+ "f0t": False,
394
+ "pitch": True,
395
+ "harmonics": False,
396
+ "aperiodics": False,
397
+ "phase_mod": False,
398
+ "crepe": False,
399
+ "sample_rate": 16000,
400
+ "hop_length": 256,
401
+ "mode": "mean",
402
+ "debug": False,
403
+ }
404
+
405
+ param = Dimensions(
406
+ vocab=40000,
407
+ mels=128,
408
+ ctx=2048,
409
+ dims=512,
410
+ head=4,
411
+ layer=4,
412
+ act="swish",
413
+ )
414
+
415
+ train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
416
+ load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
417
+
418
+ model = Model(param).to('cuda')
419
+ print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
420
+ print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
421
+
422
+ from functools import partial
423
+ metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
424
+
425
+ training_args = Seq2SeqTrainingArguments(
426
+ output_dir=log_dir,
427
+ per_device_train_batch_size=1,
428
+ per_device_eval_batch_size=1,
429
+ max_steps=1000,
430
+ eval_steps=100,
431
+ save_steps=1000,
432
+ warmup_steps=100,
433
+ logging_steps=10,
434
+ logging_dir=log_dir,
435
+ logging_strategy="steps",
436
+ eval_strategy="steps",
437
+ save_strategy="no",
438
+ report_to=["tensorboard"],
439
+ push_to_hub=False,
440
+ save_total_limit=1,
441
+ label_names=["labels"],
442
+ save_safetensors=False,
443
+ eval_on_start=False,
444
+ batch_eval_metrics=False,
445
+ disable_tqdm=False,
446
+ include_tokens_per_second=True,
447
+ include_num_input_tokens_seen=True,
448
+ learning_rate=0.00025,
449
+ weight_decay=0.025,
450
+ )
451
+
452
+ optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
453
+ amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
454
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
455
+
456
+ trainer = Seq2SeqTrainer(
457
+ args=training_args,
458
+ model=model,
459
+ train_dataset=train_dataset,
460
+ eval_dataset=test_dataset,
461
+ data_collator=DataCollator(tokenizer=tokenizer),
462
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
463
+ compute_metrics=metrics_fn,
464
+ optimizers=(optimizer, scheduler)
465
+ )
466
+
467
+ model.init_weights()
468
+ trainer.train()
469
+ if __name__ == "__main__":
470
+
471
+ main()
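
A quick way to sanity-check the new helper flow introduced in this commit is the standalone sketch below. It re-traces what qkv_init, create_qkv and attentionb.forward do to tensor shapes, using a minimal default() stand-in (assumed to mirror the helper already defined in model_simple.py), illustrative sizes, and plain self-attention (xa=None) so that create_qkv reshapes q, k and v with the same ctx. It is a shape-check sketch, not the module's actual entry point.

import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention

def default(value, fallback):
    # stand-in for the default() helper assumed to exist in model_simple.py
    return value if value is not None else fallback

def qkv_init(dims: int, head: int):
    # mirrors the projections and norms added in this commit
    head_dim = dims // head
    q = nn.Linear(dims, dims)
    k = nn.Linear(dims, dims, bias=False)
    v = nn.Linear(dims, dims)
    o = nn.Linear(dims, dims)
    lna = nn.LayerNorm(dims, bias=False)
    lnb = nn.LayerNorm(head_dim, bias=False)
    return q, k, v, o, lna, lnb

def create_qkv(dims, head, q, k, v, x, xa=None):
    # project, scale, and split into heads: (batch, ctx, dims) -> (batch, head, ctx, head_dim)
    z = default(xa, x)
    head_dim = dims // head
    scale = head_dim ** -0.25
    q, k, v = q(x) * scale, k(z) * scale, v(z)
    batch, ctx, _ = q.shape
    def _shape(t):
        return t.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)

dims, head, batch, ctx = 512, 4, 2, 16                     # illustrative sizes only
q_l, k_l, v_l, o_l, lna, lnb = qkv_init(dims, head)
x = torch.randn(batch, ctx, dims)
q, k, v = create_qkv(dims, head, q_l, k_l, v_l, lna(x))    # self-attention path
a = scaled_dot_product_attention(lnb(q), lnb(k), v)        # (batch, head, ctx, head_dim)
out = o_l(a.permute(0, 2, 1, 3).flatten(start_dim=2))      # back to (batch, ctx, dims)
print(out.shape)                                           # torch.Size([2, 16, 512])

The permute/flatten/o projection at the end matches the return path used by both attentiona and attentionb in the diff.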