darshanmakwana committed
Commit e0c2d04 · verified · 1 Parent(s): 2cddd11

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +4 -0
  2. ASR/.ipynb_checkpoints/audio_tokenizer-checkpoint.py +611 -0
  3. ASR/.ipynb_checkpoints/demo-checkpoint.ipynb +849 -0
  4. ASR/.ipynb_checkpoints/demo-checkpoint.py +24 -0
  5. ASR/.ipynb_checkpoints/tokenizer_training-checkpoint.ipynb +203 -0
  6. ASR/__pycache__/audio_tokenizer.cpython-38.pyc +0 -0
  7. ASR/__pycache__/tokenizer.cpython-38.pyc +0 -0
  8. ASR/audio_tokenizer.py +611 -0
  9. ASR/demo.ipynb +878 -0
  10. ASR/demo.py +24 -0
  11. ASR/repcodec/.ipynb_checkpoints/RepCodec-checkpoint.py +84 -0
  12. ASR/repcodec/RepCodec.py +84 -0
  13. ASR/repcodec/__pycache__/RepCodec.cpython-38.pyc +0 -0
  14. ASR/repcodec/configs/repcodec_dim1024.yaml +18 -0
  15. ASR/repcodec/configs/repcodec_dim1280.yaml +18 -0
  16. ASR/repcodec/configs/repcodec_dim768.yaml +18 -0
  17. ASR/repcodec/layers/__pycache__/conv_layer.cpython-38.pyc +0 -0
  18. ASR/repcodec/layers/__pycache__/vq_module.cpython-38.pyc +0 -0
  19. ASR/repcodec/layers/conv_layer.py +95 -0
  20. ASR/repcodec/layers/vq_module.py +155 -0
  21. ASR/repcodec/modules/__pycache__/decoder.cpython-38.pyc +0 -0
  22. ASR/repcodec/modules/__pycache__/encoder.cpython-38.pyc +0 -0
  23. ASR/repcodec/modules/__pycache__/projector.cpython-38.pyc +0 -0
  24. ASR/repcodec/modules/__pycache__/quantizer.cpython-38.pyc +0 -0
  25. ASR/repcodec/modules/__pycache__/residual_unit.cpython-38.pyc +0 -0
  26. ASR/repcodec/modules/decoder.py +109 -0
  27. ASR/repcodec/modules/encoder.py +89 -0
  28. ASR/repcodec/modules/projector.py +32 -0
  29. ASR/repcodec/modules/quantizer.py +46 -0
  30. ASR/repcodec/modules/residual_unit.py +39 -0
  31. ASR/repcodec/tokenize.py +212 -0
  32. ASR/test-gpt2-opt.onnx +3 -0
  33. ASR/test-gpt2.onnx +3 -0
  34. ASR/test-gpt2.plan +3 -0
  35. ASR/tokenized_librispeech/dataset_dict.json +1 -0
  36. ASR/tokenized_librispeech/test.clean/data-00000-of-00001.arrow +3 -0
  37. ASR/tokenized_librispeech/test.clean/dataset_info.json +29 -0
  38. ASR/tokenized_librispeech/test.clean/state.json +13 -0
  39. ASR/tokenized_librispeech/test.other/data-00000-of-00001.arrow +3 -0
  40. ASR/tokenized_librispeech/test.other/dataset_info.json +29 -0
  41. ASR/tokenized_librispeech/test.other/state.json +13 -0
  42. ASR/tokenized_librispeech/train.clean.100/data-00000-of-00001.arrow +3 -0
  43. ASR/tokenized_librispeech/train.clean.100/dataset_info.json +29 -0
  44. ASR/tokenized_librispeech/train.clean.100/state.json +13 -0
  45. ASR/tokenized_librispeech/train.clean.360/data-00000-of-00003.arrow +3 -0
  46. ASR/tokenized_librispeech/train.clean.360/data-00001-of-00003.arrow +3 -0
  47. ASR/tokenized_librispeech/train.clean.360/data-00002-of-00003.arrow +3 -0
  48. ASR/tokenized_librispeech/train.clean.360/dataset_info.json +29 -0
  49. ASR/tokenized_librispeech/train.clean.360/state.json +19 -0
  50. ASR/tokenized_librispeech/train.other.500/data-00000-of-00004.arrow +3 -0
.gitattributes CHANGED
@@ -40,3 +40,7 @@ prompting/train_data/train.other.500.json filter=lfs diff=lfs merge=lfs -text
40
  prompting/transcripts/train.clean.360.txt filter=lfs diff=lfs merge=lfs -text
41
  prompting/transcripts/train.other.500.txt filter=lfs diff=lfs merge=lfs -text
42
  prompting/wandb/run-20240615_114519-wfpe2teb/run-wfpe2teb.wandb filter=lfs diff=lfs merge=lfs -text
43
+ ASR/test-gpt2.plan filter=lfs diff=lfs merge=lfs -text
44
+ ASR/transformer-deploy/docs/infinity/infinity.xcf filter=lfs diff=lfs merge=lfs -text
45
+ ASR/transformer-deploy/resources/img/export_process.png filter=lfs diff=lfs merge=lfs -text
46
+ ASR/transformer-deploy/resources/img/gpt2.png filter=lfs diff=lfs merge=lfs -text
ASR/.ipynb_checkpoints/audio_tokenizer-checkpoint.py ADDED
@@ -0,0 +1,611 @@
1
+ import logging
2
+ import math
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+
6
+ from omegaconf import II
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+
13
+ from fairseq.modules import EMAModule, EMAModuleConfig
14
+ from fairseq.data.data_utils import compute_mask_indices
15
+ from fairseq.models import BaseFairseqModel, register_model
16
+ from fairseq.models.wav2vec import (
17
+ ConvFeatureExtractionModel,
18
+ Wav2Vec2Config,
19
+ TransformerEncoder,
20
+ )
21
+ from fairseq.modules import (
22
+ GradMultiply,
23
+ LayerNorm,
24
+ )
25
+ from fairseq.utils import index_put
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ @dataclass
32
+ class Data2VecAudioConfig(Wav2Vec2Config):
33
+
34
+ loss_beta: float = field(
35
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
36
+ )
37
+ loss_scale: Optional[float] = field(
38
+ default=None,
39
+ metadata={
40
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
41
+ },
42
+ )
43
+ average_top_k_layers: int = field(
44
+ default=8, metadata={"help": "how many layers to average"}
45
+ )
46
+
47
+ layer_norm_target_layer: bool = False
48
+ instance_norm_target_layer: bool = False
49
+ instance_norm_targets: bool = False
50
+ layer_norm_targets: bool = False
51
+ batch_norm_target_layer: bool = False
52
+ group_norm_target_layer: bool = False
53
+
54
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
55
+ ema_end_decay: float = field(
56
+ default=0.9999, metadata={"help": "final ema decay rate"}
57
+ )
58
+
59
+ # when to finish annealing ema decay rate
60
+ ema_anneal_end_step: int = II("optimization.max_update")
61
+
62
+ ema_transformer_only: bool = field(
63
+ default=True,
64
+ metadata={"help": "whether to momentum update only the transformer"},
65
+ )
66
+ ema_layers_only: bool = field(
67
+ default=True,
68
+ metadata={"help": "whether to momentum update only the transformer layers"},
69
+ )
70
+
71
+ max_update: int = II("optimization.max_update")
72
+
73
+ min_target_var: float = field(
74
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
75
+ )
76
+ min_pred_var: float = field(
77
+ default=0.01,
78
+ metadata={"help": "stop training if prediction var falls below this"},
79
+ )
80
+
81
+
82
+ def get_annealed_rate(start, end, curr_step, total_steps):
83
+ r = end - start
84
+ pct_remaining = 1 - curr_step / total_steps
85
+ return end - r * pct_remaining
86
+
87
+
88
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
89
+ class Data2VecAudioModel(BaseFairseqModel):
90
+ def __init__(self, cfg: Data2VecAudioConfig):
91
+ super().__init__()
92
+ self.cfg = cfg
93
+
94
+ feature_enc_layers = eval(cfg.conv_feature_layers)
95
+ self.extractor_embed = feature_enc_layers[-1][0]
96
+
97
+ self.ema = None
98
+ self.embed = cfg.encoder_embed_dim
99
+
100
+ self.average_top_k_layers = cfg.average_top_k_layers
101
+ self.loss_beta = cfg.loss_beta
102
+ self.loss_scale = cfg.loss_scale
103
+
104
+ self.feature_extractor = ConvFeatureExtractionModel(
105
+ conv_layers=feature_enc_layers,
106
+ dropout=0.0,
107
+ mode=cfg.extractor_mode,
108
+ conv_bias=cfg.conv_bias,
109
+ )
110
+
111
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
112
+
113
+ self.mask_prob = cfg.mask_prob
114
+ self.mask_selection = cfg.mask_selection
115
+ self.mask_other = cfg.mask_other
116
+ self.mask_length = cfg.mask_length
117
+ self.no_mask_overlap = cfg.no_mask_overlap
118
+ self.mask_min_space = cfg.mask_min_space
119
+
120
+ self.mask_channel_prob = cfg.mask_channel_prob
121
+ self.mask_channel_before = cfg.mask_channel_before
122
+ self.mask_channel_selection = cfg.mask_channel_selection
123
+ self.mask_channel_other = cfg.mask_channel_other
124
+ self.mask_channel_length = cfg.mask_channel_length
125
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
126
+ self.mask_channel_min_space = cfg.mask_channel_min_space
127
+
128
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
129
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
130
+
131
+ self.feature_grad_mult = cfg.feature_grad_mult
132
+
133
+ self.mask_emb = nn.Parameter(
134
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
135
+ )
136
+
137
+ self.encoder = TransformerEncoder(cfg)
138
+ self.layer_norm = LayerNorm(self.extractor_embed)
139
+
140
+ self.final_proj = nn.Linear(self.embed, self.embed)
141
+
142
+ self.num_updates = 0
143
+
144
+ def make_ema_teacher(self):
145
+ ema_config = EMAModuleConfig(
146
+ ema_decay=self.cfg.ema_decay,
147
+ ema_fp32=True,
148
+ )
149
+ skip_keys = set()
150
+ if self.cfg.ema_layers_only:
151
+ self.cfg.ema_transformer_only = True
152
+ for k, _ in self.encoder.pos_conv.named_parameters():
153
+ skip_keys.add(f"pos_conv.{k}")
154
+
155
+ self.ema = EMAModule(
156
+ self.encoder if self.cfg.ema_transformer_only else self,
157
+ ema_config,
158
+ skip_keys=skip_keys,
159
+ )
160
+
161
+ def set_num_updates(self, num_updates):
162
+ super().set_num_updates(num_updates)
163
+
164
+ if self.ema is None and self.final_proj is not None:
165
+ logger.info(f"making ema teacher")
166
+ self.make_ema_teacher()
167
+ elif self.training and self.ema is not None:
168
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
169
+ if num_updates >= self.cfg.ema_anneal_end_step:
170
+ decay = self.cfg.ema_end_decay
171
+ else:
172
+ decay = get_annealed_rate(
173
+ self.cfg.ema_decay,
174
+ self.cfg.ema_end_decay,
175
+ num_updates,
176
+ self.cfg.ema_anneal_end_step,
177
+ )
178
+ self.ema.set_decay(decay)
179
+ if self.ema.get_decay() < 1:
180
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
181
+
182
+ self.num_updates = num_updates
183
+
184
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
185
+ state = super().state_dict(destination, prefix, keep_vars)
186
+
187
+ if self.ema is not None:
188
+ state[prefix + "_ema"] = self.ema.fp32_params
189
+
190
+ return state
191
+
192
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
193
+ if self.ema is not None:
194
+ k = prefix + "_ema"
195
+ assert k in state_dict
196
+ self.ema.restore(state_dict[k], True)
197
+ del state_dict[k]
198
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
199
+
200
+ @classmethod
201
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
202
+ """Build a new model instance."""
203
+
204
+ return cls(cfg)
205
+
206
+ def apply_mask(
207
+ self,
208
+ x,
209
+ padding_mask,
210
+ mask_indices=None,
211
+ mask_channel_indices=None,
212
+ ):
213
+ B, T, C = x.shape
214
+
215
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
216
+ mask_channel_indices = compute_mask_indices(
217
+ (B, C),
218
+ None,
219
+ self.mask_channel_prob,
220
+ self.mask_channel_length,
221
+ self.mask_channel_selection,
222
+ self.mask_channel_other,
223
+ no_overlap=self.no_mask_channel_overlap,
224
+ min_space=self.mask_channel_min_space,
225
+ )
226
+ mask_channel_indices = (
227
+ torch.from_numpy(mask_channel_indices)
228
+ .to(x.device)
229
+ .unsqueeze(1)
230
+ .expand(-1, T, -1)
231
+ )
232
+ x[mask_channel_indices] = 0
233
+
234
+ if self.mask_prob > 0:
235
+ if mask_indices is None:
236
+ mask_indices = compute_mask_indices(
237
+ (B, T),
238
+ padding_mask,
239
+ self.mask_prob,
240
+ self.mask_length,
241
+ self.mask_selection,
242
+ self.mask_other,
243
+ min_masks=1,
244
+ no_overlap=self.no_mask_overlap,
245
+ min_space=self.mask_min_space,
246
+ require_same_masks=self.cfg.require_same_masks,
247
+ mask_dropout=self.cfg.mask_dropout,
248
+ )
249
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
250
+ x = index_put(x, mask_indices, self.mask_emb)
251
+ else:
252
+ mask_indices = None
253
+
254
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
255
+ if mask_channel_indices is None:
256
+ mask_channel_indices = compute_mask_indices(
257
+ (B, C),
258
+ None,
259
+ self.mask_channel_prob,
260
+ self.mask_channel_length,
261
+ self.mask_channel_selection,
262
+ self.mask_channel_other,
263
+ no_overlap=self.no_mask_channel_overlap,
264
+ min_space=self.mask_channel_min_space,
265
+ )
266
+ mask_channel_indices = (
267
+ torch.from_numpy(mask_channel_indices)
268
+ .to(x.device)
269
+ .unsqueeze(1)
270
+ .expand(-1, T, -1)
271
+ )
272
+ x = index_put(x, mask_channel_indices, 0)
273
+
274
+ return x, mask_indices
275
+
276
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
277
+ """
278
+ Computes the output length of the convolutional layers
279
+ """
280
+
281
+ def _conv_out_length(input_length, kernel_size, stride):
282
+ return torch.floor((input_length - kernel_size) / stride + 1)
283
+
284
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
285
+
286
+ for i in range(len(conv_cfg_list)):
287
+ input_lengths = _conv_out_length(
288
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
289
+ )
290
+
291
+ return input_lengths.to(torch.long)
292
+
293
+ def forward(
294
+ self,
295
+ source,
296
+ padding_mask=None,
297
+ mask=True,
298
+ features_only=False,
299
+ layer=None,
300
+ mask_indices=None,
301
+ mask_channel_indices=None,
302
+ padding_count=None,
303
+ ):
304
+ features = source
305
+
306
+ if self.feature_grad_mult > 0:
307
+ features = self.feature_extractor(features)
308
+ if self.feature_grad_mult != 1.0:
309
+ features = GradMultiply.apply(features, self.feature_grad_mult)
310
+ else:
311
+ with torch.no_grad():
312
+ features = self.feature_extractor(features)
313
+
314
+ features = features.transpose(1, 2)
315
+
316
+ features = self.layer_norm(features)
317
+
318
+ orig_padding_mask = padding_mask
319
+
320
+ if padding_mask is not None and padding_mask.any():
321
+ input_lengths = (1 - padding_mask.long()).sum(-1)
322
+ # apply conv formula to get real output_lengths
323
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
324
+
325
+ padding_mask = torch.zeros(
326
+ features.shape[:2], dtype=features.dtype, device=features.device
327
+ )
328
+
329
+ # these two operations make sure that all values
330
+ # before the output lengths indices are attended to
331
+ padding_mask[
332
+ (
333
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
334
+ output_lengths - 1,
335
+ )
336
+ ] = 1
337
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
338
+ else:
339
+ padding_mask = None
340
+
341
+ if self.post_extract_proj is not None:
342
+ features = self.post_extract_proj(features)
343
+
344
+ pre_encoder_features = None
345
+ if self.cfg.ema_transformer_only:
346
+ pre_encoder_features = features.clone()
347
+
348
+ features = self.dropout_input(features)
349
+
350
+ if mask:
351
+ x, mask_indices = self.apply_mask(
352
+ features,
353
+ padding_mask,
354
+ mask_indices=mask_indices,
355
+ mask_channel_indices=mask_channel_indices,
356
+ )
357
+ else:
358
+ x = features
359
+ mask_indices = None
360
+
361
+ x, layer_results = self.encoder(
362
+ x,
363
+ padding_mask=padding_mask,
364
+ layer=layer,
365
+ )
366
+
367
+ if features_only:
368
+ return {
369
+ "x": x,
370
+ "padding_mask": padding_mask,
371
+ "layer_results": layer_results,
372
+ }
373
+
374
+ result = {
375
+ "losses": {},
376
+ }
377
+
378
+ with torch.no_grad():
379
+ self.ema.model.eval()
380
+
381
+ if self.cfg.ema_transformer_only:
382
+ y, layer_results = self.ema.model.extract_features(
383
+ pre_encoder_features,
384
+ padding_mask=padding_mask,
385
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
386
+ )
387
+ y = {
388
+ "x": y,
389
+ "padding_mask": padding_mask,
390
+ "layer_results": layer_results,
391
+ }
392
+ else:
393
+ y = self.ema.model.extract_features(
394
+ source=source,
395
+ padding_mask=orig_padding_mask,
396
+ mask=False,
397
+ )
398
+
399
+ target_layer_results = [l[2] for l in y["layer_results"]]
400
+
401
+ permuted = False
402
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
403
+ target_layer_results = [
404
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
405
+ ]
406
+ permuted = True
407
+
408
+ if self.cfg.batch_norm_target_layer:
409
+ target_layer_results = [
410
+ F.batch_norm(
411
+ tl.float(), running_mean=None, running_var=None, training=True
412
+ )
413
+ for tl in target_layer_results
414
+ ]
415
+
416
+ if self.cfg.instance_norm_target_layer:
417
+ target_layer_results = [
418
+ F.instance_norm(tl.float()) for tl in target_layer_results
419
+ ]
420
+
421
+ if permuted:
422
+ target_layer_results = [
423
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
424
+ ]
425
+
426
+ if self.cfg.group_norm_target_layer:
427
+ target_layer_results = [
428
+ F.layer_norm(tl.float(), tl.shape[-2:])
429
+ for tl in target_layer_results
430
+ ]
431
+
432
+ if self.cfg.layer_norm_target_layer:
433
+ target_layer_results = [
434
+ F.layer_norm(tl.float(), tl.shape[-1:])
435
+ for tl in target_layer_results
436
+ ]
437
+
438
+ y = sum(target_layer_results) / len(target_layer_results)
439
+
440
+ if self.cfg.layer_norm_targets:
441
+ y = F.layer_norm(y.float(), y.shape[-1:])
442
+
443
+ if self.cfg.instance_norm_targets:
444
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
445
+
446
+ if not permuted:
447
+ y = y.transpose(0, 1)
448
+
449
+ y = y[mask_indices]
450
+
451
+ x = x[mask_indices]
452
+ x = self.final_proj(x)
453
+
454
+ sz = x.size(-1)
455
+
456
+ if self.loss_beta == 0:
457
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
458
+ else:
459
+ loss = F.smooth_l1_loss(
460
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
461
+ ).sum(dim=-1)
462
+
463
+ if self.loss_scale is not None:
464
+ scale = self.loss_scale
465
+ else:
466
+ scale = 1 / math.sqrt(sz)
467
+
468
+ result["losses"]["regression"] = loss.sum() * scale
469
+
470
+ if "sample_size" not in result:
471
+ result["sample_size"] = loss.numel()
472
+
473
+ with torch.no_grad():
474
+ result["target_var"] = self.compute_var(y)
475
+ result["pred_var"] = self.compute_var(x.float())
476
+
477
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
478
+ logger.error(
479
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
480
+ )
481
+ raise Exception(
482
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
483
+ )
484
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
485
+ logger.error(
486
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
487
+ )
488
+ raise Exception(
489
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
490
+ )
491
+
492
+ if self.ema is not None:
493
+ result["ema_decay"] = self.ema.get_decay() * 1000
494
+
495
+ return result
496
+
497
+ @staticmethod
498
+ def compute_var(y):
499
+ y = y.view(-1, y.size(-1))
500
+ if dist.is_initialized():
501
+ zc = torch.tensor(y.size(0)).cuda()
502
+ zs = y.sum(dim=0)
503
+ zss = (y ** 2).sum(dim=0)
504
+
505
+ dist.all_reduce(zc)
506
+ dist.all_reduce(zs)
507
+ dist.all_reduce(zss)
508
+
509
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
510
+ return torch.sqrt(var + 1e-6).mean()
511
+ else:
512
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
513
+
514
+ def extract_features(
515
+ self, source, padding_mask, mask=False, layer=None
516
+ ):
517
+ res = self.forward(
518
+ source,
519
+ padding_mask,
520
+ mask=mask,
521
+ features_only=True,
522
+ layer=layer,
523
+ )
524
+ return res
525
+
526
+ def remove_pretraining_modules(self, last_layer=None):
527
+ self.final_proj = None
528
+ self.ema = None
529
+ if last_layer is not None:
530
+ self.encoder.layers = nn.ModuleList(
531
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
532
+ )
533
+
534
+ import logging
535
+
536
+ import torch
537
+ import torch.nn.functional as F
538
+ from fairseq import tasks
539
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
540
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
541
+ from omegaconf import OmegaConf
542
+
543
+ logger = logging.getLogger("dump_feature")
544
+
545
+
546
+ class Data2vecFeatureReader(object):
547
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
548
+ state = load_checkpoint_to_cpu(ckpt_path)
549
+ cfg = state["cfg"]
550
+ # load task
551
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
552
+ task.load_state_dict(state["task_state"])
553
+ # load model config
554
+ if "layer_type" not in cfg.model:
555
+ # fix a missing key
556
+ model_config = {k: v for k, v in cfg.model.items()}
557
+ model_config["layer_type"] = "transformer"
558
+ model_config = OmegaConf.create(model_config)
559
+ else:
560
+ model_config = cfg.model
561
+
562
+ # fix param name in the state
563
+ state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
564
+ state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
565
+ del state["model"]["_ema"]
566
+
567
+ # load model
568
+ model = Data2VecAudioModel.build_model(model_config)
569
+ model.load_state_dict(
570
+ state["model"], strict=True, model_cfg=model_config
571
+ )
572
+
573
+ self.device = device
574
+ logger.info(f"device = {self.device}")
575
+
576
+ self.model = model.eval().to(self.device)
577
+ self.task = task
578
+ self.layer = layer - 1 # the layer argument is 1-based; store it as a 0-based index
579
+ self.max_chunk = max_chunk
580
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
581
+ logger.info(f" max_chunk = {self.max_chunk}")
582
+
583
+ def read_audio(self, path, ref_len=None):
584
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
585
+ if wav.ndim == 2:
586
+ wav = wav.mean(-1)
587
+ assert wav.ndim == 1, wav.ndim
588
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
589
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
590
+ return wav
591
+
592
+ def get_feats(self, path, ref_len=None):
593
+ x = self.read_audio(path, ref_len=ref_len)
594
+ with torch.no_grad():
595
+ x = torch.from_numpy(x).float().to(self.device)
596
+ if self.task.cfg.normalize:
597
+ x = F.layer_norm(x, x.shape)
598
+ x = x.view(1, -1)
599
+
600
+ feat = []
601
+ for start in range(0, x.size(1), self.max_chunk):
602
+ x_chunk = x[:, start: start + self.max_chunk]
603
+ res = self.model.extract_features(
604
+ source=x_chunk,
605
+ padding_mask=None,
606
+ mask=False,
607
+ layer=self.layer,
608
+ )
609
+ feat_chunk = res["x"]
610
+ feat.append(feat_chunk)
611
+ return torch.cat(feat, 1).squeeze(0)
ASR/.ipynb_checkpoints/demo-checkpoint.ipynb ADDED
@@ -0,0 +1,849 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "715a402a-44b9-4fa2-abf0-b0cfd2f3d80b",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Recording voice in Real Time"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "dbdf6bab-7418-4a6f-8b75-c31f98a6ada5",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "\"\"\"\n",
19
+ "Sprints:\n",
20
+ "- [ ] Do Inference optimization of ASR LM\n",
21
+ "- [ ] Train on train.other.500\n",
22
+ "- [ ] Generate dataset for prompting\n",
23
+ "\n",
24
+ "Evaluation Dates: 20th - 21st June, 2023, 3:30 - 5:30pm\n",
25
+ "Sharpen PPT Skills: 20th June, 3:30pm - 4:45pm\n",
26
+ "Flow of the PPT:\n",
27
+ "Demo -> Datasets -> Techniques -> Evaluation -> Q&A\n",
28
+ "- [ Done ] Update the one pager deck slide\n",
29
+ "https://sprinklr-my.sharepoint.com/:p:/r/personal/sricharan_narayanam_sprinklr_com/_layouts/15/Doc.aspx?sourcedoc=%7B84811f56-5fc7-4eaa-87d2-db4a3588d18c%7D&action=edit&wdPreviousSession=948ccc35-dc05-f1f9-612d-9a22300e25ba\n",
30
+ "My PPT:\n",
31
+ "https://sprinklr-my.sharepoint.com/:p:/p/darshan_makwana/Ec4jCiyMWhxMproH625msc8BClFVceNQ8o4kS3EhZBO9MA?e=YCSDxm&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718703689001&web=1\n",
32
+ "Intern Tracker:\n",
33
+ "https://sprinklr.sharepoint.com/:x:/s/AIIntuition/EbRhHPIAIw9MlZ5PpXbztmABde1LFbaSoSHJAo9qU8ggDg?e=xiLkRt&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718692666812&web=1\n",
34
+ "\"\"\""
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "150aca01-4098-4ab2-809a-25775ec52069",
40
+ "metadata": {},
41
+ "source": [
42
+ "## ASR LM Inference"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "804a58af-beb2-48c1-9530-98024e27c0d6",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from audio_tokenizer import Data2vecFeatureReader\n",
53
+ "from repcodec.RepCodec import RepCodec\n",
54
+ "import torch.nn.functional as F\n",
55
+ "import torch\n",
56
+ "import yaml\n",
57
+ "\n",
58
+ "reader = Data2vecFeatureReader(\"./../prompting/models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)\n",
59
+ "\n",
60
+ "config = \"./repcodec/configs/repcodec_dim1024.yaml\"\n",
61
+ "with open(config) as fp:\n",
62
+ " conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
63
+ "\n",
64
+ "audio_model = RepCodec(**conf)\n",
65
+ "audio_model.load_state_dict(torch.load(\"./../prompting/models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
66
+ "audio_model.quantizer.initial()\n",
67
+ "audio_model.to(\"cuda:0\")\n",
68
+ "audio_model.eval()\n",
69
+ "\n",
70
+ "print(\"Successfully Loaded Audio Tokenizer\")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "7d8da397-2030-4b36-9a42-97862488797b",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "from datasets import load_dataset\n",
81
+ "\n",
82
+ "cache_dir = \"./../cache\"\n",
83
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 2,
89
+ "id": "bb8016b2-fc9d-4c23-9e85-b6e1c5ca164c",
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "ename": "ImportError",
94
+ "evalue": "FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.",
95
+ "output_type": "error",
96
+ "traceback": [
97
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
98
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
99
+ "Cell \u001b[0;32mIn[2], line 33\u001b[0m\n\u001b[1;32m 30\u001b[0m eot_token \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m<|endoftranscript|>\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 31\u001b[0m pad_token \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m<|padding|>\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 33\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT2LMHeadModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./../out/checkpoint-10000\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_implementation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mflash_attention_2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 34\u001b[0m model\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpad_token_id \u001b[38;5;241m=\u001b[39m pad_token\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39meos_token_id \u001b[38;5;241m=\u001b[39m eot_token\n",
100
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:3620\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3617\u001b[0m init_contexts\u001b[38;5;241m.\u001b[39mappend(init_empty_weights())\n\u001b[1;32m 3619\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config) \u001b[38;5;66;03m# We do not want to modify the config inplace in from_pretrained.\u001b[39;00m\n\u001b[0;32m-> 3620\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_autoset_attn_implementation\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3621\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\n\u001b[1;32m 3622\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3624\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(init_contexts):\n\u001b[1;32m 3625\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[1;32m 3626\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(config, \u001b[38;5;241m*\u001b[39mmodel_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n",
101
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:1469\u001b[0m, in \u001b[0;36mPreTrainedModel._autoset_attn_implementation\u001b[0;34m(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)\u001b[0m\n\u001b[1;32m 1466\u001b[0m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1468\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1469\u001b[0m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_and_enable_flash_attn_2\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1470\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1471\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1472\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1473\u001b[0m \u001b[43m \u001b[49m\u001b[43mhard_check_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1474\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_device_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1475\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1476\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msdpa\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 1477\u001b[0m \u001b[38;5;66;03m# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.\u001b[39;00m\n\u001b[1;32m 1478\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_check_and_enable_sdpa(\n\u001b[1;32m 1479\u001b[0m config,\n\u001b[1;32m 1480\u001b[0m hard_check_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 1481\u001b[0m )\n",
102
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:1571\u001b[0m, in \u001b[0;36mPreTrainedModel._check_and_enable_flash_attn_2\u001b[0;34m(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)\u001b[0m\n\u001b[1;32m 1568\u001b[0m install_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1570\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mutil\u001b[38;5;241m.\u001b[39mfind_spec(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1571\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreface\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m the package flash_attn seems to be not installed. \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minstall_message\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1573\u001b[0m flash_attention_version \u001b[38;5;241m=\u001b[39m version\u001b[38;5;241m.\u001b[39mparse(importlib\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mversion(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1574\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mcuda:\n",
103
+ "\u001b[0;31mImportError\u001b[0m: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
104
+ ]
105
+ }
106
+ ],
107
+ "source": [
108
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
109
+ "import torch\n",
110
+ "import string\n",
111
+ "\n",
112
+ "def process(text):\n",
113
+ "\n",
114
+ " # Lower case every letter\n",
115
+ " text = text.lower()\n",
116
+ "\n",
117
+ " # Remove punctuation\n",
118
+ " punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
119
+ " translation_table = str.maketrans('', '', punctuation_to_remove)\n",
120
+ " text = text.translate(translation_table)\n",
121
+ "\n",
122
+ " # Remove whitespaces from front and behind\n",
123
+ " while text[0] == ' ' or text[-1] == ' ':\n",
124
+ " if text[0] == ' ':\n",
125
+ " text = text[1:]\n",
126
+ " if text[-1] == ' ':\n",
127
+ " text = text[:-1]\n",
128
+ " \n",
129
+ " return text\n",
130
+ "\n",
131
+ "device = \"cuda:0\"\n",
132
+ "dtype = torch.float16\n",
133
+ "context_length = 1877\n",
134
+ "\n",
135
+ "# Load tokenizer and add audio tokens\n",
136
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
137
+ "eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
138
+ "pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
139
+ "\n",
140
+ "model = GPT2LMHeadModel.from_pretrained(\"./../out/checkpoint-10000\", attn_implementation=\"flash_attention_2\", device_map=device, torch_dtype=dtype).eval()\n",
141
+ "model.config.pad_token_id = pad_token\n",
142
+ "model.config.eos_token_id = eot_token\n",
143
+ "# model = torch.compile(model)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "7cabe9dc-bbbf-41b4-918f-3f60ee5582f2",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "from tqdm import tqdm\n",
154
+ "from math import ceil\n",
155
+ "import torch\n",
156
+ "import time\n",
157
+ "\n",
158
+ "sample = dataset[\"train.clean.100\"][5]\n",
159
+ "\n",
160
+ "x = sample[\"audio\"][\"array\"]\n",
161
+ "\n",
162
+ "start_time = time.time()\n",
163
+ "\n",
164
+ "with torch.no_grad():\n",
165
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
166
+ " if reader.task.cfg.normalize:\n",
167
+ " x = F.layer_norm(x, x.shape)\n",
168
+ " x = x.view(1, -1)\n",
169
+ "\n",
170
+ " feat = []\n",
171
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
172
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
173
+ " res = reader.model.extract_features(\n",
174
+ " source=x_chunk,\n",
175
+ " padding_mask=None,\n",
176
+ " mask=False,\n",
177
+ " layer=reader.layer,\n",
178
+ " )\n",
179
+ " feat_chunk = res[\"x\"]\n",
180
+ " feat.append(feat_chunk)\n",
181
+ " \n",
182
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
183
+ "\n",
184
+ " x = audio_model.encoder(features)\n",
185
+ " z = audio_model.projector(x)\n",
186
+ " _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
187
+ " tokens = idx.cpu().data.numpy().tolist()[0]\n",
188
+ " \n",
189
+ "text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
190
+ "input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
191
+ "\n",
192
+ "input_time = time.time()\n",
193
+ "\n",
194
+ "generations = model.generate(\n",
195
+ " input_ids,\n",
196
+ " pad_token_id = pad_token,\n",
197
+ " eos_token_id = eot_token,\n",
198
+ " max_new_tokens = context_length,\n",
199
+ " use_cache=True\n",
200
+ ")\n",
201
+ "\n",
202
+ "finish_time = time.time()\n",
203
+ "\n",
204
+ "tokenizer.batch_decode(generations, skip_special_tokens=True)\n",
205
+ "print(\"First Token Latency: \", (input_time - start_time) * 1000, \"ms\")\n",
206
+ "# print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
207
+ "print(\"End to End Inference Time: \", (finish_time - start_time) * 1000, \"ms\")\n",
208
+ "print(\"Refer Text: \", process(sample[\"text\"]))\n",
209
+ "print(\"Transcript: \", tokenizer.batch_decode(generations, skip_special_tokens=True)[0])"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "baa8d79b-7cf5-4435-838c-1f3d4e043d60",
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "import time\n",
220
+ "\n",
221
+ "sample = dataset[\"train.clean.100\"][0]\n",
222
+ "\n",
223
+ "x = sample[\"audio\"][\"array\"]\n",
224
+ "\n",
225
+ "start_time = time.time()\n",
226
+ "\n",
227
+ "with torch.no_grad():\n",
228
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
229
+ " if reader.task.cfg.normalize:\n",
230
+ " x = F.layer_norm(x, x.shape)\n",
231
+ " x = x.view(1, -1)\n",
232
+ "\n",
233
+ " feat = []\n",
234
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
235
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
236
+ " res = reader.model.extract_features(\n",
237
+ " source=x_chunk,\n",
238
+ " padding_mask=None,\n",
239
+ " mask=False,\n",
240
+ " layer=reader.layer,\n",
241
+ " )\n",
242
+ " feat_chunk = res[\"x\"]\n",
243
+ " feat.append(feat_chunk)\n",
244
+ " \n",
245
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
246
+ "\n",
247
+ " x = audio_model.encoder(features)\n",
248
+ " z = audio_model.projector(x)\n",
249
+ " _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
250
+ " tokens = idx.cpu().data.numpy().tolist()[0]\n",
251
+ "\n",
252
+ "from tqdm import tqdm\n",
253
+ "from math import ceil\n",
254
+ "import torch\n",
255
+ "\n",
256
+ "context_length = 1877\n",
257
+ "eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
258
+ "pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
259
+ " \n",
260
+ "text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
261
+ "input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
262
+ "\n",
263
+ "max_new_tokens = context_length\n",
264
+ "num_tokens = 0\n",
265
+ "first_token = True\n",
266
+ "\n",
267
+ "while max_new_tokens > 0 and input_ids.shape[-1] < context_length:\n",
268
+ "\n",
269
+ " with torch.no_grad():\n",
270
+ " outputs = model(input_ids = input_ids)\n",
271
+ "\n",
272
+ " logits = outputs[\"logits\"][:, -1]\n",
273
+ "\n",
274
+ " # Greedy Sampling\n",
275
+ " probas = torch.softmax(logits, dim=-1)\n",
276
+ " pred_idx = torch.argmax(probas, dim=-1, keepdim=True)\n",
277
+ " next_idx = pred_idx.item()\n",
278
+ "\n",
279
+ " if first_token:\n",
280
+ " first_token_latency = time.time() - start_time\n",
281
+ " first_token = False\n",
282
+ " start_time = time.time()\n",
283
+ "\n",
284
+ " if next_idx == eot_token:\n",
285
+ " break\n",
286
+ "\n",
287
+ " input_ids = torch.cat((input_ids, pred_idx), dim=-1)\n",
288
+ "\n",
289
+ " max_new_tokens -= 1\n",
290
+ " num_tokens += 1\n",
291
+ "\n",
292
+ "total_time = time.time() - start_time\n",
293
+ "\n",
294
+ "print(\"First Token Latency: \", first_token_latency * 1000, \"ms\")\n",
295
+ "print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
296
+ "print(\"End to End Inference Time: \", (total_time + first_token_latency) * 1000, \"ms\")\n",
297
+ "print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])\n",
298
+ "print(process(sample[\"text\"]))"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "014ed999-3293-4d68-8f9c-017584adc642",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "tokenizer.batch_decode([[1, 2, 3]])"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "id": "ec11e43f-1eb8-4399-9a93-6f1427782661",
314
+ "metadata": {
315
+ "jp-MarkdownHeadingCollapsed": true
316
+ },
317
+ "source": [
318
+ "## Accelerating GPT 2 Inference"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "id": "5489cb4e-3213-4931-abe1-4c96d1a7ba56",
325
+ "metadata": {},
326
+ "outputs": [],
327
+ "source": [
328
+ "\"\"\"\n",
329
+ "- change tensorrt.tensorrt to tensorrt\n",
330
+ "- remove cpu quantization lines\n",
331
+ "- output_names [\"logits\"]\n",
332
+ "\"\"\""
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": null,
338
+ "id": "7e7e6ea6-7319-4e57-af33-5d917d26abc6",
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "import logging\n",
343
+ "import time\n",
344
+ "from typing import Callable, Dict\n",
345
+ "\n",
346
+ "import numpy as np\n",
347
+ "import tensorrt as trt\n",
348
+ "import torch\n",
349
+ "from tensorrt import ICudaEngine\n",
350
+ "from tensorrt import Logger, Runtime\n",
351
+ "from transformers import AutoTokenizer, BatchEncoding, GPT2LMHeadModel, AutoModelForCausalLM\n",
352
+ "from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
353
+ "from transformer_deploy.utils.generative_model import GPTModelWrapper\n",
354
+ "import inspect\n",
355
+ "from transformers import TensorType\n",
356
+ "\n",
357
+ "from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding, optimize_onnx\n",
358
+ "from transformer_deploy.backends.pytorch_utils import convert_to_onnx, get_model_size\n",
359
+ "from transformer_deploy.backends.trt_utils import build_engine, load_engine, save_engine"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": null,
365
+ "id": "21681412-7747-4824-894a-6006eb12a821",
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "model_name = \"gpt2\"\n",
370
+ "\n",
371
+ "model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name)\n",
372
+ "model.eval()\n",
373
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
374
+ "model.config.pad_token_id = tokenizer.eos_token_id"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "46783acd-c404-44b4-904b-d8fb687afc34",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "inputs = tokenizer(\"Here is some text to encode Hello World\", return_tensors=\"pt\")\n",
385
+ "print(\"input tensors\")\n",
386
+ "print(inputs)\n",
387
+ "print(\"input tensor shape\")\n",
388
+ "print(inputs[\"input_ids\"].size())\n",
389
+ "\n",
390
+ "with torch.no_grad():\n",
391
+ " outputs = model(**inputs)\n",
392
+ "\n",
393
+ "logits = outputs.logits\n",
394
+ "print(\"output tensor\")\n",
395
+ "print(logits)\n",
396
+ "print(\"output shape\")\n",
397
+ "print(logits.shape)"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "id": "2f6cc7bd-5e2d-4d4e-a7e6-73a6b2ecd7af",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "size = 0\n",
408
+ "for i in range(8, 256, 1):\n",
409
+ " # input sequence (input_ids) made of int-32 (4 bytes)\n",
410
+ " size += np.prod([1, i]) * 4\n",
411
+ " # output tensor made of float-32 (4 bytes)\n",
412
+ " size += np.prod([1, i, 50257]) * 4\n",
413
+ "print(f\"total size (input+output): {size / 1024**3:.2f} Gb\")\n",
414
+ "\n",
415
+ "# to manually check actual tensor size:\n",
416
+ "# np.prod(logits.shape)*32/8/1024**2:.2f}\n",
417
+ "# or\n",
418
+ "# sys.getsizeof(logits.storage())/1024**2"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "7debb40e-9941-45e4-9db8-4bb021ce44ab",
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": [
428
+ "input_ids: BatchEncoding = tokenizer(\n",
429
+ " \"Here is some text to encode Hello World\", add_special_tokens=True, return_attention_mask=False, return_tensors=\"pt\"\n",
430
+ ")\n",
431
+ "# some inference engines don't support int64 tensor as inputs, we convert all input tensors to int32 type\n",
432
+ "for k, v in input_ids.items(): # type: str, torch.Tensor\n",
433
+ " input_ids[k] = v.type(dtype=torch.int32)\n",
434
+ "\n",
435
+ "convert_to_onnx(\n",
436
+ " model_pytorch=model,\n",
437
+ " output_path=\"test-gpt2.onnx\",\n",
438
+ " inputs_pytorch=dict(input_ids),\n",
439
+ " quantization=False,\n",
440
+ " var_output_seq=True, # we inform ONNX export tool that the output shape will vary with the input shape\n",
441
+ " output_names = [\"logits\"]\n",
442
+ ")\n",
443
+ "# model may switch to train mode for some unknown reasons, we force the eval mode.\n",
444
+ "_ = model.eval()"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "id": "956c3007-2c18-4d92-af4f-6cef474d86b5",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "logging.basicConfig()\n",
455
+ "logging.getLogger().setLevel(logging.INFO)\n",
456
+ "num_attention_heads, hidden_size = get_model_size(path=model_name)\n",
457
+ "optimize_onnx(\n",
458
+ " onnx_path=\"test-gpt2.onnx\",\n",
459
+ " onnx_optim_model_path=\"test-gpt2-opt.onnx\",\n",
460
+ " fp16=False,\n",
461
+ " use_cuda=True,\n",
462
+ " num_attention_heads=num_attention_heads,\n",
463
+ " hidden_size=hidden_size,\n",
464
+ " architecture=\"gpt2\",\n",
465
+ ")"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": null,
471
+ "id": "85f30ed9-2802-46c9-9201-a70e200b6860",
472
+ "metadata": {},
473
+ "outputs": [],
474
+ "source": [
475
+ "from pathlib import Path\n",
476
+ "\n",
477
+ "trt_logger: Logger = trt.Logger(trt.Logger.ERROR)\n",
478
+ "runtime: Runtime = trt.Runtime(trt_logger)\n",
479
+ "trt_model_name = \"test-gpt2.plan\"\n",
480
+ "\n",
481
+ "# create only if it does not exist, because it's slow to run...\n",
482
+ "\n",
483
+ "engine: ICudaEngine = build_engine(\n",
484
+ " runtime=runtime,\n",
485
+ " onnx_file_path=\"test-gpt2.onnx\",\n",
486
+ " logger=trt_logger,\n",
487
+ " min_shape=(1, 1),\n",
488
+ " optimal_shape=(1, 128), # num beam, batch size\n",
489
+ " max_shape=(1, 384), # num beam, batch size\n",
490
+ " workspace_size=10000 * 1024**2,\n",
491
+ " fp16=True,\n",
492
+ " int8=False,\n",
493
+ ")\n",
494
+ "save_engine(engine, trt_model_name)"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "id": "908fe664-800e-4c5f-a1d5-adfd31fd1c64",
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "engine.num_bindings"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "id": "4626926b-fa94-4633-95d5-0d515f8db5f6",
511
+ "metadata": {},
512
+ "outputs": [],
513
+ "source": [
514
+ "print(inspect.getsource(GPTModelWrapper))"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "code",
519
+ "execution_count": null,
520
+ "id": "d5bd1de1-a949-46a3-8d15-457d51db4e40",
521
+ "metadata": {},
522
+ "outputs": [],
523
+ "source": [
524
+ "inputs = tokenizer(\n",
525
+ " \"Here is some text to encode Hello World\", # Nvidia example prompt\n",
526
+ " add_special_tokens=True,\n",
527
+ " return_attention_mask=False, # Not used\n",
528
+ " return_tensors=TensorType.PYTORCH,\n",
529
+ ")\n",
530
+ "inputs"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "id": "815b548f-fa00-4183-b72c-10ecdd4b11c7",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "from transformers.generation import GenerationConfig\n",
541
+ "\n",
542
+ "class GPTWrapper(GPTModelWrapper):\n",
543
+ " def __init__(self, *args, **kwargs):\n",
544
+ " super().__init__(*args, **kwargs)\n",
545
+ "\n",
546
+ " self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None\n",
547
+ "\n",
548
+ " @classmethod\n",
549
+ " def can_generate(cls) -> bool:\n",
550
+ " \"\"\"\n",
551
+ " Returns whether this model can generate sequences with `.generate()`.\n",
552
+ "\n",
553
+ " Returns:\n",
554
+ " `bool`: Whether this model can generate sequences with `.generate()`.\n",
555
+ " \"\"\"\n",
556
+ " # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.\n",
557
+ " # Alternativelly, the model can also have a custom `generate` function.\n",
558
+ " if \"GenerationMixin\" in str(cls.prepare_inputs_for_generation) and \"GenerationMixin\" in str(cls.generate):\n",
559
+ " return False\n",
560
+ " return True"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "id": "ca57ed1e-0bbe-48dd-ae0f-f3d8ecd7fd04",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "def inference_torch(input_ids: torch.Tensor) -> torch.Tensor:\n",
571
+ " transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = model.transformer(input_ids=input_ids)\n",
572
+ " return model.lm_head(transformer_outputs.last_hidden_state)\n",
573
+ "\n",
574
+ "\n",
575
+ "model.cuda()\n",
576
+ "model.eval()\n",
577
+ "inputs.to(\"cuda\")\n",
578
+ "with torch.inference_mode():\n",
579
+ " gpt2_model = GPTWrapper(config=model.config, device=model.device, inference=inference_torch)\n",
580
+ " sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
581
+ " print(tokenizer.decode(sample_output[0], skip_special_tokens=False))\n",
582
+ " for _ in range(2):\n",
583
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
584
+ " torch.cuda.synchronize()\n",
585
+ " start = time.time()\n",
586
+ " for _ in range(10):\n",
587
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
588
+ " torch.cuda.synchronize()\n",
589
+ " print(f\"----\\nPytorch: {(time.time() - start)/10:.2f}s/sequence\")\n",
590
+ "_ = model.cpu()"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "id": "f0849aae-876e-47bc-b045-14a594170947",
597
+ "metadata": {},
598
+ "outputs": [],
599
+ "source": [
600
+ "model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
601
+ "\n",
602
+ "\n",
603
+ "def inference_onnx_naive(input_ids: torch.Tensor) -> torch.Tensor:\n",
604
+ " data = {\"input_ids\": input_ids.detach().cpu().numpy().astype(np.int32)}\n",
605
+ " logit = model_onnx.run(None, data)\n",
606
+ " np_logit = np.array(logit) # convert list of numpy arrays to a numpy array\n",
607
+ " # we convert numpy tensor to Pytorch tensor as it's the type expected by HF decoding algorithm\n",
608
+ " return torch.squeeze(torch.from_numpy(np_logit), dim=0)\n",
609
+ "\n",
610
+ "\n",
611
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cpu\"), inference=inference_onnx_naive)\n",
612
+ "inputs.to(\"cpu\")\n",
613
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
614
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
615
+ "for _ in range(2):\n",
616
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
617
+ "start = time.time()\n",
618
+ "for _ in range(10):\n",
619
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
620
+ "print(f\"----\\nONNX Runtime (standard API): {(time.time() - start)/10:.2f}s/sequence\")\n",
621
+ "\n",
622
+ "del model_onnx"
623
+ ]
624
+ },
625
+ {
626
+ "cell_type": "code",
627
+ "execution_count": null,
628
+ "id": "96114897-894b-4997-bc61-8ac0682e0e55",
629
+ "metadata": {},
630
+ "outputs": [],
631
+ "source": [
632
+ "model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
633
+ "\n",
634
+ "\n",
635
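+ "# inference_onnx_binding (from transformer_deploy) uses ONNX Runtime IO binding with CUDA tensors,\n",
+ "# so inputs and logits stay on the GPU instead of round-tripping through NumPy as in the naive version above\n",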
+ "def inference_onnx_optimized(input_ids: torch.Tensor) -> torch.Tensor:\n",
636
+ " data = {\"input_ids\": input_ids}\n",
637
+ " return inference_onnx_binding(model_onnx=model_onnx, inputs=data, device=\"cuda\")[\"output\"]\n",
638
+ "\n",
639
+ "\n",
640
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_onnx_optimized)\n",
641
+ "inputs.to(\"cuda\")\n",
642
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
643
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
644
+ "for _ in range(2):\n",
645
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
646
+ "start = time.time()\n",
647
+ "for _ in range(10):\n",
648
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
649
+ "print(f\"----\\nONNX Runtime (binding io API): {(time.time() - start)/10:.2f}/sequence\")\n",
650
+ "del model_onnx"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": null,
656
+ "id": "0b5b5427-fd6b-4f70-b307-9c579f0f842a",
657
+ "metadata": {},
658
+ "outputs": [],
659
+ "source": [
660
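+ "# load_engine deserializes the saved TensorRT engine (test-gpt2.plan) and returns a callable\n",
+ "# that maps a dict of CUDA tensors to the logits tensor, matching the type hint below\n",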
+ "tensorrt_model: Callable[[Dict[str, torch.Tensor]], torch.Tensor] = load_engine(\n",
661
+ " engine_file_path=\"test-gpt2.plan\", runtime=runtime\n",
662
+ ")\n",
663
+ "\n",
664
+ "\n",
665
+ "def inference_tensorrt(input_ids: torch.Tensor) -> torch.Tensor:\n",
666
+ " data = {\"input_ids\": input_ids}\n",
667
+ " return tensorrt_model(data)\n",
668
+ "\n",
669
+ "\n",
670
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_tensorrt)\n",
671
+ "inputs.to(\"cuda\")\n",
672
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
673
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
674
+ "for _ in range(2):\n",
675
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
676
+ "start = time.time()\n",
677
+ "for _ in range(10):\n",
678
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
679
+ "print(f\"----\\nTensorRT + CUDA tensors: {(time.time() - start)/10:.2f}/sequence\")\n",
680
+ "\n",
681
+ "del tensorrt_model"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "markdown",
686
+ "id": "f547239d-4f7a-433b-8ef6-9e5110a61f4b",
687
+ "metadata": {
688
+ "jp-MarkdownHeadingCollapsed": true
689
+ },
690
+ "source": [
691
+ "## Using CUDAExecution Provider"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "id": "6e34c682-85fc-4e8d-b13c-7c1c9ea39ead",
698
+ "metadata": {},
699
+ "outputs": [],
700
+ "source": [
701
+ "from optimum.onnxruntime import ORTModelForCausalLM\n",
702
+ "from optimum.pipelines import pipeline\n",
703
+ "from transformers import AutoTokenizer\n",
704
+ "\n",
705
+ "model_id = \"openai-community/gpt2\"\n",
706
+ "\n",
707
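+ "# export=True converts the checkpoint to ONNX at load time; use_io_binding=True keeps\n",
+ "# input/output tensors on the GPU between decoding steps\n",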
+ "ort_model = ORTModelForCausalLM.from_pretrained(\n",
708
+ " model_id,\n",
709
+ " export=True,\n",
710
+ " provider=\"CUDAExecutionProvider\",\n",
711
+ " use_io_binding=True\n",
712
+ ")\n",
713
+ "\n",
714
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
715
+ "tokenizer.pad_token = tokenizer.eos_token\n",
716
+ "\n",
717
+ "pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": null,
723
+ "id": "17d28184-26db-4dd3-b24b-0c5a12b10d6d",
724
+ "metadata": {},
725
+ "outputs": [],
726
+ "source": [
727
+ "import time\n",
728
+ "\n",
729
+ "start_time = time.time()\n",
730
+ "\n",
731
+ "generations = pipe(\"Both the music and visual were astounding, not to mention the actors performance.\")\n",
732
+ "generations[0][\"generated_text\"]\n",
733
+ "\n",
734
+ "finish_time = time.time()\n",
735
+ "\n",
736
+ "print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "markdown",
741
+ "id": "19c4230a-3244-4dce-b5ef-d9927dec5c45",
742
+ "metadata": {},
743
+ "source": [
744
+ "## ASR LM with CUDAExcecution Provider"
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "id": "0f0f1cdc-bfcd-46c5-80a4-60bc76366cf5",
751
+ "metadata": {},
752
+ "outputs": [],
753
+ "source": [
754
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
755
+ "from datasets import DatasetDict\n",
756
+ "import torch\n",
757
+ "\n",
758
+ "device = \"cuda:0\"\n",
759
+ "dtype = torch.float16\n",
760
+ "\n",
761
+ "dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
762
+ "\n",
763
+ "from optimum.onnxruntime import ORTModelForCausalLM\n",
764
+ "from optimum.pipelines import pipeline\n",
765
+ "from transformers import AutoTokenizer\n",
766
+ "\n",
767
+ "model_id = \"./../out/checkpoint-10000\"\n",
768
+ "\n",
769
+ "ort_model = ORTModelForCausalLM.from_pretrained(\n",
770
+ " model_id,\n",
771
+ " export=True,\n",
772
+ " provider=\"CUDAExecutionProvider\",\n",
773
+ " use_io_binding=True\n",
774
+ ")\n",
775
+ "\n",
776
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
777
+ "\n",
778
+ "pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "execution_count": null,
784
+ "id": "9d32098c-b0ec-4c36-95ac-775a3a865512",
785
+ "metadata": {},
786
+ "outputs": [],
787
+ "source": [
788
+ "ort_model.config.eos_token_id = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
789
+ "ort_model.config.bos_token_id = tokenizer.encode(\"<|startoftranscript|>\")[0]"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "code",
794
+ "execution_count": null,
795
+ "id": "1fd0a1fb-9349-4c7a-af03-21e29334f420",
796
+ "metadata": {},
797
+ "outputs": [],
798
+ "source": [
799
+ "dataset[split][idx].keys()"
800
+ ]
801
+ },
802
+ {
803
+ "cell_type": "code",
804
+ "execution_count": null,
805
+ "id": "15d8b989-6460-4555-b6e2-2f9e219d7034",
806
+ "metadata": {},
807
+ "outputs": [],
808
+ "source": [
809
+ "split = \"train.clean.100\"\n",
810
+ "idx = 0\n",
811
+ "\n",
812
+ "text = \"\".join([ f\"<|audio:{tkn}|>\"for tkn in dataset[split][idx][\"audio_tokens\"]]) + \"<|startoftranscript|>\"\n",
813
+ "\n",
814
+ "import time\n",
815
+ "\n",
816
+ "start_time = time.time()\n",
817
+ "\n",
818
+ "generations = pipe(text, max_new_tokens=10, skip_special_tokens=True)\n",
819
+ "\n",
820
+ "finish_time = time.time()\n",
821
+ "\n",
822
+ "print(generations[0][\"generated_text\"])\n",
823
+ "\n",
824
+ "print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
825
+ ]
826
+ }
827
+ ],
828
+ "metadata": {
829
+ "kernelspec": {
830
+ "display_name": "Python 3 (ipykernel)",
831
+ "language": "python",
832
+ "name": "python3"
833
+ },
834
+ "language_info": {
835
+ "codemirror_mode": {
836
+ "name": "ipython",
837
+ "version": 3
838
+ },
839
+ "file_extension": ".py",
840
+ "mimetype": "text/x-python",
841
+ "name": "python",
842
+ "nbconvert_exporter": "python",
843
+ "pygments_lexer": "ipython3",
844
+ "version": "3.8.10"
845
+ }
846
+ },
847
+ "nbformat": 4,
848
+ "nbformat_minor": 5
849
+ }
ASR/.ipynb_checkpoints/demo-checkpoint.py ADDED
@@ -0,0 +1,24 @@
1
+ from flask import Flask, request
2
+ # import speech_recognition as sr
3
+
4
+ app = Flask(__name__)
5
+ # recognizer = sr.Recognizer()
6
+
7
+ @app.route("/darshan/microphone", methods=['POST'])
8
+ def handle_audio():
9
+ audio_data = request.data
10
+ print(audio_data)
+ return '', 200 # Flask requires a response; the transcription code below is commented out
11
+ # audio = sr.AudioData(audio_data, sample_rate=44100, sample_width=2) # Adjust sample rate and sample width as needed
12
+ # try:
13
+ # text = recognizer.recognize_google(audio)
14
+ # print(f"Transcription: {text}")
15
+ # return {'transcription': text}, 200
16
+ # except sr.UnknownValueError:
17
+ # print("Could not understand audio")
18
+ # return '', 400
19
+ # except sr.RequestError as e:
20
+ # print(f"Error from Google Speech Recognition service; {e}")
21
+ # return '', 500
22
+
23
+ if __name__ == '__main__':
24
+ app.run(host='0.0.0.0', port=8723) # Replace with your desired host and port
ASR/.ipynb_checkpoints/tokenizer_training-checkpoint.ipynb ADDED
@@ -0,0 +1,203 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "8f95f1d6-be90-4900-9116-c27b82bd7836",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from tokenizers import SentencePieceBPETokenizer\n",
11
+ "import transformers\n",
12
+ "from transformers import GPT2Tokenizer, AutoModelForCausalLM\n",
13
+ "from datasets import Dataset, DatasetDict\n",
14
+ "\n",
15
+ "cache_dir = \"./cache\"\n",
16
+ "\n",
17
+ "dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
18
+ "\n",
19
+ "text = []\n",
20
+ "for split in dataset.keys():\n",
21
+ " text += list(dataset[split][\"text\"])\n",
22
+ "\n",
23
+ "model_max_length = 1877\n",
24
+ "special_tokens = [ f\"<|audio:{idx}|>\" for idx in range(1024)] + [\"<|startoftranscript|>\", \"<|endoftranscript|>\", \"<|padding|>\"]\n",
25
+ "\n",
26
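+ "# Register the 1024 audio tokens plus the three control tokens as special tokens so BPE never splits them\n",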
+ "bpe_tokenizer = SentencePieceBPETokenizer()\n",
27
+ "bpe_tokenizer.train_from_iterator(\n",
28
+ " text,\n",
29
+ " vocab_size = 5000 + len(special_tokens),\n",
30
+ " min_frequency = 2,\n",
31
+ " show_progress = True,\n",
32
+ " special_tokens = special_tokens\n",
33
+ ")\n",
34
+ "\n",
35
+ "tokenizer = transformers.PreTrainedTokenizerFast(\n",
36
+ " tokenizer_object = bpe_tokenizer,\n",
37
+ " model_max_length = model_max_length,\n",
38
+ " special_tokens = special_tokens\n",
39
+ ")\n",
40
+ "\n",
41
+ "tokenizer.pad_token = \"<|padding|>\"\n",
42
+ "tokenizer.pad_token_id = bpe_tokenizer.token_to_id(\"<|padding|>\")\n",
43
+ "\n",
44
+ "tokenizer.save_pretrained(\"./tokenizer\")"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "d259b76d-1c8d-4c74-9d04-d711a4b3f395",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer\n",
55
+ "from datasets import Dataset, DatasetDict\n",
56
+ "\n",
57
+ "max_length = 1877\n",
58
+ "\n",
59
+ "dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
60
+ "\n",
61
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
62
+ "\n",
63
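+ "# Each training sequence = audio tokens + <|startoftranscript|> + transcript + <|endoftranscript|>, padded to max_length\n",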
+ "def tokenize(row):\n",
64
+ " text = \"\".join([f\"<|audio:{token}|>\" for token in row[\"audio_tokens\"]]) + \"<|startoftranscript|>\" + row[\"text\"] + \"<|endoftranscript|>\"\n",
65
+ " input_ids = tokenizer(\n",
66
+ " text,\n",
67
+ " padding=\"max_length\",\n",
68
+ " max_length=max_length,\n",
69
+ " )\n",
70
+ " return input_ids\n",
71
+ "\n",
72
+ "dataset = dataset.map(tokenize, remove_columns=[\"text\", \"audio_tokens\"])\n",
73
+ "\n",
74
+ "dataset.save_to_disk(\"tokenized_librispeech\")"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 1,
80
+ "id": "9bc4d1db-1390-4ba1-b296-d52a8993c87f",
81
+ "metadata": {},
82
+ "outputs": [
83
+ {
84
+ "name": "stderr",
85
+ "output_type": "stream",
86
+ "text": [
87
+ "/usr/local/lib/python3.8/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.\n",
88
+ " table = cls._concat_blocks(blocks, axis=0)\n"
89
+ ]
90
+ }
91
+ ],
92
+ "source": [
93
+ "from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer\n",
94
+ "from datasets import Dataset, DatasetDict\n",
95
+ "\n",
96
+ "max_length = 1877\n",
97
+ "\n",
98
+ "dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
99
+ "\n",
100
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 2,
106
+ "id": "a83d83c7-2e56-4b9e-9f63-11f7eea22d6d",
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "data": {
111
+ "text/plain": [
112
+ "[1024]"
113
+ ]
114
+ },
115
+ "execution_count": 2,
116
+ "metadata": {},
117
+ "output_type": "execute_result"
118
+ }
119
+ ],
120
+ "source": [
121
+ "tokenizer.encode(\"<|startoftranscript|>\")"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "5b9f0fa9-384b-4453-bf4c-0d49f0c1e4a5",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "tokenizer.pad_token_id"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "1964689d-687d-4ab9-967d-80d9eb95b159",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "from tqdm import tqdm\n",
142
+ "lens = []\n",
143
+ "\n",
144
+ "for split in dataset.keys():\n",
145
+ " for idx in tqdm(range(len(dataset[split]))):\n",
146
+ " sample = dataset[split][idx]\n",
147
+ " max_len = len(tokenizer.encode(sample[\"text\"])) + len(sample[\"audio_tokens\"])\n",
148
+ " lens.append(max_len)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "e49c1a27-c3e4-43eb-a4e0-68a0d739f27b",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "max(lens)"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "96dbd94d-5455-49bd-8893-4aae5dfd9b7f",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "min(lens)"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "b9aaca12-286d-416c-a2f7-0cdf689eeb2e",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "tokenizer.encode(\"<|audio:0|>\")"
179
+ ]
180
+ }
181
+ ],
182
+ "metadata": {
183
+ "kernelspec": {
184
+ "display_name": "Python 3 (ipykernel)",
185
+ "language": "python",
186
+ "name": "python3"
187
+ },
188
+ "language_info": {
189
+ "codemirror_mode": {
190
+ "name": "ipython",
191
+ "version": 3
192
+ },
193
+ "file_extension": ".py",
194
+ "mimetype": "text/x-python",
195
+ "name": "python",
196
+ "nbconvert_exporter": "python",
197
+ "pygments_lexer": "ipython3",
198
+ "version": "3.8.10"
199
+ }
200
+ },
201
+ "nbformat": 4,
202
+ "nbformat_minor": 5
203
+ }
ASR/__pycache__/audio_tokenizer.cpython-38.pyc ADDED
Binary file (14.8 kB).
 
ASR/__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (14.8 kB).
 
ASR/audio_tokenizer.py ADDED
@@ -0,0 +1,611 @@
1
+ import logging
2
+ import math
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+
6
+ from omegaconf import II
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+
13
+ from fairseq.modules import EMAModule, EMAModuleConfig
14
+ from fairseq.data.data_utils import compute_mask_indices
15
+ from fairseq.models import BaseFairseqModel, register_model
16
+ from fairseq.models.wav2vec import (
17
+ ConvFeatureExtractionModel,
18
+ Wav2Vec2Config,
19
+ TransformerEncoder,
20
+ )
21
+ from fairseq.modules import (
22
+ GradMultiply,
23
+ LayerNorm,
24
+ )
25
+ from fairseq.utils import index_put
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ @dataclass
32
+ class Data2VecAudioConfig(Wav2Vec2Config):
33
+
34
+ loss_beta: float = field(
35
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
36
+ )
37
+ loss_scale: Optional[float] = field(
38
+ default=None,
39
+ metadata={
40
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
41
+ },
42
+ )
43
+ average_top_k_layers: int = field(
44
+ default=8, metadata={"help": "how many layers to average"}
45
+ )
46
+
47
+ layer_norm_target_layer: bool = False
48
+ instance_norm_target_layer: bool = False
49
+ instance_norm_targets: bool = False
50
+ layer_norm_targets: bool = False
51
+ batch_norm_target_layer: bool = False
52
+ group_norm_target_layer: bool = False
53
+
54
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
55
+ ema_end_decay: float = field(
56
+ default=0.9999, metadata={"help": "final ema decay rate"}
57
+ )
58
+
59
+ # when to finish annealing ema decay rate
60
+ ema_anneal_end_step: int = II("optimization.max_update")
61
+
62
+ ema_transformer_only: bool = field(
63
+ default=True,
64
+ metadata={"help": "whether to momentum update only the transformer"},
65
+ )
66
+ ema_layers_only: bool = field(
67
+ default=True,
68
+ metadata={"help": "whether to momentum update only the transformer layers"},
69
+ )
70
+
71
+ max_update: int = II("optimization.max_update")
72
+
73
+ min_target_var: float = field(
74
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
75
+ )
76
+ min_pred_var: float = field(
77
+ default=0.01,
78
+ metadata={"help": "stop training if prediction var falls below this"},
79
+ )
80
+
81
+
82
+ def get_annealed_rate(start, end, curr_step, total_steps):
83
+ r = end - start
84
+ pct_remaining = 1 - curr_step / total_steps
85
+ return end - r * pct_remaining
86
+
87
+
88
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
89
+ class Data2VecAudioModel(BaseFairseqModel):
90
+ def __init__(self, cfg: Data2VecAudioConfig):
91
+ super().__init__()
92
+ self.cfg = cfg
93
+
94
+ feature_enc_layers = eval(cfg.conv_feature_layers)
95
+ self.extractor_embed = feature_enc_layers[-1][0]
96
+
97
+ self.ema = None
98
+ self.embed = cfg.encoder_embed_dim
99
+
100
+ self.average_top_k_layers = cfg.average_top_k_layers
101
+ self.loss_beta = cfg.loss_beta
102
+ self.loss_scale = cfg.loss_scale
103
+
104
+ self.feature_extractor = ConvFeatureExtractionModel(
105
+ conv_layers=feature_enc_layers,
106
+ dropout=0.0,
107
+ mode=cfg.extractor_mode,
108
+ conv_bias=cfg.conv_bias,
109
+ )
110
+
111
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
112
+
113
+ self.mask_prob = cfg.mask_prob
114
+ self.mask_selection = cfg.mask_selection
115
+ self.mask_other = cfg.mask_other
116
+ self.mask_length = cfg.mask_length
117
+ self.no_mask_overlap = cfg.no_mask_overlap
118
+ self.mask_min_space = cfg.mask_min_space
119
+
120
+ self.mask_channel_prob = cfg.mask_channel_prob
121
+ self.mask_channel_before = cfg.mask_channel_before
122
+ self.mask_channel_selection = cfg.mask_channel_selection
123
+ self.mask_channel_other = cfg.mask_channel_other
124
+ self.mask_channel_length = cfg.mask_channel_length
125
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
126
+ self.mask_channel_min_space = cfg.mask_channel_min_space
127
+
128
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
129
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
130
+
131
+ self.feature_grad_mult = cfg.feature_grad_mult
132
+
133
+ self.mask_emb = nn.Parameter(
134
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
135
+ )
136
+
137
+ self.encoder = TransformerEncoder(cfg)
138
+ self.layer_norm = LayerNorm(self.extractor_embed)
139
+
140
+ self.final_proj = nn.Linear(self.embed, self.embed)
141
+
142
+ self.num_updates = 0
143
+
144
+ def make_ema_teacher(self):
145
+ ema_config = EMAModuleConfig(
146
+ ema_decay=self.cfg.ema_decay,
147
+ ema_fp32=True,
148
+ )
149
+ skip_keys = set()
150
+ if self.cfg.ema_layers_only:
151
+ self.cfg.ema_transformer_only = True
152
+ for k, _ in self.encoder.pos_conv.named_parameters():
153
+ skip_keys.add(f"pos_conv.{k}")
154
+
155
+ self.ema = EMAModule(
156
+ self.encoder if self.cfg.ema_transformer_only else self,
157
+ ema_config,
158
+ skip_keys=skip_keys,
159
+ )
160
+
161
+ def set_num_updates(self, num_updates):
162
+ super().set_num_updates(num_updates)
163
+
164
+ if self.ema is None and self.final_proj is not None:
165
+ logger.info(f"making ema teacher")
166
+ self.make_ema_teacher()
167
+ elif self.training and self.ema is not None:
168
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
169
+ if num_updates >= self.cfg.ema_anneal_end_step:
170
+ decay = self.cfg.ema_end_decay
171
+ else:
172
+ decay = get_annealed_rate(
173
+ self.cfg.ema_decay,
174
+ self.cfg.ema_end_decay,
175
+ num_updates,
176
+ self.cfg.ema_anneal_end_step,
177
+ )
178
+ self.ema.set_decay(decay)
179
+ if self.ema.get_decay() < 1:
180
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
181
+
182
+ self.num_updates = num_updates
183
+
184
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
185
+ state = super().state_dict(destination, prefix, keep_vars)
186
+
187
+ if self.ema is not None:
188
+ state[prefix + "_ema"] = self.ema.fp32_params
189
+
190
+ return state
191
+
192
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
193
+ if self.ema is not None:
194
+ k = prefix + "_ema"
195
+ assert k in state_dict
196
+ self.ema.restore(state_dict[k], True)
197
+ del state_dict[k]
198
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
199
+
200
+ @classmethod
201
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
202
+ """Build a new model instance."""
203
+
204
+ return cls(cfg)
205
+
206
+ def apply_mask(
207
+ self,
208
+ x,
209
+ padding_mask,
210
+ mask_indices=None,
211
+ mask_channel_indices=None,
212
+ ):
213
+ B, T, C = x.shape
214
+
215
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
216
+ mask_channel_indices = compute_mask_indices(
217
+ (B, C),
218
+ None,
219
+ self.mask_channel_prob,
220
+ self.mask_channel_length,
221
+ self.mask_channel_selection,
222
+ self.mask_channel_other,
223
+ no_overlap=self.no_mask_channel_overlap,
224
+ min_space=self.mask_channel_min_space,
225
+ )
226
+ mask_channel_indices = (
227
+ torch.from_numpy(mask_channel_indices)
228
+ .to(x.device)
229
+ .unsqueeze(1)
230
+ .expand(-1, T, -1)
231
+ )
232
+ x[mask_channel_indices] = 0
233
+
234
+ if self.mask_prob > 0:
235
+ if mask_indices is None:
236
+ mask_indices = compute_mask_indices(
237
+ (B, T),
238
+ padding_mask,
239
+ self.mask_prob,
240
+ self.mask_length,
241
+ self.mask_selection,
242
+ self.mask_other,
243
+ min_masks=1,
244
+ no_overlap=self.no_mask_overlap,
245
+ min_space=self.mask_min_space,
246
+ require_same_masks=self.cfg.require_same_masks,
247
+ mask_dropout=self.cfg.mask_dropout,
248
+ )
249
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
250
+ x = index_put(x, mask_indices, self.mask_emb)
251
+ else:
252
+ mask_indices = None
253
+
254
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
255
+ if mask_channel_indices is None:
256
+ mask_channel_indices = compute_mask_indices(
257
+ (B, C),
258
+ None,
259
+ self.mask_channel_prob,
260
+ self.mask_channel_length,
261
+ self.mask_channel_selection,
262
+ self.mask_channel_other,
263
+ no_overlap=self.no_mask_channel_overlap,
264
+ min_space=self.mask_channel_min_space,
265
+ )
266
+ mask_channel_indices = (
267
+ torch.from_numpy(mask_channel_indices)
268
+ .to(x.device)
269
+ .unsqueeze(1)
270
+ .expand(-1, T, -1)
271
+ )
272
+ x = index_put(x, mask_channel_indices, 0)
273
+
274
+ return x, mask_indices
275
+
276
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
277
+ """
278
+ Computes the output length of the convolutional layers
279
+ """
280
+
281
+ def _conv_out_length(input_length, kernel_size, stride):
282
+ return torch.floor((input_length - kernel_size) / stride + 1)
283
+
284
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
285
+
286
+ for i in range(len(conv_cfg_list)):
287
+ input_lengths = _conv_out_length(
288
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
289
+ )
290
+
291
+ return input_lengths.to(torch.long)
292
+
293
+ def forward(
294
+ self,
295
+ source,
296
+ padding_mask=None,
297
+ mask=True,
298
+ features_only=False,
299
+ layer=None,
300
+ mask_indices=None,
301
+ mask_channel_indices=None,
302
+ padding_count=None,
303
+ ):
304
+ features = source
305
+
306
+ if self.feature_grad_mult > 0:
307
+ features = self.feature_extractor(features)
308
+ if self.feature_grad_mult != 1.0:
309
+ features = GradMultiply.apply(features, self.feature_grad_mult)
310
+ else:
311
+ with torch.no_grad():
312
+ features = self.feature_extractor(features)
313
+
314
+ features = features.transpose(1, 2)
315
+
316
+ features = self.layer_norm(features)
317
+
318
+ orig_padding_mask = padding_mask
319
+
320
+ if padding_mask is not None and padding_mask.any():
321
+ input_lengths = (1 - padding_mask.long()).sum(-1)
322
+ # apply conv formula to get real output_lengths
323
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
324
+
325
+ padding_mask = torch.zeros(
326
+ features.shape[:2], dtype=features.dtype, device=features.device
327
+ )
328
+
329
+ # these two operations makes sure that all values
330
+ # before the output lengths indices are attended to
331
+ padding_mask[
332
+ (
333
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
334
+ output_lengths - 1,
335
+ )
336
+ ] = 1
337
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
338
+ else:
339
+ padding_mask = None
340
+
341
+ if self.post_extract_proj is not None:
342
+ features = self.post_extract_proj(features)
343
+
344
+ pre_encoder_features = None
345
+ if self.cfg.ema_transformer_only:
346
+ pre_encoder_features = features.clone()
347
+
348
+ features = self.dropout_input(features)
349
+
350
+ if mask:
351
+ x, mask_indices = self.apply_mask(
352
+ features,
353
+ padding_mask,
354
+ mask_indices=mask_indices,
355
+ mask_channel_indices=mask_channel_indices,
356
+ )
357
+ else:
358
+ x = features
359
+ mask_indices = None
360
+
361
+ x, layer_results = self.encoder(
362
+ x,
363
+ padding_mask=padding_mask,
364
+ layer=layer,
365
+ )
366
+
367
+ if features_only:
368
+ return {
369
+ "x": x,
370
+ "padding_mask": padding_mask,
371
+ "layer_results": layer_results,
372
+ }
373
+
374
+ result = {
375
+ "losses": {},
376
+ }
377
+
378
+ with torch.no_grad():
379
+ self.ema.model.eval()
380
+
381
+ if self.cfg.ema_transformer_only:
382
+ y, layer_results = self.ema.model.extract_features(
383
+ pre_encoder_features,
384
+ padding_mask=padding_mask,
385
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
386
+ )
387
+ y = {
388
+ "x": y,
389
+ "padding_mask": padding_mask,
390
+ "layer_results": layer_results,
391
+ }
392
+ else:
393
+ y = self.ema.model.extract_features(
394
+ source=source,
395
+ padding_mask=orig_padding_mask,
396
+ mask=False,
397
+ )
398
+
399
+ target_layer_results = [l[2] for l in y["layer_results"]]
400
+
401
+ permuted = False
402
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
403
+ target_layer_results = [
404
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
405
+ ]
406
+ permuted = True
407
+
408
+ if self.cfg.batch_norm_target_layer:
409
+ target_layer_results = [
410
+ F.batch_norm(
411
+ tl.float(), running_mean=None, running_var=None, training=True
412
+ )
413
+ for tl in target_layer_results
414
+ ]
415
+
416
+ if self.cfg.instance_norm_target_layer:
417
+ target_layer_results = [
418
+ F.instance_norm(tl.float()) for tl in target_layer_results
419
+ ]
420
+
421
+ if permuted:
422
+ target_layer_results = [
423
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
424
+ ]
425
+
426
+ if self.cfg.group_norm_target_layer:
427
+ target_layer_results = [
428
+ F.layer_norm(tl.float(), tl.shape[-2:])
429
+ for tl in target_layer_results
430
+ ]
431
+
432
+ if self.cfg.layer_norm_target_layer:
433
+ target_layer_results = [
434
+ F.layer_norm(tl.float(), tl.shape[-1:])
435
+ for tl in target_layer_results
436
+ ]
437
+
438
+ y = sum(target_layer_results) / len(target_layer_results)
439
+
440
+ if self.cfg.layer_norm_targets:
441
+ y = F.layer_norm(y.float(), y.shape[-1:])
442
+
443
+ if self.cfg.instance_norm_targets:
444
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
445
+
446
+ if not permuted:
447
+ y = y.transpose(0, 1)
448
+
449
+ y = y[mask_indices]
450
+
451
+ x = x[mask_indices]
452
+ x = self.final_proj(x)
453
+
454
+ sz = x.size(-1)
455
+
456
+ if self.loss_beta == 0:
457
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
458
+ else:
459
+ loss = F.smooth_l1_loss(
460
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
461
+ ).sum(dim=-1)
462
+
463
+ if self.loss_scale is not None:
464
+ scale = self.loss_scale
465
+ else:
466
+ scale = 1 / math.sqrt(sz)
467
+
468
+ result["losses"]["regression"] = loss.sum() * scale
469
+
470
+ if "sample_size" not in result:
471
+ result["sample_size"] = loss.numel()
472
+
473
+ with torch.no_grad():
474
+ result["target_var"] = self.compute_var(y)
475
+ result["pred_var"] = self.compute_var(x.float())
476
+
477
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
478
+ logger.error(
479
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
480
+ )
481
+ raise Exception(
482
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
483
+ )
484
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
485
+ logger.error(
486
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
487
+ )
488
+ raise Exception(
489
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
490
+ )
491
+
492
+ if self.ema is not None:
493
+ result["ema_decay"] = self.ema.get_decay() * 1000
494
+
495
+ return result
496
+
497
+ @staticmethod
498
+ def compute_var(y):
499
+ y = y.view(-1, y.size(-1))
500
+ if dist.is_initialized():
501
+ zc = torch.tensor(y.size(0)).cuda()
502
+ zs = y.sum(dim=0)
503
+ zss = (y ** 2).sum(dim=0)
504
+
505
+ dist.all_reduce(zc)
506
+ dist.all_reduce(zs)
507
+ dist.all_reduce(zss)
508
+
509
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
510
+ return torch.sqrt(var + 1e-6).mean()
511
+ else:
512
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
513
+
514
+ def extract_features(
515
+ self, source, padding_mask, mask=False, layer=None
516
+ ):
517
+ res = self.forward(
518
+ source,
519
+ padding_mask,
520
+ mask=mask,
521
+ features_only=True,
522
+ layer=layer,
523
+ )
524
+ return res
525
+
526
+ def remove_pretraining_modules(self, last_layer=None):
527
+ self.final_proj = None
528
+ self.ema = None
529
+ if last_layer is not None:
530
+ self.encoder.layers = nn.ModuleList(
531
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
532
+ )
533
+
534
+ import logging
535
+
536
+ import torch
537
+ import torch.nn.functional as F
538
+ from fairseq import tasks
539
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
540
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
541
+ from omegaconf import OmegaConf
542
+
543
+ logger = logging.getLogger("dump_feature")
544
+
545
+
546
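+ # Wraps a pretrained data2vec audio checkpoint: loads the fairseq task and model state, then
+ # exposes get_feats(path), which extracts hidden states of the chosen transformer layer in chunks.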
+ class Data2vecFeatureReader(object):
547
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
548
+ state = load_checkpoint_to_cpu(ckpt_path)
549
+ cfg = state["cfg"]
550
+ # load task
551
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
552
+ task.load_state_dict(state["task_state"])
553
+ # load model config
554
+ if "layer_type" not in cfg.model:
555
+ # fix a missing key
556
+ model_config = {k: v for k, v in cfg.model.items()}
557
+ model_config["layer_type"] = "transformer"
558
+ model_config = OmegaConf.create(model_config)
559
+ else:
560
+ model_config = cfg.model
561
+
562
+ # fix param name in the state
563
+ state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
564
+ state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
565
+ del state["model"]["_ema"]
566
+
567
+ # load model
568
+ model = Data2VecAudioModel.build_model(model_config)
569
+ model.load_state_dict(
570
+ state["model"], strict=True, model_cfg=model_config
571
+ )
572
+
573
+ self.device = device
574
+ logger.info(f"device = {self.device}")
575
+
576
+ self.model = model.eval().to(self.device)
577
+ self.task = task
578
+ self.layer = layer - 1 # convert the 1-based layer argument to a 0-based index
579
+ self.max_chunk = max_chunk
580
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
581
+ logger.info(f" max_chunk = {self.max_chunk}")
582
+
583
+ def read_audio(self, path, ref_len=None):
584
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
585
+ if wav.ndim == 2:
586
+ wav = wav.mean(-1)
587
+ assert wav.ndim == 1, wav.ndim
588
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
589
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
590
+ return wav
591
+
592
+ def get_feats(self, path, ref_len=None):
593
+ x = self.read_audio(path, ref_len=ref_len)
594
+ with torch.no_grad():
595
+ x = torch.from_numpy(x).float().to(self.device)
596
+ if self.task.cfg.normalize:
597
+ x = F.layer_norm(x, x.shape)
598
+ x = x.view(1, -1)
599
+
600
+ feat = []
601
+ for start in range(0, x.size(1), self.max_chunk):
602
+ x_chunk = x[:, start: start + self.max_chunk]
603
+ res = self.model.extract_features(
604
+ source=x_chunk,
605
+ padding_mask=None,
606
+ mask=False,
607
+ layer=self.layer,
608
+ )
609
+ feat_chunk = res["x"]
610
+ feat.append(feat_chunk)
611
+ return torch.cat(feat, 1).squeeze(0)
ASR/demo.ipynb ADDED
@@ -0,0 +1,878 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "715a402a-44b9-4fa2-abf0-b0cfd2f3d80b",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Recording voice in Real Time"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "dbdf6bab-7418-4a6f-8b75-c31f98a6ada5",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "\"\"\"\n",
19
+ "Sprints:\n",
20
+ "- [ ] Do Inference optimization of ASR LM\n",
21
+ "- [ ] Train on train.other.500\n",
22
+ "- [ ] Generate dataset for prompting\n",
23
+ "\n",
24
+ "Evaluation Dates: 20th - 21st June, 2023, 3:30 - 5:30pm\n",
25
+ "Sharpen PPT Skills: 20th June, 3:30pm - 4:45pm\n",
26
+ "Flow of the PPT:\n",
27
+ "Demo -> Datasets -> Techniques -> Evaluation -> Q&A\n",
28
+ "- [ Done ] Update the one pager deck slide\n",
29
+ "https://sprinklr-my.sharepoint.com/:p:/r/personal/sricharan_narayanam_sprinklr_com/_layouts/15/Doc.aspx?sourcedoc=%7B84811f56-5fc7-4eaa-87d2-db4a3588d18c%7D&action=edit&wdPreviousSession=948ccc35-dc05-f1f9-612d-9a22300e25ba\n",
30
+ "My PPT:\n",
31
+ "https://sprinklr-my.sharepoint.com/:p:/p/darshan_makwana/Ec4jCiyMWhxMproH625msc8BClFVceNQ8o4kS3EhZBO9MA?e=YCSDxm&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718703689001&web=1\n",
32
+ "Intern Tracker:\n",
33
+ "https://sprinklr.sharepoint.com/:x:/s/AIIntuition/EbRhHPIAIw9MlZ5PpXbztmABde1LFbaSoSHJAo9qU8ggDg?e=xiLkRt&wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1718692666812&web=1\n",
34
+ "\"\"\""
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "150aca01-4098-4ab2-809a-25775ec52069",
40
+ "metadata": {},
41
+ "source": [
42
+ "## ASR LM Inference"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "804a58af-beb2-48c1-9530-98024e27c0d6",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from audio_tokenizer import Data2vecFeatureReader\n",
53
+ "from repcodec.RepCodec import RepCodec\n",
54
+ "import torch.nn.functional as F\n",
55
+ "import torch\n",
56
+ "import yaml\n",
57
+ "\n",
58
+ "reader = Data2vecFeatureReader(\"./../prompting/models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)\n",
59
+ "\n",
60
+ "config = \"./repcodec/configs/repcodec_dim1024.yaml\"\n",
61
+ "with open(config) as fp:\n",
62
+ " conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
63
+ "\n",
64
+ "audio_model = RepCodec(**conf)\n",
65
+ "audio_model.load_state_dict(torch.load(\"./../prompting/models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
66
+ "audio_model.quantizer.initial()\n",
67
+ "audio_model.to(\"cuda:0\")\n",
68
+ "audio_model.eval()\n",
69
+ "\n",
70
+ "print(\"Successfully Loaded Audio Tokenizer\")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "7d8da397-2030-4b36-9a42-97862488797b",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "from datasets import load_dataset\n",
81
+ "\n",
82
+ "cache_dir = \"./../cache\"\n",
83
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 1,
89
+ "id": "bb8016b2-fc9d-4c23-9e85-b6e1c5ca164c",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
94
+ "import torch\n",
95
+ "import string\n",
96
+ "\n",
97
+ "def process(text):\n",
98
+ "\n",
99
+ " # Lower case every letter\n",
100
+ " text = text.lower()\n",
101
+ "\n",
102
+ " # Remove punctuation\n",
103
+ " punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
104
+ " translation_table = str.maketrans('', '', punctuation_to_remove)\n",
105
+ " text = text.translate(translation_table)\n",
106
+ "\n",
107
+ " # Remove whitespaces from front and behind\n",
108
+ " while text[0] == ' ' or text[-1] == ' ':\n",
109
+ " if text[0] == ' ':\n",
110
+ " text = text[1:]\n",
111
+ " if text[-1] == ' ':\n",
112
+ " text = text[:-1]\n",
113
+ " \n",
114
+ " return text\n",
115
+ "\n",
116
+ "device = \"cuda:0\"\n",
117
+ "dtype = torch.float16\n",
118
+ "context_length = 1877\n",
119
+ "\n",
120
+ "# Load tokenizer and add audio tokens\n",
121
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
122
+ "eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
123
+ "pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
124
+ "\n",
125
+ "model = GPT2LMHeadModel.from_pretrained(\"./../out/checkpoint-19000\", attn_implementation=\"flash_attention_2\", device_map=device, torch_dtype=dtype).eval()\n",
126
+ "model.config.pad_token_id = pad_token\n",
127
+ "model.config.eos_token_id = eot_token\n",
128
+ "# model = torch.compile(model)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 3,
134
+ "id": "693db182-92ac-4e36-b848-989fafd10e73",
135
+ "metadata": {},
136
+ "outputs": [
137
+ {
138
+ "data": {
139
+ "text/plain": [
140
+ "GPT2Model(\n",
141
+ " (wte): Embedding(6027, 768)\n",
142
+ " (wpe): Embedding(1877, 768)\n",
143
+ " (drop): Dropout(p=0.1, inplace=False)\n",
144
+ " (h): ModuleList(\n",
145
+ " (0-11): 12 x GPT2Block(\n",
146
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
147
+ " (attn): GPT2FlashAttention2(\n",
148
+ " (c_attn): Conv1D()\n",
149
+ " (c_proj): Conv1D()\n",
150
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
151
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
152
+ " )\n",
153
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
154
+ " (mlp): GPT2MLP(\n",
155
+ " (c_fc): Conv1D()\n",
156
+ " (c_proj): Conv1D()\n",
157
+ " (act): NewGELUActivation()\n",
158
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
159
+ " )\n",
160
+ " )\n",
161
+ " )\n",
162
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
163
+ ")"
164
+ ]
165
+ },
166
+ "execution_count": 3,
167
+ "metadata": {},
168
+ "output_type": "execute_result"
169
+ }
170
+ ],
171
+ "source": [
172
+ "model.transformer"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "7cabe9dc-bbbf-41b4-918f-3f60ee5582f2",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "from tqdm import tqdm\n",
183
+ "from math import ceil\n",
184
+ "import torch\n",
185
+ "import time\n",
186
+ "\n",
187
+ "sample = dataset[\"train.clean.100\"][5]\n",
188
+ "\n",
189
+ "x = sample[\"audio\"][\"array\"]\n",
190
+ "\n",
191
+ "start_time = time.time()\n",
192
+ "\n",
193
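+ "# Pipeline: waveform -> data2vec layer-18 features -> RepCodec quantizer -> discrete audio tokens -> GPT-2 transcript\n",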
+ "with torch.no_grad():\n",
194
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
195
+ " if reader.task.cfg.normalize:\n",
196
+ " x = F.layer_norm(x, x.shape)\n",
197
+ " x = x.view(1, -1)\n",
198
+ "\n",
199
+ " feat = []\n",
200
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
201
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
202
+ " res = reader.model.extract_features(\n",
203
+ " source=x_chunk,\n",
204
+ " padding_mask=None,\n",
205
+ " mask=False,\n",
206
+ " layer=reader.layer,\n",
207
+ " )\n",
208
+ " feat_chunk = res[\"x\"]\n",
209
+ " feat.append(feat_chunk)\n",
210
+ " \n",
211
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
212
+ "\n",
213
+ " x = audio_model.encoder(features)\n",
214
+ " z = audio_model.projector(x)\n",
215
+ " _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
216
+ " tokens = idx.cpu().data.numpy().tolist()[0]\n",
217
+ " \n",
218
+ "text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
219
+ "input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
220
+ "\n",
221
+ "input_time = time.time()\n",
222
+ "\n",
223
+ "generations = model.generate(\n",
224
+ " input_ids,\n",
225
+ " pad_token_id = pad_token,\n",
226
+ " eos_token_id = eot_token,\n",
227
+ " max_new_tokens = context_length,\n",
228
+ " use_cache=True\n",
229
+ ")\n",
230
+ "\n",
231
+ "finish_time = time.time()\n",
232
+ "\n",
233
+ "tokenizer.batch_decode(generations, skip_special_tokens=True)\n",
234
+ "print(\"First Token Latency: \", (input_time - start_time) * 1000, \"ms\")\n",
235
+ "# print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
236
+ "print(\"End to End Inference Time: \", (finish_time - start_time) * 1000, \"ms\")\n",
237
+ "print(\"Refer Text: \", process(sample[\"text\"]))\n",
238
+ "print(\"Transcript: \", tokenizer.batch_decode(generations, skip_special_tokens=True)[0])"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "baa8d79b-7cf5-4435-838c-1f3d4e043d60",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "import time\n",
249
+ "\n",
250
+ "sample = dataset[\"train.clean.100\"][0]\n",
251
+ "\n",
252
+ "x = sample[\"audio\"][\"array\"]\n",
253
+ "\n",
254
+ "start_time = time.time()\n",
255
+ "\n",
256
+ "with torch.no_grad():\n",
257
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
258
+ " if reader.task.cfg.normalize:\n",
259
+ " x = F.layer_norm(x, x.shape)\n",
260
+ " x = x.view(1, -1)\n",
261
+ "\n",
262
+ " feat = []\n",
263
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
264
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
265
+ " res = reader.model.extract_features(\n",
266
+ " source=x_chunk,\n",
267
+ " padding_mask=None,\n",
268
+ " mask=False,\n",
269
+ " layer=reader.layer,\n",
270
+ " )\n",
271
+ " feat_chunk = res[\"x\"]\n",
272
+ " feat.append(feat_chunk)\n",
273
+ " \n",
274
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
275
+ "\n",
276
+ " x = audio_model.encoder(features)\n",
277
+ " z = audio_model.projector(x)\n",
278
+ " _, idx = audio_model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
279
+ " tokens = idx.cpu().data.numpy().tolist()[0]\n",
280
+ "\n",
281
+ "from tqdm import tqdm\n",
282
+ "from math import ceil\n",
283
+ "import torch\n",
284
+ "\n",
285
+ "context_length = 1877\n",
286
+ "eot_token = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
287
+ "pad_token = tokenizer.encode(\"<|padding|>\")[0]\n",
288
+ " \n",
289
+ "text = \"\".join([f\"<|audio:{token}|>\" for token in tokens]) + \"<|startoftranscript|>\"\n",
290
+ "input_ids = tokenizer(text, return_tensors=\"pt\").to(device)[\"input_ids\"]\n",
291
+ "\n",
292
+ "max_new_tokens = context_length\n",
293
+ "num_tokens = 0\n",
294
+ "first_token = True\n",
295
+ "\n",
296
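+ "# Manual greedy decoding: the full prefix is re-run at every step (no KV cache),\n",
+ "# which keeps the loop simple for measuring first-token latency and tokens/s\n",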
+ "while max_new_tokens > 0 and input_ids.shape[-1] < context_length:\n",
297
+ "\n",
298
+ " with torch.no_grad():\n",
299
+ " outputs = model(input_ids = input_ids)\n",
300
+ "\n",
301
+ " logits = outputs[\"logits\"][:, -1]\n",
302
+ "\n",
303
+ " # Greedy Sampling\n",
304
+ " probas = torch.softmax(logits, dim=-1)\n",
305
+ " pred_idx = torch.argmax(probas, dim=-1, keepdim=True)\n",
306
+ " next_idx = pred_idx.item()\n",
307
+ "\n",
308
+ " if first_token:\n",
309
+ " first_token_latency = time.time() - start_time\n",
310
+ " first_token = False\n",
311
+ " start_time = time.time()\n",
312
+ "\n",
313
+ " if next_idx == eot_token:\n",
314
+ " break\n",
315
+ "\n",
316
+ " input_ids = torch.cat((input_ids, pred_idx), dim=-1)\n",
317
+ "\n",
318
+ " max_new_tokens -= 1\n",
319
+ " num_tokens += 1\n",
320
+ "\n",
321
+ "total_time = time.time() - start_time\n",
322
+ "\n",
323
+ "print(\"First Token Latency: \", first_token_latency * 1000, \"ms\")\n",
324
+ "print(\"Throughput: \", (1 + num_tokens)/total_time, \"tokens/s\")\n",
325
+ "print(\"End to End Inference Time: \", (total_time + first_token_latency) * 1000, \"ms\")\n",
326
+ "print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])\n",
327
+ "print(process(sample[\"text\"]))"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "id": "014ed999-3293-4d68-8f9c-017584adc642",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "tokenizer.batch_decode([[1, 2, 3]])"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "markdown",
342
+ "id": "ec11e43f-1eb8-4399-9a93-6f1427782661",
343
+ "metadata": {
344
+ "jp-MarkdownHeadingCollapsed": true
345
+ },
346
+ "source": [
347
+ "## Accelerating GPT 2 Inference"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "id": "5489cb4e-3213-4931-abe1-4c96d1a7ba56",
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "\"\"\"\n",
358
+ "- change tensorrt.tensorrt to tensorrt\n",
359
+ "- remove cpu quantization lines\n",
360
+ "- output_names [\"logits\"]\n",
361
+ "\"\"\""
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": null,
367
+ "id": "7e7e6ea6-7319-4e57-af33-5d917d26abc6",
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "import logging\n",
372
+ "import time\n",
373
+ "from typing import Callable, Dict\n",
374
+ "\n",
375
+ "import numpy as np\n",
376
+ "import tensorrt as trt\n",
377
+ "import torch\n",
378
+ "from tensorrt import ICudaEngine\n",
379
+ "from tensorrt import Logger, Runtime\n",
380
+ "from transformers import AutoTokenizer, BatchEncoding, GPT2LMHeadModel, AutoModelForCausalLM\n",
381
+ "from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
382
+ "from transformer_deploy.utils.generative_model import GPTModelWrapper\n",
383
+ "import inspect\n",
384
+ "from transformers import TensorType\n",
385
+ "\n",
386
+ "from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding, optimize_onnx\n",
387
+ "from transformer_deploy.backends.pytorch_utils import convert_to_onnx, get_model_size\n",
388
+ "from transformer_deploy.backends.trt_utils import build_engine, load_engine, save_engine"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": null,
394
+ "id": "21681412-7747-4824-894a-6006eb12a821",
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "model_name = \"gpt2\"\n",
399
+ "\n",
400
+ "model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name)\n",
401
+ "model.eval()\n",
402
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
403
+ "model.config.pad_token_id = tokenizer.eos_token_id"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "id": "46783acd-c404-44b4-904b-d8fb687afc34",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "inputs = tokenizer(\"Here is some text to encode Hello World\", return_tensors=\"pt\")\n",
414
+ "print(\"input tensors\")\n",
415
+ "print(inputs)\n",
416
+ "print(\"input tensor shape\")\n",
417
+ "print(inputs[\"input_ids\"].size())\n",
418
+ "\n",
419
+ "with torch.no_grad():\n",
420
+ " outputs = model(**inputs)\n",
421
+ "\n",
422
+ "logits = outputs.logits\n",
423
+ "print(\"output tensor\")\n",
424
+ "print(logits)\n",
425
+ "print(\"output shape\")\n",
426
+ "print(logits.shape)"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": null,
432
+ "id": "2f6cc7bd-5e2d-4d4e-a7e6-73a6b2ecd7af",
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "size = 0\n",
437
+ "for i in range(8, 256, 1):\n",
438
+ " # input sequence (input_ids) made of int-32 (4 bytes)\n",
439
+ " size += np.prod([1, i]) * 4\n",
440
+ " # output tensor made of float-32 (4 bytes)\n",
441
+ " size += np.prod([1, i, 50257]) * 4\n",
442
+ "print(f\"total size (input+output): {size / 1024**3:.2f} Gb\")\n",
443
+ "\n",
444
+ "# to manually check actual tensor size:\n",
445
+ "# np.prod(logits.shape)*32/8/1024**2:.2f}\n",
446
+ "# or\n",
447
+ "# sys.getsizeof(logits.storage())/1024**2"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "7debb40e-9941-45e4-9db8-4bb021ce44ab",
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": [
457
+ "input_ids: BatchEncoding = tokenizer(\n",
458
+ " \"Here is some text to encode Hello World\", add_special_tokens=True, return_attention_mask=False, return_tensors=\"pt\"\n",
459
+ ")\n",
460
+ "# some inference engines don't support int64 tensor as inputs, we convert all input tensors to int32 type\n",
461
+ "for k, v in input_ids.items(): # type: str, torch.Tensor\n",
462
+ " input_ids[k] = v.type(dtype=torch.int32)\n",
463
+ "\n",
464
+ "convert_to_onnx(\n",
465
+ " model_pytorch=model,\n",
466
+ " output_path=\"test-gpt2.onnx\",\n",
467
+ " inputs_pytorch=dict(input_ids),\n",
468
+ " quantization=False,\n",
469
+ " var_output_seq=True, # we inform ONNX export tool that the output shape will vary with the input shape\n",
470
+ " output_names = [\"logits\"]\n",
471
+ ")\n",
472
+ "# model may switch to train mode for some unknown reasons, we force the eval mode.\n",
473
+ "_ = model.eval()"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": null,
479
+ "id": "956c3007-2c18-4d92-af4f-6cef474d86b5",
480
+ "metadata": {},
481
+ "outputs": [],
482
+ "source": [
483
+ "logging.basicConfig()\n",
484
+ "logging.getLogger().setLevel(logging.INFO)\n",
485
+ "num_attention_heads, hidden_size = get_model_size(path=model_name)\n",
486
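+ "# optimize_onnx applies ONNX Runtime transformer graph optimizations (operator fusion) tuned for the gpt2 architecture\n",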
+ "optimize_onnx(\n",
487
+ " onnx_path=\"test-gpt2.onnx\",\n",
488
+ " onnx_optim_model_path=\"test-gpt2-opt.onnx\",\n",
489
+ " fp16=False,\n",
490
+ " use_cuda=True,\n",
491
+ " num_attention_heads=num_attention_heads,\n",
492
+ " hidden_size=hidden_size,\n",
493
+ " architecture=\"gpt2\",\n",
494
+ ")"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "id": "85f30ed9-2802-46c9-9201-a70e200b6860",
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "from pathlib import Path\n",
505
+ "\n",
506
+ "trt_logger: Logger = trt.Logger(trt.Logger.ERROR)\n",
507
+ "runtime: Runtime = trt.Runtime(trt_logger)\n",
508
+ "trt_model_name = \"test-gpt2.plan\"\n",
509
+ "\n",
510
+ "# create only of does not exist because it's slow to run...\n",
511
+ "\n",
512
+ "engine: ICudaEngine = build_engine(\n",
513
+ " runtime=runtime,\n",
514
+ " onnx_file_path=\"test-gpt2.onnx\",\n",
515
+ " logger=trt_logger,\n",
516
+ " min_shape=(1, 1),\n",
517
+ " optimal_shape=(1, 128), # num beam, batch size\n",
518
+ " max_shape=(1, 384), # num beam, batch size\n",
519
+ " workspace_size=10000 * 1024**2,\n",
520
+ " fp16=True,\n",
521
+ " int8=False,\n",
522
+ ")\n",
523
+ "save_engine(engine, trt_model_name)"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": null,
529
+ "id": "908fe664-800e-4c5f-a1d5-adfd31fd1c64",
530
+ "metadata": {},
531
+ "outputs": [],
532
+ "source": [
533
+ "engine.num_bindings"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": null,
539
+ "id": "4626926b-fa94-4633-95d5-0d515f8db5f6",
540
+ "metadata": {},
541
+ "outputs": [],
542
+ "source": [
543
+ "print(inspect.getsource(GPTModelWrapper))"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "id": "d5bd1de1-a949-46a3-8d15-457d51db4e40",
550
+ "metadata": {},
551
+ "outputs": [],
552
+ "source": [
553
+ "inputs = tokenizer(\n",
554
+ " \"Here is some text to encode Hello World\", # Nvidia example prompt\n",
555
+ " add_special_tokens=True,\n",
556
+ " return_attention_mask=False, # Not used\n",
557
+ " return_tensors=TensorType.PYTORCH,\n",
558
+ ")\n",
559
+ "inputs"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": null,
565
+ "id": "815b548f-fa00-4183-b72c-10ecdd4b11c7",
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "from transformers.generation import GenerationConfig\n",
570
+ "\n",
571
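+ "# GPTModelWrapper predates the GenerationConfig API; this subclass adds a generation_config and a\n",
+ "# can_generate() override so that .generate() keeps working with newer transformers releases\n",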
+ "class GPTWrapper(GPTModelWrapper):\n",
572
+ " def __init__(self, *args, **kwargs):\n",
573
+ " super().__init__(*args, **kwargs)\n",
574
+ "\n",
575
+ " self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None\n",
576
+ "\n",
577
+ " @classmethod\n",
578
+ " def can_generate(cls) -> bool:\n",
579
+ " \"\"\"\n",
580
+ " Returns whether this model can generate sequences with `.generate()`.\n",
581
+ "\n",
582
+ " Returns:\n",
583
+ " `bool`: Whether this model can generate sequences with `.generate()`.\n",
584
+ " \"\"\"\n",
585
+ " # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.\n",
586
+ " # Alternativelly, the model can also have a custom `generate` function.\n",
587
+ " if \"GenerationMixin\" in str(cls.prepare_inputs_for_generation) and \"GenerationMixin\" in str(cls.generate):\n",
588
+ " return False\n",
589
+ " return True"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "id": "ca57ed1e-0bbe-48dd-ae0f-f3d8ecd7fd04",
596
+ "metadata": {},
597
+ "outputs": [],
598
+ "source": [
599
+ "def inference_torch(input_ids: torch.Tensor) -> torch.Tensor:\n",
600
+ " transformer_outputs: BaseModelOutputWithPastAndCrossAttentions = model.transformer(input_ids=input_ids)\n",
601
+ " return model.lm_head(transformer_outputs.last_hidden_state)\n",
602
+ "\n",
603
+ "\n",
604
+ "model.cuda()\n",
605
+ "model.eval()\n",
606
+ "inputs.to(\"cuda\")\n",
607
+ "with torch.inference_mode():\n",
608
+ " gpt2_model = GPTWrapper(config=model.config, device=model.device, inference=inference_torch)\n",
609
+ " sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
610
+ " print(tokenizer.decode(sample_output[0], skip_special_tokens=False))\n",
611
+ " for _ in range(2):\n",
612
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
613
+ " torch.cuda.synchronize()\n",
614
+ " start = time.time()\n",
615
+ " for _ in range(10):\n",
616
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
617
+ " torch.cuda.synchronize()\n",
618
+ " print(f\"----\\nPytorch: {(time.time() - start)/10:.2f}s/sequence\")\n",
619
+ "_ = model.cpu()"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "id": "f0849aae-876e-47bc-b045-14a594170947",
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
630
+ "\n",
631
+ "\n",
632
+ "def inference_onnx_naive(input_ids: torch.Tensor) -> torch.Tensor:\n",
633
+ " data = {\"input_ids\": input_ids.detach().cpu().numpy().astype(np.int32)}\n",
634
+ " logit = model_onnx.run(None, data)\n",
635
+ " np_logit = np.array(logit) # convert list of numpy arrays to a numpy array\n",
636
+ " # we convert numpy tensor to Pytorch tensor as it's the type expected by HF decoding algorithm\n",
637
+ " return torch.squeeze(torch.from_numpy(np_logit), dim=0)\n",
638
+ "\n",
639
+ "\n",
640
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cpu\"), inference=inference_onnx_naive)\n",
641
+ "inputs.to(\"cpu\")\n",
642
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
643
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
644
+ "for _ in range(2):\n",
645
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
646
+ "start = time.time()\n",
647
+ "for _ in range(10):\n",
648
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
649
+ "print(f\"----\\nONNX Runtime (standard API): {(time.time() - start)/10:.2f}s/sequence\")\n",
650
+ "\n",
651
+ "del model_onnx"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": null,
657
+ "id": "96114897-894b-4997-bc61-8ac0682e0e55",
658
+ "metadata": {},
659
+ "outputs": [],
660
+ "source": [
661
+ "model_onnx = create_model_for_provider(path=\"test-gpt2-opt.onnx\", provider_to_use=\"CUDAExecutionProvider\")\n",
662
+ "\n",
663
+ "\n",
664
+ "def inference_onnx_optimized(input_ids: torch.Tensor) -> torch.Tensor:\n",
665
+ " data = {\"input_ids\": input_ids}\n",
666
+ " return inference_onnx_binding(model_onnx=model_onnx, inputs=data, device=\"cuda\")[\"output\"]\n",
667
+ "\n",
668
+ "\n",
669
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_onnx_optimized)\n",
670
+ "inputs.to(\"cuda\")\n",
671
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
672
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
673
+ "for _ in range(2):\n",
674
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
675
+ "start = time.time()\n",
676
+ "for _ in range(10):\n",
677
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
678
+ "print(f\"----\\nONNX Runtime (binding io API): {(time.time() - start)/10:.2f}/sequence\")\n",
679
+ "del model_onnx"
680
+ ]
681
+ },
682
+ {
683
+ "cell_type": "code",
684
+ "execution_count": null,
685
+ "id": "0b5b5427-fd6b-4f70-b307-9c579f0f842a",
686
+ "metadata": {},
687
+ "outputs": [],
688
+ "source": [
689
+ "tensorrt_model: Callable[[Dict[str, torch.Tensor]], torch.Tensor] = load_engine(\n",
690
+ " engine_file_path=\"test-gpt2.plan\", runtime=runtime\n",
691
+ ")\n",
692
+ "\n",
693
+ "\n",
694
+ "def inference_tensorrt(input_ids: torch.Tensor) -> torch.Tensor:\n",
695
+ " data = {\"input_ids\": input_ids}\n",
696
+ " return tensorrt_model(data)\n",
697
+ "\n",
698
+ "\n",
699
+ "gpt2_model = GPTWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_tensorrt)\n",
700
+ "inputs.to(\"cuda\")\n",
701
+ "sample_output = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
702
+ "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))\n",
703
+ "for _ in range(2):\n",
704
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=64)\n",
705
+ "start = time.time()\n",
706
+ "for _ in range(10):\n",
707
+ " _ = gpt2_model.generate(inputs.input_ids, max_length=256)\n",
708
+ "print(f\"----\\nTensorRT + CUDA tensors: {(time.time() - start)/10:.2f}/sequence\")\n",
709
+ "\n",
710
+ "del tensorrt_model"
711
+ ]
712
+ },
713
+ {
714
+ "cell_type": "markdown",
715
+ "id": "f547239d-4f7a-433b-8ef6-9e5110a61f4b",
716
+ "metadata": {
717
+ "jp-MarkdownHeadingCollapsed": true
718
+ },
719
+ "source": [
720
+ "## Using CUDAExecution Provider"
721
+ ]
722
+ },
723
+ {
724
+ "cell_type": "code",
725
+ "execution_count": null,
726
+ "id": "6e34c682-85fc-4e8d-b13c-7c1c9ea39ead",
727
+ "metadata": {},
728
+ "outputs": [],
729
+ "source": [
730
+ "from optimum.onnxruntime import ORTModelForCausalLM\n",
731
+ "from optimum.pipelines import pipeline\n",
732
+ "from transformers import AutoTokenizer\n",
733
+ "\n",
734
+ "model_id = \"openai-community/gpt2\"\n",
735
+ "\n",
736
+ "ort_model = ORTModelForCausalLM.from_pretrained(\n",
737
+ " model_id,\n",
738
+ " export=True,\n",
739
+ " provider=\"CUDAExecutionProvider\",\n",
740
+ " use_io_binding=True\n",
741
+ ")\n",
742
+ "\n",
743
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
744
+ "tokenizer.pad_token = tokenizer.eos_token\n",
745
+ "\n",
746
+ "pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": null,
752
+ "id": "17d28184-26db-4dd3-b24b-0c5a12b10d6d",
753
+ "metadata": {},
754
+ "outputs": [],
755
+ "source": [
756
+ "import time\n",
757
+ "\n",
758
+ "start_time = time.time()\n",
759
+ "\n",
760
+ "generations = pipe(\"Both the music and visual were astounding, not to mention the actors performance.\")\n",
761
+ "generations[0][\"generated_text\"]\n",
762
+ "\n",
763
+ "finish_time = time.time()\n",
764
+ "\n",
765
+ "print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
766
+ ]
767
+ },
768
+ {
769
+ "cell_type": "markdown",
770
+ "id": "19c4230a-3244-4dce-b5ef-d9927dec5c45",
771
+ "metadata": {},
772
+ "source": [
773
+ "## ASR LM with CUDAExcecution Provider"
774
+ ]
775
+ },
776
+ {
777
+ "cell_type": "code",
778
+ "execution_count": null,
779
+ "id": "0f0f1cdc-bfcd-46c5-80a4-60bc76366cf5",
780
+ "metadata": {},
781
+ "outputs": [],
782
+ "source": [
783
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer\n",
784
+ "from datasets import DatasetDict\n",
785
+ "import torch\n",
786
+ "\n",
787
+ "device = \"cuda:0\"\n",
788
+ "dtype = torch.float16\n",
789
+ "\n",
790
+ "dataset = DatasetDict.load_from_disk(\"./../librispeech_tokenized.hf\")\n",
791
+ "\n",
792
+ "from optimum.onnxruntime import ORTModelForCausalLM\n",
793
+ "from optimum.pipelines import pipeline\n",
794
+ "from transformers import AutoTokenizer\n",
795
+ "\n",
796
+ "model_id = \"./../out/checkpoint-10000\"\n",
797
+ "\n",
798
+ "ort_model = ORTModelForCausalLM.from_pretrained(\n",
799
+ " model_id,\n",
800
+ " export=True,\n",
801
+ " provider=\"CUDAExecutionProvider\",\n",
802
+ " use_io_binding=True\n",
803
+ ")\n",
804
+ "\n",
805
+ "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
806
+ "\n",
807
+ "pipe = pipeline(task=\"text-generation\", model=ort_model, tokenizer=tokenizer, device=\"cuda:0\")"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "execution_count": null,
813
+ "id": "9d32098c-b0ec-4c36-95ac-775a3a865512",
814
+ "metadata": {},
815
+ "outputs": [],
816
+ "source": [
817
+ "ort_model.config.eos_token_id = tokenizer.encode(\"<|endoftranscript|>\")[0]\n",
818
+ "ort_model.config.bos_token_id = tokenizer.encode(\"<|startoftranscript|>\")[0]"
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "id": "1fd0a1fb-9349-4c7a-af03-21e29334f420",
825
+ "metadata": {},
826
+ "outputs": [],
827
+ "source": [
828
+ "dataset[split][idx].keys()"
829
+ ]
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "execution_count": null,
834
+ "id": "15d8b989-6460-4555-b6e2-2f9e219d7034",
835
+ "metadata": {},
836
+ "outputs": [],
837
+ "source": [
838
+ "split = \"train.clean.100\"\n",
839
+ "idx = 0\n",
840
+ "\n",
841
+ "text = \"\".join([ f\"<|audio:{tkn}|>\"for tkn in dataset[split][idx][\"audio_tokens\"]]) + \"<|startoftranscript|>\"\n",
842
+ "\n",
843
+ "import time\n",
844
+ "\n",
845
+ "start_time = time.time()\n",
846
+ "\n",
847
+ "generations = pipe(text, max_new_tokens=10, skip_special_tokens=True)\n",
848
+ "\n",
849
+ "finish_time = time.time()\n",
850
+ "\n",
851
+ "print(generations[0][\"generated_text\"])\n",
852
+ "\n",
853
+ "print(\"End to End Latency: \", (finish_time - start_time) * 1000, \"ms\")"
854
+ ]
855
+ }
856
+ ],
857
+ "metadata": {
858
+ "kernelspec": {
859
+ "display_name": "Python 3 (ipykernel)",
860
+ "language": "python",
861
+ "name": "python3"
862
+ },
863
+ "language_info": {
864
+ "codemirror_mode": {
865
+ "name": "ipython",
866
+ "version": 3
867
+ },
868
+ "file_extension": ".py",
869
+ "mimetype": "text/x-python",
870
+ "name": "python",
871
+ "nbconvert_exporter": "python",
872
+ "pygments_lexer": "ipython3",
873
+ "version": "3.8.10"
874
+ }
875
+ },
876
+ "nbformat": 4,
877
+ "nbformat_minor": 5
878
+ }
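
The three inference cells above share one measurement recipe: decode one sample, run two warm-up generations at max_length=64, then time ten generations at max_length=256 (with torch.cuda.synchronize() bracketing the timed loop in the PyTorch case). A helper that factors this out might look as follows; benchmark_generate is a hypothetical name and is not defined anywhere in the notebook.

# Sketch of the timing recipe used by the benchmark cells above (hypothetical helper, not in the notebook).
import time
import torch

def benchmark_generate(model_wrapper, input_ids, n_warmup=2, n_runs=10, max_length=256):
    for _ in range(n_warmup):                      # warm-up runs, excluded from the measurement
        _ = model_wrapper.generate(input_ids, max_length=64)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_runs):
        _ = model_wrapper.generate(input_ids, max_length=max_length)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.time() - start) / n_runs          # seconds per generated sequence
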
ASR/demo.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request
2
+ # import speech_recognition as sr
3
+
4
+ app = Flask(__name__)
5
+ # recognizer = sr.Recognizer()
6
+
7
+ @app.route("/darshan/microphone", methods=['POST'])
8
+ def handle_audio():
9
+ audio_data = request.data
10
+ print(audio_data)
+ return '', 200 # acknowledge receipt; the transcription code below is currently commented out
11
+ # audio = sr.AudioData(audio_data, sample_rate=44100, sample_width=2) # Adjust sample rate and sample width as needed
12
+ # try:
13
+ # text = recognizer.recognize_google(audio)
14
+ # print(f"Transcription: {text}")
15
+ # return {'transcription': text}, 200
16
+ # except sr.UnknownValueError:
17
+ # print("Could not understand audio")
18
+ # return '', 400
19
+ # except sr.RequestError as e:
20
+ # print(f"Error from Google Speech Recognition service; {e}")
21
+ # return '', 500
22
+
23
+ if __name__ == '__main__':
24
+ app.run(host='0.0.0.0', port=8723) # Replace with your desired host and port
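
For reference, the endpoint above can be exercised with a plain HTTP POST of raw audio bytes. This is only a client sketch: the route and port come from the code above, while the host and the file name are placeholders.

# Client sketch for the Flask endpoint above (file name and host are placeholders).
import requests

with open("sample.wav", "rb") as f:   # any raw audio payload; the server currently just prints the bytes
    audio_bytes = f.read()

resp = requests.post("http://localhost:8723/darshan/microphone", data=audio_bytes)
print(resp.status_code)
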
ASR/repcodec/.ipynb_checkpoints/RepCodec-checkpoint.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.modules.decoder import Decoder
11
+ from repcodec.modules.encoder import Encoder
12
+ from repcodec.modules.projector import Projector
13
+ from repcodec.modules.quantizer import Quantizer
14
+
15
+
16
+ class RepCodec(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_channels=768,
20
+ output_channels=768,
21
+ encode_channels=768,
22
+ decode_channels=768,
23
+ code_dim=768,
24
+ codebook_num=1,
25
+ codebook_size=1024,
26
+ bias=True,
27
+ enc_ratios=(1, 1),
28
+ dec_ratios=(1, 1),
29
+ enc_strides=(1, 1),
30
+ dec_strides=(1, 1),
31
+ enc_kernel_size=3,
32
+ dec_kernel_size=3,
33
+ enc_block_dilations=(1, 1),
34
+ enc_block_kernel_size=3,
35
+ dec_block_dilations=(1, 1),
36
+ dec_block_kernel_size=3
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_channels = input_channels
41
+
42
+ self.encoder = Encoder(
43
+ input_channels=input_channels,
44
+ encode_channels=encode_channels,
45
+ channel_ratios=enc_ratios,
46
+ strides=enc_strides,
47
+ kernel_size=enc_kernel_size,
48
+ bias=bias,
49
+ block_dilations=enc_block_dilations,
50
+ unit_kernel_size=enc_block_kernel_size
51
+ )
52
+
53
+ self.decoder = Decoder(
54
+ code_dim=code_dim,
55
+ output_channels=output_channels,
56
+ decode_channels=decode_channels,
57
+ channel_ratios=dec_ratios,
58
+ strides=dec_strides,
59
+ kernel_size=dec_kernel_size,
60
+ bias=bias,
61
+ block_dilations=dec_block_dilations,
62
+ unit_kernel_size=dec_block_kernel_size
63
+ )
64
+
65
+ self.projector = Projector(
66
+ input_channels=self.encoder.out_channels,
67
+ code_dim=code_dim,
68
+ kernel_size=3,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+
73
+ self.quantizer = Quantizer(
74
+ code_dim=code_dim,
75
+ codebook_num=codebook_num,
76
+ codebook_size=codebook_size
77
+ )
78
+
79
+ def forward(self, x):
80
+ x = self.encoder(x)
81
+ z = self.projector(x)
82
+ zq, vqloss, perplexity = self.quantizer(z)
83
+ y = self.decoder(zq)
84
+ return y, zq, z, vqloss, perplexity
ASR/repcodec/RepCodec.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.modules.decoder import Decoder
11
+ from repcodec.modules.encoder import Encoder
12
+ from repcodec.modules.projector import Projector
13
+ from repcodec.modules.quantizer import Quantizer
14
+
15
+
16
+ class RepCodec(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_channels=768,
20
+ output_channels=768,
21
+ encode_channels=768,
22
+ decode_channels=768,
23
+ code_dim=768,
24
+ codebook_num=1,
25
+ codebook_size=1024,
26
+ bias=True,
27
+ enc_ratios=(1, 1),
28
+ dec_ratios=(1, 1),
29
+ enc_strides=(1, 1),
30
+ dec_strides=(1, 1),
31
+ enc_kernel_size=3,
32
+ dec_kernel_size=3,
33
+ enc_block_dilations=(1, 1),
34
+ enc_block_kernel_size=3,
35
+ dec_block_dilations=(1, 1),
36
+ dec_block_kernel_size=3
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_channels = input_channels
41
+
42
+ self.encoder = Encoder(
43
+ input_channels=input_channels,
44
+ encode_channels=encode_channels,
45
+ channel_ratios=enc_ratios,
46
+ strides=enc_strides,
47
+ kernel_size=enc_kernel_size,
48
+ bias=bias,
49
+ block_dilations=enc_block_dilations,
50
+ unit_kernel_size=enc_block_kernel_size
51
+ )
52
+
53
+ self.decoder = Decoder(
54
+ code_dim=code_dim,
55
+ output_channels=output_channels,
56
+ decode_channels=decode_channels,
57
+ channel_ratios=dec_ratios,
58
+ strides=dec_strides,
59
+ kernel_size=dec_kernel_size,
60
+ bias=bias,
61
+ block_dilations=dec_block_dilations,
62
+ unit_kernel_size=dec_block_kernel_size
63
+ )
64
+
65
+ self.projector = Projector(
66
+ input_channels=self.encoder.out_channels,
67
+ code_dim=code_dim,
68
+ kernel_size=3,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+
73
+ self.quantizer = Quantizer(
74
+ code_dim=code_dim,
75
+ codebook_num=codebook_num,
76
+ codebook_size=codebook_size
77
+ )
78
+
79
+ def forward(self, x):
80
+ x = self.encoder(x)
81
+ z = self.projector(x)
82
+ zq, vqloss, perplexity = self.quantizer(z)
83
+ y = self.decoder(zq)
84
+ return y, zq, z, vqloss, perplexity
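
The default arguments above are the same values as configs/repcodec_dim768.yaml, so the model can be smoke-tested directly on random features. This is only a usage sketch, and it assumes the ASR directory is on PYTHONPATH so that the repcodec package imports the same way as in the files below.

# Smoke-test sketch for RepCodec (random 768-dim features, shapes only).
import torch
from repcodec.RepCodec import RepCodec

model = RepCodec().eval()          # defaults match configs/repcodec_dim768.yaml
feats = torch.randn(1, 768, 100)   # (batch, hidden dim, frames), e.g. HuBERT layer-9 features
with torch.no_grad():
    y, zq, z, vqloss, perplexity = model(feats)
print(y.shape, zq.shape, z.shape)  # all torch.Size([1, 768, 100]): no downsampling with strides (1, 1)
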
ASR/repcodec/__pycache__/RepCodec.cpython-38.pyc ADDED
Binary file (1.87 kB). View file
 
ASR/repcodec/configs/repcodec_dim1024.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 1024
2
+ output_channels: 1024
3
+ encode_channels: 1024
4
+ decode_channels: 1024
5
+ code_dim: 1024
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
ASR/repcodec/configs/repcodec_dim1280.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 1280
2
+ output_channels: 1280
3
+ encode_channels: 1280
4
+ decode_channels: 1280
5
+ code_dim: 1280
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
ASR/repcodec/configs/repcodec_dim768.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 768
2
+ output_channels: 768
3
+ encode_channels: 768
4
+ decode_channels: 768
5
+ code_dim: 768
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
ASR/repcodec/layers/__pycache__/conv_layer.cpython-38.pyc ADDED
Binary file (2.52 kB). View file
 
ASR/repcodec/layers/__pycache__/vq_module.cpython-38.pyc ADDED
Binary file (5.14 kB). View file
 
ASR/repcodec/layers/conv_layer.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+
11
+ class Conv1d1x1(nn.Conv1d):
12
+ """1x1 Conv1d."""
13
+
14
+ def __init__(self, in_channels, out_channels, bias=True):
15
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
16
+
17
+
18
+ class Conv1d(nn.Module):
19
+ def __init__(
20
+ self,
21
+ in_channels: int,
22
+ out_channels: int,
23
+ kernel_size: int,
24
+ stride: int = 1,
25
+ padding: int = -1,
26
+ dilation: int = 1,
27
+ groups: int = 1,
28
+ bias: bool = True
29
+ ):
30
+ super().__init__()
31
+ self.in_channels = in_channels
32
+ self.out_channels = out_channels
33
+ self.kernel_size = kernel_size
34
+ if padding < 0:
35
+ padding = (kernel_size - 1) // 2 * dilation
36
+ self.dilation = dilation
37
+ self.conv = nn.Conv1d(
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ kernel_size=kernel_size,
41
+ stride=stride,
42
+ padding=padding,
43
+ dilation=dilation,
44
+ groups=groups,
45
+ bias=bias,
46
+ )
47
+
48
+ def forward(self, x):
49
+ """
50
+ Args:
51
+ x (Tensor): Float tensor variable with the shape (B, C, T).
52
+ Returns:
53
+ Tensor: Float tensor variable with the shape (B, C, T).
54
+ """
55
+ x = self.conv(x)
56
+ return x
57
+
58
+
59
+ class ConvTranspose1d(nn.Module):
60
+ def __init__(
61
+ self,
62
+ in_channels: int,
63
+ out_channels: int,
64
+ kernel_size: int,
65
+ stride: int,
66
+ padding=-1,
67
+ output_padding=-1,
68
+ groups=1,
69
+ bias=True,
70
+ ):
71
+ super().__init__()
72
+ if padding < 0:
73
+ padding = (stride + 1) // 2
74
+ if output_padding < 0:
75
+ output_padding = 1 if stride % 2 else 0
76
+ self.deconv = nn.ConvTranspose1d(
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ kernel_size=kernel_size,
80
+ stride=stride,
81
+ padding=padding,
82
+ output_padding=output_padding,
83
+ groups=groups,
84
+ bias=bias,
85
+ )
86
+
87
+ def forward(self, x):
88
+ """
89
+ Args:
90
+ x (Tensor): Float tensor variable with the shape (B, C, T).
91
+ Returns:
92
+ Tensor: Float tensor variable with the shape (B, C', T').
93
+ """
94
+ x = self.deconv(x)
95
+ return x
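
The default padding above, (kernel_size - 1) // 2 * dilation, keeps the time axis unchanged whenever stride=1, which is what lets the RepCodec encoder/decoder (strides (1, 1) in the configs) preserve the frame rate, while ConvTranspose1d upsamples by its stride. A quick shape check, as a sketch assuming the repcodec package is importable:

# Shape check for the padding defaults above (sketch).
import torch
from repcodec.layers.conv_layer import Conv1d, ConvTranspose1d

x = torch.randn(2, 768, 50)                                          # (B, C, T)
print(Conv1d(768, 768, kernel_size=3)(x).shape)                      # torch.Size([2, 768, 50]): T kept at stride=1
print(ConvTranspose1d(768, 768, kernel_size=4, stride=2)(x).shape)   # torch.Size([2, 768, 100]): T doubled
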
ASR/repcodec/layers/vq_module.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """Vector quantization w/ exponential moving averages (EMA)"""
15
+
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ codebook_size: int,
20
+ decay=0.8,
21
+ commitment=1.,
22
+ eps=1e-5,
23
+ n_embed=None,
24
+ ):
25
+ super().__init__()
26
+ n_embed = self.default(n_embed, codebook_size)
27
+
28
+ self.dim = dim
29
+ self.n_embed = n_embed
30
+ self.decay = decay
31
+ self.eps = eps
32
+ self.commitment = commitment
33
+
34
+ embed = torch.randn(dim, n_embed)
35
+ self.register_buffer('embed', embed)
36
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
37
+ self.register_buffer('embed_avg', embed.clone())
38
+
39
+ @property
40
+ def codebook(self):
41
+ return self.embed.transpose(0, 1)
42
+
43
+ def exists(self, val):
44
+ return val is not None
45
+
46
+ def default(self, val, d):
47
+ return val if self.exists(val) else d
48
+
49
+ def ema_inplace(self, moving_avg, new, decay):
50
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
51
+
52
+ def laplace_smoothing(self, x, n_categories, eps=1e-5):
53
+ return (x + eps) / (x.sum() + n_categories * eps)
54
+
55
+ def forward(self, input):
56
+ dtype = input.dtype
57
+ flatten = input.reshape(-1, self.dim)
58
+ dist = (
59
+ flatten.pow(2).sum(1, keepdim=True)
60
+ - 2 * flatten @ self.embed
61
+ + self.embed.pow(2).sum(0, keepdim=True)
62
+ )
63
+ _, embed_ind = (-dist).max(1)
64
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
65
+ embed_ind = embed_ind.view(*input.shape[:-1])
66
+ quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
67
+
68
+ if self.training:
69
+ self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
70
+ embed_sum = flatten.transpose(0, 1) @ embed_onehot
71
+ self.ema_inplace(self.embed_avg, embed_sum, self.decay)
72
+ cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
73
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
74
+ self.embed.data.copy_(embed_normalized)
75
+
76
+ loss = F.mse_loss(quantize.detach(), input) * self.commitment
77
+ quantize = input + (quantize - input).detach()
78
+
79
+ avg_probs = torch.mean(embed_onehot, dim=0)
80
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
81
+
82
+ return quantize, loss, perplexity
83
+
84
+ def forward_index(self, input):
85
+ dtype = input.dtype
86
+ flatten = input.reshape(-1, self.dim)
87
+ dist = (
88
+ flatten.pow(2).sum(1, keepdim=True)
89
+ - 2 * flatten @ self.embed
90
+ + self.embed.pow(2).sum(0, keepdim=True)
91
+ )
92
+ _, embed_ind = (-dist).max(1)
93
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
94
+ embed_ind = embed_ind.view(*input.shape[:-1])
95
+ quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
96
+ quantize = input + (quantize - input).detach()
97
+
98
+ return quantize, embed_ind
99
+
100
+
101
+ class ResidualVQ(nn.Module):
102
+ """ Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
103
+
104
+ def __init__(
105
+ self,
106
+ *,
107
+ num_quantizers,
108
+ **kwargs
109
+ ):
110
+ super().__init__()
111
+ self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
112
+
113
+ def forward(self, x):
114
+ quantized_out = 0.
115
+ residual = x
116
+ all_losses = []
117
+ all_perplexities = []
118
+ for layer in self.layers:
119
+ quantized, loss, perplexity = layer(residual)
120
+ # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
121
+ # We found that considering only the 1st layer VQ's gradient results in better performance
122
+ # residual = residual - quantized.detach() # considering all layers' gradients
123
+ residual = residual - quantized # considering only the first layer's gradient
124
+ quantized_out = quantized_out + quantized
125
+ all_losses.append(loss)
126
+ all_perplexities.append(perplexity)
127
+ all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
128
+ return quantized_out, all_losses, all_perplexities
129
+
130
+ def forward_index(self, x, flatten_idx=False):
131
+ quantized_out = 0.
132
+ residual = x
133
+ all_indices = []
134
+ for i, layer in enumerate(self.layers):
135
+ quantized, indices = layer.forward_index(residual)
136
+ # residual = residual - quantized.detach()
137
+ residual = residual - quantized
138
+ quantized_out = quantized_out + quantized
139
+ if flatten_idx:
140
+ indices += (self.codebook_size * i)
141
+ all_indices.append(indices)
142
+ all_indices = torch.stack(all_indices)
143
+ return quantized_out, all_indices.squeeze(1)
144
+
145
+ def initial(self):
146
+ self.codebook = []
147
+ for layer in self.layers:
148
+ self.codebook.append(layer.codebook)
149
+ self.codebook_size = self.codebook[0].size(0)
150
+ self.codebook = torch.stack(self.codebook)
151
+ self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
152
+
153
+ def lookup(self, indices):
154
+ quantized_out = F.embedding(indices, self.codebook) # Num x T x C
155
+ return torch.sum(quantized_out, dim=0, keepdim=True)
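
VectorQuantize above does a nearest-neighbour codebook lookup with a straight-through estimator and only updates the codebook by EMA while the module is in training mode; ResidualVQ chains such layers and exposes forward_index/lookup for tokenization, which require initial() to have been called first. A small sanity-check sketch on random data (repcodec assumed importable):

# Sanity-check sketch for the VQ layers above (random data, shapes only).
import torch
from repcodec.layers.vq_module import VectorQuantize, ResidualVQ

vq = VectorQuantize(dim=768, codebook_size=1024).eval()   # eval() => no EMA codebook update
x = torch.randn(1, 100, 768)                              # (B, T, dim)
quantized, loss, perplexity = vq(x)
print(quantized.shape, loss.item(), perplexity.item())    # same shape as x, scalar loss and perplexity

rvq = ResidualVQ(num_quantizers=1, dim=768, codebook_size=1024).eval()
rvq.initial()                                             # flatten the codebooks for forward_index/lookup
zq, idx = rvq.forward_index(x, flatten_idx=True)
print(idx.shape)                                          # torch.Size([1, 100]): one token id per frame (B=1)
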
ASR/repcodec/modules/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (2.51 kB). View file
 
ASR/repcodec/modules/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (2.23 kB). View file
 
ASR/repcodec/modules/__pycache__/projector.cpython-38.pyc ADDED
Binary file (903 Bytes). View file
 
ASR/repcodec/modules/__pycache__/quantizer.cpython-38.pyc ADDED
Binary file (1.63 kB). View file
 
ASR/repcodec/modules/__pycache__/residual_unit.cpython-38.pyc ADDED
Binary file (1.14 kB). View file
 
ASR/repcodec/modules/decoder.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from repcodec.layers.conv_layer import Conv1d, ConvTranspose1d
12
+ from repcodec.modules.residual_unit import ResidualUnit
13
+
14
+
15
+ class DecoderBlock(nn.Module):
16
+ """ Decoder block (no up-sampling) """
17
+
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ out_channels: int,
22
+ stride: int,
23
+ dilations=(1, 1),
24
+ unit_kernel_size=3,
25
+ bias=True
26
+ ):
27
+ super().__init__()
28
+
29
+ if stride == 1:
30
+ self.conv = Conv1d(
31
+ in_channels=in_channels,
32
+ out_channels=out_channels,
33
+ kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
34
+ stride=stride,
35
+ bias=bias,
36
+ )
37
+ else:
38
+ self.conv = ConvTranspose1d(
39
+ in_channels=in_channels,
40
+ out_channels=out_channels,
41
+ kernel_size=(2 * stride),
42
+ stride=stride,
43
+ bias=bias,
44
+ )
45
+
46
+ self.res_units = torch.nn.ModuleList()
47
+ for idx, dilation in enumerate(dilations):
48
+ self.res_units += [
49
+ ResidualUnit(out_channels, out_channels,
50
+ kernel_size=unit_kernel_size,
51
+ dilation=dilation)
52
+ ]
53
+ self.num_res = len(self.res_units)
54
+
55
+ def forward(self, x):
56
+ x = self.conv(x)
57
+ for idx in range(self.num_res):
58
+ x = self.res_units[idx](x)
59
+ return x
60
+
61
+
62
+ class Decoder(nn.Module):
63
+ def __init__(
64
+ self,
65
+ code_dim: int,
66
+ output_channels: int,
67
+ decode_channels: int,
68
+ channel_ratios=(1, 1),
69
+ strides=(1, 1),
70
+ kernel_size=3,
71
+ bias=True,
72
+ block_dilations=(1, 1),
73
+ unit_kernel_size=3,
74
+ ):
75
+ super().__init__()
76
+ assert len(channel_ratios) == len(strides)
77
+
78
+ self.conv1 = Conv1d(
79
+ in_channels=code_dim,
80
+ out_channels=int(decode_channels * channel_ratios[0]),
81
+ kernel_size=kernel_size,
82
+ stride=1,
83
+ bias=False
84
+ )
85
+
86
+ self.conv_blocks = torch.nn.ModuleList()
87
+ for idx, stride in enumerate(strides):
88
+ in_channels = int(decode_channels * channel_ratios[idx])
89
+ if idx < (len(channel_ratios) - 1):
90
+ out_channels = int(decode_channels * channel_ratios[idx + 1])
91
+ else:
92
+ out_channels = decode_channels
93
+ self.conv_blocks += [
94
+ DecoderBlock(
95
+ in_channels, out_channels, stride,
96
+ dilations=block_dilations, unit_kernel_size=unit_kernel_size,
97
+ bias=bias
98
+ )
99
+ ]
100
+ self.num_blocks = len(self.conv_blocks)
101
+
102
+ self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
103
+
104
+ def forward(self, z):
105
+ x = self.conv1(z)
106
+ for i in range(self.num_blocks):
107
+ x = self.conv_blocks[i](x)
108
+ x = self.conv2(x)
109
+ return x
ASR/repcodec/modules/encoder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from repcodec.layers.conv_layer import Conv1d
12
+ from repcodec.modules.residual_unit import ResidualUnit
13
+
14
+
15
+ class EncoderBlock(nn.Module):
16
+ def __init__(
17
+ self,
18
+ in_channels: int,
19
+ out_channels: int,
20
+ stride: int,
21
+ dilations=(1, 1),
22
+ unit_kernel_size=3,
23
+ bias=True
24
+ ):
25
+ super().__init__()
26
+ self.res_units = torch.nn.ModuleList()
27
+ for dilation in dilations:
28
+ self.res_units += [
29
+ ResidualUnit(in_channels, in_channels,
30
+ kernel_size=unit_kernel_size,
31
+ dilation=dilation)
32
+ ]
33
+ self.num_res = len(self.res_units)
34
+
35
+ self.conv = Conv1d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
39
+ stride=stride,
40
+ bias=bias,
41
+ )
42
+
43
+ def forward(self, x):
44
+ for idx in range(self.num_res):
45
+ x = self.res_units[idx](x)
46
+ x = self.conv(x)
47
+ return x
48
+
49
+
50
+ class Encoder(nn.Module):
51
+ def __init__(
52
+ self,
53
+ input_channels: int,
54
+ encode_channels: int,
55
+ channel_ratios=(1, 1),
56
+ strides=(1, 1),
57
+ kernel_size=3,
58
+ bias=True,
59
+ block_dilations=(1, 1),
60
+ unit_kernel_size=3
61
+ ):
62
+ super().__init__()
63
+ assert len(channel_ratios) == len(strides)
64
+
65
+ self.conv = Conv1d(
66
+ in_channels=input_channels,
67
+ out_channels=encode_channels,
68
+ kernel_size=kernel_size,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+ self.conv_blocks = torch.nn.ModuleList()
73
+ in_channels = encode_channels
74
+ for idx, stride in enumerate(strides):
75
+ out_channels = int(encode_channels * channel_ratios[idx]) # could be float
76
+ self.conv_blocks += [
77
+ EncoderBlock(in_channels, out_channels, stride,
78
+ dilations=block_dilations, unit_kernel_size=unit_kernel_size,
79
+ bias=bias)
80
+ ]
81
+ in_channels = out_channels
82
+ self.num_blocks = len(self.conv_blocks)
83
+ self.out_channels = out_channels
84
+
85
+ def forward(self, x):
86
+ x = self.conv(x)
87
+ for i in range(self.num_blocks):
88
+ x = self.conv_blocks[i](x)
89
+ return x
ASR/repcodec/modules/projector.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.conv_layer import Conv1d
11
+
12
+
13
+ class Projector(nn.Module):
14
+ def __init__(
15
+ self,
16
+ input_channels: int,
17
+ code_dim: int,
18
+ kernel_size=3,
19
+ stride=1,
20
+ bias=False
21
+ ):
22
+ super().__init__()
23
+ self.project = Conv1d(
24
+ input_channels,
25
+ code_dim,
26
+ kernel_size=kernel_size,
27
+ stride=stride,
28
+ bias=bias
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.project(x)
ASR/repcodec/modules/quantizer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.vq_module import ResidualVQ
11
+
12
+
13
+ class Quantizer(nn.Module):
14
+ def __init__(
15
+ self,
16
+ code_dim: int,
17
+ codebook_num: int,
18
+ codebook_size: int,
19
+ ):
20
+ super().__init__()
21
+ self.codebook = ResidualVQ(
22
+ dim=code_dim,
23
+ num_quantizers=codebook_num,
24
+ codebook_size=codebook_size
25
+ )
26
+
27
+ def initial(self):
28
+ self.codebook.initial()
29
+
30
+ def forward(self, z):
31
+ zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
32
+ zq = zq.transpose(2, 1)
33
+ return zq, vqloss, perplexity
34
+
35
+ def inference(self, z):
36
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1))
37
+ zq = zq.transpose(2, 1)
38
+ return zq, indices
39
+
40
+ def encode(self, z):
41
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
42
+ return zq, indices
43
+
44
+ def decode(self, indices):
45
+ z = self.codebook.lookup(indices)
46
+ return z
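
inference(), encode() and decode() above all go through the flattened codebook that initial() builds, so initial() has to be called once after the codebook weights are in place (tokenize.py below does exactly that after load_state_dict). A round-trip sketch on a randomly initialized codebook with batch size 1:

# Round-trip sketch for the Quantizer above (random codebook, batch size 1).
import torch
from repcodec.modules.quantizer import Quantizer

q = Quantizer(code_dim=768, codebook_num=1, codebook_size=1024)
q.initial()                    # build the flattened codebook used by the methods above
z = torch.randn(1, 768, 100)   # (B, code dim, frames)
zq, indices = q.inference(z)   # zq: (1, 768, 100); indices: (1, 100) codebook ids
recovered = q.decode(indices)  # (1, 100, 768) embeddings looked up from the codebook
print(zq.shape, indices.shape, recovered.shape)
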
ASR/repcodec/modules/residual_unit.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.conv_layer import Conv1d, Conv1d1x1
11
+
12
+
13
+ class ResidualUnit(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_channels: int,
17
+ out_channels: int,
18
+ kernel_size=3,
19
+ dilation=1,
20
+ bias=False,
21
+ nonlinear_activation="ELU",
22
+ nonlinear_activation_params={},
23
+ ):
24
+ super().__init__()
25
+ self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
26
+ self.conv1 = Conv1d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=1,
31
+ dilation=dilation,
32
+ bias=bias,
33
+ )
34
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
35
+
36
+ def forward(self, x):
37
+ y = self.conv1(self.activation(x))
38
+ y = self.conv2(self.activation(y))
39
+ return x + y
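
The residual addition above implicitly assumes in_channels == out_channels, which is how EncoderBlock and DecoderBlock always instantiate it. A one-line shape check, as a sketch:

# Shape check for ResidualUnit (sketch; equal in/out channels, as used by the encoder/decoder blocks).
import torch
from repcodec.modules.residual_unit import ResidualUnit

unit = ResidualUnit(in_channels=768, out_channels=768, kernel_size=3, dilation=1)
print(unit(torch.randn(1, 768, 50)).shape)  # torch.Size([1, 768, 50]); unequal channel counts would break x + y
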
ASR/repcodec/tokenize.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Tuple, List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import yaml
15
+ from tqdm import tqdm
16
+
17
+ from repcodec.RepCodec import RepCodec
18
+
19
+ ALL_MODELS = {
20
+ "data2vec_base_l6": 768,
21
+ "data2vec_large_l18": 1024,
22
+ "hubert_base_l9": 768,
23
+ "hubert_large_l18": 1024,
24
+ "whisper_medium_l24": 1024,
25
+ "whisper_large_l32": 1280
26
+ }
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31
+ parser.add_argument(
32
+ "in_dir",
33
+ type=str,
34
+ help="directory of representations to be tokenized."
35
+ )
36
+ parser.add_argument(
37
+ "--model",
38
+ required=True,
39
+ type=str,
40
+ help="path of the RepCodec model."
41
+ )
42
+ parser.add_argument(
43
+ "--tsv_path",
44
+ required=True,
45
+ type=str,
46
+ help="path of the tsv file."
47
+ )
48
+ parser.add_argument(
49
+ "--model_config_path",
50
+ default=None,
51
+ type=str,
52
+ help="please provide this training config if you are using the model you trained yourself."
53
+ )
54
+ parser.add_argument(
55
+ "--n_shard",
56
+ required=False,
57
+ type=int,
58
+ default=1,
59
+ help="number of shards of representations."
60
+ )
61
+ parser.add_argument(
62
+ "--use_gpu",
63
+ default=False,
64
+ action="store_true",
65
+ help="whether use gpu for inference."
66
+ )
67
+ parser.add_argument(
68
+ "--batch_size",
69
+ default=1,
70
+ type=int,
71
+ help="number of utterances for each mini batch."
72
+ )
73
+ parser.add_argument(
74
+ "--out_dir",
75
+ type=str,
76
+ default=".",
77
+ help="the directory to save the output."
78
+ )
79
+ return parser.parse_args()
80
+
81
+
82
+ def load_model(model_path: str, config_path: Optional[str] = None):
83
+ if config_path is None:
84
+ name = os.path.basename(model_path).strip(".pkl")
85
+ assert name in ALL_MODELS.keys(), f"Cannot find configs for {model_path}. " \
86
+ f"Please provide the config file you used for training."
87
+ config = os.path.join(os.path.dirname(__file__), "configs", f"repcodec_dim{ALL_MODELS[name]}.yaml")
88
+ with open(config) as fp:
89
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
90
+ else:
91
+ with open(config_path) as fp:
92
+ conf = yaml.load(fp, Loader=yaml.FullLoader)["model_params"]
93
+
94
+ model = RepCodec(**conf)
95
+ model.load_state_dict(torch.load(model_path, map_location="cpu")["model"]["repcodec"])
96
+ model.quantizer.initial()
97
+ model.eval()
98
+ return model
99
+
100
+
101
+ def load_shard(in_dir: Path, rank: int, n_shard: int) -> Tuple[np.ndarray, List[int]]:
102
+ feat_path = in_dir / f"{rank}_{n_shard}.npy"
103
+ len_path = in_dir / f"{rank}_{n_shard}.len"
104
+
105
+ with open(len_path) as fp:
106
+ lengths = [int(line.strip()) for line in fp]
107
+
108
+ return np.load(feat_path.as_posix(), mmap_mode="r"), lengths
109
+
110
+
111
+ def pad_data(data: List[np.ndarray]) -> List[np.ndarray]:
112
+ max_len = max([d.shape[0] for d in data])
113
+ data = [
114
+ np.pad(d, [(0, max_len - d.shape[0]), (0, 0)], "constant", constant_values=0.0)
115
+ for d in data
116
+ ]
117
+ return data
118
+
119
+
120
+ def make_batch_data(data: np.ndarray, shard_lengths: List[int], batch_size: int):
121
+ batch_data = []
122
+ batch_lens = []
123
+ offsets = np.cumsum([0] + shard_lengths)
124
+ assert len(data) == offsets[-1], f"{len(data)} {offsets[-1]}"
125
+
126
+ # from longest to shortest
127
+ for i in range(len(shard_lengths)):
128
+ if batch_size > len(batch_data):
129
+ batch_data.append(data[offsets[i]: offsets[i + 1]])
130
+ batch_lens.append(shard_lengths[i])
131
+ else:
132
+ yield {
133
+ "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float), # (bsz, seq len, hidden dim)
134
+ "lengths": batch_lens
135
+ }
136
+ batch_data = [data[offsets[i]: offsets[i + 1]]]
137
+ batch_lens = [shard_lengths[i]]
138
+ if len(batch_data) > 0:
139
+ yield {
140
+ "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),
141
+ "lengths": batch_lens
142
+ }
143
+
144
+
145
+ def tokenize_batch(model: RepCodec, batch: dict, device: str) -> List[List[int]]:
146
+ with torch.no_grad():
147
+ data = batch["data"].transpose(1, 2).to(device) # (bsz, hidden dim, seq len)
148
+ x = model.encoder(data)
149
+ z = model.projector(x)
150
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
151
+
152
+ # when bsz=1: (1, seq len)
153
+ if idx.dim() == 2:
154
+ return idx.cpu().data.numpy().tolist()
155
+ # when bsz>1: (1, bsz, seq len)
156
+ tokens = idx.cpu().data.numpy().tolist()[0]
157
+ res = []
158
+ batch_lens = batch["lengths"]
159
+ for i in range(len(tokens)):
160
+ n_tokens = batch_lens[i]
161
+ res.append(tokens[i][:n_tokens])
162
+ return res
163
+
164
+
165
+ def load_tsv(path: str):
166
+ with open(path) as fp:
167
+ root = fp.readline().strip()
168
+ names = []
169
+ for line in fp:
170
+ names.append(line.strip().split("\t")[0])
171
+ return root, names
172
+
173
+
174
+ def cli():
175
+ args = parse_args()
176
+ device = "cuda" if args.use_gpu else "cpu"
177
+
178
+ model = load_model(model_path=args.model, config_path=args.model_config_path)
179
+ model.to(device)
180
+
181
+ in_dir = Path(args.in_dir)
182
+ n_shard = args.n_shard
183
+ batch_size = args.batch_size
184
+
185
+ root_dir, file_names = load_tsv(args.tsv_path)
186
+
187
+ output_dir = args.out_dir
188
+ os.makedirs(output_dir, exist_ok=True)
189
+
190
+ processed_cnt = 0
191
+ pbar = tqdm(total=len(file_names))
192
+ with open(os.path.join(output_dir, "tokens"), mode="w+") as fp:
193
+ fp.write(f"{root_dir}\n")
194
+
195
+ for rank in range(n_shard):
196
+ shard_data, shard_lengths = load_shard(in_dir, rank, n_shard)
197
+ for batch in make_batch_data(shard_data, shard_lengths, batch_size=batch_size):
198
+ batch_tokens = tokenize_batch(model, batch, device)
199
+
200
+ for tokens in batch_tokens:
201
+ fp.write(f"{file_names[processed_cnt]}\t{' '.join(map(str, tokens))}\n")
202
+ processed_cnt += 1
203
+
204
+ pbar.update(len(batch_tokens))
205
+ assert processed_cnt == len(file_names), f"# lines of tsv do not match # of representations!"
206
+
207
+ pbar.close()
208
+ print("Tokenize successfully!")
209
+
210
+
211
+ if __name__ == '__main__':
212
+ cli()
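
The script expects in_dir to contain, per shard, a {rank}_{n_shard}.npy array of concatenated frame features and a matching {rank}_{n_shard}.len file with one utterance length per line, plus a tsv whose first line is the audio root and whose remaining lines start with a file name; it can then be run as, e.g., python repcodec/tokenize.py reps --model hubert_base_l9.pkl --tsv_path reps/train.tsv --use_gpu --out_dir tokens_out (all paths are placeholders). A sketch that writes a toy shard in that layout:

# Sketch: build a toy shard in the layout load_shard()/load_tsv() above expect (illustrative paths and dims).
import os
import numpy as np

os.makedirs("reps", exist_ok=True)
lengths = [120, 80]                                              # frames per utterance
feats = np.random.randn(sum(lengths), 768).astype(np.float32)    # concatenated (total frames, hidden dim)

np.save("reps/0_1.npy", feats)                                   # shard 0 of n_shard=1
with open("reps/0_1.len", "w") as fp:
    fp.write("\n".join(str(n) for n in lengths) + "\n")

with open("reps/train.tsv", "w") as fp:
    fp.write("/path/to/audio/root\n")                            # first line: root directory
    fp.write("utt1.flac\t120\nutt2.flac\t80\n")                  # one file name per line (tab-separated)
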
ASR/test-gpt2-opt.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1717d8158b524053e9cc92fdbc9942bb4abc9f119576680f159f2f3177b378d
3
+ size 653546067
ASR/test-gpt2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3e0852bac1a63c63262029142c06158c40c81e90ff5f92ef45b0924ad80f6de
3
+ size 653828879
ASR/test-gpt2.plan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abdb8a96621f01394e8a783644c782815c25c507cafbc1cb86d6b4a2ccb5a6cc
3
+ size 328704308
ASR/tokenized_librispeech/dataset_dict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"splits": ["train.clean.100", "train.clean.360", "train.other.500", "validation.clean", "validation.other", "test.clean", "test.other"]}
ASR/tokenized_librispeech/test.clean/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:267be16197f2a380f021dd5b1687ffd341825474ca00875098a5582252b9b901
3
+ size 29539856
ASR/tokenized_librispeech/test.clean/dataset_info.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "token_type_ids": {
13
+ "feature": {
14
+ "dtype": "int8",
15
+ "_type": "Value"
16
+ },
17
+ "_type": "Sequence"
18
+ },
19
+ "attention_mask": {
20
+ "feature": {
21
+ "dtype": "int8",
22
+ "_type": "Value"
23
+ },
24
+ "_type": "Sequence"
25
+ }
26
+ },
27
+ "homepage": "",
28
+ "license": ""
29
+ }
ASR/tokenized_librispeech/test.clean/state.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "5d19fcebbf9d8932",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
ASR/tokenized_librispeech/test.other/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:663e0f0f58ac0ac562f319e51219064caaba5102b9bccd1e30322a780b1237ad
3
+ size 33136248
ASR/tokenized_librispeech/test.other/dataset_info.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "token_type_ids": {
13
+ "feature": {
14
+ "dtype": "int8",
15
+ "_type": "Value"
16
+ },
17
+ "_type": "Sequence"
18
+ },
19
+ "attention_mask": {
20
+ "feature": {
21
+ "dtype": "int8",
22
+ "_type": "Value"
23
+ },
24
+ "_type": "Sequence"
25
+ }
26
+ },
27
+ "homepage": "",
28
+ "license": ""
29
+ }
ASR/tokenized_librispeech/test.other/state.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "6d0c91ac40d55d91",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
ASR/tokenized_librispeech/train.clean.100/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b797259a17fc4ddaca3c0f4a09622cdde099dbf46685d5096405d4916e3428b
3
+ size 321761256
ASR/tokenized_librispeech/train.clean.100/dataset_info.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "token_type_ids": {
13
+ "feature": {
14
+ "dtype": "int8",
15
+ "_type": "Value"
16
+ },
17
+ "_type": "Sequence"
18
+ },
19
+ "attention_mask": {
20
+ "feature": {
21
+ "dtype": "int8",
22
+ "_type": "Value"
23
+ },
24
+ "_type": "Sequence"
25
+ }
26
+ },
27
+ "homepage": "",
28
+ "license": ""
29
+ }
ASR/tokenized_librispeech/train.clean.100/state.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "00eedcd713e3fb08",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
ASR/tokenized_librispeech/train.clean.360/data-00000-of-00003.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9da54e7eceac6fdf5c90656bca2655b28f079f52b3648978e228326ab84f334
3
+ size 390907152
ASR/tokenized_librispeech/train.clean.360/data-00001-of-00003.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af64dfb82934022b279d8022102c321991f99d1e53ab102e9e84a1488565d1fd
3
+ size 390895880
ASR/tokenized_librispeech/train.clean.360/data-00002-of-00003.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a40a4c69c058b01e7eff4068babf1655998beb5904bc32a620393f41018fa49a
3
+ size 390895880
ASR/tokenized_librispeech/train.clean.360/dataset_info.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "token_type_ids": {
13
+ "feature": {
14
+ "dtype": "int8",
15
+ "_type": "Value"
16
+ },
17
+ "_type": "Sequence"
18
+ },
19
+ "attention_mask": {
20
+ "feature": {
21
+ "dtype": "int8",
22
+ "_type": "Value"
23
+ },
24
+ "_type": "Sequence"
25
+ }
26
+ },
27
+ "homepage": "",
28
+ "license": ""
29
+ }
ASR/tokenized_librispeech/train.clean.360/state.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00003.arrow"
5
+ },
6
+ {
7
+ "filename": "data-00001-of-00003.arrow"
8
+ },
9
+ {
10
+ "filename": "data-00002-of-00003.arrow"
11
+ }
12
+ ],
13
+ "_fingerprint": "44573b66f4895b44",
14
+ "_format_columns": null,
15
+ "_format_kwargs": {},
16
+ "_format_type": null,
17
+ "_output_all_columns": false,
18
+ "_split": null
19
+ }
ASR/tokenized_librispeech/train.other.500/data-00000-of-00004.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8310b88f596d7f62ce3c8a9801cefdb016400702e48c316c145bf4abb8539447
3
+ size 419093384