Update model_simple.py
model_simple.py  CHANGED  (+9 −99)
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 from torch.nn.functional import scaled_dot_product_attention
-
+
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
@@ -106,7 +106,7 @@ class LocalAttentionModule(nn.Module):
 
 class attentiona(nn.Module):
     def __init__(self, dims: int, head: int, max_iters: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
-        super(
+        super(attentiona, self).__init__()
 
         self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
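The truncated super( line was the real defect in this hunk: nn.Module.__init__() has to run before any submodules are assigned, otherwise PyTorch's attribute bookkeeping does not exist yet. A minimal, self-contained illustration of the failure and the fix (class names here are illustrative, not the ones in model_simple.py):

import torch
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # super().__init__() deliberately missing, as in the old truncated line
        self.proj = nn.Linear(4, 4)

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()           # sets up _parameters/_modules before any assignment
        self.proj = nn.Linear(4, 4)  # now registered and visible to .parameters()

try:
    Broken()
except AttributeError as err:
    print("Broken:", err)            # "cannot assign module before Module.__init__() call"

print("Fixed params:", sum(p.numel() for p in Fixed().parameters()))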
@@ -122,9 +122,8 @@ class attentiona(nn.Module):
 
     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
         z = default(xa, x)
+
         q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
-        # q=self.lnb(q)
-        # k=self.lnb(k)
         iteration = 0
         prev_attn = torch.zeros_like(q)
         attn_out = torch.zeros_like(q)
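For orientation only: the names kept as context here (iteration, prev_attn, max_iters, threshold, factor) suggest _focus runs scaled dot-product attention inside a convergence loop. The sketch below is a guess at that general shape, not the actual body of _focus; the feedback rule and the stopping test are assumptions.

import torch
from torch.nn.functional import scaled_dot_product_attention

def iterative_attention(q, k, v, max_iters=3, threshold=0.01, factor=0.1):
    # Hypothetical refinement loop: rerun attention, feeding back a scaled copy of
    # the previous output, until the result stops changing or max_iters is reached.
    prev_attn = torch.zeros_like(q)
    attn_out = torch.zeros_like(q)
    for iteration in range(max_iters):
        attn_out = scaled_dot_product_attention(q + factor * prev_attn, k, v)
        if torch.norm(attn_out - prev_attn) < threshold:
            break
        prev_attn = attn_out
    return attn_out

q = k = v = torch.randn(1, 4, 16, 32)   # (batch, head, ctx, head_dim)
print(iterative_attention(q, k, v).shape)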
@@ -231,6 +230,7 @@ class attentiona(nn.Module):
 class attentionb(nn.Module):
     def __init__(self, dims: int, head: int):
         super(attentionb, self).__init__()
+
         self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head
@@ -344,7 +344,7 @@ class Model(nn.Module):
     def _init_weights(self, module):
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "processor": 0, "
+            "Conv2d": 0, "processor": 0, "attentiona": 0, "attentionb": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, RMSNorm):
                 nn.init.ones_(module.weight)
@@ -365,11 +365,10 @@ class Model(nn.Module):
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
                 self.init_counts["Conv2d"] += 1
-            elif isinstance(module,
-                self.init_counts["
-            elif isinstance(module,
-                self.init_counts["
-            elif isinstance(module, processor):
+            elif isinstance(module, attentiona):
+                self.init_counts["attentiona"] += 1
+            elif isinstance(module, attentionb):
+                self.init_counts["attentionb"] += 1
                 self.init_counts["processor"] += 1
 
     def init_weights(self):
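The repaired isinstance chain is easier to check in isolation: the counting half of it is just a walk over named_modules() keyed by module type. A minimal sketch of the same bookkeeping on a toy model (not the project's Model class, and without the per-type initialization):

import torch.nn as nn
from collections import Counter

def count_module_types(model: nn.Module) -> Counter:
    # Tally every submodule by class name, mirroring the init_counts bookkeeping
    # in Model._init_weights.
    counts = Counter()
    for name, module in model.named_modules():
        counts[type(module).__name__] += 1
    return counts

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Conv1d(4, 4, 3))
for module_type, count in count_module_types(toy).items():
    if count > 0:
        print(f"{module_type}: {count}")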
@@ -380,92 +379,3 @@ class Model(nn.Module):
             if count > 0:
                 print(f"{module_type}: {count}")
 
-def main():
-    token = ""
-    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
-    os.makedirs(log_dir, exist_ok=True)
-    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
-
-    extract_args = {
-        "waveform": False,
-        "spec": False,
-        "f0": False,
-        "f0t": False,
-        "pitch": True,
-        "harmonics": False,
-        "aperiodics": False,
-        "phase_mod": False,
-        "crepe": False,
-        "sample_rate": 16000,
-        "hop_length": 256,
-        "mode": "mean",
-        "debug": False,
-    }
-
-    param = Dimensions(
-        vocab=40000,
-        mels=128,
-        ctx=2048,
-        dims=512,
-        head=4,
-        layer=4,
-        act="swish",
-    )
-
-    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
-        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
-
-    model = Model(param).to('cuda')
-    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
-    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    from functools import partial
-    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
-
-    training_args = Seq2SeqTrainingArguments(
-        output_dir=log_dir,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        max_steps=1000,
-        eval_steps=100,
-        save_steps=1000,
-        warmup_steps=100,
-        logging_steps=10,
-        logging_dir=log_dir,
-        logging_strategy="steps",
-        eval_strategy="steps",
-        save_strategy="no",
-        report_to=["tensorboard"],
-        push_to_hub=False,
-        save_total_limit=1,
-        label_names=["labels"],
-        save_safetensors=False,
-        eval_on_start=False,
-        batch_eval_metrics=False,
-        disable_tqdm=False,
-        include_tokens_per_second=True,
-        include_num_input_tokens_seen=True,
-        learning_rate=0.00025,
-        weight_decay=0.025,
-    )
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
-        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
-
-    trainer = Seq2SeqTrainer(
-        args=training_args,
-        model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        compute_metrics=metrics_fn,
-        optimizers=(optimizer, scheduler)
-    )
-
-    model.init_weights()
-    trainer.train()
-if __name__ == "__main__":
-
-    main()
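With main() deleted, model_simple.py no longer trains anything on its own. If the removed driver is still wanted, a trimmed external script could carry it instead. The sketch below is reconstructed from the deleted code; it assumes Model, Dimensions, setup_tokenizer, prepare_datasets, and DataCollator are importable from model_simple, drops the machine-specific D:/ paths, and is not part of this commit.

# train.py -- hypothetical external driver, reconstructed from the deleted main()
import torch
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from model_simple import Model, Dimensions, setup_tokenizer, prepare_datasets, DataCollator  # assumed exports

def main():
    tokenizer = setup_tokenizer("tokenizer.json")   # path is illustrative
    param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
    # prepare_datasets took many more knobs in the deleted code; trimmed here for brevity.
    train_dataset, test_dataset = prepare_datasets(tokenizer, "", max_ctx=param.ctx)

    model = Model(param).to("cuda" if torch.cuda.is_available() else "cpu")
    model.init_weights()

    training_args = Seq2SeqTrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        max_steps=1000,
        eval_strategy="steps",
        eval_steps=100,
        logging_steps=10,
        learning_rate=2.5e-4,
        weight_decay=0.025,
        save_strategy="no",
        label_names=["labels"],
        report_to=["tensorboard"],
    )
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
    )
    trainer.train()

if __name__ == "__main__":
    main()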