Spaces:
Runtime error
Runtime error
File size: 4,593 Bytes
e79b770 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os,sys
# Append the current working directory to the module search path.
# NOTE(review): presumably this lets the project-local `AR.*` imports below
# resolve when the process is launched from the repository root -- confirm.
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
    """Lightning module wrapping a ``Text2SemanticDecoder``.

    When built for training, the module switches Lightning to manual
    optimization: ``training_step`` runs backward itself and steps the
    optimizer and LR scheduler once every ``GRAD_ACCUM_STEPS`` batches.
    """

    # Number of consecutive batches whose gradients are accumulated before a
    # single optimizer / scheduler step is taken in ``training_step``.
    GRAD_ACCUM_STEPS = 4

    def __init__(self, config, output_dir, is_train=True):
        """Build the decoder and, for training, prepare optimization state.

        Args:
            config: Configuration object. It is forwarded to
                ``Text2SemanticDecoder``; ``config.get("pretrained_s1")`` is
                read here and ``config['optimizer'][...]`` is read in
                ``configure_optimizers``.
            output_dir: Base output directory. Only used when ``is_train`` is
                true; it must support the ``/`` operator (e.g. a
                ``pathlib.Path``) because ``output_dir / 'eval'`` is created.
            is_train: When true, optionally load pretrained weights, enable
                manual optimization, save hyperparameters and create the
                ``eval`` directory. When false none of that happens (and
                ``self.eval_dir`` is not set).
        """
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)

        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # The checkpoint is expected to hold this module's state dict
            # under the "weight" key (an earlier variant used "state_dict").
            checkpoint = torch.load(pretrained_s1, map_location="cpu")
            # Print the load result so missing/unexpected keys are visible.
            print(self.load_state_dict(checkpoint["weight"]))

        if is_train:
            # training_step drives backward/step itself.
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / 'eval'
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """Run one manually-optimized training step with gradient accumulation.

        Args:
            batch: Mapping providing ``'phoneme_ids'``, ``'phoneme_ids_len'``,
                ``'semantic_ids'``, ``'semantic_ids_len'`` and
                ``'bert_feature'``, passed to the decoder in that order.
            batch_idx: Index of the batch as supplied by Lightning.

        Returns:
            None. The loss, the current learning rate and the top-k accuracy
            are reported through ``self.log``.
        """
        opt = self.optimizers()
        scheduler = self.lr_schedulers()

        loss, acc = self.model.forward(
            batch['phoneme_ids'], batch['phoneme_ids_len'],
            batch['semantic_ids'], batch['semantic_ids_len'],
            batch['bert_feature'])
        self.manual_backward(loss)

        # Step after every GRAD_ACCUM_STEPS-th batch (indices 3, 7, 11, ...
        # for the default of 4) so that each update accumulates the same
        # number of batches. Testing ``batch_idx > 0 and batch_idx % N == 0``
        # instead would fold N + 1 batches into the first update.
        # NOTE(review): gradients of a trailing partial window are not
        # stepped or cleared here -- confirm that is acceptable at epoch end.
        if (batch_idx + 1) % self.GRAD_ACCUM_STEPS == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)

    def validation_step(self, batch: Dict, batch_idx: int):
        """Validation hook; currently disabled, so it does nothing.

        Args:
            batch: Unused.
            batch_idx: Unused.

        Returns:
            None.
        """
        return
        # Disabled validation logic, kept for reference:
        #
        # # get loss
        # loss, acc = self.model.forward(
        #     batch['phoneme_ids'], batch['phoneme_ids_len'],
        #     batch['semantic_ids'], batch['semantic_ids_len'],
        #     batch['bert_feature']
        # )
        #
        # self.log(
        #     "val_total_loss",
        #     loss,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        # self.log(
        #     f"val_top_{self.top_k}_acc",
        #     acc,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        #
        # # get infer output
        # semantic_len = batch['semantic_ids'].size(1)
        # prompt_len = min(int(semantic_len * 0.5), 150)
        # prompt = batch['semantic_ids'][:, :prompt_len]
        # pred_semantic = self.model.infer(batch['phoneme_ids'],
        #                                  batch['phoneme_ids_len'], prompt,
        #                                  batch['bert_feature']
        #                                  )
        # save_name = f'semantic_toks_{batch_idx}.pt'
        # save_path = os.path.join(self.eval_dir, save_name)
        # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        """Create the ``ScaledAdam`` optimizer and its warmup-cosine schedule.

        Returns:
            A dict with the optimizer under ``"optimizer"`` and the scheduler
            under ``"lr_scheduler" -> "scheduler"``.
        """
        model_parameters = self.model.parameters()
        # A list containing one inner list of parameter names.
        # NOTE(review): presumably one inner list per parameter group, as
        # expected by ScaledAdam -- confirm against its implementation.
        parameters_names = [
            [name for name, _ in self.model.named_parameters()]
        ]
        # NOTE(review): lr=0.01 is the value handed to the optimizer; the
        # schedule below is configured separately from
        # config['optimizer'] -- confirm how the two interact.
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000, )

        optimizer_cfg = self.config['optimizer']
        lr_schedule = WarmupCosineLRSchedule(
            lm_opt,
            init_lr=optimizer_cfg['lr_init'],
            peak_lr=optimizer_cfg['lr'],
            end_lr=optimizer_cfg['lr_end'],
            warmup_steps=optimizer_cfg['warmup_steps'],
            total_steps=optimizer_cfg['decay_steps'])
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": lr_schedule,
            },
        }
|