Sin2pi committed
Commit edcaa5d · verified · 1 Parent(s): 606e3d4

Update model_simple.py

Files changed (1)
  1. model_simple.py +38 -126
model_simple.py CHANGED
@@ -1,19 +1,17 @@
 import os
-import math
 import warnings
 import logging
 from itertools import chain
 import torch
-import torch.nn.functional as feature
 from torch import nn, Tensor
-from typing import Optional, Dict, Union, List, Tuple
+from typing import Optional, Dict
 import numpy as np
-from functools import partial
 from datetime import datetime
+from dataclasses import dataclass
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
+from torch.nn.functional import scaled_dot_product_attention
 from echoutils import *
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
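
Note: the import cleanup drops aliases the file no longer needs (`math`, `functools.partial`, `torch.nn.functional as feature`) and narrows `typing`; `dataclass` and `scaled_dot_product_attention` are new. `Dimensions` itself still comes from the `echoutils` star import, so the following is only a hypothetical sketch of a compatible config, based on the fields this file actually passes to it:

```python
from dataclasses import dataclass

@dataclass
class Dimensions:
    # Hypothetical stand-in for the echoutils config this file consumes;
    # field names mirror the arguments used below (vocab, mels, ctx, ...).
    vocab: int
    mels: int
    ctx: int
    dims: int
    head: int
    layer: int
    act: str

param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
print(param.dims // param.head)  # per-head width: 128
```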
@@ -36,11 +34,18 @@ class rotary(nn.Module):
         self.head = head
         self.head_dim = dims // head
         self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
+        self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
+
+    def _compute_freqs_base(self):
+        mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
+        return 200 * mel_scale / 1000
+
     def forward(self, x, ctx) -> Tensor:
-        freqs = (self.theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
-        t = torch.arange(ctx, device=device, dtype=dtype)
-        freqs = t[:, None] * freqs
+        freqs = (self.theta / 220.0) * self.freqs_base
+        pos = torch.arange(ctx, device=device, dtype=dtype)
+        freqs = pos[:, None] * freqs
         freqs=torch.polar(torch.ones_like(freqs), freqs)
+
         x1 = x[..., :freqs.shape[-1]*2]
         x2 = x[..., freqs.shape[-1]*2:]
         orig_shape = x1.shape
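
Note: the rotary change precomputes the mel-spaced frequency base once in `_compute_freqs_base` (now spanning roughly 200–4000 Hz instead of the old 700/8000 constants) and stores it as a non-persistent buffer, so `forward` only rescales it by the learnable `theta`. A standalone sketch of the same arithmetic, assuming `head_dim = 64`:

```python
import math
import torch

head_dim = 64  # assumed for illustration

# Mel-spaced frequency base between 200 Hz and 4000 Hz, expressed in kHz
mel_max = 2595 * math.log10(1 + 4000 / 200)
mel_scale = torch.pow(10, torch.linspace(0, mel_max, head_dim // 2) / 2595) - 1
freqs_base = 200 * mel_scale / 1000

theta = torch.tensor(10000.0)                # learnable parameter in the model
freqs = (theta / 220.0) * freqs_base         # rescaling done on every forward pass

pos = torch.arange(8, dtype=torch.float32)   # token positions for a short context
angles = pos[:, None] * freqs                # (ctx, head_dim // 2)
rot = torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex rotations
print(freqs_base.shape, rot.shape)           # torch.Size([32]) torch.Size([8, 32])
```

Caching the base avoids rebuilding the linspace/pow tensors on every call while keeping `theta` as the only trainable part.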
@@ -63,8 +68,8 @@ class attention(nn.Module):
         self.rope = rotary(dims=dims, head=head)
         self.lny = nn.LayerNorm(self.head_dim, bias = False)
         self.lnx = nn.LayerNorm(dims, bias = False)
+
     def forward(self, x: Tensor, xa = None, mask = None):
-        scale = (self.dims // self.head) ** -0.25
         q = self.q(self.lnx(x))
         k = self.k(self.lnx(x if xa is None else xa))
         v = self.v(self.lnx(x if xa is None else xa))
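
Note: dropping the hand-rolled `scale = (dims // head) ** -0.25` is consistent with the newly imported `scaled_dot_product_attention`, which applies the standard `1/sqrt(head_dim)` scaling internally (the rest of the attention body falls outside this hunk). A quick check of the equivalence, with hypothetical shapes:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 4, 16, 64)  # (batch, heads, ctx, head_dim), hypothetical
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)

# Old style: split the 1/sqrt(d) factor as d**-0.25 applied to both q and k
scale = q.shape[-1] ** -0.25
manual = torch.softmax((q * scale) @ (k * scale).transpose(-1, -2), dim=-1) @ v

# Fused kernel: scaling (and optional masking) handled internally
fused = scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))  # True up to numerical noise
```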
@@ -80,59 +85,58 @@
 class tgate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
-        self.gate_projections = nn.ModuleList([
-            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            for _ in range(num_types)])
-        self.type_classifier = nn.Sequential(
-            Linear(dims, num_types),
-            nn.Softmax(dim=-1))
+        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
+        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))
     def forward(self, x):
-        type_probs = self.type_classifier(x)
-        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
-        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
-        return comb_gate
+        types = self.classifier(x)
+        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
+        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
+        return cgate

 class Residual(nn.Module):
     _seen = set()
     def __init__(self, dims: int, head: int, act: str = "silu"):
         super().__init__()
-        act_fn = get_activation(act)
+        self.ln = nn.LayerNorm(dims, bias = False)
         self.blend = nn.Parameter(torch.tensor(0.5))
         self.attn = attention(dims, head)
-        self.mlp = nn.Sequential(Linear(dims, dims*4), act_fn, Linear(dims*4, dims))
+        self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
         self.tgate = tgate(dims=dims, num_types=4*2)
-        self.lna = nn.LayerNorm(dims, bias = False)
+
     def forward(self, x, xa=None, mask=None) -> Tensor:
-        xb = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
+        xb = x + self.attn(self.ln(x), xa=None, mask=mask)
         if xa is not None:
-            x = x + self.attn(self.lna(x), xa=xa, mask=None)[0]
+            x = x + self.attn(self.ln(x), xa=xa, mask=None)
         b = torch.sigmoid(self.blend)
-        x = b * xb + (1 - b) * x
-        out = self.mlp(self.lna(x))
-        gate = self.tgate(self.lna(x))
+        x = b * xb + (1 - b) * x
+        out = self.mlp(self.ln(x))
+        gate = self.tgate(self.ln(x))
         x = x + gate * out
         return x

 class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
-        act_fn = get_activation(act)
+        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
+        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+
+        act_fn = get_activation(act)
         self.encoder = nn.Sequential(
             Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
         self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
         self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
-        self.norm = nn.LayerNorm(dims, device=device, dtype=dtype)

     def forward(self, x, xa) -> Tensor:
-        x = self.token(x.long()) + self.positional[:x.shape[1]]
+
+        x = self.token(x.long()) + self.positional[:x.shape[1]]
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 10000.0).to(device, dtype)
         for b in chain(self.bA or []):
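
Note: `tgate` keeps the same computation under shorter names: a softmax type classifier mixes `num_types` per-token sigmoid gates into a single scalar gate, which `Residual.forward` applies to the MLP branch (`x = x + gate * out`). The attention call also no longer takes `[0]`, implying `attention.forward` now returns a bare tensor. One caveat worth flagging: `bA`/`bB` still pass `act=act_fn` (a module) where `Residual` expects an activation name string, unless `get_activation` in echoutils also accepts module instances. A shape-level sketch of the gating math, with hypothetical sizes:

```python
import torch

batch, ctx, dims, num_types = 2, 16, 512, 8  # hypothetical sizes (Residual uses 4*2 types)

x = torch.randn(batch, ctx, dims)
types = torch.softmax(torch.randn(batch, ctx, num_types), dim=-1)  # classifier output
gates = torch.sigmoid(torch.randn(batch, ctx, 1, num_types))       # stacked per-type gates
cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)              # (batch, ctx, 1)
print(cgate.shape)  # torch.Size([2, 16, 1])

# In Residual.forward this scalar gate then modulates the MLP branch:
# x = x + cgate * mlp(ln(x))
```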
@@ -141,7 +145,7 @@ class processor(nn.Module):
             x = b(x=x, xa=None, mask=self.mask)
             x = b(x, xa=xa, mask=None)
         x = nn.functional.dropout(x, p=0.001, training=self.training)
-        x = self.norm(x)
+        x = self.ln(x)
         x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
         return x

@@ -149,7 +153,6 @@ class Model(nn.Module):
     def __init__(self, param: Dimensions):
         super().__init__()
         self.param = param
-
         self.processor = processor(
             vocab=param.vocab,
             mels=param.mels,
@@ -157,14 +160,12 @@ class Model(nn.Module):
            dims=param.dims,
            head=param.head,
            layer=param.layer,
-           act=param.act,
-           )
+           act=param.act)

    def forward(self,
        labels=None, input_ids=None, pitch: Optional[torch.Tensor]=None) -> Dict[str, Optional[torch.Tensor]]:
-        if pitch is not None:
-            xa = pitch
        x = input_ids
+        xa = pitch if pitch is not None else torch.zeros(1, 1, self.param.mels, device=device, dtype=dtype)
        logits = self.processor(x, xa)
        loss = None
        if labels is not None:
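
Note: the old guard only bound `xa` when `pitch` was provided, leaving it undefined otherwise; the new line always hands the processor an encoder input, falling back to a zero tensor. A minimal sketch of the fallback:

```python
import torch

mels = 128      # matches the config used in the removed main()
pitch = None    # e.g. the batch arrived without a pitch feature

# Same pattern as the new Model.forward: never leave xa unbound
xa = pitch if pitch is not None else torch.zeros(1, 1, mels)
print(xa.shape)  # torch.Size([1, 1, 128])
```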
@@ -210,92 +211,3 @@ class Model(nn.Module):
             if count > 0:
                 print(f"{module_type}: {count}")

-def main():
-    token = ""
-    log_dir = os.path.join('D:/newmodel/output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
-    os.makedirs(log_dir, exist_ok=True)
-    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
-
-    extract_args = {
-        "waveform": False,
-        "spec": False,
-        "f0": False,
-        "f0t": False,
-        "pitch": True,
-        "harmonics": False,
-        "aperiodics": False,
-        "phase_mod": False,
-        "crepe": False,
-        "sample_rate": 16000,
-        "hop_length": 256,
-        "mode": "mean",
-        "debug": False,
-        }
-
-    param = Dimensions(
-        vocab=40000,
-        mels=128,
-        ctx=2048,
-        dims=512,
-        head=4,
-        layer=4,
-        act="swish",
-        )
-
-    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
-        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
-
-    model = Model(param).to('cuda')
-    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
-    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    from functools import partial
-    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
-
-    training_args = Seq2SeqTrainingArguments(
-        output_dir=log_dir,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        max_steps=1000,
-        eval_steps=100,
-        save_steps=1000,
-        warmup_steps=100,
-        logging_steps=10,
-        logging_dir=log_dir,
-        logging_strategy="steps",
-        eval_strategy="steps",
-        save_strategy="no",
-        report_to=["tensorboard"],
-        push_to_hub=False,
-        save_total_limit=1,
-        label_names=["labels"],
-        save_safetensors=False,
-        eval_on_start=False,
-        batch_eval_metrics=False,
-        disable_tqdm=False,
-        include_tokens_per_second=True,
-        include_num_input_tokens_seen=True,
-        learning_rate=0.00025,
-        weight_decay=0.025,
-        )
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
-        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
-
-    trainer = Seq2SeqTrainer(
-        args=training_args,
-        model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        compute_metrics=metrics_fn,
-        optimizers=(optimizer, scheduler)
-        )
-
-    model.init_weights()
-    trainer.train()
-if __name__ == "__main__":
-
-    main()
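
Note: with `main()` gone, model_simple.py no longer carries tokenizer setup, dataset preparation, or the `Seq2SeqTrainer` wiring; the module now only defines the model, and training is expected to be driven from outside. A minimal smoke-test sketch of what remains, assuming `Dimensions` is re-exported through the echoutils star import and reusing the removed script's hyperparameters:

```python
import torch
from model_simple import Model, Dimensions, device  # Dimensions arrives via `from echoutils import *`

param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
model = Model(param).to(device)
model.init_weights()  # the removed main() called this before training

input_ids = torch.randint(0, param.vocab, (1, 16), device=device)
pitch = torch.randn(1, 1, 2048, device=device)  # hypothetical pitch contour for the encoder

out = model(input_ids=input_ids, labels=input_ids, pitch=pitch)
print({k: (tuple(v.shape) if isinstance(v, torch.Tensor) else v) for k, v in out.items()})
```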
 