hoang1007 committed
Commit · 5381499
1 Parent(s): 74c8f6d

init

Browse files
- .gitignore +1 -0
- app.py +44 -0
- checkpoints/.gitkeep +0 -0
- finetuning/preprocess.py +27 -0
- finetuning/run.sh +13 -0
- finetuning/train.py +135 -0
- finetuning/wav2vec2.py +200 -0
- packages.txt +1 -0
- requirements.txt +9 -0
- src/__init__.py +0 -0
- src/config/__init__.py +0 -0
- src/config/model.py +57 -0
- src/datamodule/__init__.py +4 -0
- src/datamodule/vlsp2020.py +131 -0
- src/model/__init__.py +1 -0
- src/model/modules/__init__.py +4 -0
- src/model/modules/context_encoder.py +149 -0
- src/model/modules/feature_extractor.py +103 -0
- src/model/modules/processor.py +42 -0
- src/model/modules/quantization.py +103 -0
- src/model/modules/transformers.py +200 -0
- src/model/wav2vec2.py +293 -0
- src/train.py +27 -0
- src/utils/__init__.py +1 -0
- src/utils/functional.py +28 -0
- src/utils/metrics.py +72 -0
- src/utils/scheduler.py +83 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
app.py
ADDED
@@ -0,0 +1,44 @@
+import sys
+
+sys.path.append("..")
+
+import gradio
+import torch, torchaudio
+import numpy as np
+from transformers import (
+    Wav2Vec2ForPreTraining,
+    Wav2Vec2CTCTokenizer,
+    Wav2Vec2FeatureExtractor,
+)
+from finetuning.wav2vec2 import SpeechRecognizer
+
+
+def load_model(ckpt_path: str):
+    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
+
+    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
+    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+
+    model = SpeechRecognizer.load_from_checkpoint(
+        ckpt_path,
+        wav2vec2=wav2vec2,
+        tokenizer=tokenizer,
+        feature_extractor=feature_extractor,
+    )
+
+    return model
+
+model = load_model("checkpoints/last.ckpt")
+model.eval()
+
+def transcribe(audio):
+    sample_rate, waveform = audio
+    waveform = torch.from_numpy(waveform[:, 0]).float().unsqueeze_(0)
+    waveform = torchaudio.functional.resample(waveform, sample_rate, 16_000)
+
+    transcript = model.predict(waveform)[0]
+
+    return transcript
+
+gradio.Interface(fn=transcribe, inputs=gradio.Audio(source="microphone", type="numpy"), outputs="textbox").launch()
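
For quick reference, a minimal sketch (not part of the commit) of how the `transcribe` callback defined in app.py could be exercised without launching the Gradio UI. The synthetic sine wave, its 44.1 kHz sample rate, and the assumption that `checkpoints/last.ckpt` exists and that the definitions above are already in scope are illustrative only.

# Illustrative only; not part of the commit. Assumes the definitions from app.py
# (model, transcribe) are in scope, e.g. pasted into a REPL before .launch().
import numpy as np

sample_rate = 44_100                      # assumed microphone sample rate
t = np.linspace(0, 2, 2 * sample_rate, endpoint=False)
signal = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Gradio passes (sample_rate, samples); transcribe indexes samples[:, 0],
# so give the dummy signal an explicit channel axis.
audio = (sample_rate, signal[:, None])

print(transcribe(audio))                  # prints the decoded transcript (noise for a sine tone)
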
checkpoints/.gitkeep
ADDED
File without changes
finetuning/preprocess.py
ADDED
@@ -0,0 +1,27 @@
+import sys
+
+sys.path.append("..")
+
+import os
+import argparse
+from torch.utils.data import random_split
+from src.datamodule import VLSP2020TarDataset, VLSP2020Dataset
+
+
+def prepare_tar_dataset(data_dir: str, dest_dir: str):
+    dts = VLSP2020Dataset(data_dir)
+    train_set, val_set = random_split(dts, [42_000, 14_427])
+
+    VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_train_set.tar")).convert(
+        train_set
+    )
+    VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_val_set.tar")).convert(val_set)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_dir", type=str, required=True)
+    parser.add_argument("--dest_dir", type=str, required=True)
+    args = parser.parse_args()
+
+    prepare_tar_dataset(args.data_dir, args.dest_dir)
finetuning/run.sh
ADDED
@@ -0,0 +1,13 @@
+python3 main.py \
+    --batch_size 2 \
+    --num_workers 2 \
+    --classifier_lr 1e-4 \
+    --wav2vec2_lr 1e-5 \
+    --max_epochs 10 \
+    --accelerator cpu \
+    --weight_decay 0.001 \
+    --warmup_steps 0.1 \
+    --constant_steps 0.4 \
+    --scheduler_factor 0.001 \
+    --data_dir data \
+    --ckpt_dir ckpt
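
As a side note, a small sketch (not part of the commit) of the step bookkeeping that finetuning/train.py derives from these flags. The 42,000 figure is the train-split size hard-coded there; treating the leftover steps as the decay phase of the tri-state schedule is an assumption, since src/utils/scheduler.py is not shown in this view.

# Illustrative only; not part of the commit.
batch_size, max_epochs = 2, 10
warmup_frac, constant_frac = 0.1, 0.4

total_steps = max_epochs * 42_000 // batch_size          # 210_000 optimizer steps
warmup_steps = int(total_steps * warmup_frac)            # 21_000
constant_steps = int(total_steps * constant_frac)        # 84_000
decay_steps = total_steps - warmup_steps - constant_steps  # 105_000 (assumed decay phase)
print(total_steps, warmup_steps, constant_steps, decay_steps)
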
finetuning/train.py
ADDED
@@ -0,0 +1,135 @@
+import sys
+
+sys.path.append("..")
+
+from argparse import ArgumentParser
+import os, string
+from transformers import (
+    Wav2Vec2ForPreTraining,
+    Wav2Vec2CTCTokenizer,
+    Wav2Vec2FeatureExtractor,
+)
+from pytorch_lightning import seed_everything
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.loggers import WandbLogger
+
+from src.datamodule import VLSP2020TarDataset
+from src.datamodule.vlsp2020 import get_dataloader
+from finetuning.wav2vec2 import SpeechRecognizer
+
+
+def remove_punctuation(text: str):
+    return text.translate(str.maketrans("", "", string.punctuation)).lower()
+
+
+def prepare_dataloader(data_dir, batch_size, num_workers):
+    train_dataset = VLSP2020TarDataset(
+        os.path.join(data_dir, "vlsp2020_train_set.tar")
+    ).load()
+    val_dataset = VLSP2020TarDataset(
+        os.path.join(data_dir, "vlsp2020_val_set.tar")
+    ).load()
+
+    train_dataloader = get_dataloader(
+        train_dataset,
+        return_transcript=True,
+        target_transform=remove_punctuation,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    val_dataloader = get_dataloader(
+        val_dataset,
+        return_transcript=True,
+        target_transform=remove_punctuation,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    return train_dataloader, val_dataloader
+
+
+def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
+    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
+
+    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
+    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+
+    model = SpeechRecognizer(
+        wav2vec2, tokenizer, feature_extractor, adam_config, tristate_scheduler_config
+    )
+
+    return model
+
+
+def main():
+    parser = ArgumentParser()
+
+    parser.add_argument("--batch_size", type=int, default=2)
+    parser.add_argument("--num_workers", type=int, default=0)
+    parser.add_argument("--classifier_lr", type=float, default=1e-4)
+    parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
+    parser.add_argument("--max_epochs", type=int, default=10)
+    parser.add_argument("--accelerator", type=str, default="gpu")
+    parser.add_argument("--weight_decay", type=float, default=0.0)
+    parser.add_argument("--warmup_steps", type=float, default=0.1)
+    parser.add_argument("--constant_steps", type=float, default=0.4)
+    parser.add_argument("--scheduler_factor", type=float, default=1e-3)
+    parser.add_argument("--data_dir", type=str, default="data")
+    parser.add_argument("--ckpt_dir", type=str, default="ckpt")
+    parser.add_argument("--ckpt_path", type=str, default=None)
+    parser.add_argument("--detect_anomaly", type=bool, default=False)
+    parser.add_argument("--grad_clip", type=float, default=None)
+    parser.add_argument("--wandb_id", type=str, default=None)
+
+    args = parser.parse_args()
+    print(args)
+
+    train_loader, val_loader = prepare_dataloader(
+        args.data_dir, args.batch_size, args.num_workers
+    )
+
+    total_steps = args.max_epochs * 42_000 // args.batch_size
+    warmup_steps = int(total_steps * args.warmup_steps)
+    constant_steps = int(total_steps * args.constant_steps)
+
+    model = prepare_model(
+        {
+            "wav2vec2_lr": args.wav2vec2_lr,
+            "classifier_lr": args.classifier_lr,
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "warmup_steps": warmup_steps,
+            "constant_steps": constant_steps,
+            "total_steps": total_steps,
+            "factor": args.scheduler_factor,
+        },
+    )
+
+    trainer = Trainer(
+        accelerator=args.accelerator,
+        callbacks=[
+            ModelCheckpoint(
+                args.ckpt_dir,
+                monitor="val/wer",
+                mode="min",
+                save_top_k=1,
+                save_last=True,
+            ),
+            LearningRateMonitor(logging_interval="step"),
+        ],
+        logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
+        max_epochs=args.max_epochs,
+        detect_anomaly=args.detect_anomaly,
+        gradient_clip_val=args.grad_clip,
+    )
+
+    trainer.fit(model, train_loader, val_loader)
+
+
+if __name__ == "__main__":
+    seed_everything(188)
+    main()
finetuning/wav2vec2.py
ADDED
@@ -0,0 +1,200 @@
+from typing import Tuple
+import torch
+from pytorch_lightning import LightningModule
+from torchmetrics import MeanMetric
+from transformers import (
+    Wav2Vec2ForPreTraining,
+    Wav2Vec2CTCTokenizer,
+    Wav2Vec2FeatureExtractor,
+)
+
+from src.utils.metrics import character_error_rate, word_error_rate
+from src.utils.scheduler import TriStateScheduler
+
+
+class SpeechRecognizer(LightningModule):
+    def __init__(
+        self,
+        wav2vec2: Wav2Vec2ForPreTraining,
+        tokenizer: Wav2Vec2CTCTokenizer,
+        feature_extractor: Wav2Vec2FeatureExtractor,
+        adam_config: dict,
+        tristate_scheduler_config: dict,
+    ):
+        super().__init__()
+
+        self.hidden_size = wav2vec2.config.proj_codevector_dim
+        self.vocab_size = tokenizer.vocab_size
+
+        self.wav2vec2 = wav2vec2
+        self.wav2vec2.freeze_feature_encoder()
+        self.tokenizer = tokenizer
+        self.feature_extractor = feature_extractor
+
+        self.adam_config = adam_config
+        self.tristate_scheduler_config = tristate_scheduler_config
+
+        self.dropout = torch.nn.Dropout(0.1)
+        self.fc = torch.nn.Sequential(
+            torch.nn.Linear(self.hidden_size, self.hidden_size // 2),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.Linear(self.hidden_size // 2, self.vocab_size),
+        )
+
+        self.criterion = torch.nn.CTCLoss(blank=tokenizer.pad_token_id, zero_infinity=True)
+
+        self.train_loss = MeanMetric()
+
+        self.save_hyperparameters(ignore=["wav2vec2", "tokenizer", "feature_extractor"])
+
+    def forward(self, waveforms: Tuple[torch.Tensor], transcripts: Tuple[str] = None):
+        # convert torch.Tensor to numpy.ndarray
+        waveforms = tuple(waveform.cpu().numpy() for waveform in waveforms)
+
+        input_values, attention_mask = self.feature_extractor(
+            waveforms,
+            sampling_rate=16000,
+            padding=True,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).values()
+
+        input_values = input_values.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+
+        # hidden_states.shape == (batch_size, sequence_length, hidden_size)
+        hidden_states = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )[0]
+
+        hidden_states = self.dropout(hidden_states)
+
+        # logits.shape == (batch_size, sequence_length, vocab_size)
+        logits = self.fc(hidden_states)
+
+        # get the length of the valid sequence
+        input_lengths = self.wav2vec2._get_feat_extract_output_lengths(
+            attention_mask.sum(-1)
+        ).long()
+
+        if transcripts is not None:
+            # tokenize transcripts
+            target_ids, target_lengths = self.tokenizer(
+                transcripts,
+                padding=True,
+                return_length=True,
+                return_attention_mask=False,
+                return_tensors="pt",
+            ).values()
+
+            target_ids = target_ids.to(self.device)
+            assert (
+                target_ids < self.tokenizer.vocab_size
+            ).all(), "target_ids is out of range"
+
+            target_lengths = target_lengths.to(self.device)
+            assert (
+                target_lengths <= logits.size(1)
+            ).all(), "target_lengths is out of range"
+
+            # (batch_size, sequence_length, vocab_size) -> (sequence_length, batch_size, vocab_size)
+            log_probs = torch.nn.functional.log_softmax(logits, dim=-1).transpose(0, 1)
+
+            # compute loss
+            loss = self.criterion(log_probs, target_ids, input_lengths, target_lengths)
+
+            return loss, logits, input_lengths
+        else:
+            return logits, input_lengths
+
+    @staticmethod
+    def _get_predicted_ids(logits: torch.Tensor, lengths: torch.Tensor):
+        # logits.shape == (batch_size, sequence_length, vocab_size)
+        # lengths.shape == (batch_size, )
+
+        # get the max value of logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+
+        # remove the padding
+        predicted_ids = [
+            predicted_id[:length]
+            for predicted_id, length in zip(predicted_ids, lengths)
+        ]
+
+        return predicted_ids
+
+    def training_step(self, batch, batch_idx):
+        transcripts, waveforms = batch
+
+        loss = self(waveforms, transcripts)[0]
+
+        self.train_loss(loss)
+
+        if self.global_step % 500 == 0:
+            self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)
+
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        self.train_loss.reset()
+
+    def validation_step(self, batch, batch_idx):
+        transcripts, waveforms = batch
+
+        logits, seq_lengths = self(waveforms)
+
+        predicted_ids = self._get_predicted_ids(logits, seq_lengths)
+        predicted_texts = self.tokenizer.batch_decode(
+            predicted_ids, skip_special_tokens=True
+        )
+
+        wer = word_error_rate(predicted_texts, transcripts)
+        cer = character_error_rate(predicted_texts, transcripts)
+
+        return wer, cer
+
+    def validation_epoch_end(self, outputs):
+        wer, cer = zip(*outputs)
+
+        wer = sum(wer) / len(wer)
+        cer = sum(cer) / len(cer)
+
+        self.log("val/wer", wer, on_epoch=True)
+        self.log("val/cer", cer, on_epoch=True)
+
+    @torch.no_grad()
+    def predict(self, waveforms: Tuple[torch.Tensor]):
+        logits, seq_lengths = self(waveforms)
+
+        predicted_ids = self._get_predicted_ids(logits, seq_lengths)
+        predicted_texts = self.tokenizer.batch_decode(
+            predicted_ids, skip_special_tokens=True
+        )
+
+        return predicted_texts
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            params=[
+                {
+                    "params": self.wav2vec2.parameters(),
+                    "lr": self.adam_config["wav2vec2_lr"],
+                },
+                {
+                    "params": self.fc.parameters(),
+                    "lr": self.adam_config["classifier_lr"],
+                },
+            ],
+            weight_decay=self.adam_config["weight_decay"],
+        )
+
+        scheduler = TriStateScheduler(optimizer, **self.tristate_scheduler_config)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": "step",
+                "frequency": 1,
+            },
+        }
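
For illustration (not part of the commit), the greedy CTC decoding that `_get_predicted_ids` plus `tokenizer.batch_decode` amount to can be sketched in plain PyTorch: merge consecutive repeated ids, then drop the blank id, which `SpeechRecognizer` ties to the tokenizer's pad token. The concrete ids and blank id below are made up.

# Illustrative only; not part of the commit.
import torch

def greedy_ctc_collapse(ids: torch.Tensor, blank_id: int) -> list:
    ids = torch.unique_consecutive(ids)              # merge repeated frames
    return [i.item() for i in ids if i.item() != blank_id]

frame_ids = torch.tensor([0, 0, 7, 7, 0, 7, 12, 12, 0])   # 0 = assumed blank/pad id
print(greedy_ctc_collapse(frame_ids, blank_id=0))         # [7, 7, 12]
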
packages.txt
ADDED
@@ -0,0 +1 @@
+ffmpeg
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch
+torchaudio
+pytorch-lightning
+einops
+easydict
+webdataset
+transformers
+gradio
+altair
src/__init__.py
ADDED
File without changes
src/config/__init__.py
ADDED
File without changes
src/config/model.py
ADDED
@@ -0,0 +1,57 @@
+from easydict import EasyDict as dict
+
+D_MODEL = 768
+HIDDEN_SIZE = 512
+
+
+
+context_encoder = dict(
+    feature_projection=dict(
+        in_features=HIDDEN_SIZE,
+        out_features=D_MODEL,
+        dropout=0.1,
+    ),
+    encoder=dict(
+        d_model=D_MODEL,
+        num_layers=12,
+        layer_drop=0.05,
+        pos_embedding=dict(
+            d_model=D_MODEL,
+            kernel_size=3,
+            groups=2,
+            dropout=0.1,
+        ),
+        layer=dict(
+            d_model=D_MODEL,
+            num_heads=8,
+            layer_norm_first=False,
+            feed_forward_dim=2048,
+            dropout=0.1,
+        ),
+    )
+)
+
+feature_extractor = dict(
+    num_channels=7 * (HIDDEN_SIZE,),
+    kernel_sizes=(10,) + 4 * (3,) + 2 * (2,),
+    strides=(5,) + 6 * (2,),
+)
+
+quantizer = dict(
+    in_features=HIDDEN_SIZE,
+    num_codebooks=2,
+    num_codewords=320,
+    d_model=D_MODEL,
+)
+
+wav2vec2_pretraining = dict(
+    context_encoder=context_encoder,
+    feature_extractor=feature_extractor,
+    quantizer=quantizer,
+    mask_prob=0.65,
+    mask_length=10,
+    min_masks=2,
+    num_negatives=100,
+    contrastive_logits_temperature=0.1,
+    diversity_loss_weight=0.2,
+)
src/datamodule/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .vlsp2020 import (
+    VLSP2020TarDataset,
+    VLSP2020Dataset,
+)
src/datamodule/vlsp2020.py
ADDED
@@ -0,0 +1,131 @@
+from typing import Callable, Optional, Union
+from tqdm import tqdm
+import os
+import torch
+import torchaudio
+import torchaudio.functional as F
+from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split
+from pytorch_lightning import LightningDataModule
+import webdataset
+
+
+class VLSP2020Dataset(Dataset):
+    def __init__(self, root: str, sample_rate: int = 16000):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.memory = self._prepare_data(root)
+        self._memory = tuple(
+            (v["transcript"], v["audio"]) for v in self.memory.values()
+        )
+
+    @staticmethod
+    def _prepare_data(root: str):
+        memory = {}
+
+        for f in os.scandir(root):
+            file_name, file_ext = os.path.splitext(f.name)
+
+            if file_ext == ".txt":
+                if file_name not in memory:
+                    memory[file_name] = {"transcript": f.path}
+                elif "transcript" not in memory[file_name]:
+                    memory[file_name]["transcript"] = f.path
+                else:
+                    raise ValueError(f"Duplicate transcript for {f.path}")
+            else:
+                if file_name not in memory:
+                    memory[file_name] = {"audio": f.path}
+                elif "audio" not in memory[file_name]:
+                    memory[file_name]["audio"] = f.path
+                else:
+                    raise ValueError(f"Duplicate audio for {f.path}")
+
+        for key, value in memory.items():
+            if "audio" not in value:
+                raise ValueError(f"Missing audio for {key}")
+            elif "transcript" not in value:
+                raise ValueError(f"Missing transcript for {key}")
+
+        return memory
+
+    def __len__(self):
+        return len(self.memory)
+
+    def __getitem__(self, index: int):
+        transcript, audio = self._memory[index]
+
+        with open(transcript, "r") as f:
+            transcript = f.read()
+
+        audio, sample_rate = torchaudio.load(audio)
+        audio = F.resample(audio, sample_rate, self.sample_rate)
+
+        return transcript, audio
+
+
+class VLSP2020TarDataset:
+    def __init__(self, outpath: str):
+        self.outpath = outpath
+
+    def convert(self, dataset: VLSP2020Dataset):
+        writer = webdataset.TarWriter(self.outpath)
+
+        for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
+            writer.write(
+                {
+                    "__key__": f"{idx:08d}",
+                    "txt": transcript,
+                    "pth": audio,
+                }
+            )
+
+        writer.close()
+
+    def load(self) -> webdataset.WebDataset:
+        self.data = (
+            webdataset.WebDataset(self.outpath)
+            .decode(
+                webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
+                webdataset.torch_audio,
+            )
+            .to_tuple("txt", "pth")
+        )
+
+        return self.data
+
+
+def get_dataloader(
+    dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
+    return_transcript: bool = False,
+    target_transform: Optional[Callable] = None,
+    batch_size: int = 32,
+    num_workers: int = 2,
+):
+    def collate_fn(batch):
+        def get_audio(item):
+            audio = item[1]
+
+            assert (
+                isinstance(audio, torch.Tensor)
+                and audio.ndim == 2
+                and audio.size(0) == 1
+            )
+
+            return audio.squeeze(0)
+
+        audio = tuple(get_audio(item) for item in batch)
+
+        if return_transcript:
+            if target_transform is not None:
+                transcript = tuple(target_transform(item[0]) for item in batch)
+            else:
+                transcript = tuple(item[0] for item in batch)
+
+            return transcript, audio
+        else:
+            return audio
+
+    return DataLoader(
+        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
+    )
src/model/__init__.py
ADDED
@@ -0,0 +1 @@
+from .wav2vec2 import Wav2Vec2PretrainingModule
src/model/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .context_encoder import ContextEncoder
+from .feature_extractor import FeatureExtractor
+from .quantization import QuantizationModule
+from .processor import Wav2Vec2Processor
src/model/modules/context_encoder.py
ADDED
@@ -0,0 +1,149 @@
+from typing import Optional
+import torch
+from torch import nn
+import torch.nn.functional as F
+from .transformers import EncoderLayer
+
+
+class FeatureProjection(nn.Module):
+    def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
+        """
+        Projects the extracted features to the encoder dimension.
+
+        Args:
+            x (Tensor): The input features. Shape: (batch, num_frames, in_features)
+
+        Returns:
+            hiddens (Tensor): The latent features. Shape: (batch, num_frames, out_features)
+        """
+        super().__init__()
+
+        self.projection = nn.Linear(in_features, out_features)
+        self.layernorm = nn.LayerNorm(in_features)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor):
+
+        hiddens = self.layernorm(x)
+        hiddens = self.projection(x)
+        hiddens = self.dropout(hiddens)
+        return hiddens
+
+
+class RelativePositionalEmbedding(nn.Module):
+    def __init__(
+        self, d_model: int, kernel_size: int, groups: int, dropout: float = 0.1
+    ):
+        """
+        Args:
+            x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
+
+        Returns:
+            out (Tensor): The output which encoded the relative positional information. Shape: (batch, num_frames, d_model)
+        """
+        super().__init__()
+
+        self.conv = nn.Conv1d(
+            in_channels=d_model,
+            out_channels=d_model,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2,
+            groups=groups,
+        )
+        self.dropout = nn.Dropout(dropout)
+        self.num_remove = 1 if kernel_size % 2 == 0 else 0
+
+    def forward(self, x: torch.Tensor):
+        # (batch, channels=d_model, num_frames)
+        out = x.transpose(1, 2)
+
+        out = self.conv(out)
+
+        if self.num_remove > 0:
+            out = out[..., : -self.num_remove]
+
+        out = F.gelu(out)
+
+        # (batch, num_frames, channels=d_model)
+        out = out.transpose_(1, 2)
+        out = out + x
+        out = self.dropout(out)
+
+        return out
+
+
+class TranformerEncoder(nn.Module):
+    def __init__(self, config):
+        """
+        Args:
+            x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
+            mask (Tensor): The mask for the valid frames. Shape: (batch, num_frames)
+
+        Returns:
+            out (Tensor): The output of the transformer encoder. Shape: (batch, num_frames, d_model)
+        """
+        super().__init__()
+
+        self.pos_embedding = RelativePositionalEmbedding(**config.pos_embedding)
+        self.layernorm = nn.LayerNorm(config.d_model)
+        self.layer_drop = config.layer_drop
+
+        self.layers = nn.ModuleList(
+            EncoderLayer(**config.layer) for _ in range(config.num_layers)
+        )
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
+        out = self.pos_embedding(x)
+
+        for layer in self.layers:
+            skip_layer = self.training and torch.rand(1).item() < self.layer_drop
+
+            if skip_layer:
+                continue
+            else:
+                out, _ = layer(out, attention_mask=mask)
+
+        out = self.layernorm(out)
+
+        return out
+
+
+class ContextEncoder(nn.Module):
+    def __init__(self, config):
+        """
+        Args:
+            x (Tensor): The extracted features. Shape: (batch, num_frames, in_features)
+            attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+        """
+        super().__init__()
+
+        self.feature_projection = FeatureProjection(**config.feature_projection)
+        self.encoder = TranformerEncoder(config.encoder)
+        self.masked_spec_embed = nn.Parameter(
+            torch.FloatTensor(config.feature_projection.out_features).uniform_()
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        mask_time_indices: torch.Tensor = None,
+    ):
+        x = self.feature_projection(x)
+
+        if mask_time_indices is not None:
+            x[mask_time_indices] = self.masked_spec_embed.to(x.dtype)
+
+        if attention_mask is not None:
+            x[attention_mask] = 0.0  # turn invalid frames to zero
+
+            attention_mask = attention_mask[:, None, None, :]
+            # (batch, 1, num_frames, num_frames)
+            # mask = mask[:, None, None, :].repeat(1, 1, mask.size(1), 1)  # TODO: check this
+            attention_mask = (
+                torch.maximum(attention_mask, attention_mask.transpose(2, 3)) * -1e6
+            )
+
+        x = self.encoder(x, mask=attention_mask)
+
+        return x
src/model/modules/feature_extractor.py
ADDED
@@ -0,0 +1,103 @@
+from typing import Tuple
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class _Conv1DLayer(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+    ):
+        """
+        Args:
+            x (Tensor): The input. Shape: (batch, in_channels, in_frames)
+            length (Tensor): The valid length of each sample. Shape: (batch)
+
+        Returns:
+            x (Tensor): The output. Shape: (batch, out_channels, out_frames)
+            length (Tensor): The valid length of each sample. Shape: (batch)
+        """
+        super().__init__()
+
+        self.kernel_size = kernel_size
+        self.stride = stride
+
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            stride=stride,
+            kernel_size=kernel_size,
+            bias=False,
+        )
+
+        self.layernorm = nn.LayerNorm(out_channels)
+
+    def forward(self, x: torch.Tensor, length: torch.Tensor):
+        x = self.conv(x)
+        x = x.transpose_(1, 2)
+        x = self.layernorm(x)
+        x = x.transpose_(1, 2)
+        x = F.gelu(x)
+
+        length = (length - self.kernel_size) // self.stride + 1
+        length = length.clamp_min_(min=0)  # prevent negative lengths
+        return x, length
+
+
+class FeatureExtractor(nn.Module):
+    def __init__(self, config):
+        """
+        Extracts features from the waveform.
+
+        Args:
+            waveforms (Tensor): The waveform to extract features from. Shape: (batch, wavelength)
+            wavelength (Tensor): The valid length of each waveform. Shape: (batch)
+
+        Returns:
+            features (Tensor): The extracted features. Shape: (batch, num_frames, num_channels)
+            num_frames (Tensor): The valid length of each feature. Shape: (batch)
+        """
+        super().__init__()
+
+        num_channels = config.num_channels
+        kernel_sizes = config.kernel_sizes
+        strides = config.strides
+
+        assert (
+            len(num_channels) == len(kernel_sizes) == len(strides)
+        ), "The number of layers must be the same for all parameters"
+
+        self.conv_layers = nn.ModuleList(
+            (
+                _Conv1DLayer(
+                    in_channels=1,
+                    out_channels=num_channels[0],
+                    kernel_size=kernel_sizes[0],
+                    stride=strides[0],
+                ),
+            )
+        )
+
+        for i in range(1, len(num_channels)):
+            self.conv_layers.append(
+                _Conv1DLayer(
+                    in_channels=num_channels[i - 1],
+                    out_channels=num_channels[i],
+                    kernel_size=kernel_sizes[i],
+                    stride=strides[i],
+                )
+            )
+
+    def forward(self, waveforms: torch.Tensor, wavelength: torch.Tensor):
+        features = waveforms.unsqueeze(1)
+
+        for conv_layer in self.conv_layers:
+            features, wavelength = conv_layer(features, wavelength)
+
+        # (batch, num_channels, num_frames) -> (batch, num_frames, num_channels)
+        features = features.transpose(1, 2)
+        return features, wavelength
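
For intuition (not part of the commit), the length recurrence in `_Conv1DLayer.forward` can be unrolled over the kernel/stride stack defined in src/config/model.py; the 16,000-sample input is an assumed one-second waveform at 16 kHz.

# Illustrative only; not part of the commit.
kernel_sizes = (10,) + 4 * (3,) + 2 * (2,)
strides = (5,) + 6 * (2,)

length = 16_000  # one second at 16 kHz (assumed)
for k, s in zip(kernel_sizes, strides):
    length = (length - k) // s + 1
print(length)  # 49 feature frames, i.e. roughly one frame every 20 ms
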
src/model/modules/processor.py
ADDED
@@ -0,0 +1,42 @@
+from typing import Tuple
+import torch
+from torch import nn
+
+
+class Wav2Vec2Processor(nn.Module):
+    def __init__(self):
+        """
+        Convert a tuple of waveforms of different lengths into a padded batch.
+
+        Args:
+            waveforms (Tuple[torch.Tensor]): The waveforms. Shape: (batch_size, wave_length).
+
+        Returns:
+            waveforms (torch.Tensor): The batched waveforms. Shape: (batch_size, max_wave_length).
+            wave_lengths (torch.Tensor): The wave length of each waveform. Shape: (batch_size,).
+        """
+        super().__init__()
+
+    def forward(self, waveforms: Tuple[torch.Tensor, ...]):
+        device = waveforms[0].device
+        wave_lengths = torch.tensor(
+            tuple(waveform.size(0) for waveform in waveforms), device=device
+        )
+
+        max_length = wave_lengths.max().item()
+
+        padded = []
+
+        for waveform in waveforms:
+            padded.append(
+                nn.functional.pad(
+                    waveform,
+                    (0, max_length - waveform.size(0)),
+                    mode="constant",
+                    value=0.0,
+                )
+            )
+
+        batched_waveforms = torch.stack(padded, dim=0)
+
+        return batched_waveforms, wave_lengths
src/model/modules/quantization.py
ADDED
@@ -0,0 +1,103 @@
+from typing import Optional
+import torch
+from torch import nn
+import torch.nn.functional as F
+import einops
+
+
+class QuantizationModule(nn.Module):
+    def __init__(
+        self, config
+    ):
+        """
+        Args:
+            x (Tensor): The extracted features from waveforms. Shape: (batch, num_frames, in_features)
+            mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+
+        Returns:
+            out (Tensor): The quantized features. Shape: (batch, num_frames, d_model)
+            perplexity (Tensor): The perplexity of the quantized features. Shape: (1)
+        """
+        super().__init__()
+
+        assert (
+            config.d_model % config.num_codebooks == 0
+        ), "d_model must be divisible by num_codebooks"
+
+        self.num_codebooks = config.num_codebooks
+        self.num_codewords = config.num_codewords
+        self.d_model = config.d_model
+        self.codeword_dim = config.d_model // config.num_codebooks
+
+        self.codebooks = self._init_codebooks()
+
+        self.projection = nn.Linear(
+            config.in_features, self.num_codebooks * self.num_codewords
+        )
+
+        self.tau = 1  # temperature factor
+
+    def _init_codebooks(self):
+        codebooks = torch.randn(
+            1, 1, self.num_codebooks, self.num_codewords, self.codeword_dim
+        )
+        nn.init.xavier_uniform_(codebooks)
+
+        return nn.Parameter(codebooks)
+
+    @property
+    def total_codewords(self):
+        return self.num_codebooks * self.num_codewords
+
+    @staticmethod
+    def _compute_perplexity(probs: torch.Tensor, mask: Optional[torch.Tensor] = None):
+        """
+        Computes the perplexity of the quantized features. (Diversity loss)
+
+        Args:
+            probs (Tensor): The probability distribution of words in codebooks. Shape: (batch, num_frames, num_codebooks, num_codewords)
+            mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+        """
+        if mask is not None:
+            probs = (
+                probs * ~mask[..., None, None]
+            )  # Turn invalid frames' probability to 0
+            marginal_probs = (
+                einops.reduce(probs, "b nf nb nw -> nb nw", "sum") / mask.sum()
+            )
+        else:
+            marginal_probs = einops.reduce(probs, "b nf nb nw -> nb nw", "mean")
+
+        perplexity = torch.exp(
+            -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
+        ).sum()
+        return perplexity
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
+        batch_size, num_frames, _ = x.shape
+
+        logits = self.projection(x)
+        logits = logits.view(
+            batch_size, num_frames, self.num_codebooks, self.num_codewords
+        )
+
+        if self.training:
+            word_probs = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=-1)
+            word_soft_probs = F.softmax(logits, dim=-1)
+
+            perplexity = self._compute_perplexity(word_soft_probs, mask=mask)
+        else:
+            word_ids = torch.argmax(logits, dim=-1, keepdim=True)
+            word_probs = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)  # One-hot
+
+            perplexity = self._compute_perplexity(word_probs, mask=mask)
+
+        # (batch, num_frames, num_codebooks, num_codewords, 1) x (1, 1, num_codebooks, num_codewords, codeword_dim)
+        # -> (batch, num_frames, num_codebooks x codeword_dim)
+        quantized = einops.reduce(
+            word_probs.unsqueeze_(-1) * self.codebooks,
+            "b nf nb nw d -> b nf (nb d)",
+            reduction="sum",
+        )
+
+        return quantized, perplexity
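
For intuition (not part of the commit), the perplexity formula in `_compute_perplexity` can be evaluated on two toy codeword distributions; the single 4-word codebook below is an assumption chosen for readability, not the repo's config values. A uniform usage of codewords maximizes perplexity, which drives the diversity loss `1 - perplexity / total_codewords` (used in src/model/wav2vec2.py) toward zero.

# Illustrative only; not part of the commit.
import torch

def perplexity(marginal_probs: torch.Tensor) -> torch.Tensor:
    # marginal_probs: (num_codebooks, num_codewords), already averaged over frames
    return torch.exp(
        -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
    ).sum()

uniform = torch.full((1, 4), 0.25)                     # every codeword used equally
collapsed = torch.tensor([[0.97, 0.01, 0.01, 0.01]])   # one codeword dominates

print(perplexity(uniform))    # ~4.0  -> diversity loss about 1 - 4/4 = 0
print(perplexity(collapsed))  # ~1.2  -> diversity loss about 1 - 1.2/4 = 0.7
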
src/model/modules/transformers.py
ADDED
@@ -0,0 +1,200 @@
+"""
+This file contains the implementation of the Transformer Encoder layer.
+Source: https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py
+"""
+from typing import Optional, Tuple
+import torch
+from torch import nn, Tensor
+from torch.nn import Module
+
+
+class SelfAttention(Module):
+    """Multihead Self Attention module
+    Args:
+        embed_dim (int): Total dimension of the model.
+        num_heads (int): The number of heads.
+        dropout (float, optional):
+            Dropout probability on attn_output_weights. Default: ``0.0``
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        head_dim = embed_dim // num_heads
+        if head_dim * num_heads != embed_dim:
+            raise ValueError(
+                f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`"
+            )
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = torch.nn.Dropout(dropout)
+        self.head_dim = head_dim
+
+        self.scaling = self.head_dim**-0.5
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+    def forward(
+        self,
+        x: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        position_bias: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        """
+        Args:
+            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
+            attention_mask (Tensor or ``None``, optional):
+                shape: ``[batch_size, 1, sequence_length, sequence_length]``
+            position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
+            key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
+                :py:class:`WavLMSelfAttention`.
+        Returns:
+            (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
+                with :py:class:`WavLMSelAttention`).
+                Attention output shape: ``[batch, sequence_length, embed_dim]``.
+        """
+        if x.ndim != 3 or x.shape[2] != self.embed_dim:
+            raise ValueError(
+                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). "
+                f"Found {x.shape}."
+            )
+        batch_size, length, embed_dim = x.size()
+        if attention_mask is not None:
+            shape_ = (batch_size, 1, length, length)
+            if attention_mask.size() != shape_:
+                raise ValueError(
+                    f"The expected attention mask shape is {shape_}. "
+                    f"Found {attention_mask.size()}."
+                )
+
+        shape = (batch_size, length, self.num_heads, self.head_dim)
+        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
+        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
+
+        # scale down q to avoid value overflow.
+        weights = (self.scaling * q) @ k  # B, nH, L, L
+        if attention_mask is not None:
+            weights += attention_mask
+        # subtracting a constant value from the tensor won't change the output of softmax.
+        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
+        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
+        weights = weights - weights.max(dim=-1, keepdim=True)[0]
+
+        weights = torch.nn.functional.softmax(weights, dim=-1)
+        weights = self.dropout(weights)
+
+        output = weights @ v  # B, nH, L, Hd
+        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
+
+        output = self.out_proj(output)
+        return output, None  # Necessary for compatibility with WavLMSelAttention
+
+
+class FeedForward(Module):
+    """Layer that follows attention layer in encoder layer."""
+
+    def __init__(
+        self,
+        io_features: int,
+        intermediate_features: int,
+        intermediate_dropout: float,
+        output_dropout: float,
+    ):
+        super().__init__()
+        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
+        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
+        self.output_dense = nn.Linear(intermediate_features, io_features)
+        self.output_dropout = nn.Dropout(output_dropout)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
+        Returns:
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
+        """
+        x = self.intermediate_dense(x)
+        x = torch.nn.functional.gelu(x)
+        x = self.intermediate_dropout(x)
+
+        x = self.output_dense(x)
+        x = self.output_dropout(x)
+        return x
+
+
+class EncoderLayer(Module):
+    """A layer unit in encoder. Combines multihead self attention and feed forward."""
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        layer_norm_first: bool,
+        feed_forward_dim: int,
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+        self.attention = SelfAttention(
+            embed_dim=d_model,
+            num_heads=num_heads,
+            dropout=dropout,
+        )
+
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm_first = layer_norm_first
+        self.feed_forward = FeedForward(d_model, feed_forward_dim, dropout, dropout)
+        self.final_layer_norm = nn.LayerNorm(d_model)
+
+    def forward(
+        self,
+        x: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        position_bias: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        """
+        Args:
+            x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
+            attention_mask (Tensor or ``None``, optional): attention mask
+                of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
+            position_bias (Tensor or ``None``, optional): position bias of shape
+                ``(batch_size * num_heads, src_len, src_len)``.
+                Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
+            key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
+                Only used for WavLM model, ignored otherwise. (Default: ``None``)
+        Returns:
+            (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model,
+                ``None`` otherwise.
+        """
+        residual = x
+
+        if self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        x, position_bias = self.attention(
+            x,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            key_padding_mask=key_padding_mask,
+        )
+
+        x = self.dropout(x)
+        x = residual + x
+
+        if self.layer_norm_first:
+            x = x + self.feed_forward(self.final_layer_norm(x))
+        else:
+            x = self.layer_norm(x)
+            x = self.final_layer_norm(x + self.feed_forward(x))
+        return x, position_bias
src/model/wav2vec2.py
ADDED
@@ -0,0 +1,293 @@
+"""
+A wrapper of Wav2Vec2 for training phase.
+"""
+from typing import Tuple, Optional
+import torch
+from pytorch_lightning import LightningModule
+import einops
+from torchmetrics import MeanMetric
+
+from .modules import (
+    ContextEncoder,
+    FeatureExtractor,
+    QuantizationModule,
+    Wav2Vec2Processor,
+)
+from src.utils import init_module_weights
+
+
+class Wav2Vec2PretrainingModule(LightningModule):
+    def __init__(self, config):
+        super().__init__()
+
+        self.save_hyperparameters(config)
+
+        self.processor = Wav2Vec2Processor()
+        self.context_encoder = ContextEncoder(config.context_encoder)
+        self.feature_extractor = FeatureExtractor(config.feature_extractor)
+        self.quantizer = QuantizationModule(config.quantizer)
+
+        self.train_loss = MeanMetric()
+
+    def forward(self, waveforms: Tuple[torch.Tensor, ...]):
+        """
+        Args:
+            waveforms (Tuple[torch.Tensor]): The waveforms. Shape: (batch_size, wave_length).
+
+        Returns:
+            loss: The loss of the model. Contrastive loss + Diversity loss.
+        """
+        waveforms, wave_lengths = self.processor(waveforms)
+
+        # features.shape == (batch_size, num_frames, hidden_size)
+        features, num_frames = self.feature_extractor(waveforms, wave_lengths)
+
+        attention_mask = self._compute_attention_mask(num_frames)
+        mask_time_indices = self._compute_mask_span(
+            shape=features.shape[:-1],
+            mask_prob=self.hparams.mask_prob,
+            mask_length=self.hparams.mask_length,
+            attention_mask=attention_mask,
+            device=features.device,
+            min_masks=self.hparams.min_masks,
+        )
+
+        context_features = self.context_encoder(
+            features, attention_mask=attention_mask, mask_time_indices=mask_time_indices
+        )
+
+        quantized_features, perplexity = self.quantizer(features, attention_mask)
+
+        negative_quantized_features = self._sample_negatives(
+            quantized_features,
+            num_negatives=self.hparams.num_negatives,
+            attention_mask=attention_mask,
+        )
+
+        # (batch_size, num_frames, num_negatives + 1)
+        contrastive_logits = self._compute_contrastive_logits(
+            context_features,
+            quantized_features,
+            negative_quantized_features,
+            self.hparams.contrastive_logits_temperature,
+        ).flatten(0, -2)
+
+        # compute contrastive loss
+        # positive indices are always the first one
+        targets = (1 - mask_time_indices.long().flatten()) * -100
+
+        contrastive_loss = torch.nn.functional.cross_entropy(
+            contrastive_logits, targets, reduction="sum"
+        )
+
+        # compute diversity loss
+        diversity_loss = 1 - perplexity / self.quantizer.total_codewords
+
+        loss = contrastive_loss + diversity_loss * self.hparams.diversity_loss_weight
+
+        return loss
+
+    @staticmethod
+    def _sample_negatives(
+        features: torch.Tensor,
+        num_negatives: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Sampling negative features from quantized features to compute the contrastive loss.
+
+        Args:
+            features (torch.Tensor): The quantized features. Shape: (batch_size, num_frames, d_model).
+            num_negatives (int): The number of negative samples.
+            attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).
+
+        Returns:
+            sampled_negatives (torch.Tensor): The sampled negative features. Shape: (batch_size, num_frames, num_negatives, d_model).
+        """
+
+        batch_size, num_frames, d_model = features.shape
+
+        features = features.view(-1, d_model)  # (batch_size * num_frames, d_model)
+
+        with torch.no_grad():
+            sampled_ids = []
+
+            for batch_idx in range(batch_size):
+                num_valid_frames = (
+                    features.size(1)
+                    if attention_mask is None
+                    else (1 - attention_mask[batch_idx].long()).sum()
+                ).item()
+
+                sampled_ids.append(
+                    torch.randint(
+                        0,
+                        num_valid_frames - 1,
+                        (num_frames * num_negatives,),
+                        device=features.device,
+                    )
+                )
+
+            sampled_ids = torch.stack(
+                sampled_ids, dim=0
+            )  # (batch_size, num_frames * num_negatives)
+
+            feature_ids = einops.repeat(
+                torch.arange(num_frames, device=features.device),
+                "f -> (f n)",
+                n=num_negatives,
+            )
+
+            # avoid sampling the same positive vector, but keep the distribution uniform
+            sampled_ids[sampled_ids >= feature_ids] += 1
+
+        # correct for batch size
+        # E.g [[0, 1, 2], [0, 1, 2]] -> [0, 1, 2, 3, 4, 5]
+        sampled_ids += torch.arange(
+            0, batch_size * num_frames, num_frames, device=features.device
+        ).unsqueeze_(-1)
+
+        sampled_negatives = features[sampled_ids.view(-1)]
+        sampled_negatives = einops.rearrange(
+            sampled_negatives,
+            "(b f n) d -> b f n d",
+            b=batch_size,
+            f=num_frames,
+            n=num_negatives,
+        )
+
+        return sampled_negatives
+
+    @staticmethod
+    def _compute_contrastive_logits(
+        predicted_features: torch.Tensor,
+        target_features: torch.Tensor,
+        negative_features: torch.Tensor,
+        temperature: int = 1,
+    ):
+        """
+        Compute the logits for contrastive loss.
+
+        Args:
+            predicted_features (torch.Tensor): The predicted features. Shape: (batch_size, num_frames, d_model).
+            target_features (torch.Tensor): The target features. Shape: (batch_size, num_frames, d_model).
+            negative_features (torch.Tensor): The negative features. Shape: (batch_size, num_frames, num_negatives, d_model).
+            temperature (int): The temperature for contrastive loss.
+
+        Returns:
+            logits (torch.Tensor): The logits for contrastive loss. Shape: (batch_size, num_frames, num_negatives + 1).
+        """
+
+        # (batch_size, num_frames, num_negatives + 1, d_model)
+        target_features = torch.cat(
+            (target_features.unsqueeze_(2), negative_features), dim=2
+        )
+
+        # (batch_size, num_frames, 1, d_model)
+        predicted_features = predicted_features.unsqueeze_(2)
+
+        # (batch_size, num_frames, num_negatives + 1)
+        logits = torch.cosine_similarity(predicted_features, target_features, dim=-1)
+        logits /= temperature
+
+        return logits
+
+    @staticmethod
+    def _compute_mask_span(
+        shape: Tuple[int, int],
+        mask_prob: float = 0.065,
+        mask_length: int = 10,
+        attention_mask: Optional[torch.Tensor] = None,
+        device: torch.device = torch.device("cpu"),
+        min_masks: int = 0,
+    ):
+        """
+        Compute the mask span for contrastive task.
+
+        Args:
+            shape (Tuple[int, int]): The shape of the mask span. Shape: (batch_size, num_frames).
+            mask_prob (float): The probability of choosing a frame to be the start of masking position.
+            mask_length (int): The length of the mask span.
+            attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).
+            device (torch.device): The device of the mask span.
+            min_masks (int): The minimum number of masks.
+
+        Returns:
+            mask_span (torch.Tensor): The mask span. Shape: (batch_size, num_frames).
+        """
+
+        batch_size, num_frames = shape
+
+        # NOTE: num_frames / mask_length: the number of spans in one waveform
+        num_masked_spans = int(
+            mask_prob * num_frames / mask_length + torch.rand(1).item()
+        )
+        num_masked_spans = max(num_masked_spans, min_masks)
+
+        # make sure num masked indices <= num frames
+        if num_masked_spans * mask_length > num_frames:
+            num_masked_spans = num_frames // mask_length
+
+        # uniform distribution to sample from
+        # NOTE: num_frames - (mask_length - 1): the number of start positions of the span
+        uniform_dist = torch.ones(
+            (batch_size, num_frames - (mask_length - 1)), device=device
+        )
+
+        # (batch_size, num_masked_spans)
+        mask_span_ids = torch.multinomial(uniform_dist, num_masked_spans)
+
+        # (batch_size, num_masked_spans * mask_length)
+        mask_span_ids = einops.repeat(mask_span_ids, "b n -> b (n l)", l=mask_length)
+
+        offsets = einops.repeat(
+            torch.arange(mask_length, device=device),
+            "l -> b (n l)",
+            b=batch_size,
+            n=num_masked_spans,
+        )
+
+        mask_span_ids = mask_span_ids + offsets
+
+        mask_span = torch.zeros(shape, device=device, dtype=torch.bool)
+        mask_span = mask_span.scatter_(1, mask_span_ids, True)
+
+        if attention_mask is not None:
+            # Make sure the invalid frames are not masked
+            mask_span = torch.where(attention_mask.bool(), mask_span, False)
+
+        return mask_span
+
+    @staticmethod
+    def _compute_attention_mask(length: torch.Tensor):
+        """
+        Args:
+            length (Tensor): The length of valid frames. Shape: (batch)
|
266 |
+
max_length (int): The maximum length of the frames.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
|
270 |
+
"""
|
271 |
+
max_length = length.max().item()
|
272 |
+
|
273 |
+
mask = (
|
274 |
+
torch.arange(max_length, device=length.device).expand(
|
275 |
+
length.size(0), max_length
|
276 |
+
)
|
277 |
+
>= length[:, None]
|
278 |
+
)
|
279 |
+
|
280 |
+
return mask
|
281 |
+
|
282 |
+
def training_step(self, batch, batch_idx):
|
283 |
+
loss = self(batch)
|
284 |
+
|
285 |
+
self.train_loss(loss)
|
286 |
+
|
287 |
+
if batch_idx % 100 == 0:
|
288 |
+
self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)
|
289 |
+
|
290 |
+
return loss
|
291 |
+
|
292 |
+
def configure_optimizers(self):
|
293 |
+
return torch.optim.AdamW(self.parameters(), lr=1e-4)
|
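For illustration, the following standalone sketch shows how the pieces above compose: the negatives are stacked after the positive along a candidate axis, cosine similarities form the logits, and the cross-entropy target is class 0 (the positive) for every masked frame. The tensor shapes and the temperature of 0.1 are illustrative assumptions, not values taken from the config.

import torch

# Illustrative sizes: 2 utterances, 50 frames, 256-dim features, 10 negatives per frame.
batch_size, num_frames, d_model, num_negatives = 2, 50, 256, 10

predicted = torch.randn(batch_size, num_frames, d_model)                 # context-network outputs at masked frames
positive = torch.randn(batch_size, num_frames, 1, d_model)               # quantized target per frame
negatives = torch.randn(batch_size, num_frames, num_negatives, d_model)  # distractors drawn from other frames

# cosine similarity of each prediction against the positive followed by the negatives
candidates = torch.cat((positive, negatives), dim=2)  # (b, f, 1 + num_negatives, d)
logits = torch.cosine_similarity(predicted.unsqueeze(2), candidates, dim=-1) / 0.1

# the positive always sits at index 0, so the cross-entropy target is 0 for every frame
targets = torch.zeros(batch_size * num_frames, dtype=torch.long)
contrastive_loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, 1 + num_negatives), targets, reduction="sum"
)
print(contrastive_loss.item())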
src/train.py
ADDED
@@ -0,0 +1,27 @@
import sys
sys.path.append(".")

from src.config import model as conf
from src.model import Wav2Vec2PretrainingModule
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


if __name__ == "__main__":
    model = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)
    dts = WebDatasetConverter(conf.dataset.path).get_dataset()
    dtm = VLSP2020ForPretrainingDataModule(dts, **conf.dataset)
    trainer = Trainer(
        callbacks=[
            ModelCheckpoint(
                monitor="val/loss",
                dirpath=conf.checkpoint_dir,
            )
        ],
        gradient_clip_val=1.0,
        accelerator="gpu",
    )

    trainer.fit(model, dtm)
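As a usage note, a checkpoint written by the ModelCheckpoint callback above can later be restored with Lightning's load_from_checkpoint. This is a hedged sketch: the checkpoint path is illustrative, and the pre-training config may need to be passed explicitly if the module does not store its hyperparameters in the checkpoint.

from src.model import Wav2Vec2PretrainingModule

# Illustrative path; pass the pre-training config as well if it is not saved in the checkpoint.
restored = Wav2Vec2PretrainingModule.load_from_checkpoint("checkpoints/epoch=9-step=10000.ckpt")
restored.eval()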
src/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .functional import init_module_weights
src/utils/functional.py
ADDED
@@ -0,0 +1,28 @@
import torch


def init_module_weights(module):
    """Initialize the weights"""

    from src.model.modules import QuantizationModule

    # gumbel softmax requires special init
    if isinstance(module, QuantizationModule):
        module.weight_proj.weight.data.normal_(mean=0.0, std=1)
        module.weight_proj.bias.data.zero_()
        torch.nn.init.uniform_(module.codebooks)
    elif isinstance(module, torch.nn.Linear):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=0.5)
    elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    elif isinstance(module, torch.nn.Conv1d):
        torch.nn.init.kaiming_normal_(module.weight.data)

    if (
        isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
        and module.bias is not None
    ):
        module.bias.data.zero_()
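A minimal usage sketch: init_module_weights is meant to be handed to torch.nn.Module.apply, which calls it on every submodule recursively. The toy model below is purely illustrative.

import torch

from src.utils import init_module_weights

toy = torch.nn.Sequential(
    torch.nn.Conv1d(1, 16, kernel_size=3),  # Kaiming-normal weights, zero bias
    torch.nn.GroupNorm(4, 16),              # unit weight, zero bias
    torch.nn.Linear(16, 8),                 # normal(0, 0.5) weights, zero bias
)
toy.apply(init_module_weights)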
src/utils/metrics.py
ADDED
@@ -0,0 +1,72 @@
from typing import Tuple, Union
import re


def levenshtein_distance(source: Tuple[str], target: Tuple[str]):
    """
    Compute the Levenshtein distance between two sequences.
    """
    n, m = len(source), len(target)
    if n > m:
        # Make sure n <= m, to use O(min(n, m)) space
        source, target = target, source
        n, m = m, n

    # Keep only the current and previous rows, not the entire matrix
    current_row = range(n + 1)
    for i in range(1, m + 1):
        previous_row, current_row = current_row, [i] + [0] * n
        for j in range(1, n + 1):
            add, delete, change = (
                previous_row[j] + 1,
                current_row[j - 1] + 1,
                previous_row[j - 1],
            )
            if source[j - 1] != target[i - 1]:
                change += 1
            current_row[j] = min(add, delete, change)

    return current_row[n]


def word_error_rate(
    predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
):
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)

    pattern = r"\W+"

    err, total = 0, 0

    for pred, tgt in zip(predicted, transcript):
        pred_tokens = re.split(pattern, pred)
        tgt_tokens = re.split(pattern, tgt)
        err += levenshtein_distance(pred_tokens, tgt_tokens)
        total += len(tgt_tokens)

    return err / total


def character_error_rate(
    predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
):
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)

    err, total = 0, 0

    for pred, tgt in zip(predicted, transcript):
        err += levenshtein_distance(pred, tgt)
        total += len(tgt)

    return err / total
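A worked example of the two metrics on an illustrative prediction/reference pair:

from src.utils.metrics import word_error_rate, character_error_rate

pred = "xin chao cac ban"
ref = "xin chào các bạn"

# 3 of the 4 reference words differ from the prediction, so WER = 3 / 4
print(word_error_rate(pred, ref))       # 0.75
# both strings have 16 characters and 3 of them differ, so CER = 3 / 16
print(character_error_rate(pred, ref))  # 0.1875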
src/utils/scheduler.py
ADDED
@@ -0,0 +1,83 @@
import math
from torch.optim.lr_scheduler import _LRScheduler


class WarmUpScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        feature_size: int,
        factor: float = 1.0,
        last_epoch=-1,
    ):
        self.warmup_steps = warmup_steps
        self.feature_size = feature_size
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        lr = self._compute_lr()
        return [lr] * len(self.base_lrs)

    def _compute_lr(self):
        if self.last_epoch == 0:
            return 0.0

        lr = (self.feature_size ** (-0.5)) * min(
            self.last_epoch ** (-0.5), self.last_epoch * self.warmup_steps ** (-1.5)
        )

        return lr * self.factor


class TriStateScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        total_steps: int,
        warmup_steps: int,
        constant_steps: int,
        factor: float = 0.3,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.constant_steps = constant_steps
        self.total_steps = total_steps
        self.factor = factor

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not hasattr(self, "eta_min"):
            self.eta_max = self.base_lrs.copy()
            self.eta_min = [eta_max * self.factor for eta_max in self.eta_max]

        return [
            self._compute_lr(group["lr"], eta_min, eta_max)
            for group, eta_min, eta_max in zip(
                self.optimizer.param_groups, self.eta_min, self.eta_max
            )
        ]

    def _compute_lr(self, prev_lr: float, eta_min: float, eta_max: float):
        # first stage
        if self.last_epoch <= self.warmup_steps:
            lr = eta_max - 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * self.last_epoch / self.warmup_steps)
            )
        # second stage
        elif self.last_epoch <= self.warmup_steps + self.constant_steps:
            lr = prev_lr
        else:
            # third stage
            decay_steps = self.total_steps - self.warmup_steps - self.constant_steps
            k = self.last_epoch - self.warmup_steps - self.constant_steps
            lr = eta_min + 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * k / decay_steps)
            )

        return lr

    def state_dict(self) -> dict:
        return super().state_dict()
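A minimal usage sketch of TriStateScheduler; the step counts and base learning rate below are illustrative, not the values used for pre-training.

import torch

from src.utils.scheduler import TriStateScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=5e-4)
scheduler = TriStateScheduler(
    optimizer, total_steps=100_000, warmup_steps=8_000, constant_steps=42_000, factor=0.3
)

for step in range(100_000):
    optimizer.step()
    scheduler.step()  # cosine warm-up to 5e-4, hold, then cosine decay towards 0.3 * 5e-4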