Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +4 -0
- ASR/.ipynb_checkpoints/audio_tokenizer-checkpoint.py +611 -0
- ASR/.ipynb_checkpoints/demo-checkpoint.ipynb +849 -0
- ASR/.ipynb_checkpoints/demo-checkpoint.py +24 -0
- ASR/.ipynb_checkpoints/tokenizer_training-checkpoint.ipynb +203 -0
- ASR/__pycache__/audio_tokenizer.cpython-38.pyc +0 -0
- ASR/__pycache__/tokenizer.cpython-38.pyc +0 -0
- ASR/audio_tokenizer.py +611 -0
- ASR/demo.ipynb +878 -0
- ASR/demo.py +24 -0
- ASR/repcodec/.ipynb_checkpoints/RepCodec-checkpoint.py +84 -0
- ASR/repcodec/RepCodec.py +84 -0
- ASR/repcodec/__pycache__/RepCodec.cpython-38.pyc +0 -0
- ASR/repcodec/configs/repcodec_dim1024.yaml +18 -0
- ASR/repcodec/configs/repcodec_dim1280.yaml +18 -0
- ASR/repcodec/configs/repcodec_dim768.yaml +18 -0
- ASR/repcodec/layers/__pycache__/conv_layer.cpython-38.pyc +0 -0
- ASR/repcodec/layers/__pycache__/vq_module.cpython-38.pyc +0 -0
- ASR/repcodec/layers/conv_layer.py +95 -0
- ASR/repcodec/layers/vq_module.py +155 -0
- ASR/repcodec/modules/__pycache__/decoder.cpython-38.pyc +0 -0
- ASR/repcodec/modules/__pycache__/encoder.cpython-38.pyc +0 -0
- ASR/repcodec/modules/__pycache__/projector.cpython-38.pyc +0 -0
- ASR/repcodec/modules/__pycache__/quantizer.cpython-38.pyc +0 -0
- ASR/repcodec/modules/__pycache__/residual_unit.cpython-38.pyc +0 -0
- ASR/repcodec/modules/decoder.py +109 -0
- ASR/repcodec/modules/encoder.py +89 -0
- ASR/repcodec/modules/projector.py +32 -0
- ASR/repcodec/modules/quantizer.py +46 -0
- ASR/repcodec/modules/residual_unit.py +39 -0
- ASR/repcodec/tokenize.py +212 -0
- ASR/test-gpt2-opt.onnx +3 -0
- ASR/test-gpt2.onnx +3 -0
- ASR/test-gpt2.plan +3 -0
- ASR/tokenized_librispeech/dataset_dict.json +1 -0
- ASR/tokenized_librispeech/test.clean/data-00000-of-00001.arrow +3 -0
- ASR/tokenized_librispeech/test.clean/dataset_info.json +29 -0
- ASR/tokenized_librispeech/test.clean/state.json +13 -0
- ASR/tokenized_librispeech/test.other/data-00000-of-00001.arrow +3 -0
- ASR/tokenized_librispeech/test.other/dataset_info.json +29 -0
- ASR/tokenized_librispeech/test.other/state.json +13 -0
- ASR/tokenized_librispeech/train.clean.100/data-00000-of-00001.arrow +3 -0
- ASR/tokenized_librispeech/train.clean.100/dataset_info.json +29 -0
- ASR/tokenized_librispeech/train.clean.100/state.json +13 -0
- ASR/tokenized_librispeech/train.clean.360/data-00000-of-00003.arrow +3 -0
- ASR/tokenized_librispeech/train.clean.360/data-00001-of-00003.arrow +3 -0
- ASR/tokenized_librispeech/train.clean.360/data-00002-of-00003.arrow +3 -0
- ASR/tokenized_librispeech/train.clean.360/dataset_info.json +29 -0
- ASR/tokenized_librispeech/train.clean.360/state.json +19 -0
- ASR/tokenized_librispeech/train.other.500/data-00000-of-00004.arrow +3 -0
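A commit like this one is normally produced with the huggingface_hub upload API. The sketch below is illustrative only; the local folder path and repo id are placeholders, not values taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
# Upload every file under a local folder in a single commit; files matched by the
# repo's .gitattributes LFS rules (or known binary types) are stored via Git LFS.
api.upload_folder(
    folder_path="./ASR",                 # hypothetical local folder
    path_in_repo="ASR",                  # keep the same layout inside the repo
    repo_id="your-username/your-repo",   # hypothetical repo id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)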
.gitattributes
CHANGED
@@ -40,3 +40,7 @@ prompting/train_data/train.other.500.json filter=lfs diff=lfs merge=lfs -text
 prompting/transcripts/train.clean.360.txt filter=lfs diff=lfs merge=lfs -text
 prompting/transcripts/train.other.500.txt filter=lfs diff=lfs merge=lfs -text
 prompting/wandb/run-20240615_114519-wfpe2teb/run-wfpe2teb.wandb filter=lfs diff=lfs merge=lfs -text
+ASR/test-gpt2.plan filter=lfs diff=lfs merge=lfs -text
+ASR/transformer-deploy/docs/infinity/infinity.xcf filter=lfs diff=lfs merge=lfs -text
+ASR/transformer-deploy/resources/img/export_process.png filter=lfs diff=lfs merge=lfs -text
+ASR/transformer-deploy/resources/img/gpt2.png filter=lfs diff=lfs merge=lfs -text
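The four added rules mark the TensorRT engine and the transformer-deploy image assets as Git LFS objects; they are the lines that `git lfs track` writes. A minimal sketch of generating them locally (assumes git-lfs is installed and the repo is already cloned; purely illustrative):

import subprocess

for pattern in [
    "ASR/test-gpt2.plan",
    "ASR/transformer-deploy/docs/infinity/infinity.xcf",
    "ASR/transformer-deploy/resources/img/export_process.png",
    "ASR/transformer-deploy/resources/img/gpt2.png",
]:
    # appends "<pattern> filter=lfs diff=lfs merge=lfs -text" to .gitattributes
    subprocess.run(["git", "lfs", "track", pattern], check=True)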
ASR/.ipynb_checkpoints/audio_tokenizer-checkpoint.py
ADDED
@@ -0,0 +1,611 @@
import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations makes sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )


import logging

import torch
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.audio.audio_utils import get_features_or_waveform
from omegaconf import OmegaConf

logger = logging.getLogger("dump_feature")


class Data2vecFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        state = load_checkpoint_to_cpu(ckpt_path)
        cfg = state["cfg"]
        # load task
        task = tasks.setup_task(cfg.task, from_checkpoint=True)
        task.load_state_dict(state["task_state"])
        # load model config
        if "layer_type" not in cfg.model:
            # fix a missing key
            model_config = {k: v for k, v in cfg.model.items()}
            model_config["layer_type"] = "transformer"
            model_config = OmegaConf.create(model_config)
        else:
            model_config = cfg.model

        # fix param name in the state
        state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
        state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
        del state["model"]["_ema"]

        # load model
        model = Data2VecAudioModel.build_model(model_config)
        model.load_state_dict(
            state["model"], strict=True, model_cfg=model_config
        )

        self.device = device
        logger.info(f"device = {self.device}")

        self.model = model.eval().to(self.device)
        self.task = task
        self.layer = layer - 1  # make it 1-based
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len=ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(self.device)
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                res = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    layer=self.layer,
                )
                feat_chunk = res["x"]
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
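The file above bundles the fairseq data2vec_audio model with a Data2vecFeatureReader that extracts layer-18 features in fixed-size chunks; together with RepCodec these features are quantized into discrete audio tokens. A condensed usage sketch, mirroring the demo notebook that follows (the checkpoint and config paths are the ones referenced there; the input wav path is a hypothetical placeholder):

import torch
import yaml

from audio_tokenizer import Data2vecFeatureReader
from repcodec.RepCodec import RepCodec

# data2vec teacher checkpoint plus a RepCodec quantizer trained on its layer-18 features
reader = Data2vecFeatureReader("./../prompting/models/vox_pretrained.pt", 18, device="cuda:0")

with open("./repcodec/configs/repcodec_dim1024.yaml") as fp:
    conf = yaml.load(fp, Loader=yaml.FullLoader)

audio_model = RepCodec(**conf)
audio_model.load_state_dict(
    torch.load("./../prompting/models/data2vec_large_l18.pkl", map_location="cuda:0")["model"]["repcodec"]
)
audio_model.quantizer.initial()
audio_model.to("cuda:0").eval()

with torch.no_grad():
    # (T, 1024) features -> (1, 1024, T), the layout expected by the RepCodec encoder
    feats = reader.get_feats("sample.wav").unsqueeze(0).permute(0, 2, 1)
    x = audio_model.encoder(feats)
    z = audio_model.projector(x)
    _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))
    tokens = idx.cpu().data.numpy().tolist()[0]  # discrete audio token ids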
ASR/.ipynb_checkpoints/demo-checkpoint.ipynb
ADDED
@@ -0,0 +1,849 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "715a402a-44b9-4fa2-abf0-b0cfd2f3d80b",
"metadata": {},
"source": [
"## Recording voice in Real Time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbdf6bab-7418-4a6f-8b75-c31f98a6ada5",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Sprints:\n",
"- [ ] Do Inference optimization of ASR LM\n",
"- [ ] Train on train.other.500\n",
"- [ ] Generate dataset for prompting\n",
"\n",
"Evaluation Dates: 20th - 21st June, 2023, 3:30 - 5:30pm\n",
"Sharpen PPT Skills: 20th June, 3:30pm - 4:45pm\n",
"Flow of the PPT:\n",
"Demo -> Datasets -> Techniques -> Evaluation -> Q&A\n",
"- [ Done ] Update the one pager deck slide\n",
"https://sprinklr-my.sharepoint.com/:p:/r/personal/sricharan_narayanam_sprinklr_com/_layouts/15/Doc.aspx?sourcedoc=%7B84811f56-5fc7-4eaa-87d2-db4a3588d18c%7D&action=edit&wdPreviousSession=948ccc35-dc05-f1f9-612d-9a22300e25ba\n",
"My PPT:\n",
"https://sprinklr-my.sharepoint.com/:p:/p/darshan_makwana/Ec4jCiyMWhxMproH625msc8BClFVceNQ8o4kS3EhZBO9MA?e=YCSDxm&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718703689001&web=1\n",
"Intern Tracker:\n",
"https://sprinklr.sharepoint.com/:x:/s/AIIntuition/EbRhHPIAIw9MlZ5PpXbztmABde1LFbaSoSHJAo9qU8ggDg?e=xiLkRt&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718692666812&web=1\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "150aca01-4098-4ab2-809a-25775ec52069",
"metadata": {},
"source": [
"## ASR LM Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "804a58af-beb2-48c1-9530-98024e27c0d6",
"metadata": {},
"outputs": [],
"source": [
"from audio_tokenizer import Data2vecFeatureReader\n",
"from repcodec.RepCodec import RepCodec\n",
"import torch.nn.functional as F\n",
"import torch\n",
"import yaml\n",
"\n",
"reader = Data2vecFeatureReader(\"./../prompting/models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)\n",
"\n",
"config = \"./repcodec/configs/repcodec_dim1024.yaml\"\n",
"with open(config) as fp:\n",
"    conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
"\n",
"audio_model = RepCodec(**conf)\n",
"audio_model.load_state_dict(torch.load(\"./../prompting/models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
"audio_model.quantizer.initial()\n",
"audio_model.to(\"cuda:0\")\n",
"audio_model.eval()\n",
"\n",
"print(\"Successfully Loaded Audio Tokenizer\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d8da397-2030-4b36-9a42-97862488797b",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"cache_dir = \"./../cache\"\n",
"dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bb8016b2-fc9d-4c23-9e85-b6e1c5ca164c",
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 33\u001b[0m\n\u001b[1;32m 30\u001b[0m eot_token \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m<|endoftranscript|>\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 31\u001b[0m pad_token \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m<|padding|>\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 33\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT2LMHeadModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./../out/checkpoint-10000\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_implementation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mflash_attention_2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 34\u001b[0m model\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpad_token_id \u001b[38;5;241m=\u001b[39m pad_token\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39meos_token_id \u001b[38;5;241m=\u001b[39m eot_token\n",
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:3620\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3617\u001b[0m init_contexts\u001b[38;5;241m.\u001b[39mappend(init_empty_weights())\n\u001b[1;32m 3619\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config) \u001b[38;5;66;03m# We do not want to modify the config inplace in from_pretrained.\u001b[39;00m\n\u001b[0;32m-> 3620\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_autoset_attn_implementation\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3621\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\n\u001b[1;32m 3622\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3624\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(init_contexts):\n\u001b[1;32m 3625\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[1;32m 3626\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(config, \u001b[38;5;241m*\u001b[39mmodel_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:1469\u001b[0m, in \u001b[0;36mPreTrainedModel._autoset_attn_implementation\u001b[0;34m(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)\u001b[0m\n\u001b[1;32m 1466\u001b[0m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1468\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1469\u001b[0m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_and_enable_flash_attn_2\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1470\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1471\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1472\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1473\u001b[0m \u001b[43m \u001b[49m\u001b[43mhard_check_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1474\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_device_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1475\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1476\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msdpa\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 1477\u001b[0m \u001b[38;5;66;03m# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.\u001b[39;00m\n\u001b[1;32m 1478\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_check_and_enable_sdpa(\n\u001b[1;32m 1479\u001b[0m config,\n\u001b[1;32m 1480\u001b[0m hard_check_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 1481\u001b[0m )\n",
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:1571\u001b[0m, in \u001b[0;36mPreTrainedModel._check_and_enable_flash_attn_2\u001b[0;34m(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)\u001b[0m\n\u001b[1;32m 1568\u001b[0m install_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1570\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mutil\u001b[38;5;241m.\u001b[39mfind_spec(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1571\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreface\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m the package flash_attn seems to be not installed. \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minstall_message\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1573\u001b[0m flash_attention_version \u001b[38;5;241m=\u001b[39m version\u001b[38;5;241m.\u001b[39mparse(importlib\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mversion(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1574\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mcuda:\n",
"\u001b[0;31mImportError\u001b[0m: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
104 |
+
]
|
105 |
+
}
|
106 |
+
],
|
107 |
+
"source": [
|
108 |
+
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
|
109 |
+
"import torch\n",
|
110 |
+
"import string\n",
|
111 |
+
"\n",
|
112 |
+
"def process(text):\n",
|
113 |
+
"\n",
|
114 |
+
" # Lower case every letter\n",
|
115 |
+
" text = text.lower()\n",
|
116 |
+
"\n",
|
117 |
+
" # Remove punctuation\n",
|
118 |
+
" punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
|
119 |
+
" translation_table = str.maketrans('', '', punctuation_to_remove)\n",
|
120 |
+
" text = text.translate(translation_table)\n",
|
121 |
+
"\n",
|
122 |
+
" # Remove whitespaces from front and behind\n",
|
123 |
+
" while text[0] == ' ' or text[-1] == ' ':\n",
|
124 |
+
" if text[0] == ' ':\n",
|
125 |
+
" text = text[1:]\n",
|
126 |
+
" if text[-1] == ' ':\n",
|
127 |
+
" text = text[:-1]\n",
|
128 |
+
" \n",
|
129 |
+
" return text\n",
|
130 |
+
"\n",
|
131 |
+
"device = \"cuda:0\"\n",
|
132 |
+
"dtype = torch.float16\n",
|
133 |
+
"context_length = 1877\n",
|
134 |
+
"\n",
|
135 |
+
"# Load tokenizer and add audio tokens\n",
|
136 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
|
137 |
+
"eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
138 |
+
"pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
|
139 |
+
"\n",
|
140 |
+
"model = GPT2LMHeadModel.from_pretrained(\"./../out/checkpoint-10000\", attn_implementation=\"flash_attention_2\", device_map=device, torch_dtype=dtype).eval()\n",
|
141 |
+
"model.config.pad_token_id = pad_token\n",
|
142 |
+
"model.config.eos_token_id = eot_token\n",
|
143 |
+
"# model = torch.compile(model)"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": null,
|
149 |
+
"id": "7cabe9dc-bbbf-41b4-918f-3f60ee5582f2",
|
150 |
+
"metadata": {},
|
151 |
+
"outputs": [],
|
152 |
+
"source": [
|
153 |
+
"from tqdm import tqdm\n",
|
154 |
+
"from math import ceil\n",
|
155 |
+
"import torch\n",
|
156 |
+
"import time\n",
|
157 |
+
"\n",
|
158 |
+
"sample = dataset[\"train.clean.100\"][5]\n",
|
159 |
+
"\n",
|
160 |
+
"x = sample[\"audio\"][\"array\"]\n",
|
161 |
+
"\n",
|
162 |
+
"start_time = time.time()\n",
|
163 |
+
"\n",
|
164 |
+
"with torch.no_grad():\n",
|
165 |
+
" x = torch.from_numpy(x).float().to(reader.device)\n",
|
166 |
+
" if reader.task.cfg.normalize:\n",
|
167 |
+
" x = F.layer_norm(x, x.shape)\n",
|
168 |
+
" x = x.view(1, -1)\n",
|
169 |
+
"\n",
|
170 |
+
" feat = []\n",
|
171 |
+
" for start in range(0, x.size(1), reader.max_chunk):\n",
|
172 |
+
" x_chunk = x[:, start: start + reader.max_chunk]\n",
|
173 |
+
" res = reader.model.extract_features(\n",
|
174 |
+
" source=x_chunk,\n",
|
175 |
+
" padding_mask=None,\n",
|
176 |
+
" mask=False,\n",
|
177 |
+
" layer=reader.layer,\n",
|
178 |
+
" )\n",
|
179 |
+
" feat_chunk = res[\"x\"]\n",
|
180 |
+
" feat.append(feat_chunk)\n",
|
181 |
+
" \n",
|
182 |
+
" features = torch.cat(feat, 1).permute(0, 2, 1)\n",
|
183 |
+
"\n",
|
184 |
+
" x = audio_model.encoder(features)\n",
|
185 |
+
" z = audio_model.projector(x)\n",
|
186 |
+
" _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
|
187 |
+
" tokens = idx.cpu().data.numpy().tolist()[0]\n",
|
188 |
+
" \n",
|
189 |
+
"text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
|
190 |
+
"input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
|
191 |
+
"\n",
|
192 |
+
"input_time = time.time()\n",
|
193 |
+
"\n",
|
194 |
+
"generations = model.generate(\n",
|
195 |
+
" input_ids,\n",
|
196 |
+
" pad_token_id = pad_token,\n",
|
197 |
+
" eos_token_id = eot_token,\n",
|
198 |
+
" max_new_tokens = context_length,\n",
|
199 |
+
" use_cache=True\n",
|
200 |
+
")\n",
|
201 |
+
"\n",
|
202 |
+
"finish_time = time.time()\n",
|
203 |
+
"\n",
|
204 |
+
"tokenizer.batch_decode(generations, skip_special_tokens=True)\n",
|
205 |
+
"print(\"First Token Latency: \", (input_time - start_time) * 1000, \"ms\")\n",
|
206 |
+
"# print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
|
207 |
+
"print(\"End to End Inference Time: \", (finish_time - start_time) * 1000, \"ms\")\n",
|
208 |
+
"print(\"Refer Text: \", process(sample[\"text\"]))\n",
|
209 |
+
"print(\"Transcript: \", tokenizer.batch_decode(generations, skip_special_tokens=True)[0])"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": null,
|
215 |
+
"id": "baa8d79b-7cf5-4435-838c-1f3d4e043d60",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [],
|
218 |
+
"source": [
|
219 |
+
"import time\n",
|
220 |
+
"\n",
|
221 |
+
"sample = dataset[\"train.clean.100\"][0]\n",
|
222 |
+
"\n",
|
223 |
+
"x = sample[\"audio\"][\"array\"]\n",
|
224 |
+
"\n",
|
225 |
+
"start_time = time.time()\n",
|
226 |
+
"\n",
|
227 |
+
"with torch.no_grad():\n",
|
228 |
+
" x = torch.from_numpy(x).float().to(reader.device)\n",
|
229 |
+
" if reader.task.cfg.normalize:\n",
|
230 |
+
" x = F.layer_norm(x, x.shape)\n",
|
231 |
+
" x = x.view(1, -1)\n",
|
232 |
+
"\n",
|
233 |
+
" feat = []\n",
|
234 |
+
" for start in range(0, x.size(1), reader.max_chunk):\n",
|
235 |
+
" x_chunk = x[:, start: start + reader.max_chunk]\n",
|
236 |
+
" res = reader.model.extract_features(\n",
|
237 |
+
" source=x_chunk,\n",
|
238 |
+
" padding_mask=None,\n",
|
239 |
+
" mask=False,\n",
|
240 |
+
" layer=reader.layer,\n",
|
241 |
+
" )\n",
|
242 |
+
" feat_chunk = res[\"x\"]\n",
|
243 |
+
" feat.append(feat_chunk)\n",
|
244 |
+
" \n",
|
245 |
+
" features = torch.cat(feat, 1).permute(0, 2, 1)\n",
|
246 |
+
"\n",
|
247 |
+
" x = audio_model.encoder(features)\n",
|
248 |
+
" z = audio_model.projector(x)\n",
|
249 |
+
" _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
|
250 |
+
" tokens = idx.cpu().data.numpy().tolist()[0]\n",
|
251 |
+
"\n",
|
252 |
+
"from tqdm import tqdm\n",
|
253 |
+
"from math import ceil\n",
|
254 |
+
"import torch\n",
|
255 |
+
"\n",
|
256 |
+
"context_length = 1877\n",
|
257 |
+
"eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
258 |
+
"pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
|
259 |
+
" \n",
|
260 |
+
"text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
|
261 |
+
"input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
|
262 |
+
"\n",
|
263 |
+
"max_new_tokens = context_length\n",
|
264 |
+
"num_tokens = 0\n",
|
265 |
+
"first_token = True\n",
|
266 |
+
"\n",
|
267 |
+
"while max_new_tokens > 0 and input_ids.shape[-1] < context_length:\n",
|
268 |
+
"\n",
|
269 |
+
" with torch.no_grad():\n",
|
270 |
+
" outputs = model(input_ids = input_ids)\n",
|
271 |
+
"\n",
|
272 |
+
" logits = outputs[\"logits\"][:, -1]\n",
|
273 |
+
"\n",
|
274 |
+
" # Greedy Sampling\n",
|
275 |
+
" probas = torch.softmax(logits, dim=-1)\n",
|
276 |
+
" pred_idx = torch.argmax(probas, dim=-1, keepdim=True)\n",
|
277 |
+
" next_idx = pred_idx.item()\n",
|
278 |
+
"\n",
|
279 |
+
" if first_token:\n",
|
280 |
+
" first_token_latency = time.time() - start_time\n",
|
281 |
+
" first_token = False\n",
|
282 |
+
" start_time = time.time()\n",
|
283 |
+
"\n",
|
284 |
+
" if next_idx == eot_token:\n",
|
285 |
+
" break\n",
|
286 |
+
"\n",
|
287 |
+
" input_ids = torch.cat((input_ids, pred_idx), dim=-1)\n",
|
288 |
+
"\n",
|
289 |
+
" max_new_tokens -= 1\n",
|
290 |
+
" num_tokens += 1\n",
|
291 |
+
"\n",
|
292 |
+
"total_time = time.time() - start_time\n",
|
293 |
+
"\n",
|
294 |
+
"print(\"First Token Latency: \", first_token_latency * 1000, \"ms\")\n",
|
295 |
+
"print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
|
296 |
+
"print(\"End to End Inference Time: \", (total_time + first_token_latency) * 1000, \"ms\")\n",
|
297 |
+
"print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])\n",
|
298 |
+
"print(process(sample[\"text\"]))"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"id": "014ed999-3293-4d68-8f9c-017584adc642",
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"tokenizer.batch_decode([[1, 2, 3]])"
|
309 |
+
]
|
310 |
+
},
|
311 |
+
{
|
312 |
+
"cell_type": "markdown",
|
313 |
+
"id": "ec11e43f-1eb8-4399-9a93-6f1427782661",
|
314 |
+
"metadata": {
|
315 |
+
"jp-MarkdownHeadingCollapsed": true
|
316 |
+
},
|
317 |
+
"source": [
|
318 |
+
"## Accelerating GPT 2 Inference"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": null,
|
324 |
+
"id": "5489cb4e-3213-4931-abe1-4c96d1a7ba56",
|
325 |
+
"metadata": {},
|
326 |
+
"outputs": [],
|
327 |
+
"source": [
|
328 |
+
"\"\"\"\n",
|
329 |
+
"- change tensorrt.tensorrt to tensorrt\n",
|
330 |
+
"- remove cpu quantization lines\n",
|
331 |
+
"- output_names [\"logits\"]\n",
|
332 |
+
"\"\"\""
|
333 |
+
]
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"cell_type": "code",
|
337 |
+
"execution_count": null,
|
338 |
+
"id": "7e7e6ea6-7319-4e57-af33-5d917d26abc6",
|
339 |
+
"metadata": {},
|
340 |
+
"outputs": [],
|
341 |
+
"source": [
|
342 |
+
"import logging\n",
|
343 |
+
"import time\n",
|
344 |
+
"from typing import Callable, Dict\n",
|
345 |
+
"\n",
|
346 |
+
"import numpy as np\n",
|
347 |
+
"import tensorrt as trt\n",
|
348 |
+
"import torch\n",
|
349 |
+
"from tensorrt import ICudaEngine\n",
|
350 |
+
"from tensorrt import Logger, Runtime\n",
|
351 |
+
"from transformers import AutoTokenizer, BatchEncoding, GPT2LMHeadModel, AutoModelForCausalLM\n",
|
352 |
+
"from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
|
353 |
+
"from transformer_deploy.utils.generative_model import GPTModelWrapper\n",
|
354 |
+
"import inspect\n",
|
355 |
+
"from transformers import TensorType\n",
|
356 |
+
"\n",
|
357 |
+
"from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding, optimize_onnx\n",
|
358 |
+
"from transformer_deploy.backends.pytorch_utils import convert_to_onnx, get_model_size\n",
|
359 |
+
"from transformer_deploy.backends.trt_utils import build_engine, load_engine, save_engine"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": null,
|
365 |
+
"id": "21681412-7747-4824-894a-6006eb12a821",
|
366 |
+
"metadata": {},
|
367 |
+
"outputs": [],
|
368 |
+
"source": [
|
369 |
+
"model_name = \"gpt2\"\n",
|
370 |
+
"\n",
|
371 |
+
"model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name)\n",
|
372 |
+
"model.eval()\n",
|
373 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
374 |
+
"model.config.pad_token_id = tokenizer.eos_token_id"
|
375 |
+
]
|
376 |
+
},
|
377 |
+
{
|
378 |
+
"cell_type": "code",
|
379 |
+
"execution_count": null,
|
380 |
+
"id": "46783acd-c404-44b4-904b-d8fb687afc34",
|
381 |
+
"metadata": {},
|
382 |
+
"outputs": [],
|
383 |
+
"source": [
|
384 |
+
"inputs = tokenizer(\"Here is some text to encode Hello World\", return_tensors=\"pt\")\n",
|
385 |
+
"print(\"input tensors\")\n",
|
386 |
+
"print(inputs)\n",
|
387 |
+
"print(\"input tensor shape\")\n",
|
388 |
+
"print(inputs[\"input_ids\"].size())\n",
|
389 |
+
"\n",
|
390 |
+
"with torch.no_grad():\n",
|
391 |
+
" outputs = model(**inputs)\n",
|
392 |
+
"\n",
|
393 |
+
"logits = outputs.logits\n",
|
394 |
+
"print(\"output tensor\")\n",
|
395 |
+
"print(logits)\n",
|
396 |
+
"print(\"output shape\")\n",
|
397 |
+
"print(logits.shape)"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": null,
|
403 |
+
"id": "2f6cc7bd-5e2d-4d4e-a7e6-73a6b2ecd7af",
|
404 |
+
"metadata": {},
|
405 |
+
"outputs": [],
|
406 |
+
"source": [
|
407 |
+
"size = 0\n",
|
408 |
+
"for i in range(8, 256, 1):\n",
|
409 |
+
" # input sequence (input_ids) made of int-32 (4 bytes)\n",
|
410 |
+
" size += np.prod([1, i]) * 4\n",
|
411 |
+
" # output tensor made of float-32 (4 bytes)\n",
|
412 |
+
" size += np.prod([1, i, 50257]) * 4\n",
|
413 |
+
"print(f\"total size (input+output): {size / 1024**3:.2f} Gb\")\n",
|
414 |
+
"\n",
|
415 |
+
"# to manually check actual tensor size:\n",
|
416 |
+
"# np.prod(logits.shape)*32/8/1024**2:.2f}\n",
|
417 |
+
"# or\n",
|
418 |
+
"# sys.getsizeof(logits.storage())/1024**2"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": null,
|
424 |
+
"id": "7debb40e-9941-45e4-9db8-4bb021ce44ab",
|
425 |
+
"metadata": {},
|
426 |
+
"outputs": [],
|
427 |
+
"source": [
|
428 |
+
"input_ids: BatchEncoding = tokenizer(\n",
|
429 |
+
" \"Here is some text to encode Hello World\", add_special_tokens=True, return_attention_mask=False, return_tensors=\"pt\"\n",
|
430 |
+
")\n",
|
431 |
+
"# some inference engines don't support int64 tensor as inputs, we convert all input tensors to int32 type\n",
|
432 |
+
"for k, v in input_ids.items(): # type: str, torch.Tensor\n",
|
433 |
+
" input_ids[k] = v.type(dtype=torch.int32)\n",
|
434 |
+
"\n",
|
435 |
+
"convert_to_onnx(\n",
|
436 |
+
" model_pytorch=model,\n",
|
437 |
+
" output_path=\"test-gpt2.onnx\",\n",
|
438 |
+
" inputs_pytorch=dict(input_ids),\n",
|
439 |
+
" quantization=False,\n",
|
440 |
+
" var_output_seq=True, # we inform ONNX export tool that the output shape will vary with the input shape\n",
|
441 |
+
" output_names = [\"logits\"]\n",
|
442 |
+
")\n",
|
443 |
+
"# model may switch to train mode for some unknown reasons, we force the eval mode.\n",
|
444 |
+
"_ = model.eval()"
|
445 |
+
]
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"cell_type": "code",
|
449 |
+
"execution_count": null,
|
450 |
+
"id": "956c3007-2c18-4d92-af4f-6cef474d86b5",
|
451 |
+
"metadata": {},
|
452 |
+
"outputs": [],
|
453 |
+
"source": [
|
454 |
+
"logging.basicConfig()\n",
|
455 |
+
"logging.getLogger().setLevel(logging.INFO)\n",
|
456 |
+
"num_attention_heads, hidden_size = get_model_size(path=model_name)\n",
|
457 |
+
"optimize_onnx(\n",
|
458 |
+
" onnx_path=\"test-gpt2.onnx\",\n",
|
459 |
+
" onnx_optim_model_path=\"test-gpt2-opt.onnx\",\n",
|
460 |
+
" fp16=False,\n",
|
461 |
+
" use_cuda=True,\n",
|
462 |
+
" num_attention_heads=num_attention_heads,\n",
|
463 |
+
" hidden_size=hidden_size,\n",
|
464 |
+
" architecture=\"gpt2\",\n",
|
465 |
+
")"
|
466 |
+
]
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"cell_type": "code",
|
470 |
+
"execution_count": null,
|
471 |
+
"id": "85f30ed9-2802-46c9-9201-a70e200b6860",
|
472 |
+
"metadata": {},
|
473 |
+
"outputs": [],
|
474 |
+
"source": [
|
475 |
+
"from pathlib import Path\n",
|
476 |
+
"\n",
|
477 |
+
"trt_logger: Logger = trt.Logger(trt.Logger.ERROR)\n",
|
478 |
+
"runtime: Runtime = trt.Runtime(trt_logger)\n",
|
479 |
+
"trt_model_name = \"test-gpt2.plan\"\n",
|
480 |
+
"\n",
|
481 |
+
"# create only of does not exist because it's slow to run...\n",
|
482 |
+
"\n",
|
483 |
+
"engine: ICudaEngine = build_engine(\n",
|
484 |
+
" runtime=runtime,\n",
|
485 |
+
" onnx_file_path=\"test-gpt2.onnx\",\n",
|
486 |
+
" logger=trt_logger,\n",
|
487 |
+
" min_shape=(1, 1),\n",
|
488 |
+
" optimal_shape=(1, 128), # num beam, batch size\n",
|
489 |
+
" max_shape=(1, 384), # num beam, batch size\n",
|
490 |
+
" workspace_size=10000 * 1024**2,\n",
|
491 |
+
" fp16=True,\n",
|
492 |
+
" int8=False,\n",
|
493 |
+
")\n",
|
494 |
+
"save_engine(engine, trt_model_name)"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": null,
|
500 |
+
"id": "908fe664-800e-4c5f-a1d5-adfd31fd1c64",
|
501 |
+
"metadata": {},
|
502 |
+
"outputs": [],
|
503 |
+
"source": [
|
504 |
+
"engine.num_bindings"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "code",
|
509 |
+
"execution_count": null,
|
510 |
+
"id": "4626926b-fa94-4633-95d5-0d515f8db5f6",
|
511 |
+
"metadata": {},
|
512 |
+
"outputs": [],
|
513 |
+
"source": [
|
514 |
+
"print(inspect.getsource(GPTModelWrapper))"
|
515 |
+
]
|
516 |
+
},
|
517 |
+
{
|
518 |
+
"cell_type": "code",
|
519 |
+
"execution_count": null,
|
520 |
+
"id": "d5bd1de1-a949-46a3-8d15-457d51db4e40",
|
521 |
+
"metadata": {},
|
522 |
+
"outputs": [],
|
523 |
+
"source": [
|
524 |
+
"inputs = tokenizer(\n",
|
525 |
+
" \"Here is some text to encode Hello World\", # Nvidia example prompt\n",
|
526 |
+
" add_special_tokens=True,\n",
|
527 |
+
" return_attention_mask=False, # Not used\n",
|
528 |
+
" return_tensors=TensorType.PYTORCH,\n",
|
529 |
+
")\n",
|
530 |
+
"inputs"
|
531 |
+
]
|
532 |
+
},
|
533 |
+
{
|
534 |
+
"cell_type": "code",
|
535 |
+
"execution_count": null,
|
536 |
+
"id": "815b548f-fa00-4183-b72c-10ecdd4b11c7",
|
537 |
+
"metadata": {},
|
538 |
+
"outputs": [],
|
539 |
+
"source": [
|
540 |
+
"from transformers.generation import GenerationConfig\n",
|
541 |
+
"\n",
|
542 |
+
"class GPTWrapper(GPTModelWrapper):\n",
|
543 |
+
" def __init__(self, *args, **kwargs):\n",
|
544 |
+
" super().__init__(*args, **kwargs)\n",
|
545 |
+
"\n",
|
546 |
+
" self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None\n",
|
547 |
+
"\n",
|
548 |
+
" @classmethod\n",
|
549 |
+
" def can_generate(cls) -> bool:\n",
|
550 |
+
" \"\"\"\n",
|
551 |
+
" Returns whether this model can generate sequences with `.generate()`.\n",
|
552 |
+
"\n",
|
553 |
+
" Returns:\n",
|
554 |
+
" `bool`: Whether this model can generate sequences with `.generate()`.\n",
|
555 |
+
" \"\"\"\n",
|
556 |
+
" # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.\n",
|
557 |
+
" # Alternativelly, the model can also have a custom `generate` function.\n",
|
558 |
+
" if \"GenerationMixin\" in str(cls.prepare_inputs_for_generation) and \"GenerationMixin\" in str(cls.generate):\n",
|
559 |
+
" return False\n",
|
560 |
+
" return True"
|
561 |
+
]
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"cell_type": "code",
|
565 |
+
"execution_count": null,
|
566 |
+
"id": "ca57ed1e-0bbe-48dd-ae0f-f3d8ecd7fd04",
|
567 |
+
"metadata": {},
|
568 |
+
"outputs": [],
|
569 |
+
"source": [
|
570 |
+
"def inference_torch(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
571 |
+
" transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = model.transformer(input_ids=input_ids)\n",
|
572 |
+
" return model.lm_head(transformer_outputs.last_hidden_state)\n",
|
573 |
+
"\n",
|
574 |
+
"\n",
|
575 |
+
"model.cuda()\n",
|
576 |
+
"model.eval()\n",
|
577 |
+
"inputs.to(\"cuda\")\n",
|
578 |
+
"with torch.inference_mode():\n",
|
579 |
+
" gpt2_model = GPTWrapper(config=model.config, device=model.device, inference=inference_torch)\n",
|
580 |
+
" sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
581 |
+
" print(tokenizer.decode(sample_output[0], skip_special_tokens=False))\n",
|
582 |
+
" for _ in range(2):\n",
|
583 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
584 |
+
" torch.cuda.synchronize()\n",
|
585 |
+
" start = time.time()\n",
|
586 |
+
" for _ in range(10):\n",
|
587 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
588 |
+
" torch.cuda.synchronize()\n",
|
589 |
+
" print(f\"----\\nPytorch: {(time.time() - start)/10:.2f}s/sequence\")\n",
|
590 |
+
"_ = model.cpu()"
|
591 |
+
]
|
592 |
+
},
|
593 |
+
{
|
594 |
+
"cell_type": "code",
|
595 |
+
"execution_count": null,
|
596 |
+
"id": "f0849aae-876e-47bc-b045-14a594170947",
|
597 |
+
"metadata": {},
|
598 |
+
"outputs": [],
|
599 |
+
"source": [
|
600 |
+
"model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
|
601 |
+
"\n",
|
602 |
+
"\n",
|
603 |
+
"def inference_onnx_naive(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
604 |
+
" data = {\"input_ids\": input_ids.detach().cpu().numpy().astype(np.int32)}\n",
|
605 |
+
" logit = model_onnx.run(None, data)\n",
|
606 |
+
" np_logit = np.array(logit) # convert list of numpy arrays to a numpy array\n",
|
607 |
+
" # we convert numpy tensor to Pytorch tensor as it's the type expected by HF decoding algorithm\n",
|
608 |
+
" return torch.squeeze(torch.from_numpy(np_logit), dim=0)\n",
|
609 |
+
"\n",
|
610 |
+
"\n",
|
611 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cpu\"), inference=inference_onnx_naive)\n",
|
612 |
+
"inputs.to(\"cpu\")\n",
|
613 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
614 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
615 |
+
"for _ in range(2):\n",
|
616 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
617 |
+
"start = time.time()\n",
|
618 |
+
"for _ in range(10):\n",
|
619 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
620 |
+
"print(f\"----\\nONNX Runtime (standard API): {(time.time() - start)/10:.2f}s/sequence\")\n",
|
621 |
+
"\n",
|
622 |
+
"del model_onnx"
|
623 |
+
]
|
624 |
+
},
|
625 |
+
{
|
626 |
+
"cell_type": "code",
|
627 |
+
"execution_count": null,
|
628 |
+
"id": "96114897-894b-4997-bc61-8ac0682e0e55",
|
629 |
+
"metadata": {},
|
630 |
+
"outputs": [],
|
631 |
+
"source": [
|
632 |
+
"model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
|
633 |
+
"\n",
|
634 |
+
"\n",
|
635 |
+
"def inference_onnx_optimized(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
636 |
+
" data = {\"input_ids\": input_ids}\n",
|
637 |
+
" return inference_onnx_binding(model_onnx=model_onnx, inputs=data, device=\"cuda\")[\"output\"]\n",
|
638 |
+
"\n",
|
639 |
+
"\n",
|
640 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_onnx_optimized)\n",
|
641 |
+
"inputs.to(\"cuda\")\n",
|
642 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
643 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
644 |
+
"for _ in range(2):\n",
|
645 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
646 |
+
"start = time.time()\n",
|
647 |
+
"for _ in range(10):\n",
|
648 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
649 |
+
"print(f\"----\\nONNX Runtime (binding io API): {(time.time() - start)/10:.2f}/sequence\")\n",
|
650 |
+
"del model_onnx"
|
651 |
+
]
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"cell_type": "code",
|
655 |
+
"execution_count": null,
|
656 |
+
"id": "0b5b5427-fd6b-4f70-b307-9c579f0f842a",
|
657 |
+
"metadata": {},
|
658 |
+
"outputs": [],
|
659 |
+
"source": [
|
660 |
+
"tensorrt_model: Callable[[Dict[str, torch.Tensor]], torch.Tensor] = load_engine(\n",
|
661 |
+
" engine_file_path=\"test-gpt2.plan\", runtime=runtime\n",
|
662 |
+
")\n",
|
663 |
+
"\n",
|
664 |
+
"\n",
|
665 |
+
"def inference_tensorrt(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
666 |
+
" data = {\"input_ids\": input_ids}\n",
|
667 |
+
" return tensorrt_model(data)\n",
|
668 |
+
"\n",
|
669 |
+
"\n",
|
670 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_tensorrt)\n",
|
671 |
+
"inputs.to(\"cuda\")\n",
|
672 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
673 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
674 |
+
"for _ in range(2):\n",
|
675 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
676 |
+
"start = time.time()\n",
|
677 |
+
"for _ in range(10):\n",
|
678 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
679 |
+
"print(f\"----\\nTensorRT + CUDA tensors: {(time.time() - start)/10:.2f}/sequence\")\n",
|
680 |
+
"\n",
|
681 |
+
"del tensorrt_model"
|
682 |
+
]
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"cell_type": "markdown",
|
686 |
+
"id": "f547239d-4f7a-433b-8ef6-9e5110a61f4b",
|
687 |
+
"metadata": {
|
688 |
+
"jp-MarkdownHeadingCollapsed": true
|
689 |
+
},
|
690 |
+
"source": [
|
691 |
+
"## Using CUDAExecution Provider"
|
692 |
+
]
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"cell_type": "code",
|
696 |
+
"execution_count": null,
|
697 |
+
"id": "6e34c682-85fc-4e8d-b13c-7c1c9ea39ead",
|
698 |
+
"metadata": {},
|
699 |
+
"outputs": [],
|
700 |
+
"source": [
|
701 |
+
"from optimum.onnxruntime import ORTModelForCausalLM\n",
|
702 |
+
"from optimum.pipelines import pipeline\n",
|
703 |
+
"from transformers import AutoTokenizer\n",
|
704 |
+
"\n",
|
705 |
+
"model_id = \"openai-community/gpt2\"\n",
|
706 |
+
"\n",
|
707 |
+
"ort_model = ORTModelForCausalLM.from_pretrained(\n",
|
708 |
+
" model_id,\n",
|
709 |
+
" export=True,\n",
|
710 |
+
" provider=\"CUDAExecutionProvider\",\n",
|
711 |
+
" use_io_binding=True\n",
|
712 |
+
")\n",
|
713 |
+
"\n",
|
714 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
715 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
716 |
+
"\n",
|
717 |
+
"pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
|
718 |
+
]
|
719 |
+
},
|
720 |
+
{
|
721 |
+
"cell_type": "code",
|
722 |
+
"execution_count": null,
|
723 |
+
"id": "17d28184-26db-4dd3-b24b-0c5a12b10d6d",
|
724 |
+
"metadata": {},
|
725 |
+
"outputs": [],
|
726 |
+
"source": [
|
727 |
+
"import time\n",
|
728 |
+
"\n",
|
729 |
+
"start_time = time.time()\n",
|
730 |
+
"\n",
|
731 |
+
"generations = pipe(\"Both the music and visual were astounding, not to mention the actors performance.\")\n",
|
732 |
+
"generations[0][\"generated_text\"]\n",
|
733 |
+
"\n",
|
734 |
+
"finish_time = time.time()\n",
|
735 |
+
"\n",
|
736 |
+
"print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
|
737 |
+
]
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"cell_type": "markdown",
|
741 |
+
"id": "19c4230a-3244-4dce-b5ef-d9927dec5c45",
|
742 |
+
"metadata": {},
|
743 |
+
"source": [
|
744 |
+
"## ASR LM with CUDAExcecution Provider"
|
745 |
+
]
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "code",
|
749 |
+
"execution_count": null,
|
750 |
+
"id": "0f0f1cdc-bfcd-46c5-80a4-60bc76366cf5",
|
751 |
+
"metadata": {},
|
752 |
+
"outputs": [],
|
753 |
+
"source": [
|
754 |
+
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
|
755 |
+
"from datasets import DatasetDict\n",
|
756 |
+
"import torch\n",
|
757 |
+
"\n",
|
758 |
+
"device = \"cuda:0\"\n",
|
759 |
+
"dtype = torch.float16\n",
|
760 |
+
"\n",
|
761 |
+
"dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
|
762 |
+
"\n",
|
763 |
+
"from optimum.onnxruntime import ORTModelForCausalLM\n",
|
764 |
+
"from optimum.pipelines import pipeline\n",
|
765 |
+
"from transformers import AutoTokenizer\n",
|
766 |
+
"\n",
|
767 |
+
"model_id = \"./../out/checkpoint-10000\"\n",
|
768 |
+
"\n",
|
769 |
+
"ort_model = ORTModelForCausalLM.from_pretrained(\n",
|
770 |
+
" model_id,\n",
|
771 |
+
" export=True,\n",
|
772 |
+
" provider=\"CUDAExecutionProvider\",\n",
|
773 |
+
" use_io_binding=True\n",
|
774 |
+
")\n",
|
775 |
+
"\n",
|
776 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
|
777 |
+
"\n",
|
778 |
+
"pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
|
779 |
+
]
|
780 |
+
},
|
781 |
+
{
|
782 |
+
"cell_type": "code",
|
783 |
+
"execution_count": null,
|
784 |
+
"id": "9d32098c-b0ec-4c36-95ac-775a3a865512",
|
785 |
+
"metadata": {},
|
786 |
+
"outputs": [],
|
787 |
+
"source": [
|
788 |
+
"ort_model.config.eos_token_id = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
789 |
+
"ort_model.config.bos_token_id = tokenizer.encode(\"<|startoftranscript|>\")[0]"
|
790 |
+
]
|
791 |
+
},
|
792 |
+
{
|
793 |
+
"cell_type": "code",
|
794 |
+
"execution_count": null,
|
795 |
+
"id": "1fd0a1fb-9349-4c7a-af03-21e29334f420",
|
796 |
+
"metadata": {},
|
797 |
+
"outputs": [],
|
798 |
+
"source": [
|
799 |
+
"dataset[split][idx].keys()"
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"cell_type": "code",
|
804 |
+
"execution_count": null,
|
805 |
+
"id": "15d8b989-6460-4555-b6e2-2f9e219d7034",
|
806 |
+
"metadata": {},
|
807 |
+
"outputs": [],
|
808 |
+
"source": [
|
809 |
+
"split = \"train.clean.100\"\n",
|
810 |
+
"idx = 0\n",
|
811 |
+
"\n",
|
812 |
+
"text = \"\".join([ f\"<|audio:{tkn}|>\"for tkn in dataset[split][idx][\"audio_tokens\"]]) + \"<|startoftranscript|>\"\n",
|
813 |
+
"\n",
|
814 |
+
"import time\n",
|
815 |
+
"\n",
|
816 |
+
"start_time = time.time()\n",
|
817 |
+
"\n",
|
818 |
+
"generations = pipe(text, max_new_tokens=10, skip_special_tokens=True)\n",
|
819 |
+
"\n",
|
820 |
+
"finish_time = time.time()\n",
|
821 |
+
"\n",
|
822 |
+
"print(generations[0][\"generated_text\"])\n",
|
823 |
+
"\n",
|
824 |
+
"print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
|
825 |
+
]
|
826 |
+
}
|
827 |
+
],
|
828 |
+
"metadata": {
|
829 |
+
"kernelspec": {
|
830 |
+
"display_name": "Python 3 (ipykernel)",
|
831 |
+
"language": "python",
|
832 |
+
"name": "python3"
|
833 |
+
},
|
834 |
+
"language_info": {
|
835 |
+
"codemirror_mode": {
|
836 |
+
"name": "ipython",
|
837 |
+
"version": 3
|
838 |
+
},
|
839 |
+
"file_extension": ".py",
|
840 |
+
"mimetype": "text/x-python",
|
841 |
+
"name": "python",
|
842 |
+
"nbconvert_exporter": "python",
|
843 |
+
"pygments_lexer": "ipython3",
|
844 |
+
"version": "3.8.10"
|
845 |
+
}
|
846 |
+
},
|
847 |
+
"nbformat": 4,
|
848 |
+
"nbformat_minor": 5
|
849 |
+
}
|
ASR/.ipynb_checkpoints/demo-checkpoint.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
from flask import Flask, request
|
2 |
+
# import speech_recognition as sr
|
3 |
+
|
4 |
+
app = Flask(__name__)
|
5 |
+
# recognizer = sr.Recognizer()
|
6 |
+
|
7 |
+
@app.route("/darshan/microphone", methods=['POST'])
|
8 |
+
def handle_audio():
|
9 |
+
audio_data = request.data
|
10 |
+
print(audio_data)
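# a possible next step (a sketch, not part of the original handler; assumes 16-bit PCM as in the
# commented-out sr.AudioData call below):
# import numpy as np
# pcm = np.frombuffer(audio_data, dtype=np.int16)
# waveform = pcm.astype(np.float32) / 32768.0  # scale samples to [-1, 1) for an audio tokenizer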
|
11 |
+
# audio = sr.AudioData(audio_data, sample_rate=44100, sample_width=2) # Adjust sample rate and sample width as needed
|
12 |
+
# try:
|
13 |
+
# text = recognizer.recognize_google(audio)
|
14 |
+
# print(f"Transcription: {text}")
|
15 |
+
# return {'transcription': text}, 200
|
16 |
+
# except sr.UnknownValueError:
|
17 |
+
# print("Could not understand audio")
|
18 |
+
# return '', 400
|
19 |
+
# except sr.RequestError as e:
|
20 |
+
# print(f"Error from Google Speech Recognition service; {e}")
|
21 |
+
# return '', 500
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
app.run(host='0.0.0.0', port=8723) # Replace with your desired host and port
|
ASR/.ipynb_checkpoints/tokenizer_training-checkpoint.ipynb
ADDED
@@ -0,0 +1,203 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "8f95f1d6-be90-4900-9116-c27b82bd7836",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from tokenizers import SentencePieceBPETokenizer\n",
|
11 |
+
"import transformers\n",
|
12 |
+
"from transformers import GPT2Tokenizer, AutoModelForCausalLM\n",
|
13 |
+
"from datasets import Dataset, DatasetDict\n",
|
14 |
+
"\n",
|
15 |
+
"cache_dir = \"./cache\"\n",
|
16 |
+
"\n",
|
17 |
+
"dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
|
18 |
+
"\n",
|
19 |
+
"text = []\n",
|
20 |
+
"for split in dataset.keys():\n",
|
21 |
+
" text += list(dataset[split][\"text\"])\n",
|
22 |
+
"\n",
|
23 |
+
"model_max_length = 1877\n",
|
24 |
+
"special_tokens = [ f\"<|audio:{idx}|>\" for idx in range(1024)] + [\"<|startoftranscript|>\", \"<|endoftranscript|>\", \"<|padding|>\"]\n",
|
25 |
+
"\n",
|
26 |
+
"bpe_tokenizer = SentencePieceBPETokenizer()\n",
|
27 |
+
"bpe_tokenizer.train_from_iterator(\n",
|
28 |
+
" text,\n",
|
29 |
+
" vocab_size = 5000 + len(special_tokens),\n",
|
30 |
+
" min_frequency = 2,\n",
|
31 |
+
" show_progress = True,\n",
|
32 |
+
" special_tokens = special_tokens\n",
|
33 |
+
")\n",
|
34 |
+
"\n",
|
35 |
+
"tokenizer = transformers.PreTrainedTokenizerFast(\n",
|
36 |
+
" tokenizer_object = bpe_tokenizer,\n",
|
37 |
+
" model_max_length = model_max_length,\n",
|
38 |
+
" special_tokens = special_tokens\n",
|
39 |
+
")\n",
|
40 |
+
"\n",
|
41 |
+
"tokenizer.pad_token = \"<|padding|>\"\n",
|
42 |
+
"tokenizer.pad_token_id = bpe_tokenizer.token_to_id(\"<|padding|>\")\n",
|
43 |
+
"\n",
|
44 |
+
"tokenizer.save_pretrained(\"./tokenizer\")"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": null,
|
50 |
+
"id": "d259b76d-1c8d-4c74-9d04-d711a4b3f395",
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer\n",
|
55 |
+
"from datasets import Dataset, DatasetDict\n",
|
56 |
+
"\n",
|
57 |
+
"max_length = 1877\n",
|
58 |
+
"\n",
|
59 |
+
"dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
|
60 |
+
"\n",
|
61 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
|
62 |
+
"\n",
|
63 |
+
"def tokenize(row):\n",
|
64 |
+
" text = \"\".join([f\"<|audio:{token}|>\" for token in row[\"audio_tokens\"]]) + \"<|startoftranscript|>\" + row[\"text\"] + \"<|endoftranscript|>\"\n",
|
65 |
+
" input_ids = tokenizer(\n",
|
66 |
+
" text,\n",
|
67 |
+
" padding=\"max_length\",\n",
|
68 |
+
" max_length=max_length,\n",
|
69 |
+
" )\n",
|
70 |
+
" return input_ids\n",
|
71 |
+
"\n",
|
72 |
+
"dataset = dataset.map(tokenize, remove_columns=[\"text\", \"audio_tokens\"])\n",
|
73 |
+
"\n",
|
74 |
+
"dataset.save_to_disk(\"tokenized_librispeech\")"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"cell_type": "code",
|
79 |
+
"execution_count": 1,
|
80 |
+
"id": "9bc4d1db-1390-4ba1-b296-d52a8993c87f",
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [
|
83 |
+
{
|
84 |
+
"name": "stderr",
|
85 |
+
"output_type": "stream",
|
86 |
+
"text": [
|
87 |
+
"/usr/local/lib/python3.8/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.\n",
|
88 |
+
" table = cls._concat_blocks(blocks, axis=0)\n"
|
89 |
+
]
|
90 |
+
}
|
91 |
+
],
|
92 |
+
"source": [
|
93 |
+
"from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer\n",
|
94 |
+
"from datasets import Dataset, DatasetDict\n",
|
95 |
+
"\n",
|
96 |
+
"max_length = 1877\n",
|
97 |
+
"\n",
|
98 |
+
"dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
|
99 |
+
"\n",
|
100 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": 2,
|
106 |
+
"id": "a83d83c7-2e56-4b9e-9f63-11f7eea22d6d",
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [
|
109 |
+
{
|
110 |
+
"data": {
|
111 |
+
"text/plain": [
|
112 |
+
"[1024]"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
"execution_count": 2,
|
116 |
+
"metadata": {},
|
117 |
+
"output_type": "execute_result"
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"tokenizer.encode(\"<|startoftranscript|>\")"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": null,
|
127 |
+
"id": "5b9f0fa9-384b-4453-bf4c-0d49f0c1e4a5",
|
128 |
+
"metadata": {},
|
129 |
+
"outputs": [],
|
130 |
+
"source": [
|
131 |
+
"tokenizer.pad_token_id"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"id": "1964689d-687d-4ab9-967d-80d9eb95b159",
|
138 |
+
"metadata": {},
|
139 |
+
"outputs": [],
|
140 |
+
"source": [
|
141 |
+
"from tqdm import tqdm\n",
|
142 |
+
"lens = []\n",
|
143 |
+
"\n",
|
144 |
+
"for split in dataset.keys():\n",
|
145 |
+
" for idx in tqdm(range(len(dataset[split]))):\n",
|
146 |
+
" sample = dataset[split][idx]\n",
|
147 |
+
" max_len = len(tokenizer.encode(sample[\"text\"])) + len(sample[\"audio_tokens\"])\n",
|
148 |
+
" lens.append(max_len)"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": null,
|
154 |
+
"id": "e49c1a27-c3e4-43eb-a4e0-68a0d739f27b",
|
155 |
+
"metadata": {},
|
156 |
+
"outputs": [],
|
157 |
+
"source": [
|
158 |
+
"max(lens)"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"id": "96dbd94d-5455-49bd-8893-4aae5dfd9b7f",
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"min(lens)"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"id": "b9aaca12-286d-416c-a2f7-0cdf689eeb2e",
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"tokenizer.encode(\"<|audio:0|>\")"
|
179 |
+
]
|
180 |
+
}
|
181 |
+
],
|
182 |
+
"metadata": {
|
183 |
+
"kernelspec": {
|
184 |
+
"display_name": "Python 3 (ipykernel)",
|
185 |
+
"language": "python",
|
186 |
+
"name": "python3"
|
187 |
+
},
|
188 |
+
"language_info": {
|
189 |
+
"codemirror_mode": {
|
190 |
+
"name": "ipython",
|
191 |
+
"version": 3
|
192 |
+
},
|
193 |
+
"file_extension": ".py",
|
194 |
+
"mimetype": "text/x-python",
|
195 |
+
"name": "python",
|
196 |
+
"nbconvert_exporter": "python",
|
197 |
+
"pygments_lexer": "ipython3",
|
198 |
+
"version": "3.8.10"
|
199 |
+
}
|
200 |
+
},
|
201 |
+
"nbformat": 4,
|
202 |
+
"nbformat_minor": 5
|
203 |
+
}
|
ASR/__pycache__/audio_tokenizer.cpython-38.pyc
ADDED
Binary file (14.8 kB). View file
|
|
ASR/__pycache__/tokenizer.cpython-38.pyc
ADDED
Binary file (14.8 kB). View file
|
|
ASR/audio_tokenizer.py
ADDED
@@ -0,0 +1,611 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from omegaconf import II
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.distributed as dist
|
12 |
+
|
13 |
+
from fairseq.modules import EMAModule, EMAModuleConfig
|
14 |
+
from fairseq.data.data_utils import compute_mask_indices
|
15 |
+
from fairseq.models import BaseFairseqModel, register_model
|
16 |
+
from fairseq.models.wav2vec import (
|
17 |
+
ConvFeatureExtractionModel,
|
18 |
+
Wav2Vec2Config,
|
19 |
+
TransformerEncoder,
|
20 |
+
)
|
21 |
+
from fairseq.modules import (
|
22 |
+
GradMultiply,
|
23 |
+
LayerNorm,
|
24 |
+
)
|
25 |
+
from fairseq.utils import index_put
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class Data2VecAudioConfig(Wav2Vec2Config):
|
33 |
+
|
34 |
+
loss_beta: float = field(
|
35 |
+
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
|
36 |
+
)
|
37 |
+
loss_scale: Optional[float] = field(
|
38 |
+
default=None,
|
39 |
+
metadata={
|
40 |
+
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
|
41 |
+
},
|
42 |
+
)
|
43 |
+
average_top_k_layers: int = field(
|
44 |
+
default=8, metadata={"help": "how many layers to average"}
|
45 |
+
)
|
46 |
+
|
47 |
+
layer_norm_target_layer: bool = False
|
48 |
+
instance_norm_target_layer: bool = False
|
49 |
+
instance_norm_targets: bool = False
|
50 |
+
layer_norm_targets: bool = False
|
51 |
+
batch_norm_target_layer: bool = False
|
52 |
+
group_norm_target_layer: bool = False
|
53 |
+
|
54 |
+
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
|
55 |
+
ema_end_decay: float = field(
|
56 |
+
default=0.9999, metadata={"help": "final ema decay rate"}
|
57 |
+
)
|
58 |
+
|
59 |
+
# when to finish annealing ema decay rate
|
60 |
+
ema_anneal_end_step: int = II("optimization.max_update")
|
61 |
+
|
62 |
+
ema_transformer_only: bool = field(
|
63 |
+
default=True,
|
64 |
+
metadata={"help": "whether to momentum update only the transformer"},
|
65 |
+
)
|
66 |
+
ema_layers_only: bool = field(
|
67 |
+
default=True,
|
68 |
+
metadata={"help": "whether to momentum update only the transformer layers"},
|
69 |
+
)
|
70 |
+
|
71 |
+
max_update: int = II("optimization.max_update")
|
72 |
+
|
73 |
+
min_target_var: float = field(
|
74 |
+
default=0.1, metadata={"help": "stop training if target var falls below this"}
|
75 |
+
)
|
76 |
+
min_pred_var: float = field(
|
77 |
+
default=0.01,
|
78 |
+
metadata={"help": "stop training if prediction var falls below this"},
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def get_annealed_rate(start, end, curr_step, total_steps):
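# linearly interpolate the ema decay from `start` (at step 0) to `end` (at `total_steps`)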
|
83 |
+
r = end - start
|
84 |
+
pct_remaining = 1 - curr_step / total_steps
|
85 |
+
return end - r * pct_remaining
|
86 |
+
|
87 |
+
|
88 |
+
@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
|
89 |
+
class Data2VecAudioModel(BaseFairseqModel):
|
90 |
+
def __init__(self, cfg: Data2VecAudioConfig):
|
91 |
+
super().__init__()
|
92 |
+
self.cfg = cfg
|
93 |
+
|
94 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
95 |
+
self.extractor_embed = feature_enc_layers[-1][0]
|
96 |
+
|
97 |
+
self.ema = None
|
98 |
+
self.embed = cfg.encoder_embed_dim
|
99 |
+
|
100 |
+
self.average_top_k_layers = cfg.average_top_k_layers
|
101 |
+
self.loss_beta = cfg.loss_beta
|
102 |
+
self.loss_scale = cfg.loss_scale
|
103 |
+
|
104 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
105 |
+
conv_layers=feature_enc_layers,
|
106 |
+
dropout=0.0,
|
107 |
+
mode=cfg.extractor_mode,
|
108 |
+
conv_bias=cfg.conv_bias,
|
109 |
+
)
|
110 |
+
|
111 |
+
self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
|
112 |
+
|
113 |
+
self.mask_prob = cfg.mask_prob
|
114 |
+
self.mask_selection = cfg.mask_selection
|
115 |
+
self.mask_other = cfg.mask_other
|
116 |
+
self.mask_length = cfg.mask_length
|
117 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
118 |
+
self.mask_min_space = cfg.mask_min_space
|
119 |
+
|
120 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
121 |
+
self.mask_channel_before = cfg.mask_channel_before
|
122 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
123 |
+
self.mask_channel_other = cfg.mask_channel_other
|
124 |
+
self.mask_channel_length = cfg.mask_channel_length
|
125 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
126 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
127 |
+
|
128 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
129 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
130 |
+
|
131 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
132 |
+
|
133 |
+
self.mask_emb = nn.Parameter(
|
134 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
135 |
+
)
|
136 |
+
|
137 |
+
self.encoder = TransformerEncoder(cfg)
|
138 |
+
self.layer_norm = LayerNorm(self.extractor_embed)
|
139 |
+
|
140 |
+
self.final_proj = nn.Linear(self.embed, self.embed)
|
141 |
+
|
142 |
+
self.num_updates = 0
|
143 |
+
|
144 |
+
def make_ema_teacher(self):
|
145 |
+
ema_config = EMAModuleConfig(
|
146 |
+
ema_decay=self.cfg.ema_decay,
|
147 |
+
ema_fp32=True,
|
148 |
+
)
|
149 |
+
skip_keys = set()
|
150 |
+
if self.cfg.ema_layers_only:
|
151 |
+
self.cfg.ema_transformer_only = True
|
152 |
+
for k, _ in self.encoder.pos_conv.named_parameters():
|
153 |
+
skip_keys.add(f"pos_conv.{k}")
|
154 |
+
|
155 |
+
self.ema = EMAModule(
|
156 |
+
self.encoder if self.cfg.ema_transformer_only else self,
|
157 |
+
ema_config,
|
158 |
+
skip_keys=skip_keys,
|
159 |
+
)
|
160 |
+
|
161 |
+
def set_num_updates(self, num_updates):
|
162 |
+
super().set_num_updates(num_updates)
|
163 |
+
|
164 |
+
if self.ema is None and self.final_proj is not None:
|
165 |
+
logger.info(f"making ema teacher")
|
166 |
+
self.make_ema_teacher()
|
167 |
+
elif self.training and self.ema is not None:
|
168 |
+
if self.cfg.ema_decay != self.cfg.ema_end_decay:
|
169 |
+
if num_updates >= self.cfg.ema_anneal_end_step:
|
170 |
+
decay = self.cfg.ema_end_decay
|
171 |
+
else:
|
172 |
+
decay = get_annealed_rate(
|
173 |
+
self.cfg.ema_decay,
|
174 |
+
self.cfg.ema_end_decay,
|
175 |
+
num_updates,
|
176 |
+
self.cfg.ema_anneal_end_step,
|
177 |
+
)
|
178 |
+
self.ema.set_decay(decay)
|
179 |
+
if self.ema.get_decay() < 1:
|
180 |
+
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
|
181 |
+
|
182 |
+
self.num_updates = num_updates
|
183 |
+
|
184 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
185 |
+
state = super().state_dict(destination, prefix, keep_vars)
|
186 |
+
|
187 |
+
if self.ema is not None:
|
188 |
+
state[prefix + "_ema"] = self.ema.fp32_params
|
189 |
+
|
190 |
+
return state
|
191 |
+
|
192 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
193 |
+
if self.ema is not None:
|
194 |
+
k = prefix + "_ema"
|
195 |
+
assert k in state_dict
|
196 |
+
self.ema.restore(state_dict[k], True)
|
197 |
+
del state_dict[k]
|
198 |
+
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
199 |
+
|
200 |
+
@classmethod
|
201 |
+
def build_model(cls, cfg: Data2VecAudioConfig, task=None):
|
202 |
+
"""Build a new model instance."""
|
203 |
+
|
204 |
+
return cls(cfg)
|
205 |
+
|
206 |
+
def apply_mask(
|
207 |
+
self,
|
208 |
+
x,
|
209 |
+
padding_mask,
|
210 |
+
mask_indices=None,
|
211 |
+
mask_channel_indices=None,
|
212 |
+
):
|
213 |
+
B, T, C = x.shape
|
214 |
+
|
215 |
+
if self.mask_channel_prob > 0 and self.mask_channel_before:
|
216 |
+
mask_channel_indices = compute_mask_indices(
|
217 |
+
(B, C),
|
218 |
+
None,
|
219 |
+
self.mask_channel_prob,
|
220 |
+
self.mask_channel_length,
|
221 |
+
self.mask_channel_selection,
|
222 |
+
self.mask_channel_other,
|
223 |
+
no_overlap=self.no_mask_channel_overlap,
|
224 |
+
min_space=self.mask_channel_min_space,
|
225 |
+
)
|
226 |
+
mask_channel_indices = (
|
227 |
+
torch.from_numpy(mask_channel_indices)
|
228 |
+
.to(x.device)
|
229 |
+
.unsqueeze(1)
|
230 |
+
.expand(-1, T, -1)
|
231 |
+
)
|
232 |
+
x[mask_channel_indices] = 0
|
233 |
+
|
234 |
+
if self.mask_prob > 0:
|
235 |
+
if mask_indices is None:
|
236 |
+
mask_indices = compute_mask_indices(
|
237 |
+
(B, T),
|
238 |
+
padding_mask,
|
239 |
+
self.mask_prob,
|
240 |
+
self.mask_length,
|
241 |
+
self.mask_selection,
|
242 |
+
self.mask_other,
|
243 |
+
min_masks=1,
|
244 |
+
no_overlap=self.no_mask_overlap,
|
245 |
+
min_space=self.mask_min_space,
|
246 |
+
require_same_masks=self.cfg.require_same_masks,
|
247 |
+
mask_dropout=self.cfg.mask_dropout,
|
248 |
+
)
|
249 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
250 |
+
x = index_put(x, mask_indices, self.mask_emb)
|
251 |
+
else:
|
252 |
+
mask_indices = None
|
253 |
+
|
254 |
+
if self.mask_channel_prob > 0 and not self.mask_channel_before:
|
255 |
+
if mask_channel_indices is None:
|
256 |
+
mask_channel_indices = compute_mask_indices(
|
257 |
+
(B, C),
|
258 |
+
None,
|
259 |
+
self.mask_channel_prob,
|
260 |
+
self.mask_channel_length,
|
261 |
+
self.mask_channel_selection,
|
262 |
+
self.mask_channel_other,
|
263 |
+
no_overlap=self.no_mask_channel_overlap,
|
264 |
+
min_space=self.mask_channel_min_space,
|
265 |
+
)
|
266 |
+
mask_channel_indices = (
|
267 |
+
torch.from_numpy(mask_channel_indices)
|
268 |
+
.to(x.device)
|
269 |
+
.unsqueeze(1)
|
270 |
+
.expand(-1, T, -1)
|
271 |
+
)
|
272 |
+
x = index_put(x, mask_channel_indices, 0)
|
273 |
+
|
274 |
+
return x, mask_indices
|
275 |
+
|
276 |
+
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
277 |
+
"""
|
278 |
+
Computes the output length of the convolutional layers
|
279 |
+
"""
|
280 |
+
|
281 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
282 |
+
return torch.floor((input_length - kernel_size) / stride + 1)
|
283 |
+
|
284 |
+
conv_cfg_list = eval(self.cfg.conv_feature_layers)
|
285 |
+
|
286 |
+
for i in range(len(conv_cfg_list)):
|
287 |
+
input_lengths = _conv_out_length(
|
288 |
+
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
|
289 |
+
)
|
290 |
+
|
291 |
+
return input_lengths.to(torch.long)
|
292 |
+
|
293 |
+
def forward(
|
294 |
+
self,
|
295 |
+
source,
|
296 |
+
padding_mask=None,
|
297 |
+
mask=True,
|
298 |
+
features_only=False,
|
299 |
+
layer=None,
|
300 |
+
mask_indices=None,
|
301 |
+
mask_channel_indices=None,
|
302 |
+
padding_count=None,
|
303 |
+
):
|
304 |
+
features = source
|
305 |
+
|
306 |
+
if self.feature_grad_mult > 0:
|
307 |
+
features = self.feature_extractor(features)
|
308 |
+
if self.feature_grad_mult != 1.0:
|
309 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
310 |
+
else:
|
311 |
+
with torch.no_grad():
|
312 |
+
features = self.feature_extractor(features)
|
313 |
+
|
314 |
+
features = features.transpose(1, 2)
|
315 |
+
|
316 |
+
features = self.layer_norm(features)
|
317 |
+
|
318 |
+
orig_padding_mask = padding_mask
|
319 |
+
|
320 |
+
if padding_mask is not None and padding_mask.any():
|
321 |
+
input_lengths = (1 - padding_mask.long()).sum(-1)
|
322 |
+
# apply conv formula to get real output_lengths
|
323 |
+
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
|
324 |
+
|
325 |
+
padding_mask = torch.zeros(
|
326 |
+
features.shape[:2], dtype=features.dtype, device=features.device
|
327 |
+
)
|
328 |
+
|
329 |
+
# these two operations make sure that all values
|
330 |
+
# before the output lengths indices are attended to
|
331 |
+
padding_mask[
|
332 |
+
(
|
333 |
+
torch.arange(padding_mask.shape[0], device=padding_mask.device),
|
334 |
+
output_lengths - 1,
|
335 |
+
)
|
336 |
+
] = 1
|
337 |
+
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
|
338 |
+
else:
|
339 |
+
padding_mask = None
|
340 |
+
|
341 |
+
if self.post_extract_proj is not None:
|
342 |
+
features = self.post_extract_proj(features)
|
343 |
+
|
344 |
+
pre_encoder_features = None
|
345 |
+
if self.cfg.ema_transformer_only:
|
346 |
+
pre_encoder_features = features.clone()
|
347 |
+
|
348 |
+
features = self.dropout_input(features)
|
349 |
+
|
350 |
+
if mask:
|
351 |
+
x, mask_indices = self.apply_mask(
|
352 |
+
features,
|
353 |
+
padding_mask,
|
354 |
+
mask_indices=mask_indices,
|
355 |
+
mask_channel_indices=mask_channel_indices,
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
x = features
|
359 |
+
mask_indices = None
|
360 |
+
|
361 |
+
x, layer_results = self.encoder(
|
362 |
+
x,
|
363 |
+
padding_mask=padding_mask,
|
364 |
+
layer=layer,
|
365 |
+
)
|
366 |
+
|
367 |
+
if features_only:
|
368 |
+
return {
|
369 |
+
"x": x,
|
370 |
+
"padding_mask": padding_mask,
|
371 |
+
"layer_results": layer_results,
|
372 |
+
}
|
373 |
+
|
374 |
+
result = {
|
375 |
+
"losses": {},
|
376 |
+
}
|
377 |
+
|
378 |
+
with torch.no_grad():
|
379 |
+
self.ema.model.eval()
|
380 |
+
|
381 |
+
if self.cfg.ema_transformer_only:
|
382 |
+
y, layer_results = self.ema.model.extract_features(
|
383 |
+
pre_encoder_features,
|
384 |
+
padding_mask=padding_mask,
|
385 |
+
min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
|
386 |
+
)
|
387 |
+
y = {
|
388 |
+
"x": y,
|
389 |
+
"padding_mask": padding_mask,
|
390 |
+
"layer_results": layer_results,
|
391 |
+
}
|
392 |
+
else:
|
393 |
+
y = self.ema.model.extract_features(
|
394 |
+
source=source,
|
395 |
+
padding_mask=orig_padding_mask,
|
396 |
+
mask=False,
|
397 |
+
)
|
398 |
+
|
399 |
+
target_layer_results = [l[2] for l in y["layer_results"]]
|
400 |
+
|
401 |
+
permuted = False
|
402 |
+
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
403 |
+
target_layer_results = [
|
404 |
+
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
|
405 |
+
]
|
406 |
+
permuted = True
|
407 |
+
|
408 |
+
if self.cfg.batch_norm_target_layer:
|
409 |
+
target_layer_results = [
|
410 |
+
F.batch_norm(
|
411 |
+
tl.float(), running_mean=None, running_var=None, training=True
|
412 |
+
)
|
413 |
+
for tl in target_layer_results
|
414 |
+
]
|
415 |
+
|
416 |
+
if self.cfg.instance_norm_target_layer:
|
417 |
+
target_layer_results = [
|
418 |
+
F.instance_norm(tl.float()) for tl in target_layer_results
|
419 |
+
]
|
420 |
+
|
421 |
+
if permuted:
|
422 |
+
target_layer_results = [
|
423 |
+
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
424 |
+
]
|
425 |
+
|
426 |
+
if self.cfg.group_norm_target_layer:
|
427 |
+
target_layer_results = [
|
428 |
+
F.layer_norm(tl.float(), tl.shape[-2:])
|
429 |
+
for tl in target_layer_results
|
430 |
+
]
|
431 |
+
|
432 |
+
if self.cfg.layer_norm_target_layer:
|
433 |
+
target_layer_results = [
|
434 |
+
F.layer_norm(tl.float(), tl.shape[-1:])
|
435 |
+
for tl in target_layer_results
|
436 |
+
]
|
437 |
+
|
438 |
+
y = sum(target_layer_results) / len(target_layer_results)
|
439 |
+
|
440 |
+
if self.cfg.layer_norm_targets:
|
441 |
+
y = F.layer_norm(y.float(), y.shape[-1:])
|
442 |
+
|
443 |
+
if self.cfg.instance_norm_targets:
|
444 |
+
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
|
445 |
+
|
446 |
+
if not permuted:
|
447 |
+
y = y.transpose(0, 1)
|
448 |
+
|
449 |
+
y = y[mask_indices]
|
450 |
+
|
451 |
+
x = x[mask_indices]
|
452 |
+
x = self.final_proj(x)
|
453 |
+
|
454 |
+
sz = x.size(-1)
|
455 |
+
|
456 |
+
if self.loss_beta == 0:
|
457 |
+
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
|
458 |
+
else:
|
459 |
+
loss = F.smooth_l1_loss(
|
460 |
+
x.float(), y.float(), reduction="none", beta=self.loss_beta
|
461 |
+
).sum(dim=-1)
|
462 |
+
|
463 |
+
if self.loss_scale is not None:
|
464 |
+
scale = self.loss_scale
|
465 |
+
else:
|
466 |
+
scale = 1 / math.sqrt(sz)
|
467 |
+
|
468 |
+
result["losses"]["regression"] = loss.sum() * scale
|
469 |
+
|
470 |
+
if "sample_size" not in result:
|
471 |
+
result["sample_size"] = loss.numel()
|
472 |
+
|
473 |
+
with torch.no_grad():
|
474 |
+
result["target_var"] = self.compute_var(y)
|
475 |
+
result["pred_var"] = self.compute_var(x.float())
|
476 |
+
|
477 |
+
if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
|
478 |
+
logger.error(
|
479 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
480 |
+
)
|
481 |
+
raise Exception(
|
482 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
483 |
+
)
|
484 |
+
if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
|
485 |
+
logger.error(
|
486 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
487 |
+
)
|
488 |
+
raise Exception(
|
489 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
490 |
+
)
|
491 |
+
|
492 |
+
if self.ema is not None:
|
493 |
+
result["ema_decay"] = self.ema.get_decay() * 1000
|
494 |
+
|
495 |
+
return result
|
496 |
+
|
497 |
+
@staticmethod
|
498 |
+
def compute_var(y):
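# per-dimension standard deviation of the flattened activations, averaged over dimensions;
# when torch.distributed is initialized, the statistics are all-reduced across workers first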
|
499 |
+
y = y.view(-1, y.size(-1))
|
500 |
+
if dist.is_initialized():
|
501 |
+
zc = torch.tensor(y.size(0)).cuda()
|
502 |
+
zs = y.sum(dim=0)
|
503 |
+
zss = (y ** 2).sum(dim=0)
|
504 |
+
|
505 |
+
dist.all_reduce(zc)
|
506 |
+
dist.all_reduce(zs)
|
507 |
+
dist.all_reduce(zss)
|
508 |
+
|
509 |
+
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
|
510 |
+
return torch.sqrt(var + 1e-6).mean()
|
511 |
+
else:
|
512 |
+
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
513 |
+
|
514 |
+
def extract_features(
|
515 |
+
self, source, padding_mask, mask=False, layer=None
|
516 |
+
):
|
517 |
+
res = self.forward(
|
518 |
+
source,
|
519 |
+
padding_mask,
|
520 |
+
mask=mask,
|
521 |
+
features_only=True,
|
522 |
+
layer=layer,
|
523 |
+
)
|
524 |
+
return res
|
525 |
+
|
526 |
+
def remove_pretraining_modules(self, last_layer=None):
|
527 |
+
self.final_proj = None
|
528 |
+
self.ema = None
|
529 |
+
if last_layer is not None:
|
530 |
+
self.encoder.layers = nn.ModuleList(
|
531 |
+
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
|
532 |
+
)
|
533 |
+
|
534 |
+
import logging
|
535 |
+
|
536 |
+
import torch
|
537 |
+
import torch.nn.functional as F
|
538 |
+
from fairseq import tasks
|
539 |
+
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
|
540 |
+
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
541 |
+
from omegaconf import OmegaConf
|
542 |
+
|
543 |
+
logger = logging.getLogger("dump_feature")
|
544 |
+
|
545 |
+
|
546 |
+
class Data2vecFeatureReader(object):
|
547 |
+
def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
|
548 |
+
state = load_checkpoint_to_cpu(ckpt_path)
|
549 |
+
cfg = state["cfg"]
|
550 |
+
# load task
|
551 |
+
task = tasks.setup_task(cfg.task, from_checkpoint=True)
|
552 |
+
task.load_state_dict(state["task_state"])
|
553 |
+
# load model config
|
554 |
+
if "layer_type" not in cfg.model:
|
555 |
+
# fix a missing key
|
556 |
+
model_config = {k: v for k, v in cfg.model.items()}
|
557 |
+
model_config["layer_type"] = "transformer"
|
558 |
+
model_config = OmegaConf.create(model_config)
|
559 |
+
else:
|
560 |
+
model_config = cfg.model
|
561 |
+
|
562 |
+
# fix param name in the state
|
563 |
+
state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
|
564 |
+
state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
|
565 |
+
del state["model"]["_ema"]
|
566 |
+
|
567 |
+
# load model
|
568 |
+
model = Data2VecAudioModel.build_model(model_config)
|
569 |
+
model.load_state_dict(
|
570 |
+
state["model"], strict=True, model_cfg=model_config
|
571 |
+
)
|
572 |
+
|
573 |
+
self.device = device
|
574 |
+
logger.info(f"device = {self.device}")
|
575 |
+
|
576 |
+
self.model = model.eval().to(self.device)
|
577 |
+
self.task = task
|
578 |
+
self.layer = layer - 1  # convert the 1-based layer argument to a 0-based index
|
579 |
+
self.max_chunk = max_chunk
|
580 |
+
logger.info(f"TASK CONFIG:\n{self.task.cfg}")
|
581 |
+
logger.info(f" max_chunk = {self.max_chunk}")
|
582 |
+
|
583 |
+
def read_audio(self, path, ref_len=None):
|
584 |
+
wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
|
585 |
+
if wav.ndim == 2:
|
586 |
+
wav = wav.mean(-1)
|
587 |
+
assert wav.ndim == 1, wav.ndim
|
588 |
+
if ref_len is not None and abs(ref_len - len(wav)) > 160:
|
589 |
+
logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
|
590 |
+
return wav
|
591 |
+
|
592 |
+
def get_feats(self, path, ref_len=None):
|
593 |
+
x = self.read_audio(path, ref_len=ref_len)
|
594 |
+
with torch.no_grad():
|
595 |
+
x = torch.from_numpy(x).float().to(self.device)
|
596 |
+
if self.task.cfg.normalize:
|
597 |
+
x = F.layer_norm(x, x.shape)
|
598 |
+
x = x.view(1, -1)
|
599 |
+
|
600 |
+
feat = []
|
601 |
+
for start in range(0, x.size(1), self.max_chunk):
|
602 |
+
x_chunk = x[:, start: start + self.max_chunk]
|
603 |
+
res = self.model.extract_features(
|
604 |
+
source=x_chunk,
|
605 |
+
padding_mask=None,
|
606 |
+
mask=False,
|
607 |
+
layer=self.layer,
|
608 |
+
)
|
609 |
+
feat_chunk = res["x"]
|
610 |
+
feat.append(feat_chunk)
|
611 |
+
return torch.cat(feat, 1).squeeze(0)
|
ASR/demo.ipynb
ADDED
@@ -0,0 +1,878 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "715a402a-44b9-4fa2-abf0-b0cfd2f3d80b",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"## Recording voice in Real Time"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": null,
|
14 |
+
"id": "dbdf6bab-7418-4a6f-8b75-c31f98a6ada5",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"\"\"\"\n",
|
19 |
+
"Sprints:\n",
|
20 |
+
"- [ ] Do Inference optimization of ASR LM\n",
|
21 |
+
"- [ ] Train on train.other.500\n",
|
22 |
+
"- [ ] Generate dataset for prompting\n",
|
23 |
+
"\n",
|
24 |
+
"Evaluation Dates: 20th - 21st June, 2023, 3:30 - 5:30pm\n",
|
25 |
+
"Sharpen PPT Skills: 20th June, 3:30pm - 4:45pm\n",
|
26 |
+
"Flow of the PPT:\n",
|
27 |
+
"Demo -> Datasets -> Techniques -> Evaluation -> Q&A\n",
|
28 |
+
"- [ Done ] Update the one pager deck slide\n",
|
29 |
+
"https://sprinklr-my.sharepoint.com/:p:/r/personal/sricharan_narayanam_sprinklr_com/_layouts/15/Doc.aspx?sourcedoc=%7B84811f56-5fc7-4eaa-87d2-db4a3588d18c%7D&action=edit&wdPreviousSession=948ccc35-dc05-f1f9-612d-9a22300e25ba\n",
|
30 |
+
"My PPT:\n",
|
31 |
+
"https://sprinklr-my.sharepoint.com/:p:/p/darshan_makwana/Ec4jCiyMWhxMproH625msc8BClFVceNQ8o4kS3EhZBO9MA?e=YCSDxm&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718703689001&web=1\n",
|
32 |
+
"Intern Tracker:\n",
|
33 |
+
"https://sprinklr.sharepoint.com/:x:/s/AIIntuition/EbRhHPIAIw9MlZ5PpXbztmABde1LFbaSoSHJAo9qU8ggDg?e=xiLkRt&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718692666812&web=1\n",
|
34 |
+
"\"\"\""
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"id": "150aca01-4098-4ab2-809a-25775ec52069",
|
40 |
+
"metadata": {},
|
41 |
+
"source": [
|
42 |
+
"## ASR LM Inference"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"id": "804a58af-beb2-48c1-9530-98024e27c0d6",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"from audio_tokenizer import Data2vecFeatureReader\n",
|
53 |
+
"from repcodec.RepCodec import RepCodec\n",
|
54 |
+
"import torch.nn.functional as F\n",
|
55 |
+
"import torch\n",
|
56 |
+
"import yaml\n",
|
57 |
+
"\n",
|
58 |
+
"reader = Data2vecFeatureReader(\"./../prompting/models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)\n",
|
59 |
+
"\n",
|
60 |
+
"config = \"./repcodec/configs/repcodec_dim1024.yaml\"\n",
|
61 |
+
"with open(config) as fp:\n",
|
62 |
+
" conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
|
63 |
+
"\n",
|
64 |
+
"audio_model = RepCodec(**conf)\n",
|
65 |
+
"audio_model.load_state_dict(torch.load(\"./../prompting/models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
|
66 |
+
"audio_model.quantizer.initial()\n",
|
67 |
+
"audio_model.to(\"cuda:0\")\n",
|
68 |
+
"audio_model.eval()\n",
|
69 |
+
"\n",
|
70 |
+
"print(\"Successfully Loaded Audio Tokenizer\")"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": null,
|
76 |
+
"id": "7d8da397-2030-4b36-9a42-97862488797b",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [],
|
79 |
+
"source": [
|
80 |
+
"from datasets import load_dataset\n",
|
81 |
+
"\n",
|
82 |
+
"cache_dir = \"./../cache\"\n",
|
83 |
+
"dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 1,
|
89 |
+
"id": "bb8016b2-fc9d-4c23-9e85-b6e1c5ca164c",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
|
94 |
+
"import torch\n",
|
95 |
+
"import string\n",
|
96 |
+
"\n",
|
97 |
+
"def process(text):\n",
|
98 |
+
"\n",
|
99 |
+
" # Lower case every letter\n",
|
100 |
+
" text = text.lower()\n",
|
101 |
+
"\n",
|
102 |
+
" # Remove punctuation\n",
|
103 |
+
" punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
|
104 |
+
" translation_table = str.maketrans('', '', punctuation_to_remove)\n",
|
105 |
+
" text = text.translate(translation_table)\n",
|
106 |
+
"\n",
|
107 |
+
" # Remove whitespaces from front and behind\n",
|
108 |
+
" while text[0] == ' ' or text[-1] == ' ':\n",
|
109 |
+
" if text[0] == ' ':\n",
|
110 |
+
" text = text[1:]\n",
|
111 |
+
" if text[-1] == ' ':\n",
|
112 |
+
" text = text[:-1]\n",
|
113 |
+
" \n",
|
114 |
+
" return text\n",
|
115 |
+
"\n",
|
116 |
+
"device = \"cuda:0\"\n",
|
117 |
+
"dtype = torch.float16\n",
|
118 |
+
"context_length = 1877\n",
|
119 |
+
"\n",
|
120 |
+
"# Load tokenizer and add audio tokens\n",
|
121 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
|
122 |
+
"eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
123 |
+
"pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
|
124 |
+
"\n",
|
125 |
+
"model = GPT2LMHeadModel.from_pretrained(\"./../out/checkpoint-19000\", attn_implementation=\"flash_attention_2\", device_map=device, torch_dtype=dtype).eval()\n",
|
126 |
+
"model.config.pad_token_id = pad_token\n",
|
127 |
+
"model.config.eos_token_id = eot_token\n",
|
128 |
+
"# model = torch.compile(model)"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 3,
|
134 |
+
"id": "693db182-92ac-4e36-b848-989fafd10e73",
|
135 |
+
"metadata": {},
|
136 |
+
"outputs": [
|
137 |
+
{
|
138 |
+
"data": {
|
139 |
+
"text/plain": [
|
140 |
+
"GPT2Model(\n",
|
141 |
+
" (wte): Embedding(6027, 768)\n",
|
142 |
+
" (wpe): Embedding(1877, 768)\n",
|
143 |
+
" (drop): Dropout(p=0.1, inplace=False)\n",
|
144 |
+
" (h): ModuleList(\n",
|
145 |
+
" (0-11): 12 x GPT2Block(\n",
|
146 |
+
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
147 |
+
" (attn): GPT2FlashAttention2(\n",
|
148 |
+
" (c_attn): Conv1D()\n",
|
149 |
+
" (c_proj): Conv1D()\n",
|
150 |
+
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
|
151 |
+
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
|
152 |
+
" )\n",
|
153 |
+
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
154 |
+
" (mlp): GPT2MLP(\n",
|
155 |
+
" (c_fc): Conv1D()\n",
|
156 |
+
" (c_proj): Conv1D()\n",
|
157 |
+
" (act): NewGELUActivation()\n",
|
158 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
159 |
+
" )\n",
|
160 |
+
" )\n",
|
161 |
+
" )\n",
|
162 |
+
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
163 |
+
")"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
"execution_count": 3,
|
167 |
+
"metadata": {},
|
168 |
+
"output_type": "execute_result"
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"source": [
|
172 |
+
"model.transformer"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": null,
|
178 |
+
"id": "7cabe9dc-bbbf-41b4-918f-3f60ee5582f2",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"from tqdm import tqdm\n",
|
183 |
+
"from math import ceil\n",
|
184 |
+
"import torch\n",
|
185 |
+
"import time\n",
|
186 |
+
"\n",
|
187 |
+
"sample = dataset[\"train.clean.100\"][5]\n",
|
188 |
+
"\n",
|
189 |
+
"x = sample[\"audio\"][\"array\"]\n",
|
190 |
+
"\n",
|
191 |
+
"start_time = time.time()\n",
|
192 |
+
"\n",
|
193 |
+
"with torch.no_grad():\n",
|
194 |
+
" x = torch.from_numpy(x).float().to(reader.device)\n",
|
195 |
+
" if reader.task.cfg.normalize:\n",
|
196 |
+
" x = F.layer_norm(x, x.shape)\n",
|
197 |
+
" x = x.view(1, -1)\n",
|
198 |
+
"\n",
|
199 |
+
" feat = []\n",
|
200 |
+
" for start in range(0, x.size(1), reader.max_chunk):\n",
|
201 |
+
" x_chunk = x[:, start: start + reader.max_chunk]\n",
|
202 |
+
" res = reader.model.extract_features(\n",
|
203 |
+
" source=x_chunk,\n",
|
204 |
+
" padding_mask=None,\n",
|
205 |
+
" mask=False,\n",
|
206 |
+
" layer=reader.layer,\n",
|
207 |
+
" )\n",
|
208 |
+
" feat_chunk = res[\"x\"]\n",
|
209 |
+
" feat.append(feat_chunk)\n",
|
210 |
+
" \n",
|
211 |
+
" features = torch.cat(feat, 1).permute(0, 2, 1)\n",
|
212 |
+
"\n",
|
213 |
+
" x = audio_model.encoder(features)\n",
|
214 |
+
" z = audio_model.projector(x)\n",
|
215 |
+
" _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
|
216 |
+
" tokens = idx.cpu().data.numpy().tolist()[0]\n",
|
217 |
+
" \n",
|
218 |
+
"text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
|
219 |
+
"input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
|
220 |
+
"\n",
|
221 |
+
"input_time = time.time()\n",
|
222 |
+
"\n",
|
223 |
+
"generations = model.generate(\n",
|
224 |
+
" input_ids,\n",
|
225 |
+
" pad_token_id = pad_token,\n",
|
226 |
+
" eos_token_id = eot_token,\n",
|
227 |
+
" max_new_tokens = context_length,\n",
|
228 |
+
" use_cache=True\n",
|
229 |
+
")\n",
|
230 |
+
"\n",
|
231 |
+
"finish_time = time.time()\n",
|
232 |
+
"\n",
|
233 |
+
"tokenizer.batch_decode(generations, skip_special_tokens=True)\n",
|
234 |
+
"print(\"First Token Latency: \", (input_time - start_time) * 1000, \"ms\")\n",
|
235 |
+
"# print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
|
236 |
+
"print(\"End to End Inference Time: \", (finish_time - start_time) * 1000, \"ms\")\n",
|
237 |
+
"print(\"Refer Text: \", process(sample[\"text\"]))\n",
|
238 |
+
"print(\"Transcript: \", tokenizer.batch_decode(generations, skip_special_tokens=True)[0])"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"execution_count": null,
|
244 |
+
"id": "baa8d79b-7cf5-4435-838c-1f3d4e043d60",
|
245 |
+
"metadata": {},
|
246 |
+
"outputs": [],
|
247 |
+
"source": [
|
248 |
+
"import time\n",
|
249 |
+
"\n",
|
250 |
+
"sample = dataset[\"train.clean.100\"][0]\n",
|
251 |
+
"\n",
|
252 |
+
"x = sample[\"audio\"][\"array\"]\n",
|
253 |
+
"\n",
|
254 |
+
"start_time = time.time()\n",
|
255 |
+
"\n",
|
256 |
+
"with torch.no_grad():\n",
|
257 |
+
" x = torch.from_numpy(x).float().to(reader.device)\n",
|
258 |
+
" if reader.task.cfg.normalize:\n",
|
259 |
+
" x = F.layer_norm(x, x.shape)\n",
|
260 |
+
" x = x.view(1, -1)\n",
|
261 |
+
"\n",
|
262 |
+
" feat = []\n",
|
263 |
+
" for start in range(0, x.size(1), reader.max_chunk):\n",
|
264 |
+
" x_chunk = x[:, start: start + reader.max_chunk]\n",
|
265 |
+
" res = reader.model.extract_features(\n",
|
266 |
+
" source=x_chunk,\n",
|
267 |
+
" padding_mask=None,\n",
|
268 |
+
" mask=False,\n",
|
269 |
+
" layer=reader.layer,\n",
|
270 |
+
" )\n",
|
271 |
+
" feat_chunk = res[\"x\"]\n",
|
272 |
+
" feat.append(feat_chunk)\n",
|
273 |
+
" \n",
|
274 |
+
" features = torch.cat(feat, 1).permute(0, 2, 1)\n",
|
275 |
+
"\n",
|
276 |
+
" x = audio_model.encoder(features)\n",
|
277 |
+
" z = audio_model.projector(x)\n",
|
278 |
+
" _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
|
279 |
+
" tokens = idx.cpu().data.numpy().tolist()[0]\n",
|
280 |
+
"\n",
|
281 |
+
"from tqdm import tqdm\n",
|
282 |
+
"from math import ceil\n",
|
283 |
+
"import torch\n",
|
284 |
+
"\n",
|
285 |
+
"context_length = 1877\n",
|
286 |
+
"eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
287 |
+
"pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
|
288 |
+
" \n",
|
289 |
+
"text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
|
290 |
+
"input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
|
291 |
+
"\n",
|
292 |
+
"max_new_tokens = context_length\n",
|
293 |
+
"num_tokens = 0\n",
|
294 |
+
"first_token = True\n",
|
295 |
+
"\n",
|
296 |
+
"while max_new_tokens > 0 and input_ids.shape[-1] < context_length:\n",
|
297 |
+
"\n",
|
298 |
+
" with torch.no_grad():\n",
|
299 |
+
" outputs = model(input_ids = input_ids)\n",
|
300 |
+
"\n",
|
301 |
+
" logits = outputs[\"logits\"][:, -1]\n",
|
302 |
+
"\n",
|
303 |
+
" # Greedy Sampling\n",
|
304 |
+
" probas = torch.softmax(logits, dim=-1)\n",
|
305 |
+
" pred_idx = torch.argmax(probas, dim=-1, keepdim=True)\n",
|
306 |
+
" next_idx = pred_idx.item()\n",
|
307 |
+
"\n",
|
308 |
+
" if first_token:\n",
|
309 |
+
" first_token_latency = time.time() - start_time\n",
|
310 |
+
" first_token = False\n",
|
311 |
+
" start_time = time.time()\n",
|
312 |
+
"\n",
|
313 |
+
" if next_idx == eot_token:\n",
|
314 |
+
" break\n",
|
315 |
+
"\n",
|
316 |
+
" input_ids = torch.cat((input_ids, pred_idx), dim=-1)\n",
|
317 |
+
"\n",
|
318 |
+
" max_new_tokens -= 1\n",
|
319 |
+
" num_tokens += 1\n",
|
320 |
+
"\n",
|
321 |
+
"total_time = time.time() - start_time\n",
|
322 |
+
"\n",
|
323 |
+
"print(\"First Token Latency: \", first_token_latency * 1000, \"ms\")\n",
|
324 |
+
"print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
|
325 |
+
"print(\"End to End Inference Time: \", (total_time + first_token_latency) * 1000, \"ms\")\n",
|
326 |
+
"print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])\n",
|
327 |
+
"print(process(sample[\"text\"]))"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": null,
|
333 |
+
"id": "014ed999-3293-4d68-8f9c-017584adc642",
|
334 |
+
"metadata": {},
|
335 |
+
"outputs": [],
|
336 |
+
"source": [
|
337 |
+
"tokenizer.batch_decode([[1, 2, 3]])"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "markdown",
|
342 |
+
"id": "ec11e43f-1eb8-4399-9a93-6f1427782661",
|
343 |
+
"metadata": {
|
344 |
+
"jp-MarkdownHeadingCollapsed": true
|
345 |
+
},
|
346 |
+
"source": [
|
347 |
+
"## Accelerating GPT 2 Inference"
|
348 |
+
]
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "code",
|
352 |
+
"execution_count": null,
|
353 |
+
"id": "5489cb4e-3213-4931-abe1-4c96d1a7ba56",
|
354 |
+
"metadata": {},
|
355 |
+
"outputs": [],
|
356 |
+
"source": [
|
357 |
+
"\"\"\"\n",
|
358 |
+
"- change tensorrt.tensorrt to tensorrt\n",
|
359 |
+
"- remove cpu quantization lines\n",
|
360 |
+
"- output_names [\"logits\"]\n",
|
361 |
+
"\"\"\""
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "code",
|
366 |
+
"execution_count": null,
|
367 |
+
"id": "7e7e6ea6-7319-4e57-af33-5d917d26abc6",
|
368 |
+
"metadata": {},
|
369 |
+
"outputs": [],
|
370 |
+
"source": [
|
371 |
+
"import logging\n",
|
372 |
+
"import time\n",
|
373 |
+
"from typing import Callable, Dict\n",
|
374 |
+
"\n",
|
375 |
+
"import numpy as np\n",
|
376 |
+
"import tensorrt as trt\n",
|
377 |
+
"import torch\n",
|
378 |
+
"from tensorrt import ICudaEngine\n",
|
379 |
+
"from tensorrt import Logger, Runtime\n",
|
380 |
+
"from transformers import AutoTokenizer, BatchEncoding, GPT2LMHeadModel, AutoModelForCausalLM\n",
|
381 |
+
"from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
|
382 |
+
"from transformer_deploy.utils.generative_model import GPTModelWrapper\n",
|
383 |
+
"import inspect\n",
|
384 |
+
"from transformers import TensorType\n",
|
385 |
+
"\n",
|
386 |
+
"from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding, optimize_onnx\n",
|
387 |
+
"from transformer_deploy.backends.pytorch_utils import convert_to_onnx, get_model_size\n",
|
388 |
+
"from transformer_deploy.backends.trt_utils import build_engine, load_engine, save_engine"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"cell_type": "code",
|
393 |
+
"execution_count": null,
|
394 |
+
"id": "21681412-7747-4824-894a-6006eb12a821",
|
395 |
+
"metadata": {},
|
396 |
+
"outputs": [],
|
397 |
+
"source": [
|
398 |
+
"model_name = \"gpt2\"\n",
|
399 |
+
"\n",
|
400 |
+
"model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name)\n",
|
401 |
+
"model.eval()\n",
|
402 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
403 |
+
"model.config.pad_token_id = tokenizer.eos_token_id"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": null,
|
409 |
+
"id": "46783acd-c404-44b4-904b-d8fb687afc34",
|
410 |
+
"metadata": {},
|
411 |
+
"outputs": [],
|
412 |
+
"source": [
|
413 |
+
"inputs = tokenizer(\"Here is some text to encode Hello World\", return_tensors=\"pt\")\n",
|
414 |
+
"print(\"input tensors\")\n",
|
415 |
+
"print(inputs)\n",
|
416 |
+
"print(\"input tensor shape\")\n",
|
417 |
+
"print(inputs[\"input_ids\"].size())\n",
|
418 |
+
"\n",
|
419 |
+
"with torch.no_grad():\n",
|
420 |
+
" outputs = model(**inputs)\n",
|
421 |
+
"\n",
|
422 |
+
"logits = outputs.logits\n",
|
423 |
+
"print(\"output tensor\")\n",
|
424 |
+
"print(logits)\n",
|
425 |
+
"print(\"output shape\")\n",
|
426 |
+
"print(logits.shape)"
|
427 |
+
]
|
428 |
+
},
|
429 |
+
{
|
430 |
+
"cell_type": "code",
|
431 |
+
"execution_count": null,
|
432 |
+
"id": "2f6cc7bd-5e2d-4d4e-a7e6-73a6b2ecd7af",
|
433 |
+
"metadata": {},
|
434 |
+
"outputs": [],
|
435 |
+
"source": [
|
436 |
+
"size = 0\n",
|
437 |
+
"for i in range(8, 256, 1):\n",
|
438 |
+
" # input sequence (input_ids) made of int-32 (4 bytes)\n",
|
439 |
+
" size += np.prod([1, i]) * 4\n",
|
440 |
+
" # output tensor made of float-32 (4 bytes)\n",
|
441 |
+
" size += np.prod([1, i, 50257]) * 4\n",
|
442 |
+
"print(f\"total size (input+output): {size / 1024**3:.2f} Gb\")\n",
|
443 |
+
"\n",
|
444 |
+
"# to manually check actual tensor size:\n",
|
445 |
+
"# np.prod(logits.shape)*32/8/1024**2:.2f}\n",
|
446 |
+
"# or\n",
|
447 |
+
"# sys.getsizeof(logits.storage())/1024**2"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": null,
|
453 |
+
"id": "7debb40e-9941-45e4-9db8-4bb021ce44ab",
|
454 |
+
"metadata": {},
|
455 |
+
"outputs": [],
|
456 |
+
"source": [
|
457 |
+
"input_ids: BatchEncoding = tokenizer(\n",
|
458 |
+
" \"Here is some text to encode Hello World\", add_special_tokens=True, return_attention_mask=False, return_tensors=\"pt\"\n",
|
459 |
+
")\n",
|
460 |
+
"# some inference engines don't support int64 tensor as inputs, we convert all input tensors to int32 type\n",
|
461 |
+
"for k, v in input_ids.items(): # type: str, torch.Tensor\n",
|
462 |
+
" input_ids[k] = v.type(dtype=torch.int32)\n",
|
463 |
+
"\n",
|
464 |
+
"convert_to_onnx(\n",
|
465 |
+
" model_pytorch=model,\n",
|
466 |
+
" output_path=\"test-gpt2.onnx\",\n",
|
467 |
+
" inputs_pytorch=dict(input_ids),\n",
|
468 |
+
" quantization=False,\n",
|
469 |
+
" var_output_seq=True, # we inform ONNX export tool that the output shape will vary with the input shape\n",
|
470 |
+
" output_names = [\"logits\"]\n",
|
471 |
+
")\n",
|
472 |
+
"# model may switch to train mode for some unknown reasons, we force the eval mode.\n",
|
473 |
+
"_ = model.eval()"
|
474 |
+
]
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"cell_type": "code",
|
478 |
+
"execution_count": null,
|
479 |
+
"id": "956c3007-2c18-4d92-af4f-6cef474d86b5",
|
480 |
+
"metadata": {},
|
481 |
+
"outputs": [],
|
482 |
+
"source": [
|
483 |
+
"logging.basicConfig()\n",
|
484 |
+
"logging.getLogger().setLevel(logging.INFO)\n",
|
485 |
+
"num_attention_heads, hidden_size = get_model_size(path=model_name)\n",
|
486 |
+
"optimize_onnx(\n",
|
487 |
+
" onnx_path=\"test-gpt2.onnx\",\n",
|
488 |
+
" onnx_optim_model_path=\"test-gpt2-opt.onnx\",\n",
|
489 |
+
" fp16=False,\n",
|
490 |
+
" use_cuda=True,\n",
|
491 |
+
" num_attention_heads=num_attention_heads,\n",
|
492 |
+
" hidden_size=hidden_size,\n",
|
493 |
+
" architecture=\"gpt2\",\n",
|
494 |
+
")"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": null,
|
500 |
+
"id": "85f30ed9-2802-46c9-9201-a70e200b6860",
|
501 |
+
"metadata": {},
|
502 |
+
"outputs": [],
|
503 |
+
"source": [
|
504 |
+
"from pathlib import Path\n",
|
505 |
+
"\n",
|
506 |
+
"trt_logger: Logger = trt.Logger(trt.Logger.ERROR)\n",
|
507 |
+
"runtime: Runtime = trt.Runtime(trt_logger)\n",
|
508 |
+
"trt_model_name = \"test-gpt2.plan\"\n",
|
509 |
+
"\n",
|
510 |
+
"# create only of does not exist because it's slow to run...\n",
|
511 |
+
"\n",
|
512 |
+
"engine: ICudaEngine = build_engine(\n",
|
513 |
+
" runtime=runtime,\n",
|
514 |
+
" onnx_file_path=\"test-gpt2.onnx\",\n",
|
515 |
+
" logger=trt_logger,\n",
|
516 |
+
" min_shape=(1, 1),\n",
|
517 |
+
" optimal_shape=(1, 128), # num beam, batch size\n",
|
518 |
+
" max_shape=(1, 384), # num beam, batch size\n",
|
519 |
+
" workspace_size=10000 * 1024**2,\n",
|
520 |
+
" fp16=True,\n",
|
521 |
+
" int8=False,\n",
|
522 |
+
")\n",
|
523 |
+
"save_engine(engine, trt_model_name)"
|
524 |
+
]
|
525 |
+
},
|
526 |
+
{
|
527 |
+
"cell_type": "code",
|
528 |
+
"execution_count": null,
|
529 |
+
"id": "908fe664-800e-4c5f-a1d5-adfd31fd1c64",
|
530 |
+
"metadata": {},
|
531 |
+
"outputs": [],
|
532 |
+
"source": [
|
533 |
+
"engine.num_bindings"
|
534 |
+
]
|
535 |
+
},
|
536 |
+
{
|
537 |
+
"cell_type": "code",
|
538 |
+
"execution_count": null,
|
539 |
+
"id": "4626926b-fa94-4633-95d5-0d515f8db5f6",
|
540 |
+
"metadata": {},
|
541 |
+
"outputs": [],
|
542 |
+
"source": [
|
543 |
+
"print(inspect.getsource(GPTModelWrapper))"
|
544 |
+
]
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "code",
|
548 |
+
"execution_count": null,
|
549 |
+
"id": "d5bd1de1-a949-46a3-8d15-457d51db4e40",
|
550 |
+
"metadata": {},
|
551 |
+
"outputs": [],
|
552 |
+
"source": [
|
553 |
+
"inputs = tokenizer(\n",
|
554 |
+
" \"Here is some text to encode Hello World\", # Nvidia example prompt\n",
|
555 |
+
" add_special_tokens=True,\n",
|
556 |
+
" return_attention_mask=False, # Not used\n",
|
557 |
+
" return_tensors=TensorType.PYTORCH,\n",
|
558 |
+
")\n",
|
559 |
+
"inputs"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"cell_type": "code",
|
564 |
+
"execution_count": null,
|
565 |
+
"id": "815b548f-fa00-4183-b72c-10ecdd4b11c7",
|
566 |
+
"metadata": {},
|
567 |
+
"outputs": [],
|
568 |
+
"source": [
|
569 |
+
"from transformers.generation import GenerationConfig\n",
|
570 |
+
"\n",
|
571 |
+
"class GPTWrapper(GPTModelWrapper):\n",
|
572 |
+
" def __init__(self, *args, **kwargs):\n",
|
573 |
+
" super().__init__(*args, **kwargs)\n",
|
574 |
+
"\n",
|
575 |
+
" self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None\n",
|
576 |
+
"\n",
|
577 |
+
" @classmethod\n",
|
578 |
+
" def can_generate(cls) -> bool:\n",
|
579 |
+
" \"\"\"\n",
|
580 |
+
" Returns whether this model can generate sequences with `.generate()`.\n",
|
581 |
+
"\n",
|
582 |
+
" Returns:\n",
|
583 |
+
" `bool`: Whether this model can generate sequences with `.generate()`.\n",
|
584 |
+
" \"\"\"\n",
|
585 |
+
" # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.\n",
|
586 |
+
" # Alternativelly, the model can also have a custom `generate` function.\n",
|
587 |
+
" if \"GenerationMixin\" in str(cls.prepare_inputs_for_generation) and \"GenerationMixin\" in str(cls.generate):\n",
|
588 |
+
" return False\n",
|
589 |
+
" return True"
|
590 |
+
]
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"cell_type": "code",
|
594 |
+
"execution_count": null,
|
595 |
+
"id": "ca57ed1e-0bbe-48dd-ae0f-f3d8ecd7fd04",
|
596 |
+
"metadata": {},
|
597 |
+
"outputs": [],
|
598 |
+
"source": [
|
599 |
+
"def inference_torch(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
600 |
+
" transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = model.transformer(input_ids=input_ids)\n",
|
601 |
+
" return model.lm_head(transformer_outputs.last_hidden_state)\n",
|
602 |
+
"\n",
|
603 |
+
"\n",
|
604 |
+
"model.cuda()\n",
|
605 |
+
"model.eval()\n",
|
606 |
+
"inputs.to(\"cuda\")\n",
|
607 |
+
"with torch.inference_mode():\n",
|
608 |
+
" gpt2_model = GPTWrapper(config=model.config, device=model.device, inference=inference_torch)\n",
|
609 |
+
" sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
610 |
+
" print(tokenizer.decode(sample_output[0], skip_special_tokens=False))\n",
|
611 |
+
" for _ in range(2):\n",
|
612 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
613 |
+
" torch.cuda.synchronize()\n",
|
614 |
+
" start = time.time()\n",
|
615 |
+
" for _ in range(10):\n",
|
616 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
617 |
+
" torch.cuda.synchronize()\n",
|
618 |
+
" print(f\"----\\nPytorch: {(time.time() - start)/10:.2f}s/sequence\")\n",
|
619 |
+
"_ = model.cpu()"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": null,
|
625 |
+
"id": "f0849aae-876e-47bc-b045-14a594170947",
|
626 |
+
"metadata": {},
|
627 |
+
"outputs": [],
|
628 |
+
"source": [
|
629 |
+
"model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
|
630 |
+
"\n",
|
631 |
+
"\n",
|
632 |
+
"def inference_onnx_naive(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
633 |
+
" data = {\"input_ids\": input_ids.detach().cpu().numpy().astype(np.int32)}\n",
|
634 |
+
" logit = model_onnx.run(None, data)\n",
|
635 |
+
" np_logit = np.array(logit) # convert list of numpy arrays to a numpy array\n",
|
636 |
+
" # we convert numpy tensor to Pytorch tensor as it's the type expected by HF decoding algorithm\n",
|
637 |
+
" return torch.squeeze(torch.from_numpy(np_logit), dim=0)\n",
|
638 |
+
"\n",
|
639 |
+
"\n",
|
640 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cpu\"), inference=inference_onnx_naive)\n",
|
641 |
+
"inputs.to(\"cpu\")\n",
|
642 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
643 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
644 |
+
"for _ in range(2):\n",
|
645 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
646 |
+
"start = time.time()\n",
|
647 |
+
"for _ in range(10):\n",
|
648 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
649 |
+
"print(f\"----\\nONNX Runtime (standard API): {(time.time() - start)/10:.2f}s/sequence\")\n",
|
650 |
+
"\n",
|
651 |
+
"del model_onnx"
|
652 |
+
]
|
653 |
+
},
|
654 |
+
{
|
655 |
+
"cell_type": "code",
|
656 |
+
"execution_count": null,
|
657 |
+
"id": "96114897-894b-4997-bc61-8ac0682e0e55",
|
658 |
+
"metadata": {},
|
659 |
+
"outputs": [],
|
660 |
+
"source": [
|
661 |
+
"model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
|
662 |
+
"\n",
|
663 |
+
"\n",
|
664 |
+
"def inference_onnx_optimized(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
665 |
+
" data = {\"input_ids\": input_ids}\n",
|
666 |
+
" return inference_onnx_binding(model_onnx=model_onnx, inputs=data, device=\"cuda\")[\"output\"]\n",
|
667 |
+
"\n",
|
668 |
+
"\n",
|
669 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_onnx_optimized)\n",
|
670 |
+
"inputs.to(\"cuda\")\n",
|
671 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
672 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
673 |
+
"for _ in range(2):\n",
|
674 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
675 |
+
"start = time.time()\n",
|
676 |
+
"for _ in range(10):\n",
|
677 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
678 |
+
"print(f\"----\\nONNX Runtime (binding io API): {(time.time() - start)/10:.2f}/sequence\")\n",
|
679 |
+
"del model_onnx"
|
680 |
+
]
|
681 |
+
},
|
682 |
+
{
|
683 |
+
"cell_type": "code",
|
684 |
+
"execution_count": null,
|
685 |
+
"id": "0b5b5427-fd6b-4f70-b307-9c579f0f842a",
|
686 |
+
"metadata": {},
|
687 |
+
"outputs": [],
|
688 |
+
"source": [
|
689 |
+
"tensorrt_model: Callable[[Dict[str, torch.Tensor]], torch.Tensor] = load_engine(\n",
|
690 |
+
" engine_file_path=\"test-gpt2.plan\", runtime=runtime\n",
|
691 |
+
")\n",
|
692 |
+
"\n",
|
693 |
+
"\n",
|
694 |
+
"def inference_tensorrt(input_ids: torch.Tensor) -> torch.Tensor:\n",
|
695 |
+
" data = {\"input_ids\": input_ids}\n",
|
696 |
+
" return tensorrt_model(data)\n",
|
697 |
+
"\n",
|
698 |
+
"\n",
|
699 |
+
"gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_tensorrt)\n",
|
700 |
+
"inputs.to(\"cuda\")\n",
|
701 |
+
"sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
702 |
+
"print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
|
703 |
+
"for _ in range(2):\n",
|
704 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
|
705 |
+
"start = time.time()\n",
|
706 |
+
"for _ in range(10):\n",
|
707 |
+
" _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
|
708 |
+
"print(f\"----\\nTensorRT + CUDA tensors: {(time.time() - start)/10:.2f}/sequence\")\n",
|
709 |
+
"\n",
|
710 |
+
"del tensorrt_model"
|
711 |
+
]
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"cell_type": "markdown",
|
715 |
+
"id": "f547239d-4f7a-433b-8ef6-9e5110a61f4b",
|
716 |
+
"metadata": {
|
717 |
+
"jp-MarkdownHeadingCollapsed": true
|
718 |
+
},
|
719 |
+
"source": [
|
720 |
+
"## Using CUDAExecution Provider"
|
721 |
+
]
|
722 |
+
},
|
723 |
+
{
|
724 |
+
"cell_type": "code",
|
725 |
+
"execution_count": null,
|
726 |
+
"id": "6e34c682-85fc-4e8d-b13c-7c1c9ea39ead",
|
727 |
+
"metadata": {},
|
728 |
+
"outputs": [],
|
729 |
+
"source": [
|
730 |
+
"from optimum.onnxruntime import ORTModelForCausalLM\n",
|
731 |
+
"from optimum.pipelines import pipeline\n",
|
732 |
+
"from transformers import AutoTokenizer\n",
|
733 |
+
"\n",
|
734 |
+
"model_id = \"openai-community/gpt2\"\n",
|
735 |
+
"\n",
|
736 |
+
"ort_model = ORTModelForCausalLM.from_pretrained(\n",
|
737 |
+
" model_id,\n",
|
738 |
+
" export=True,\n",
|
739 |
+
" provider=\"CUDAExecutionProvider\",\n",
|
740 |
+
" use_io_binding=True\n",
|
741 |
+
")\n",
|
742 |
+
"\n",
|
743 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
744 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
745 |
+
"\n",
|
746 |
+
"pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
|
747 |
+
]
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": null,
|
752 |
+
"id": "17d28184-26db-4dd3-b24b-0c5a12b10d6d",
|
753 |
+
"metadata": {},
|
754 |
+
"outputs": [],
|
755 |
+
"source": [
|
756 |
+
"import time\n",
|
757 |
+
"\n",
|
758 |
+
"start_time = time.time()\n",
|
759 |
+
"\n",
|
760 |
+
"generations = pipe(\"Both the music and visual were astounding, not to mention the actors performance.\")\n",
|
761 |
+
"generations[0][\"generated_text\"]\n",
|
762 |
+
"\n",
|
763 |
+
"finish_time = time.time()\n",
|
764 |
+
"\n",
|
765 |
+
"print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
|
766 |
+
]
|
767 |
+
},
|
768 |
+
{
|
769 |
+
"cell_type": "markdown",
|
770 |
+
"id": "19c4230a-3244-4dce-b5ef-d9927dec5c45",
|
771 |
+
"metadata": {},
|
772 |
+
"source": [
|
773 |
+
"## ASR LM with CUDAExcecution Provider"
|
774 |
+
]
|
775 |
+
},
|
776 |
+
{
|
777 |
+
"cell_type": "code",
|
778 |
+
"execution_count": null,
|
779 |
+
"id": "0f0f1cdc-bfcd-46c5-80a4-60bc76366cf5",
|
780 |
+
"metadata": {},
|
781 |
+
"outputs": [],
|
782 |
+
"source": [
|
783 |
+
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
|
784 |
+
"from datasets import DatasetDict\n",
|
785 |
+
"import torch\n",
|
786 |
+
"\n",
|
787 |
+
"device = \"cuda:0\"\n",
|
788 |
+
"dtype = torch.float16\n",
|
789 |
+
"\n",
|
790 |
+
"dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
|
791 |
+
"\n",
|
792 |
+
"from optimum.onnxruntime import ORTModelForCausalLM\n",
|
793 |
+
"from optimum.pipelines import pipeline\n",
|
794 |
+
"from transformers import AutoTokenizer\n",
|
795 |
+
"\n",
|
796 |
+
"model_id = \"./../out/checkpoint-10000\"\n",
|
797 |
+
"\n",
|
798 |
+
"ort_model = ORTModelForCausalLM.from_pretrained(\n",
|
799 |
+
" model_id,\n",
|
800 |
+
" export=True,\n",
|
801 |
+
" provider=\"CUDAExecutionProvider\",\n",
|
802 |
+
" use_io_binding=True\n",
|
803 |
+
")\n",
|
804 |
+
"\n",
|
805 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
|
806 |
+
"\n",
|
807 |
+
"pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
|
808 |
+
]
|
809 |
+
},
|
810 |
+
{
|
811 |
+
"cell_type": "code",
|
812 |
+
"execution_count": null,
|
813 |
+
"id": "9d32098c-b0ec-4c36-95ac-775a3a865512",
|
814 |
+
"metadata": {},
|
815 |
+
"outputs": [],
|
816 |
+
"source": [
|
817 |
+
"ort_model.config.eos_token_id = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
|
818 |
+
"ort_model.config.bos_token_id = tokenizer.encode(\"<|startoftranscript|>\")[0]"
|
819 |
+
]
|
820 |
+
},
|
821 |
+
{
|
822 |
+
"cell_type": "code",
|
823 |
+
"execution_count": null,
|
824 |
+
"id": "1fd0a1fb-9349-4c7a-af03-21e29334f420",
|
825 |
+
"metadata": {},
|
826 |
+
"outputs": [],
|
827 |
+
"source": [
|
828 |
+
"dataset[split][idx].keys()"
|
829 |
+
]
|
830 |
+
},
|
831 |
+
{
|
832 |
+
"cell_type": "code",
|
833 |
+
"execution_count": null,
|
834 |
+
"id": "15d8b989-6460-4555-b6e2-2f9e219d7034",
|
835 |
+
"metadata": {},
|
836 |
+
"outputs": [],
|
837 |
+
"source": [
|
838 |
+
"split = \"train.clean.100\"\n",
|
839 |
+
"idx = 0\n",
|
840 |
+
"\n",
|
841 |
+
"text = \"\".join([ f\"<|audio:{tkn}|>\"for tkn in dataset[split][idx][\"audio_tokens\"]]) + \"<|startoftranscript|>\"\n",
|
842 |
+
"\n",
|
843 |
+
"import time\n",
|
844 |
+
"\n",
|
845 |
+
"start_time = time.time()\n",
|
846 |
+
"\n",
|
847 |
+
"generations = pipe(text, max_new_tokens=10, skip_special_tokens=True)\n",
|
848 |
+
"\n",
|
849 |
+
"finish_time = time.time()\n",
|
850 |
+
"\n",
|
851 |
+
"print(generations[0][\"generated_text\"])\n",
|
852 |
+
"\n",
|
853 |
+
"print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
|
854 |
+
]
|
855 |
+
}
|
856 |
+
],
|
857 |
+
"metadata": {
|
858 |
+
"kernelspec": {
|
859 |
+
"display_name": "Python 3 (ipykernel)",
|
860 |
+
"language": "python",
|
861 |
+
"name": "python3"
|
862 |
+
},
|
863 |
+
"language_info": {
|
864 |
+
"codemirror_mode": {
|
865 |
+
"name": "ipython",
|
866 |
+
"version": 3
|
867 |
+
},
|
868 |
+
"file_extension": ".py",
|
869 |
+
"mimetype": "text/x-python",
|
870 |
+
"name": "python",
|
871 |
+
"nbconvert_exporter": "python",
|
872 |
+
"pygments_lexer": "ipython3",
|
873 |
+
"version": "3.8.10"
|
874 |
+
}
|
875 |
+
},
|
876 |
+
"nbformat": 4,
|
877 |
+
"nbformat_minor": 5
|
878 |
+
}
|
ASR/demo.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request
|
2 |
+
# import speech_recognition as sr
|
3 |
+
|
4 |
+
app = Flask(__name__)
|
5 |
+
# recognizer = sr.Recognizer()
|
6 |
+
|
7 |
+
@app.route("/darshan/microphone", methods=['POST'])
|
8 |
+
def handle_audio():
|
9 |
+
audio_data = request.data
|
10 |
+
print(audio_data)
|
11 |
+
# audio = sr.AudioData(audio_data, sample_rate=44100, sample_width=2) # Adjust sample rate and sample width as needed
|
12 |
+
# try:
|
13 |
+
# text = recognizer.recognize_google(audio)
|
14 |
+
# print(f"Transcription: {text}")
|
15 |
+
# return {'transcription': text}, 200
|
16 |
+
# except sr.UnknownValueError:
|
17 |
+
# print("Could not understand audio")
|
18 |
+
# return '', 400
|
19 |
+
# except sr.RequestError as e:
|
20 |
+
# print(f"Error from Google Speech Recognition service; {e}")
|
21 |
+
# return '', 500
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
app.run(host='0.0.0.0', port=8723) # Replace with your desired host and port
|
ASR/repcodec/.ipynb_checkpoints/RepCodec-checkpoint.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from repcodec.modules.decoder import Decoder
|
11 |
+
from repcodec.modules.encoder import Encoder
|
12 |
+
from repcodec.modules.projector import Projector
|
13 |
+
from repcodec.modules.quantizer import Quantizer
|
14 |
+
|
15 |
+
|
16 |
+
class RepCodec(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
input_channels=768,
|
20 |
+
output_channels=768,
|
21 |
+
encode_channels=768,
|
22 |
+
decode_channels=768,
|
23 |
+
code_dim=768,
|
24 |
+
codebook_num=1,
|
25 |
+
codebook_size=1024,
|
26 |
+
bias=True,
|
27 |
+
enc_ratios=(1, 1),
|
28 |
+
dec_ratios=(1, 1),
|
29 |
+
enc_strides=(1, 1),
|
30 |
+
dec_strides=(1, 1),
|
31 |
+
enc_kernel_size=3,
|
32 |
+
dec_kernel_size=3,
|
33 |
+
enc_block_dilations=(1, 1),
|
34 |
+
enc_block_kernel_size=3,
|
35 |
+
dec_block_dilations=(1, 1),
|
36 |
+
dec_block_kernel_size=3
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.input_channels = input_channels
|
41 |
+
|
42 |
+
self.encoder = Encoder(
|
43 |
+
input_channels=input_channels,
|
44 |
+
encode_channels=encode_channels,
|
45 |
+
channel_ratios=enc_ratios,
|
46 |
+
strides=enc_strides,
|
47 |
+
kernel_size=enc_kernel_size,
|
48 |
+
bias=bias,
|
49 |
+
block_dilations=enc_block_dilations,
|
50 |
+
unit_kernel_size=enc_block_kernel_size
|
51 |
+
)
|
52 |
+
|
53 |
+
self.decoder = Decoder(
|
54 |
+
code_dim=code_dim,
|
55 |
+
output_channels=output_channels,
|
56 |
+
decode_channels=decode_channels,
|
57 |
+
channel_ratios=dec_ratios,
|
58 |
+
strides=dec_strides,
|
59 |
+
kernel_size=dec_kernel_size,
|
60 |
+
bias=bias,
|
61 |
+
block_dilations=dec_block_dilations,
|
62 |
+
unit_kernel_size=dec_block_kernel_size
|
63 |
+
)
|
64 |
+
|
65 |
+
self.projector = Projector(
|
66 |
+
input_channels=self.encoder.out_channels,
|
67 |
+
code_dim=code_dim,
|
68 |
+
kernel_size=3,
|
69 |
+
stride=1,
|
70 |
+
bias=False
|
71 |
+
)
|
72 |
+
|
73 |
+
self.quantizer = Quantizer(
|
74 |
+
code_dim=code_dim,
|
75 |
+
codebook_num=codebook_num,
|
76 |
+
codebook_size=codebook_size
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = self.encoder(x)
|
81 |
+
z = self.projector(x)
|
82 |
+
zq, vqloss, perplexity = self.quantizer(z)
|
83 |
+
y = self.decoder(zq)
|
84 |
+
return y, zq, z, vqloss, perplexity
|
ASR/repcodec/RepCodec.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from repcodec.modules.decoder import Decoder
|
11 |
+
from repcodec.modules.encoder import Encoder
|
12 |
+
from repcodec.modules.projector import Projector
|
13 |
+
from repcodec.modules.quantizer import Quantizer
|
14 |
+
|
15 |
+
|
16 |
+
class RepCodec(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
input_channels=768,
|
20 |
+
output_channels=768,
|
21 |
+
encode_channels=768,
|
22 |
+
decode_channels=768,
|
23 |
+
code_dim=768,
|
24 |
+
codebook_num=1,
|
25 |
+
codebook_size=1024,
|
26 |
+
bias=True,
|
27 |
+
enc_ratios=(1, 1),
|
28 |
+
dec_ratios=(1, 1),
|
29 |
+
enc_strides=(1, 1),
|
30 |
+
dec_strides=(1, 1),
|
31 |
+
enc_kernel_size=3,
|
32 |
+
dec_kernel_size=3,
|
33 |
+
enc_block_dilations=(1, 1),
|
34 |
+
enc_block_kernel_size=3,
|
35 |
+
dec_block_dilations=(1, 1),
|
36 |
+
dec_block_kernel_size=3
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.input_channels = input_channels
|
41 |
+
|
42 |
+
self.encoder = Encoder(
|
43 |
+
input_channels=input_channels,
|
44 |
+
encode_channels=encode_channels,
|
45 |
+
channel_ratios=enc_ratios,
|
46 |
+
strides=enc_strides,
|
47 |
+
kernel_size=enc_kernel_size,
|
48 |
+
bias=bias,
|
49 |
+
block_dilations=enc_block_dilations,
|
50 |
+
unit_kernel_size=enc_block_kernel_size
|
51 |
+
)
|
52 |
+
|
53 |
+
self.decoder = Decoder(
|
54 |
+
code_dim=code_dim,
|
55 |
+
output_channels=output_channels,
|
56 |
+
decode_channels=decode_channels,
|
57 |
+
channel_ratios=dec_ratios,
|
58 |
+
strides=dec_strides,
|
59 |
+
kernel_size=dec_kernel_size,
|
60 |
+
bias=bias,
|
61 |
+
block_dilations=dec_block_dilations,
|
62 |
+
unit_kernel_size=dec_block_kernel_size
|
63 |
+
)
|
64 |
+
|
65 |
+
self.projector = Projector(
|
66 |
+
input_channels=self.encoder.out_channels,
|
67 |
+
code_dim=code_dim,
|
68 |
+
kernel_size=3,
|
69 |
+
stride=1,
|
70 |
+
bias=False
|
71 |
+
)
|
72 |
+
|
73 |
+
self.quantizer = Quantizer(
|
74 |
+
code_dim=code_dim,
|
75 |
+
codebook_num=codebook_num,
|
76 |
+
codebook_size=codebook_size
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = self.encoder(x)
|
81 |
+
z = self.projector(x)
|
82 |
+
zq, vqloss, perplexity = self.quantizer(z)
|
83 |
+
y = self.decoder(zq)
|
84 |
+
return y, zq, z, vqloss, perplexity
|
ASR/repcodec/__pycache__/RepCodec.cpython-38.pyc
ADDED
Binary file (1.87 kB). View file
|
|
ASR/repcodec/configs/repcodec_dim1024.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_channels: 1024
|
2 |
+
output_channels: 1024
|
3 |
+
encode_channels: 1024
|
4 |
+
decode_channels: 1024
|
5 |
+
code_dim: 1024
|
6 |
+
codebook_num: 1
|
7 |
+
codebook_size: 1024
|
8 |
+
bias: true
|
9 |
+
enc_ratios: [ 1, 1 ]
|
10 |
+
dec_ratios: [ 1, 1 ]
|
11 |
+
enc_strides: [ 1, 1 ] # no downsampling
|
12 |
+
dec_strides: [ 1, 1 ]
|
13 |
+
enc_kernel_size: 3
|
14 |
+
dec_kernel_size: 3
|
15 |
+
enc_block_dilations: [ 1, 1 ]
|
16 |
+
enc_block_kernel_size: 3
|
17 |
+
dec_block_dilations: [ 1, 1 ]
|
18 |
+
dec_block_kernel_size: 3
|
ASR/repcodec/configs/repcodec_dim1280.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_channels: 1280
|
2 |
+
output_channels: 1280
|
3 |
+
encode_channels: 1280
|
4 |
+
decode_channels: 1280
|
5 |
+
code_dim: 1280
|
6 |
+
codebook_num: 1
|
7 |
+
codebook_size: 1024
|
8 |
+
bias: true
|
9 |
+
enc_ratios: [ 1, 1 ]
|
10 |
+
dec_ratios: [ 1, 1 ]
|
11 |
+
enc_strides: [ 1, 1 ] # no downsampling
|
12 |
+
dec_strides: [ 1, 1 ]
|
13 |
+
enc_kernel_size: 3
|
14 |
+
dec_kernel_size: 3
|
15 |
+
enc_block_dilations: [ 1, 1 ]
|
16 |
+
enc_block_kernel_size: 3
|
17 |
+
dec_block_dilations: [ 1, 1 ]
|
18 |
+
dec_block_kernel_size: 3
|
ASR/repcodec/configs/repcodec_dim768.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_channels: 768
|
2 |
+
output_channels: 768
|
3 |
+
encode_channels: 768
|
4 |
+
decode_channels: 768
|
5 |
+
code_dim: 768
|
6 |
+
codebook_num: 1
|
7 |
+
codebook_size: 1024
|
8 |
+
bias: true
|
9 |
+
enc_ratios: [ 1, 1 ]
|
10 |
+
dec_ratios: [ 1, 1 ]
|
11 |
+
enc_strides: [ 1, 1 ] # no downsampling
|
12 |
+
dec_strides: [ 1, 1 ]
|
13 |
+
enc_kernel_size: 3
|
14 |
+
dec_kernel_size: 3
|
15 |
+
enc_block_dilations: [ 1, 1 ]
|
16 |
+
enc_block_kernel_size: 3
|
17 |
+
dec_block_dilations: [ 1, 1 ]
|
18 |
+
dec_block_kernel_size: 3
|
ASR/repcodec/layers/__pycache__/conv_layer.cpython-38.pyc
ADDED
Binary file (2.52 kB). View file
|
|
ASR/repcodec/layers/__pycache__/vq_module.cpython-38.pyc
ADDED
Binary file (5.14 kB). View file
|
|
ASR/repcodec/layers/conv_layer.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
|
11 |
+
class Conv1d1x1(nn.Conv1d):
|
12 |
+
"""1x1 Conv1d."""
|
13 |
+
|
14 |
+
def __init__(self, in_channels, out_channels, bias=True):
|
15 |
+
super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
|
16 |
+
|
17 |
+
|
18 |
+
class Conv1d(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
in_channels: int,
|
22 |
+
out_channels: int,
|
23 |
+
kernel_size: int,
|
24 |
+
stride: int = 1,
|
25 |
+
padding: int = -1,
|
26 |
+
dilation: int = 1,
|
27 |
+
groups: int = 1,
|
28 |
+
bias: bool = True
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
self.in_channels = in_channels
|
32 |
+
self.out_channels = out_channels
|
33 |
+
self.kernel_size = kernel_size
|
34 |
+
if padding < 0:
|
35 |
+
padding = (kernel_size - 1) // 2 * dilation
|
36 |
+
self.dilation = dilation
|
37 |
+
self.conv = nn.Conv1d(
|
38 |
+
in_channels=in_channels,
|
39 |
+
out_channels=out_channels,
|
40 |
+
kernel_size=kernel_size,
|
41 |
+
stride=stride,
|
42 |
+
padding=padding,
|
43 |
+
dilation=dilation,
|
44 |
+
groups=groups,
|
45 |
+
bias=bias,
|
46 |
+
)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
52 |
+
Returns:
|
53 |
+
Tensor: Float tensor variable with the shape (B, C, T).
|
54 |
+
"""
|
55 |
+
x = self.conv(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class ConvTranspose1d(nn.Module):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
in_channels: int,
|
63 |
+
out_channels: int,
|
64 |
+
kernel_size: int,
|
65 |
+
stride: int,
|
66 |
+
padding=-1,
|
67 |
+
output_padding=-1,
|
68 |
+
groups=1,
|
69 |
+
bias=True,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
if padding < 0:
|
73 |
+
padding = (stride + 1) // 2
|
74 |
+
if output_padding < 0:
|
75 |
+
output_padding = 1 if stride % 2 else 0
|
76 |
+
self.deconv = nn.ConvTranspose1d(
|
77 |
+
in_channels=in_channels,
|
78 |
+
out_channels=out_channels,
|
79 |
+
kernel_size=kernel_size,
|
80 |
+
stride=stride,
|
81 |
+
padding=padding,
|
82 |
+
output_padding=output_padding,
|
83 |
+
groups=groups,
|
84 |
+
bias=bias,
|
85 |
+
)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
"""
|
89 |
+
Args:
|
90 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
91 |
+
Returns:
|
92 |
+
Tensor: Float tensor variable with the shape (B, C', T').
|
93 |
+
"""
|
94 |
+
x = self.deconv(x)
|
95 |
+
return x
|
ASR/repcodec/layers/vq_module.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class VectorQuantize(nn.Module):
|
14 |
+
"""Vector quantization w/ exponential moving averages (EMA)"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
dim: int,
|
19 |
+
codebook_size: int,
|
20 |
+
decay=0.8,
|
21 |
+
commitment=1.,
|
22 |
+
eps=1e-5,
|
23 |
+
n_embed=None,
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
n_embed = self.default(n_embed, codebook_size)
|
27 |
+
|
28 |
+
self.dim = dim
|
29 |
+
self.n_embed = n_embed
|
30 |
+
self.decay = decay
|
31 |
+
self.eps = eps
|
32 |
+
self.commitment = commitment
|
33 |
+
|
34 |
+
embed = torch.randn(dim, n_embed)
|
35 |
+
self.register_buffer('embed', embed)
|
36 |
+
self.register_buffer('cluster_size', torch.zeros(n_embed))
|
37 |
+
self.register_buffer('embed_avg', embed.clone())
|
38 |
+
|
39 |
+
@property
|
40 |
+
def codebook(self):
|
41 |
+
return self.embed.transpose(0, 1)
|
42 |
+
|
43 |
+
def exists(self, val):
|
44 |
+
return val is not None
|
45 |
+
|
46 |
+
def default(self, val, d):
|
47 |
+
return val if self.exists(val) else d
|
48 |
+
|
49 |
+
def ema_inplace(self, moving_avg, new, decay):
|
50 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
51 |
+
|
52 |
+
def laplace_smoothing(self, x, n_categories, eps=1e-5):
|
53 |
+
return (x + eps) / (x.sum() + n_categories * eps)
|
54 |
+
|
55 |
+
def forward(self, input):
|
56 |
+
dtype = input.dtype
|
57 |
+
flatten = input.reshape(-1, self.dim)
|
58 |
+
dist = (
|
59 |
+
flatten.pow(2).sum(1, keepdim=True)
|
60 |
+
- 2 * flatten @ self.embed
|
61 |
+
+ self.embed.pow(2).sum(0, keepdim=True)
|
62 |
+
)
|
63 |
+
_, embed_ind = (-dist).max(1)
|
64 |
+
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
|
65 |
+
embed_ind = embed_ind.view(*input.shape[:-1])
|
66 |
+
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
|
67 |
+
|
68 |
+
if self.training:
|
69 |
+
self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
70 |
+
embed_sum = flatten.transpose(0, 1) @ embed_onehot
|
71 |
+
self.ema_inplace(self.embed_avg, embed_sum, self.decay)
|
72 |
+
cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
|
73 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
74 |
+
self.embed.data.copy_(embed_normalized)
|
75 |
+
|
76 |
+
loss = F.mse_loss(quantize.detach(), input) * self.commitment
|
77 |
+
quantize = input + (quantize - input).detach()
|
78 |
+
|
79 |
+
avg_probs = torch.mean(embed_onehot, dim=0)
|
80 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
81 |
+
|
82 |
+
return quantize, loss, perplexity
|
83 |
+
|
84 |
+
def forward_index(self, input):
|
85 |
+
dtype = input.dtype
|
86 |
+
flatten = input.reshape(-1, self.dim)
|
87 |
+
dist = (
|
88 |
+
flatten.pow(2).sum(1, keepdim=True)
|
89 |
+
- 2 * flatten @ self.embed
|
90 |
+
+ self.embed.pow(2).sum(0, keepdim=True)
|
91 |
+
)
|
92 |
+
_, embed_ind = (-dist).max(1)
|
93 |
+
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
|
94 |
+
embed_ind = embed_ind.view(*input.shape[:-1])
|
95 |
+
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
|
96 |
+
quantize = input + (quantize - input).detach()
|
97 |
+
|
98 |
+
return quantize, embed_ind
|
99 |
+
|
100 |
+
|
101 |
+
class ResidualVQ(nn.Module):
|
102 |
+
""" Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
*,
|
107 |
+
num_quantizers,
|
108 |
+
**kwargs
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
quantized_out = 0.
|
115 |
+
residual = x
|
116 |
+
all_losses = []
|
117 |
+
all_perplexities = []
|
118 |
+
for layer in self.layers:
|
119 |
+
quantized, loss, perplexity = layer(residual)
|
120 |
+
# Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
|
121 |
+
# We found considering only the 1st layer VQ's graident results in better performance
|
122 |
+
# residual = residual - quantized.detach() # considering all layers' graidents
|
123 |
+
residual = residual - quantized # considering only the first layer's graident
|
124 |
+
quantized_out = quantized_out + quantized
|
125 |
+
all_losses.append(loss)
|
126 |
+
all_perplexities.append(perplexity)
|
127 |
+
all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
|
128 |
+
return quantized_out, all_losses, all_perplexities
|
129 |
+
|
130 |
+
def forward_index(self, x, flatten_idx=False):
|
131 |
+
quantized_out = 0.
|
132 |
+
residual = x
|
133 |
+
all_indices = []
|
134 |
+
for i, layer in enumerate(self.layers):
|
135 |
+
quantized, indices = layer.forward_index(residual)
|
136 |
+
# residual = residual - quantized.detach()
|
137 |
+
residual = residual - quantized
|
138 |
+
quantized_out = quantized_out + quantized
|
139 |
+
if flatten_idx:
|
140 |
+
indices += (self.codebook_size * i)
|
141 |
+
all_indices.append(indices)
|
142 |
+
all_indices = torch.stack(all_indices)
|
143 |
+
return quantized_out, all_indices.squeeze(1)
|
144 |
+
|
145 |
+
def initial(self):
|
146 |
+
self.codebook = []
|
147 |
+
for layer in self.layers:
|
148 |
+
self.codebook.append(layer.codebook)
|
149 |
+
self.codebook_size = self.codebook[0].size(0)
|
150 |
+
self.codebook = torch.stack(self.codebook)
|
151 |
+
self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
|
152 |
+
|
153 |
+
def lookup(self, indices):
|
154 |
+
quantized_out = F.embedding(indices, self.codebook) # Num x T x C
|
155 |
+
return torch.sum(quantized_out, dim=0, keepdim=True)
|
ASR/repcodec/modules/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (2.51 kB). View file
|
|
ASR/repcodec/modules/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (2.23 kB). View file
|
|
ASR/repcodec/modules/__pycache__/projector.cpython-38.pyc
ADDED
Binary file (903 Bytes). View file
|
|
ASR/repcodec/modules/__pycache__/quantizer.cpython-38.pyc
ADDED
Binary file (1.63 kB). View file
|
|
ASR/repcodec/modules/__pycache__/residual_unit.cpython-38.pyc
ADDED
Binary file (1.14 kB). View file
|
|
ASR/repcodec/modules/decoder.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from repcodec.layers.conv_layer import Conv1d, ConvTranspose1d
|
12 |
+
from repcodec.modules.residual_unit import ResidualUnit
|
13 |
+
|
14 |
+
|
15 |
+
class DecoderBlock(nn.Module):
|
16 |
+
""" Decoder block (no up-sampling) """
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_channels: int,
|
21 |
+
out_channels: int,
|
22 |
+
stride: int,
|
23 |
+
dilations=(1, 1),
|
24 |
+
unit_kernel_size=3,
|
25 |
+
bias=True
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
if stride == 1:
|
30 |
+
self.conv = Conv1d(
|
31 |
+
in_channels=in_channels,
|
32 |
+
out_channels=out_channels,
|
33 |
+
kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
|
34 |
+
stride=stride,
|
35 |
+
bias=bias,
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
self.conv = ConvTranspose1d(
|
39 |
+
in_channels=in_channels,
|
40 |
+
out_channels=out_channels,
|
41 |
+
kernel_size=(2 * stride),
|
42 |
+
stride=stride,
|
43 |
+
bias=bias,
|
44 |
+
)
|
45 |
+
|
46 |
+
self.res_units = torch.nn.ModuleList()
|
47 |
+
for idx, dilation in enumerate(dilations):
|
48 |
+
self.res_units += [
|
49 |
+
ResidualUnit(out_channels, out_channels,
|
50 |
+
kernel_size=unit_kernel_size,
|
51 |
+
dilation=dilation)
|
52 |
+
]
|
53 |
+
self.num_res = len(self.res_units)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
x = self.conv(x)
|
57 |
+
for idx in range(self.num_res):
|
58 |
+
x = self.res_units[idx](x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
class Decoder(nn.Module):
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
code_dim: int,
|
66 |
+
output_channels: int,
|
67 |
+
decode_channels: int,
|
68 |
+
channel_ratios=(1, 1),
|
69 |
+
strides=(1, 1),
|
70 |
+
kernel_size=3,
|
71 |
+
bias=True,
|
72 |
+
block_dilations=(1, 1),
|
73 |
+
unit_kernel_size=3,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
assert len(channel_ratios) == len(strides)
|
77 |
+
|
78 |
+
self.conv1 = Conv1d(
|
79 |
+
in_channels=code_dim,
|
80 |
+
out_channels=int(decode_channels * channel_ratios[0]),
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
stride=1,
|
83 |
+
bias=False
|
84 |
+
)
|
85 |
+
|
86 |
+
self.conv_blocks = torch.nn.ModuleList()
|
87 |
+
for idx, stride in enumerate(strides):
|
88 |
+
in_channels = int(decode_channels * channel_ratios[idx])
|
89 |
+
if idx < (len(channel_ratios) - 1):
|
90 |
+
out_channels = int(decode_channels * channel_ratios[idx + 1])
|
91 |
+
else:
|
92 |
+
out_channels = decode_channels
|
93 |
+
self.conv_blocks += [
|
94 |
+
DecoderBlock(
|
95 |
+
in_channels, out_channels, stride,
|
96 |
+
dilations=block_dilations, unit_kernel_size=unit_kernel_size,
|
97 |
+
bias=bias
|
98 |
+
)
|
99 |
+
]
|
100 |
+
self.num_blocks = len(self.conv_blocks)
|
101 |
+
|
102 |
+
self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
|
103 |
+
|
104 |
+
def forward(self, z):
|
105 |
+
x = self.conv1(z)
|
106 |
+
for i in range(self.num_blocks):
|
107 |
+
x = self.conv_blocks[i](x)
|
108 |
+
x = self.conv2(x)
|
109 |
+
return x
|
ASR/repcodec/modules/encoder.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import torch
+import torch.nn as nn
+
+from repcodec.layers.conv_layer import Conv1d
+from repcodec.modules.residual_unit import ResidualUnit
+
+
+class EncoderBlock(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            stride: int,
+            dilations=(1, 1),
+            unit_kernel_size=3,
+            bias=True
+    ):
+        super().__init__()
+        self.res_units = torch.nn.ModuleList()
+        for dilation in dilations:
+            self.res_units += [
+                ResidualUnit(in_channels, in_channels,
+                             kernel_size=unit_kernel_size,
+                             dilation=dilation)
+            ]
+        self.num_res = len(self.res_units)
+
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3 if stride == 1 else (2 * stride),  # special case: stride=1, do not use kernel=2
+            stride=stride,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        for idx in range(self.num_res):
+            x = self.res_units[idx](x)
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+            self,
+            input_channels: int,
+            encode_channels: int,
+            channel_ratios=(1, 1),
+            strides=(1, 1),
+            kernel_size=3,
+            bias=True,
+            block_dilations=(1, 1),
+            unit_kernel_size=3
+    ):
+        super().__init__()
+        assert len(channel_ratios) == len(strides)
+
+        self.conv = Conv1d(
+            in_channels=input_channels,
+            out_channels=encode_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            bias=False
+        )
+        self.conv_blocks = torch.nn.ModuleList()
+        in_channels = encode_channels
+        for idx, stride in enumerate(strides):
+            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
+            self.conv_blocks += [
+                EncoderBlock(in_channels, out_channels, stride,
+                             dilations=block_dilations, unit_kernel_size=unit_kernel_size,
+                             bias=bias)
+            ]
+            in_channels = out_channels
+        self.num_blocks = len(self.conv_blocks)
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        x = self.conv(x)
+        for i in range(self.num_blocks):
+            x = self.conv_blocks[i](x)
+        return x
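For orientation, a minimal sketch of running this Encoder on a dummy feature sequence; the hyperparameter values below are illustrative assumptions, not the values from the repcodec_dim*.yaml configs shipped in this repo:

# Illustrative only: hyperparameters are assumptions for the sketch.
import torch
from repcodec.modules.encoder import Encoder

encoder = Encoder(
    input_channels=768,      # e.g. hubert_base_l9 features (assumed)
    encode_channels=512,
    channel_ratios=(1, 1),
    strides=(1, 1),
    kernel_size=3,
    bias=True,
    block_dilations=(1, 1),
    unit_kernel_size=3,
)

feats = torch.randn(2, 768, 100)   # (batch, hidden dim, seq len)
out = encoder(feats)
# With stride-1 blocks the length should be preserved (assuming the custom
# Conv1d pads for same-length output), giving roughly torch.Size([2, 512, 100]).
print(out.shape)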
ASR/repcodec/modules/projector.py
ADDED
@@ -0,0 +1,32 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import torch.nn as nn
+
+from repcodec.layers.conv_layer import Conv1d
+
+
+class Projector(nn.Module):
+    def __init__(
+            self,
+            input_channels: int,
+            code_dim: int,
+            kernel_size=3,
+            stride=1,
+            bias=False
+    ):
+        super().__init__()
+        self.project = Conv1d(
+            input_channels,
+            code_dim,
+            kernel_size=kernel_size,
+            stride=stride,
+            bias=bias
+        )
+
+    def forward(self, x):
+        return self.project(x)
ASR/repcodec/modules/quantizer.py
ADDED
@@ -0,0 +1,46 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import torch.nn as nn
+
+from repcodec.layers.vq_module import ResidualVQ
+
+
+class Quantizer(nn.Module):
+    def __init__(
+            self,
+            code_dim: int,
+            codebook_num: int,
+            codebook_size: int,
+    ):
+        super().__init__()
+        self.codebook = ResidualVQ(
+            dim=code_dim,
+            num_quantizers=codebook_num,
+            codebook_size=codebook_size
+        )
+
+    def initial(self):
+        self.codebook.initial()
+
+    def forward(self, z):
+        zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
+        zq = zq.transpose(2, 1)
+        return zq, vqloss, perplexity
+
+    def inference(self, z):
+        zq, indices = self.codebook.forward_index(z.transpose(2, 1))
+        zq = zq.transpose(2, 1)
+        return zq, indices
+
+    def encode(self, z):
+        zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
+        return zq, indices
+
+    def decode(self, indices):
+        z = self.codebook.lookup(indices)
+        return z
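A minimal sketch of the shape convention this Quantizer expects: it takes channel-first features (batch, code dim, seq len), transposes them into ResidualVQ, and transposes the quantized output back. The dimensions below are assumptions for illustration:

# Illustrative only: code_dim and codebook sizes are assumed values.
import torch
from repcodec.modules.quantizer import Quantizer

quantizer = Quantizer(code_dim=256, codebook_num=1, codebook_size=1024)
quantizer.initial()  # mirrors load_model() in tokenize.py, which calls model.quantizer.initial()

z = torch.randn(2, 256, 100)              # (batch, code dim, seq len) -- the Projector output
zq, vqloss, perplexity = quantizer(z)     # zq has the same channel-first shape as z
zq_inf, indices = quantizer.inference(z)  # indices are the discrete token ids used downstream
print(zq.shape, indices.shape)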
ASR/repcodec/modules/residual_unit.py
ADDED
@@ -0,0 +1,39 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import torch.nn as nn
+
+from repcodec.layers.conv_layer import Conv1d, Conv1d1x1
+
+
+class ResidualUnit(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            dilation=1,
+            bias=False,
+            nonlinear_activation="ELU",
+            nonlinear_activation_params={},
+    ):
+        super().__init__()
+        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
+        self.conv1 = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            dilation=dilation,
+            bias=bias,
+        )
+        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
+
+    def forward(self, x):
+        y = self.conv1(self.activation(x))
+        y = self.conv2(self.activation(y))
+        return x + y
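Because the skip connection is a plain x + y, the unit only works when input and output shapes match, which is how EncoderBlock uses it (in_channels == out_channels). A small sanity sketch, assuming the custom Conv1d pads to preserve sequence length:

# Sanity sketch; shape-preserving padding in repcodec's Conv1d is assumed.
import torch
from repcodec.modules.residual_unit import ResidualUnit

unit = ResidualUnit(in_channels=64, out_channels=64, kernel_size=3, dilation=2)
x = torch.randn(1, 64, 50)
y = unit(x)
print(y.shape)  # expected torch.Size([1, 64, 50]); the residual add requires matching shapes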
ASR/repcodec/tokenize.py
ADDED
@@ -0,0 +1,212 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+from pathlib import Path
+from typing import Tuple, List, Optional
+
+import numpy as np
+import torch
+import yaml
+from tqdm import tqdm
+
+from repcodec.RepCodec import RepCodec
+
+ALL_MODELS = {
+    "data2vec_base_l6": 768,
+    "data2vec_large_l18": 1024,
+    "hubert_base_l9": 768,
+    "hubert_large_l18": 1024,
+    "whisper_medium_l24": 1024,
+    "whisper_large_l32": 1280
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument(
+        "in_dir",
+        type=str,
+        help="directory of representations to be tokenized."
+    )
+    parser.add_argument(
+        "--model",
+        required=True,
+        type=str,
+        help="path of the RepCodec model."
+    )
+    parser.add_argument(
+        "--tsv_path",
+        required=True,
+        type=str,
+        help="path of the tsv file."
+    )
+    parser.add_argument(
+        "--model_config_path",
+        default=None,
+        type=str,
+        help="please provide this training config if you are using the model you trained yourself."
+    )
+    parser.add_argument(
+        "--n_shard",
+        required=False,
+        type=int,
+        default=1,
+        help="number of shards of representations."
+    )
+    parser.add_argument(
+        "--use_gpu",
+        default=False,
+        action="store_true",
+        help="whether use gpu for inference."
+    )
+    parser.add_argument(
+        "--batch_size",
+        default=1,
+        type=int,
+        help="number of utterances for each mini batch."
+    )
+    parser.add_argument(
+        "--out_dir",
+        type=str,
+        default=".",
+        help="the directory to save the output."
+    )
+    return parser.parse_args()
+
+
+def load_model(model_path: str, config_path: Optional[str] = None):
+    if config_path is None:
+        name = os.path.basename(model_path).strip(".pkl")
+        assert name in ALL_MODELS.keys(), f"Cannot find configs for {model_path}. " \
+                                          f"Please provide the config file you used for training."
+        config = os.path.join(os.path.dirname(__file__), "configs", f"repcodec_dim{ALL_MODELS[name]}.yaml")
+        with open(config) as fp:
+            conf = yaml.load(fp, Loader=yaml.FullLoader)
+    else:
+        with open(config_path) as fp:
+            conf = yaml.load(fp, Loader=yaml.FullLoader)["model_params"]
+
+    model = RepCodec(**conf)
+    model.load_state_dict(torch.load(model_path, map_location="cpu")["model"]["repcodec"])
+    model.quantizer.initial()
+    model.eval()
+    return model
+
+
+def load_shard(in_dir: Path, rank: int, n_shard: int) -> Tuple[np.ndarray, List[int]]:
+    feat_path = in_dir / f"{rank}_{n_shard}.npy"
+    len_path = in_dir / f"{rank}_{n_shard}.len"
+
+    with open(len_path) as fp:
+        lengths = [int(line.strip()) for line in fp]
+
+    return np.load(feat_path.as_posix(), mmap_mode="r"), lengths
+
+
+def pad_data(data: List[np.ndarray]) -> List[np.ndarray]:
+    max_len = max([d.shape[0] for d in data])
+    data = [
+        np.pad(d, [(0, max_len - d.shape[0]), (0, 0)], "constant", constant_values=0.0)
+        for d in data
+    ]
+    return data
+
+
+def make_batch_data(data: np.ndarray, shard_lengths: List[int], batch_size: int):
+    batch_data = []
+    batch_lens = []
+    offsets = np.cumsum([0] + shard_lengths)
+    assert len(data) == offsets[-1], f"{len(data)} {offsets[-1]}"
+
+    # from longest to shortest
+    for i in range(len(shard_lengths)):
+        if batch_size > len(batch_data):
+            batch_data.append(data[offsets[i]: offsets[i + 1]])
+            batch_lens.append(shard_lengths[i])
+        else:
+            yield {
+                "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),  # (bsz, seq len, hidden dim)
+                "lengths": batch_lens
+            }
+            batch_data = [data[offsets[i]: offsets[i + 1]]]
+            batch_lens = [shard_lengths[i]]
+    if len(batch_data) > 0:
+        yield {
+            "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),
+            "lengths": batch_lens
+        }
+
+
+def tokenize_batch(model: RepCodec, batch: dict, device: str) -> List[List[int]]:
+    with torch.no_grad():
+        data = batch["data"].transpose(1, 2).to(device)  # (bsz, hidden dim, seq len)
+        x = model.encoder(data)
+        z = model.projector(x)
+        _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
+
+    # when bsz=1: (1, seq len)
+    if idx.dim() == 2:
+        return idx.cpu().data.numpy().tolist()
+    # when bsz>1: (1, bsz, seq len)
+    tokens = idx.cpu().data.numpy().tolist()[0]
+    res = []
+    batch_lens = batch["lengths"]
+    for i in range(len(tokens)):
+        n_tokens = batch_lens[i]
+        res.append(tokens[i][:n_tokens])
+    return res
+
+
+def load_tsv(path: str):
+    with open(path) as fp:
+        root = fp.readline().strip()
+        names = []
+        for line in fp:
+            names.append(line.strip().split("\t")[0])
+    return root, names
+
+
+def cli():
+    args = parse_args()
+    device = "cuda" if args.use_gpu else "cpu"
+
+    model = load_model(model_path=args.model, config_path=args.model_config_path)
+    model.to(device)
+
+    in_dir = Path(args.in_dir)
+    n_shard = args.n_shard
+    batch_size = args.batch_size
+
+    root_dir, file_names = load_tsv(args.tsv_path)
+
+    output_dir = args.out_dir
+    os.makedirs(output_dir, exist_ok=True)
+
+    processed_cnt = 0
+    pbar = tqdm(total=len(file_names))
+    with open(os.path.join(output_dir, "tokens"), mode="w+") as fp:
+        fp.write(f"{root_dir}\n")
+
+        for rank in range(n_shard):
+            shard_data, shard_lengths = load_shard(in_dir, rank, n_shard)
+            for batch in make_batch_data(shard_data, shard_lengths, batch_size=batch_size):
+                batch_tokens = tokenize_batch(model, batch, device)
+
+                for tokens in batch_tokens:
+                    fp.write(f"{file_names[processed_cnt]}\t{' '.join(map(str, tokens))}\n")
+                    processed_cnt += 1
+
+                pbar.update(len(batch_tokens))
+    assert processed_cnt == len(file_names), f"# lines of tsv do not match # of representations!"
+
+    pbar.close()
+    print("Tokenize successfully!")
+
+
+if __name__ == '__main__':
+    cli()
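tokenize.py expects its inputs as {rank}_{n_shard}.npy / .len shard pairs (see load_shard) plus a tsv whose first line is the audio root directory and whose remaining lines start with a relative file name (see load_tsv). A small sketch that writes a fake single-shard layout; the paths, feature dimension, and checkpoint name are illustrative assumptions:

# Sketch of the on-disk layout consumed by ASR/repcodec/tokenize.py.
# Paths, dimensions, and the checkpoint file name are assumptions.
import os
import numpy as np

in_dir = "reps"                      # becomes the positional in_dir argument
n_shard = 1
lengths = [120, 80]                  # frames per utterance, concatenated below

os.makedirs(in_dir, exist_ok=True)
feats = np.random.randn(sum(lengths), 768).astype(np.float32)
np.save(f"{in_dir}/0_{n_shard}.npy", feats)          # load_shard reads {rank}_{n_shard}.npy
with open(f"{in_dir}/0_{n_shard}.len", "w") as fp:   # ... and {rank}_{n_shard}.len
    fp.write("\n".join(map(str, lengths)) + "\n")

with open("files.tsv", "w") as fp:                   # load_tsv: first line = root dir
    fp.write("/data/LibriSpeech\n")
    fp.write("utt1.flac\t120\n")
    fp.write("utt2.flac\t80\n")

# Then, assuming a RepCodec checkpoint named after one of the ALL_MODELS keys:
#   python tokenize.py reps --model hubert_base_l9.pkl --tsv_path files.tsv \
#       --n_shard 1 --batch_size 2 --out_dir out
# The token sequences land in out/tokens, one tab-separated line per utterance.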
ASR/test-gpt2-opt.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1717d8158b524053e9cc92fdbc9942bb4abc9f119576680f159f2f3177b378d
+size 653546067
ASR/test-gpt2.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3e0852bac1a63c63262029142c06158c40c81e90ff5f92ef45b0924ad80f6de
+size 653828879
ASR/test-gpt2.plan
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:abdb8a96621f01394e8a783644c782815c25c507cafbc1cb86d6b4a2ccb5a6cc
+size 328704308
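The three entries above are git-lfs pointers to exported GPT-2 artifacts: a plain ONNX export, an optimized ONNX export, and a .plan file (conventionally a TensorRT engine). A minimal sketch of inspecting the ONNX export, assuming the LFS objects have been pulled and onnxruntime is installed:

# Assumes `git lfs pull` has materialised the .onnx file and onnxruntime is installed.
import onnxruntime as ort

sess = ort.InferenceSession("ASR/test-gpt2-opt.onnx", providers=["CPUExecutionProvider"])
for i in sess.get_inputs():
    print("input :", i.name, i.shape, i.type)
for o in sess.get_outputs():
    print("output:", o.name, o.shape, o.type)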
ASR/tokenized_librispeech/dataset_dict.json
ADDED
@@ -0,0 +1 @@
+{"splits": ["train.clean.100", "train.clean.360", "train.other.500", "validation.clean", "validation.other", "test.clean", "test.other"]}
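The tokenized_librispeech directory follows the layout written by the datasets library's save_to_disk: dataset_dict.json lists the splits, and each split folder below holds arrow shards plus dataset_info.json and state.json. A minimal sketch of reading it back, assuming the datasets package is installed and the LFS-tracked arrow files have been pulled:

# Assumes `pip install datasets` and that the arrow shards are available locally.
from datasets import load_from_disk

ds = load_from_disk("ASR/tokenized_librispeech")
print(ds)                        # DatasetDict with train.clean.100 ... test.other
sample = ds["test.clean"][0]
print(sample.keys())             # input_ids, token_type_ids, attention_mask
print(sample["input_ids"][:10])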
ASR/tokenized_librispeech/test.clean/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:267be16197f2a380f021dd5b1687ffd341825474ca00875098a5582252b9b901
+size 29539856
ASR/tokenized_librispeech/test.clean/dataset_info.json
ADDED
@@ -0,0 +1,29 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "token_type_ids": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "attention_mask": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
ASR/tokenized_librispeech/test.clean/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "5d19fcebbf9d8932",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
ASR/tokenized_librispeech/test.other/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:663e0f0f58ac0ac562f319e51219064caaba5102b9bccd1e30322a780b1237ad
+size 33136248
ASR/tokenized_librispeech/test.other/dataset_info.json
ADDED
@@ -0,0 +1,29 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "token_type_ids": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "attention_mask": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
ASR/tokenized_librispeech/test.other/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "6d0c91ac40d55d91",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
ASR/tokenized_librispeech/train.clean.100/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b797259a17fc4ddaca3c0f4a09622cdde099dbf46685d5096405d4916e3428b
+size 321761256
ASR/tokenized_librispeech/train.clean.100/dataset_info.json
ADDED
@@ -0,0 +1,29 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "token_type_ids": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "attention_mask": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
ASR/tokenized_librispeech/train.clean.100/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "00eedcd713e3fb08",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
ASR/tokenized_librispeech/train.clean.360/data-00000-of-00003.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9da54e7eceac6fdf5c90656bca2655b28f079f52b3648978e228326ab84f334
+size 390907152
ASR/tokenized_librispeech/train.clean.360/data-00001-of-00003.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af64dfb82934022b279d8022102c321991f99d1e53ab102e9e84a1488565d1fd
+size 390895880
ASR/tokenized_librispeech/train.clean.360/data-00002-of-00003.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a40a4c69c058b01e7eff4068babf1655998beb5904bc32a620393f41018fa49a
+size 390895880
ASR/tokenized_librispeech/train.clean.360/dataset_info.json
ADDED
@@ -0,0 +1,29 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "token_type_ids": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "attention_mask": {
+      "feature": {
+        "dtype": "int8",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
ASR/tokenized_librispeech/train.clean.360/state.json
ADDED
@@ -0,0 +1,19 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00003.arrow"
+    },
+    {
+      "filename": "data-00001-of-00003.arrow"
+    },
+    {
+      "filename": "data-00002-of-00003.arrow"
+    }
+  ],
+  "_fingerprint": "44573b66f4895b44",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
ASR/tokenized_librispeech/train.other.500/data-00000-of-00004.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8310b88f596d7f62ce3c8a9801cefdb016400702e48c316c145bf4abb8539447
+size 419093384