Sin2pi committed
Commit a479bac (verified) · 1 parent: 90b8189

Create model_simple.py

Files changed (1): model_simple.py (+448, -0)
model_simple.py ADDED
import os
import math
import warnings
import logging
from itertools import chain
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from tensordict import TensorDict
from typing import Optional, Dict, Union, List, Tuple
import numpy as np
from functools import partial
from dataclasses import dataclass
from datetime import datetime
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from echoutils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

@dataclass
class Dimensions:
    vocab: int
    mels: int
    ctx: int
    dims: int
    head: int
    layer: int
    act: str

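# Rotary position embedding with a learnable base (theta). The per-channel
# frequencies follow the mel-scale formula (the 700 / 2595 / 8000 Hz constants
# are the inverse-mel curve), scaled by theta, so rotation rates can adapt
# during training.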
class rotary(nn.Module):
    def __init__(self, dims, head):
        super(rotary, self).__init__()
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.theta = nn.Parameter(torch.tensor(36000, device=device, dtype=dtype), requires_grad=True)

    def forward(self, x=None) -> Tensor:
        freqs = (self.theta / 220.0) * 700 * (
            torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
            self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
        t = torch.arange(x, device=device, dtype=dtype)  # type: ignore
        freqs = t[:, None] * freqs
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs.unsqueeze(0)

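    # Rotates channels as complex pairs. Only the first freqs.shape[-1] * 2
    # channels are rotated; any remaining channels pass through unchanged.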
    @staticmethod
    def apply_rotary(x, freqs):
        x1 = x[..., :freqs.shape[-1]*2]
        x2 = x[..., freqs.shape[-1]*2:]
        orig_shape = x1.shape
        if x1.ndim == 2:
            x1 = x1.unsqueeze(0)
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)

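# Multi-head attention with rotary positions applied to q and k. Scaling and
# causal masking are handled inside scaled_dot_product_attention (expected to
# come from the echoutils star import, presumably torch's
# F.scaled_dot_product_attention); the qk return slot is kept for API
# compatibility and is always None here.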
class MultiheadA(nn.Module):

    def __init__(self, dims: int, head: int, debug: List[str] = []):
        super(MultiheadA, self).__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug

        self.q = nn.Linear(dims, dims).to(device, dtype)
        self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
        self.v = nn.Linear(dims, dims).to(device, dtype)
        self.o = nn.Linear(dims, dims).to(device, dtype)
        self.rope = rotary(dims=dims, head=head)

    def forward(self, x: Tensor, xa = None, mask = None):
        q = self.q(x)
        k = self.k(x if xa is None else xa)
        v = self.v(x if xa is None else xa)
        batch, ctx, dims = q.shape
        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        q = self.rope.apply_rotary(q, (self.rope(q.shape[2])))  # type: ignore
        k = self.rope.apply_rotary(k, (self.rope(k.shape[2])))  # type: ignore
        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and ctx > 1)
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
        qk = None
        return self.o(out), qk

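# Soft type-routing gate: a softmax classifier predicts a distribution over
# num_types "token types", each type has its own sigmoid gate, and the gates
# are combined weighted by the type probabilities.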
class t_gate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return comb_gate

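# Transformer block: pre-norm self-attention, optional cross-attention whose
# output is blended with the self-attention path through a learned sigmoid
# gate, and an MLP whose contribution is scaled per position by t_gate.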
class Residual(nn.Module):
    _seen = set()

    def __init__(self, dims: int, head: int, ctx: int, act: str = "silu"):
        super().__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head

        self.blend = nn.Parameter(torch.tensor(0.5))
        act_fn = get_activation(act)
        self.attn = MultiheadA(dims, head)
        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
        self.t_gate = t_gate(dims=dims, num_types=4*2)

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims)
        self.lnc = RMSNorm(dims)

    def forward(self, x, xa=None, mask=None) -> Tensor:
        x = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
        xb = x
        if xa is not None:
            x = x + self.attn(self.lnb(x), xa=xa, mask=None)[0]  # type: ignore
            b = torch.sigmoid(self.blend)
            x = b * xb + (1 - b) * x
        normx = self.lnc(x)
        mlp_out = self.mlp(normx)
        gate = self.t_gate(normx)
        x = x + gate * mlp_out
        return x

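# Convolutional front end for acoustic features: three Conv1d layers (the
# commented-out variant takes a single-channel pitch track instead of a mel
# spectrogram), followed by sinusoidal positions (sinusoids from echoutils),
# dropout, and RMSNorm.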
class feature_encoder(nn.Module):
    def __init__(self, mels, dims, head, layer, act="gelu"):
        super().__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        act_fn = get_activation(act)

        # pitch
        # self.encoder = nn.Sequential(
        #     Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        # spectrogram
        self.encoder = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def forward(self, x, xa=None, mask=None, max_tscale=36000):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        # x = self.pitch(x).permute(0, 2, 1)
        x = self.encoder(x).permute(0, 2, 1)
        max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
        x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

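# Encoder/decoder stack. bA runs the acoustic branch (feature_encoder plus
# Residual layers, no mask); bB runs the text branch with a causal mask, then
# cross-attends to the acoustic output, blending the two paths with a learned
# sigmoid gate (unless sequential=True). Logits come from the tied
# token-embedding projection.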
class processor(nn.Module):
    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
        super(processor, self).__init__()
        self.dims = dims
        self.head = head
        self.layer = layer
        self.ctx = ctx
        self.act = act
        self.dropout = 0.01

        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
        self.positional = nn.Parameter(torch.zeros(ctx, dims, device=device, dtype=dtype), requires_grad=True)
        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)

        self.bA = nn.ModuleList(
            [feature_encoder(mels=mels, dims=dims, head=head, layer=layer, act=act)] +
            [Residual(ctx=ctx, dims=dims, head=head, act=act) for _ in range(layer)])
        self.bB = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act=act)
            for _ in range(layer)])

        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        self.norm = nn.LayerNorm(dims, device=device, dtype=dtype)

    def forward(self, x, xa, sequential=False) -> Tensor:
        x = self.token(x.long()) + self.positional[:x.shape[1]]

        for b in chain(self.bA or []):
            xa = b(x=xa, xa=None, mask=None)

        for b in chain(self.bB or []):
            x = b(x=x, xa=None, mask=self.mask)
            xc = b(x, xa=xa, mask=None)
            if sequential:
                x = xc
            else:
                a = torch.sigmoid(self.blend)
                x = a * xc + (1 - a) * x

        x = self.norm(x)
        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
        return x

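# Top-level model: routes whichever acoustic feature is provided (spectrogram
# takes precedence over pitch) into the processor and computes cross-entropy
# against the labels; ignore_index=0 assumes index 0 is the tokenizer's pad id.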
class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param

        self.processor = processor(
            vocab=param.vocab,
            mels=param.mels,
            ctx=param.ctx,
            dims=param.dims,
            head=param.head,
            layer=param.layer,
            act=param.act,
        )

    def forward(self,
        labels=None,
        input_ids=None,
        spectrogram: Optional[torch.Tensor]=None,
        pitch: Optional[torch.Tensor]=None,
        ) -> Dict[str, Optional[torch.Tensor]]:

        enc = {}
        xa = None
        if pitch is not None:
            xa = pitch
            enc["pitch"] = pitch
        if spectrogram is not None:
            xa = spectrogram
            enc["spectrogram"] = spectrogram
        assert xa is not None, "either a spectrogram or a pitch tensor is required"

        x = input_ids
        logits = self.processor(x, xa)

        loss = None
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
        return {"logits": logits, "loss": loss}

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

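    # Weight init applied per module via self.apply: Xavier-uniform for Linear,
    # normal (std 0.02) for Conv1d/Conv2d, ones for RMSNorm gains; per-type
    # counts are tallied so init_weights can print a summary.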
    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1
        elif isinstance(module, feature_encoder):
            self.init_counts["feature_encoder"] += 1
        elif isinstance(module, processor):
            self.init_counts["processor"] += 1
        elif isinstance(module, Echo):
            self.init_counts["Echo"] += 1

    def init_weights(self):
        print("Initializing model weights...")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "processor": 0, "Echo": 0,
            "Residual": 0, "MultiheadA": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0, "feature_encoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

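# Training entry point. The paths below are machine specific (Windows drive D:),
# and setup_tokenizer, prepare_datasets, DataCollator, compute_metrics and
# preprocess_logits_for_metrics are expected from echoutils.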
def main():
    token = ""
    log_dir = os.path.join('D:/newmodel/output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
    os.makedirs(log_dir, exist_ok=True)
    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")

    sanity_check = False
    streaming = False
    load_saved = False
    save_dataset = False
    cache_dir = None

    extract_args = {
        "waveform": False,
        "spec": False,
        "f0": False,
        "f0t": False,
        "pitch": True,
        "harmonics": False,
        "aperiodics": False,
        "phase_mod": False,
        "crepe": False,
        "sample_rate": 16000,
        "hop_length": 256,
        "mode": "mean",
        "debug": False,
    }

    param = Dimensions(
        vocab=40000,
        mels=128,
        ctx=2048,
        dims=512,
        head=4,
        layer=4,
        act="swish",
    )

    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=sanity_check, sample_rate=16000, streaming=streaming,
        load_saved=load_saved, save_dataset=save_dataset, cache_dir=cache_dir, extract_args=extract_args, max_ctx=param.ctx)

    model = Echo(param).to(device)
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1,
        tokenizer=tokenizer, model=model)

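    # Two Seq2SeqTrainingArguments presets: a 10-step sanity-check run with a
    # tiny learning rate, and the full 1000-step run. Neither saves checkpoints
    # (save_strategy="no").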
    if sanity_check:
        training_args = Seq2SeqTrainingArguments(
            output_dir=log_dir,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            max_steps=10,
            eval_steps=5,
            save_steps=0,
            warmup_steps=0,
            logging_steps=1,
            logging_dir=log_dir,
            eval_strategy="steps",
            save_strategy="no",
            logging_strategy="no",
            report_to=["tensorboard"],
            push_to_hub=False,
            save_total_limit=1,
            label_names=["labels"],
            save_safetensors=False,
            eval_on_start=False,
            batch_eval_metrics=False,
            disable_tqdm=False,
            include_tokens_per_second=True,
            include_num_input_tokens_seen=True,
            learning_rate=1e-7,
            weight_decay=0.01,
        )
    else:
        training_args = Seq2SeqTrainingArguments(
            output_dir=log_dir,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            max_steps=1000,
            eval_steps=100,
            save_steps=1000,
            warmup_steps=100,
            logging_steps=10,
            logging_dir=log_dir,
            logging_strategy="steps",
            eval_strategy="steps",
            save_strategy="no",
            report_to=["tensorboard"],
            push_to_hub=False,
            save_total_limit=1,
            label_names=["labels"],
            save_safetensors=False,
            eval_on_start=False,
            batch_eval_metrics=False,
            disable_tqdm=False,
            include_tokens_per_second=True,
            include_num_input_tokens_seen=True,
            learning_rate=0.00025,
            weight_decay=0.025,
        )

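    # AdamW with cosine annealing over max_steps; passing (optimizer, scheduler)
    # to Seq2SeqTrainer overrides the trainer's default optimizer setup.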
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        compute_metrics=metrics_fn,
        optimizers=(optimizer, scheduler)
    )

    model.init_weights()
    trainer.train()

if __name__ == "__main__":
    main()
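
A minimal smoke-test sketch (illustrative, not part of the commit): assuming echoutils supplies the Linear, Conv1d, RMSNorm, get_activation, and sinusoids helpers used in model_simple.py, this runs a single forward pass on random inputs to exercise the acoustic and text branches and the loss path.

# hypothetical usage sketch; run with model_simple.py and echoutils importable
param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="gelu")
model = Echo(param).to(device)
ids = torch.randint(1, param.vocab, (1, 16), device=device)          # fake token ids
spec = torch.randn(1, param.mels, 100, device=device, dtype=dtype)   # fake mel spectrogram (batch, mels, frames)
out = model(input_ids=ids, labels=ids, spectrogram=spec)
print(out["logits"].shape, out["loss"].item())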