tomxxie committed
Commit 568e264 · 1 Parent(s): 66817ed

Adapt for zeroGPU

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +11 -3
  2. wenet/LLM/causallm_model.py +207 -0
  3. wenet/LLM/decoder.py +161 -0
  4. wenet/LLM/sampler.py +43 -0
  5. wenet/__init__.py +1 -0
  6. wenet/bin/alignment.py +268 -0
  7. wenet/bin/average_model.py +125 -0
  8. wenet/bin/export_ipex.py +95 -0
  9. wenet/bin/export_jit.py +71 -0
  10. wenet/bin/export_onnx_bpu.py +1065 -0
  11. wenet/bin/export_onnx_cpu.py +470 -0
  12. wenet/bin/export_onnx_gpu.py +1263 -0
  13. wenet/bin/recognize.py +336 -0
  14. wenet/bin/recognize4llmasr.py +340 -0
  15. wenet/bin/recognize_onnx_gpu.py +297 -0
  16. wenet/bin/train.py +232 -0
  17. wenet/branchformer/__init__.py +0 -0
  18. wenet/branchformer/cgmlp.py +194 -0
  19. wenet/branchformer/encoder.py +177 -0
  20. wenet/branchformer/encoder_layer.py +245 -0
  21. wenet/cli/__init__.py +0 -0
  22. wenet/cli/hub.py +116 -0
  23. wenet/cli/model.py +176 -0
  24. wenet/cli/paraformer_model.py +82 -0
  25. wenet/cli/transcribe.py +87 -0
  26. wenet/ctl_model/asr_model_ctl.py +277 -0
  27. wenet/ctl_model/encoder.py +172 -0
  28. wenet/dataset/__init__.py +0 -0
  29. wenet/dataset/datapipes.py +470 -0
  30. wenet/dataset/dataset.py +234 -0
  31. wenet/dataset/deprecated/dataset.py +202 -0
  32. wenet/dataset/deprecated/processor.py +1023 -0
  33. wenet/dataset/kaldi_io.py +772 -0
  34. wenet/dataset/processor.py +694 -0
  35. wenet/dataset/wav_distortion.py +336 -0
  36. wenet/e_branchformer/encoder.py +165 -0
  37. wenet/e_branchformer/encoder_layer.py +187 -0
  38. wenet/efficient_conformer/__init__.py +0 -0
  39. wenet/efficient_conformer/attention.py +257 -0
  40. wenet/efficient_conformer/convolution.py +154 -0
  41. wenet/efficient_conformer/encoder.py +560 -0
  42. wenet/efficient_conformer/encoder_layer.py +165 -0
  43. wenet/efficient_conformer/subsampling.py +74 -0
  44. wenet/finetune/lora/__init__.py +0 -0
  45. wenet/finetune/lora/config.yaml +13 -0
  46. wenet/finetune/lora/layers.py +350 -0
  47. wenet/finetune/lora/utils.py +334 -0
  48. wenet/k2/__init__.py +0 -0
  49. wenet/k2/model.py +304 -0
  50. wenet/llm_asr/__init__.py +0 -0
app.py CHANGED
@@ -10,9 +10,9 @@ import os
  
  import sys
  
+ import yaml
  
  sys.path.insert(0, './')
- from gxl_ai_utils.utils import utils_file
  from wenet.utils.init_tokenizer import init_tokenizer
  from wenet.utils.init_model import init_model
  import logging
@@ -20,6 +20,14 @@ import librosa
  import torch
  import torchaudio
  import numpy as np
+ def makedir_for_file(filepath):
+     dirpath = os.path.dirname(filepath)
+     if not os.path.exists(dirpath):
+         os.makedirs(dirpath)
+ def load_dict_from_yaml(file_path: str):
+     with open(file_path, 'rt', encoding='utf-8') as f:
+         dict_1 = yaml.load(f, Loader=yaml.FullLoader)
+     return dict_1
  
  # Convert the image to Base64
  with open("lab.png", "rb") as image_file:
@@ -53,7 +61,7 @@ def init_model_my():
      args = SimpleNamespace(**{
          "checkpoint": checkpoint_path,
      })
-     configs = utils_file.load_dict_from_yaml(config_path)
+     configs = load_dict_from_yaml(config_path)
      model, configs = init_model(args, configs)
      model = model.cuda()
      tokenizer = init_tokenizer(configs)
@@ -73,7 +81,7 @@ def do_resample(input_wav_path, output_wav_path):
      waveform = torch.mean(waveform, dim=0, keepdim=True)
      waveform = torchaudio.transforms.Resample(
          orig_freq=sample_rate, new_freq=16000)(waveform)
-     utils_file.makedir_for_file(output_wav_path)
+     makedir_for_file(output_wav_path)
      torchaudio.save(output_wav_path, waveform, 16000)
  
  def true_decode_fuc(input_wav_path, input_prompt):
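
The do_resample() hunk above keeps torchaudio's usual downmix-and-resample pattern (mono, 16 kHz). A standalone sketch of that pattern, with hypothetical file names:

    import torch
    import torchaudio

    # Load, downmix to mono, resample to 16 kHz, and write back out.
    waveform, sample_rate = torchaudio.load('input.wav')          # hypothetical input
    waveform = torch.mean(waveform, dim=0, keepdim=True)          # (channels, T) -> (1, T)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                              new_freq=16000)(waveform)
    torchaudio.save('output_16k.wav', waveform, 16000)            # hypothetical output
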
wenet/LLM/causallm_model.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Dict, List, Optional, Union
2
+ import torch
3
+ from wenet.LLM.decoder import DecoderOnly
4
+ from wenet.LLM.sampler import sampler
5
+ from wenet.utils.common import IGNORE_ID, th_accuracy
6
+ from wenet.utils.mask import make_pad_mask, subsequent_mask
7
+
8
+
9
+ class CausalLM(torch.nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size: int,
14
+ decoder: DecoderOnly,
15
+ special_tokens: dict,
16
+ tie_word_embedding: bool = False,
17
+ linear_bias: bool = False,
18
+ ignore_id: int = IGNORE_ID,
19
+ lsm_weight: float = 0.0,
20
+ reduction: str = 'mean',
21
+ ) -> None:
22
+ super().__init__()
23
+ del special_tokens
24
+
25
+ self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size)
26
+ self.out = torch.nn.Linear(decoder.hidden_size,
27
+ vocab_size,
28
+ bias=linear_bias)
29
+
30
+ self.decoder = decoder
31
+ self.vocab_size = vocab_size
32
+ self.criterion_att = torch.nn.CrossEntropyLoss(
33
+ ignore_index=ignore_id,
34
+ label_smoothing=lsm_weight,
35
+ reduction=reduction,
36
+ )
37
+ self.tie_word_embedding = tie_word_embedding
38
+ self.ignore_id = ignore_id
39
+
40
+ @torch.jit.unused
41
+ def forward(
42
+ self,
43
+ batch: dict,
44
+ device: torch.device,
45
+ ) -> Dict[str, Optional[torch.Tensor]]:
46
+ """ Forward for training
47
+ """
48
+ text = batch['feats'].to(device)
49
+ target = batch['target'].to(device)
50
+ text_length = batch['feats_lengths'].to(device)
51
+
52
+ mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze(
53
+ 1) # (B,1,L)
54
+ causal_mask = subsequent_mask(
55
+ mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L)
56
+ att_mask = causal_mask & mask # (B, L, L)
57
+
58
+ embeding = self.embed(text)
59
+ decoder_out = self.out(self.decoder(embeding,
60
+ att_mask)[0]) # (B, L, vocab_size)
61
+ loss = self.criterion_att(decoder_out.view(-1, self.vocab_size),
62
+ target.view(-1))
63
+ acc = th_accuracy(decoder_out.view(-1, self.vocab_size),
64
+ target,
65
+ ignore_label=self.ignore_id)
66
+
67
+ return {
68
+ "loss": loss,
69
+ "ppl": torch.exp(loss.detach()),
70
+ "th_accuracy": acc
71
+ }
72
+
73
+ def tie_or_clone_weights(self, jit_mode: bool):
74
+ if not self.tie_word_embedding:
75
+ return
76
+ if jit_mode:
77
+ self.out.weight = torch.nn.Parameter(self.embed.weight.clone())
78
+ else:
79
+ self.out.weight = self.embed.weight
80
+ # TODO(Mddct): whether to deal bias for other llm model
81
+
82
+ @torch.jit.unused
83
+ @torch.inference_mode()
84
+ def generate(
85
+ self,
86
+ prompts_tokens: List[List[int]],
87
+ device: torch.device,
88
+ stop_tokens: List[int],
89
+ dtype: torch.dtype = torch.float32,
90
+ output_len: int = 100,
91
+ temperature: Union[float, None] = 0.95,
92
+ top_p: float = 1.0,
93
+ top_k: int = 100,
94
+ ) -> List[List[int]]:
95
+ """Generates responses for given prompts using Gemma model."""
96
+ # If a single prompt is provided, treat it as a batch of 1.
97
+ batch_size = len(prompts_tokens)
98
+ min_prompt_len = min(len(p) for p in prompts_tokens)
99
+ max_prompt_len = max(len(p) for p in prompts_tokens)
100
+ max_seq_len = max_prompt_len + output_len
101
+ assert max_seq_len <= self.decoder.pos_enc.max_len
102
+
103
+ # build KV caches
104
+ kv_caches = []
105
+ for _ in range(len(self.decoder.decoders)):
106
+ size = (batch_size, 0, self.decoder.n_kv_head,
107
+ self.decoder.head_dim)
108
+ k_cache = torch.zeros(size=size, dtype=dtype, device=device)
109
+ v_cache = torch.zeros(size=size, dtype=dtype, device=device)
110
+ kv_caches.append((k_cache, v_cache))
111
+
112
+ # prepare inputs
113
+ token_ids_tensor = torch.full((batch_size, max_seq_len),
114
+ IGNORE_ID,
115
+ dtype=torch.int64,
116
+ device=device)
117
+ input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
118
+ IGNORE_ID,
119
+ dtype=torch.int64,
120
+ device=device)
121
+ # right padding
122
+ for i, p in enumerate(prompts_tokens):
123
+ token_ids_tensor[i, :len(p)] = torch.tensor(p)
124
+ input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
125
+ p[:min_prompt_len])
126
+
127
+ prompt_mask_tensor = token_ids_tensor != IGNORE_ID
128
+ input_positions_tensor = torch.arange(0,
129
+ min_prompt_len,
130
+ dtype=torch.int64).to(device)
131
+ mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len),
132
+ dtype=torch.bool)
133
+ mask_tensor = torch.tril(mask_tensor).to(device)
134
+ curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
135
+ att_mask = curr_mask_tensor.squeeze(
136
+ 1)[:, :min_prompt_len, :min_prompt_len]
137
+ output_positions_tensor = torch.LongTensor([min_prompt_len - 1
138
+ ]).to(device)
139
+ temperatures_tensor = None if not temperature else torch.FloatTensor(
140
+ [temperature] * batch_size).to(device)
141
+ top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
142
+ top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
143
+ output_index = torch.tensor(min_prompt_len,
144
+ dtype=torch.int64).to(device)
145
+
146
+ input_token_embeding = self.embed(input_token_ids_tensor)
147
+ offset = torch.tensor([0] * len(prompts_tokens)).to(device)
148
+ input_offset = offset
149
+
150
+ stop_tokens_tensor = torch.tensor(stop_tokens, device=device)
151
+ # Prefill up to min_prompt_len tokens, then treat other prefill as
152
+ # decode and ignore output.
153
+ for i in range(max_seq_len - min_prompt_len):
154
+ decoder_out, kv_caches, = self.decoder(
155
+ input_token_embeding,
156
+ att_mask,
157
+ input_offset,
158
+ kv_caches,
159
+ )
160
+ decoder_out = self.out(decoder_out)
161
+ decoder_out = decoder_out.index_select(1, output_positions_tensor)
162
+ next_token_ids = sampler(
163
+ decoder_out,
164
+ temperatures_tensor,
165
+ top_ps_tensor,
166
+ top_ks_tensor,
167
+ )
168
+ curr_prompt_mask = prompt_mask_tensor.index_select(
169
+ 1, output_index).squeeze(dim=1)
170
+ curr_token_ids = token_ids_tensor.index_select(
171
+ 1, output_index).squeeze(dim=1)
172
+ output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
173
+ next_token_ids).unsqueeze(dim=1)
174
+ token_ids_tensor.index_copy_(1, output_index, output_token_ids)
175
+
176
+ input_token_ids_tensor = output_token_ids
177
+ input_token_embeding = self.embed(input_token_ids_tensor)
178
+
179
+ input_positions_tensor = output_index.unsqueeze(dim=-1)
180
+ curr_mask_tensor = mask_tensor.index_select(
181
+ 2, input_positions_tensor)
182
+ att_mask = curr_mask_tensor.squeeze(1)[:, :output_index +
183
+ 1, :output_index + 1]
184
+
185
+ output_positions_tensor = torch.tensor(
186
+ 0, dtype=torch.int64).to(device)
187
+ input_offset = offset + output_index.unsqueeze(-1)
188
+ output_index = output_index + 1
189
+
190
+ if all(torch.isin(next_token_ids, stop_tokens_tensor)):
191
+ break
192
+
193
+ token_ids = token_ids_tensor.tolist()
194
+ results = []
195
+ for i, tokens in enumerate(token_ids):
196
+ trimmed_output = tokens[len(prompts_tokens[i]
197
+ ):len(prompts_tokens[i]) + output_len]
198
+ for stop_token in stop_tokens:
199
+ try:
200
+ eos_index = trimmed_output.index(stop_token)
201
+ trimmed_output = trimmed_output[:eos_index]
202
+ break
203
+ except Exception:
204
+ continue
205
+ results.append(trimmed_output)
206
+
207
+ return results
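
A minimal usage sketch for CausalLM.generate(), assuming the surrounding WeNet modules are importable; the decoder sizes, token ids, and stop token below are illustrative only, not values from this commit:

    import torch
    from wenet.LLM.decoder import DecoderOnly
    from wenet.LLM.causallm_model import CausalLM

    # Tiny decoder-only LM; all hyperparameters are made up for illustration.
    decoder = DecoderOnly(n_kv_head=2, head_dim=32, hidden_size=128,
                          attention_heads=4, linear_units=512, num_blocks=2)
    model = CausalLM(vocab_size=1000, decoder=decoder, special_tokens={})
    model.eval()

    prompts = [[1, 5, 9], [1, 7, 2, 4]]              # prompt token ids (hypothetical)
    outputs = model.generate(prompts,
                             device=torch.device('cpu'),
                             stop_tokens=[2],        # hypothetical EOS id
                             output_len=16,
                             temperature=0.95,
                             top_p=0.9,
                             top_k=50)
    print([len(o) for o in outputs])                 # each at most output_len tokens
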
wenet/LLM/decoder.py ADDED
@@ -0,0 +1,161 @@
1
+ from functools import partial
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch
4
+ import torch.utils.checkpoint as ckpt
5
+ from wenet.transformer.attention import T_CACHE
6
+
7
+ from wenet.transformer.encoder_layer import TransformerEncoderLayer
8
+ from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
9
+ WENET_ATTENTION_CLASSES,
10
+ WENET_EMB_CLASSES, WENET_MLP_CLASSES,
11
+ WENET_NORM_CLASSES)
12
+ from wenet.utils.common import mask_to_bias
13
+
14
+
15
+ class DecoderOnly(torch.nn.Module):
16
+
17
+ def __init__(
18
+ self,
19
+ n_kv_head: int,
20
+ head_dim: int,
21
+ hidden_size: int,
22
+ attention_heads: int = 4,
23
+ linear_units: int = 2048,
24
+ num_blocks: int = 6,
25
+ dropout_rate: float = 0.1,
26
+ positional_dropout_rate: float = 0.1,
27
+ attention_dropout_rate: float = 0.0,
28
+ normalize_before: bool = True,
29
+ query_bias: bool = False,
30
+ key_bias: bool = False,
31
+ value_bias: bool = False,
32
+ mlp_bias: bool = False,
33
+ activation_type: str = "gelu",
34
+ gelu_approximate: Union[str, None] = None,
35
+ max_position_embeding: int = 8192,
36
+ mlp_type: str = 'gated',
37
+ layer_norm_type: str = 'rms_norm',
38
+ norm_eps: float = 1e-5,
39
+ rms_norm_offset: bool = True,
40
+ selfattention_layer_type: str = "rope_abs_selfattn",
41
+ use_sdpa: bool = False,
42
+ gradient_checkpointing: bool = False,
43
+ rope_theta: float = 10000.0,
44
+ rope_style: str = 'google',
45
+ scale_embed: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ assert selfattention_layer_type in ['rope_abs_selfattn']
50
+ self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
51
+ hidden_size,
52
+ head_dim,
53
+ max_len=max_position_embeding,
54
+ dropout_rate=positional_dropout_rate,
55
+ rope_theta=rope_theta,
56
+ scale=scale_embed)
57
+ if activation_type == "gelu" and gelu_approximate is not None:
58
+ activation = WENET_ACTIVATION_CLASSES['gelu'](
59
+ approximate=gelu_approximate)
60
+ else:
61
+ activation = WENET_ACTIVATION_CLASSES[activation_type]()
62
+
63
+ mlp_class = WENET_MLP_CLASSES[mlp_type]
64
+ self.num_blocks = num_blocks
65
+ # TODO: support lora & refactor lora
66
+ self.decoders = torch.nn.ModuleList([
67
+ TransformerEncoderLayer(
68
+ hidden_size,
69
+ WENET_ATTENTION_CLASSES[selfattention_layer_type](
70
+ attention_heads,
71
+ hidden_size,
72
+ attention_dropout_rate,
73
+ query_bias,
74
+ key_bias,
75
+ value_bias,
76
+ use_sdpa,
77
+ n_kv_head,
78
+ head_dim,
79
+ style=rope_style),
80
+ mlp_class(hidden_size, linear_units, dropout_rate, activation,
81
+ mlp_bias),
82
+ dropout_rate,
83
+ normalize_before,
84
+ layer_norm_type=layer_norm_type,
85
+ norm_eps=norm_eps,
86
+ rms_norm_offset=rms_norm_offset,
87
+ ) for _ in range(self.num_blocks)
88
+ ])
89
+ self.pre_norm = normalize_before
90
+ self.final_norm: Optional[torch.nn.Module] = None
91
+ if self.pre_norm:
92
+ norm_class = WENET_NORM_CLASSES[layer_norm_type]
93
+ if layer_norm_type == "rms_norm":
94
+ norm_class = partial(
95
+ norm_class,
96
+ add_unit_offset=rms_norm_offset,
97
+ )
98
+ self.final_norm = norm_class(hidden_size, eps=norm_eps)
99
+
100
+ self.n_kv_head = n_kv_head
101
+ self.head_dim = head_dim
102
+ self._hidden_size = hidden_size
103
+ self.use_sdpa = use_sdpa
104
+ self.gradient_checkpointing = gradient_checkpointing
105
+
106
+ def forward(
107
+ self,
108
+ input: torch.Tensor,
109
+ att_mask: torch.Tensor,
110
+ input_position: Union[int, torch.Tensor] = 0,
111
+ kv_caches: Optional[List[T_CACHE]] = None,
112
+ ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
113
+ xs, pos_emb = self.pos_enc(input, offset=input_position)
114
+ if self.use_sdpa:
115
+ att_mask = mask_to_bias(att_mask, xs.dtype)
116
+
117
+ if self.gradient_checkpointing and self.training:
118
+ xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
119
+ else:
120
+ xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
121
+ kv_caches)
122
+ if self.pre_norm and self.final_norm is not None:
123
+ xs = self.final_norm(xs)
124
+ return xs, kv_caches
125
+
126
+ def forward_layers(
127
+ self,
128
+ xs: torch.Tensor,
129
+ att_mask: torch.Tensor,
130
+ pos_emb: torch.Tensor,
131
+ kv_caches: Optional[List[T_CACHE]] = None,
132
+ ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
133
+ if self.training:
134
+ for (i, layer) in enumerate(self.decoders):
135
+ xs, _, _, _ = layer(xs, att_mask, pos_emb)
136
+ new_kv_caches = kv_caches
137
+ else:
138
+ assert kv_caches is not None
139
+ new_kv_caches = []
140
+ for (i, layer) in enumerate(self.decoders):
141
+ xs, _, new_kv_cache, _ = layer(xs,
142
+ att_mask,
143
+ pos_emb,
144
+ att_cache=(kv_caches[i][0],
145
+ kv_caches[i][1]))
146
+ new_kv_caches.append(new_kv_cache)
147
+
148
+ return xs, new_kv_caches
149
+
150
+ @torch.jit.ignore(drop=True)
151
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
152
+ att_mask: torch.Tensor,
153
+ pos_emb: torch.Tensor) -> torch.Tensor:
154
+ for layer in self.decoders:
155
+ xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
156
+ pos_emb)
157
+ return xs
158
+
159
+ @property
160
+ def hidden_size(self):
161
+ return self._hidden_size
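
A shape sketch for DecoderOnly.forward() in training mode, mirroring how CausalLM.forward() drives it; sizes are illustrative and the sketch assumes wenet.utils.mask is importable:

    import torch
    from wenet.LLM.decoder import DecoderOnly
    from wenet.utils.mask import make_pad_mask, subsequent_mask

    decoder = DecoderOnly(n_kv_head=2, head_dim=32, hidden_size=128,
                          attention_heads=4, linear_units=512, num_blocks=2)
    decoder.train()

    B, L = 2, 10
    xs = torch.randn(B, L, 128)                                  # already-embedded tokens
    lengths = torch.tensor([10, 7])
    pad_mask = ~make_pad_mask(lengths, max_len=L).unsqueeze(1)   # (B, 1, L)
    att_mask = subsequent_mask(L, device=xs.device).unsqueeze(0) & pad_mask  # (B, L, L)

    out, _ = decoder(xs, att_mask)                               # (B, L, 128)
    print(out.shape)
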
wenet/LLM/sampler.py ADDED
@@ -0,0 +1,43 @@
1
+ from typing import Union
2
+ import torch
3
+
4
+
5
+ # modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26
6
+ @torch.no_grad()
7
+ def sampler(
8
+ logits: torch.Tensor,
9
+ temperatures: Union[torch.Tensor, None],
10
+ top_ps: torch.Tensor,
11
+ top_ks: torch.Tensor,
12
+ ) -> torch.Tensor:
13
+ assert logits.size(1) == 1
14
+ logits = logits.squeeze(1) # (batch_size, vocab_size)
15
+ if temperatures is None:
16
+ return torch.argmax(logits, dim=-1).squeeze(dim=-1)
17
+
18
+ # Apply temperature scaling.
19
+ logits.div_(temperatures.unsqueeze(dim=1))
20
+
21
+ # Calculate probabilities with softmax.
22
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
23
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
24
+
25
+ # Apply top-p, top-k.
26
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
27
+ top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
28
+ probs_sort = torch.where(top_ps_mask, 0, probs_sort)
29
+
30
+ top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
31
+ top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
32
+ top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
33
+ probs_sort = torch.where(top_ks_mask, 0, probs_sort)
34
+
35
+ # Re-normalization.
36
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
37
+ probs = torch.gather(probs_sort,
38
+ dim=-1,
39
+ index=torch.argsort(probs_idx, dim=-1))
40
+
41
+ next_token_ids = torch.multinomial(probs, num_samples=1,
42
+ replacement=True).squeeze(dim=-1)
43
+ return next_token_ids
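
A usage sketch for sampler(): it consumes the logits of the current decoding step only, shaped (batch, 1, vocab), and returns one sampled token id per batch entry. Tensor sizes below are illustrative:

    import torch
    from wenet.LLM.sampler import sampler

    batch, vocab = 2, 32
    logits = torch.randn(batch, 1, vocab)                      # current-step logits
    temperatures = torch.full((batch,), 0.8)                   # pass None for greedy argmax
    top_ps = torch.full((batch,), 0.9)
    top_ks = torch.full((batch,), 10, dtype=torch.long)

    next_ids = sampler(logits, temperatures, top_ps, top_ks)   # (batch,) token ids
    print(next_ids)
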
wenet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from wenet.cli.model import load_model # noqa
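
This re-export provides the package-level entry point. A usage sketch, assuming the load_model()/transcribe() interface defined in wenet/cli/model.py (model name and audio path are hypothetical; check that file for the exact signatures):

    import wenet

    model = wenet.load_model('chinese')       # assumed pretrained-model identifier
    result = model.transcribe('test.wav')     # assumed method on the returned model
    print(result)
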
wenet/bin/alignment.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu)
2
+ # 2022 Tinnove Inc (authors: Wei Ren)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import copy
20
+ import logging
21
+ import os
22
+ import sys
23
+
24
+ import torch
25
+ import yaml
26
+ from torch.utils.data import DataLoader
27
+ from textgrid import TextGrid, IntervalTier
28
+ import math
29
+
30
+ from wenet.dataset.dataset import Dataset
31
+ from wenet.utils.ctc_utils import force_align
32
+ from wenet.utils.common import get_subsample
33
+ from wenet.utils.init_model import init_model
34
+ from wenet.utils.init_tokenizer import init_tokenizer
35
+
36
+
37
+ def generator_textgrid(maxtime, lines, output):
38
+ # Download Praat: https://www.fon.hum.uva.nl/praat/
39
+ interval = maxtime / (len(lines) + 1)
40
+ margin = 0.0001
41
+
42
+ tg = TextGrid(maxTime=maxtime)
43
+ linetier = IntervalTier(name="line", maxTime=maxtime)
44
+
45
+ i = 0
46
+ for l in lines:
47
+ s, e, w = l.split()
48
+ linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w)
49
+
50
+ tg.append(linetier)
51
+ print("successfully generator {}".format(output))
52
+ tg.write(output)
53
+
54
+
55
+ def get_frames_timestamp(alignment,
56
+ prob,
57
+ blank_thres=0.999,
58
+ thres=0.0000000001):
59
+ # convert the alignment to a Praat format; Praat ("doing phonetics
60
+ # by computer") helps analyze the alignment
61
+ timestamp = []
62
+ # get frames level duration for each token
63
+ start = 0
64
+ end = 0
65
+ local_start = 0
66
+ while end < len(alignment):
67
+ while end < len(alignment) and alignment[end] == 0:
68
+ end += 1
69
+ if end == len(alignment):
70
+ timestamp[-1] += alignment[start:]
71
+ break
72
+ end += 1
73
+ while end < len(alignment) and alignment[end - 1] == alignment[end]:
74
+ end += 1
75
+ local_start = end - 1
76
+ # find the possible front border for current token
77
+ while local_start >= start and (
78
+ prob[local_start][0] < math.log(blank_thres)
79
+ or prob[local_start][alignment[end - 1]] > math.log(thres)):
80
+ alignment[local_start] = alignment[end - 1]
81
+ local_start -= 1
82
+ cur_alignment = alignment[start:end]
83
+ timestamp.append(cur_alignment)
84
+ start = end
85
+ return timestamp
86
+
87
+
88
+ def get_labformat(timestamp, subsample):
89
+ begin = 0
90
+ begin_time = 0
91
+ duration = 0
92
+ labformat = []
93
+ for idx, t in enumerate(timestamp):
94
+ # 25ms frame_length,10ms hop_length, 1/subsample
95
+ subsample = get_subsample(configs)
96
+ # time duration
97
+ i = 0
98
+ while t[i] == 0:
99
+ i += 1
100
+ begin = i
101
+ dur = 0
102
+ while i < len(t) and t[i] != 0:
103
+ i += 1
104
+ dur += 1
105
+ begin = begin_time + begin * 0.01 * subsample
106
+ duration = dur * 0.01 * subsample
107
+ if idx < len(timestamp) - 1:
108
+ print("{:.2f} {:.2f} {}".format(begin, begin + duration,
109
+ char_dict[t[-1]]))
110
+ labformat.append("{:.2f} {:.2f} {}\n".format(
111
+ begin, begin + duration, char_dict[t[-1]]))
112
+ else: # last token
113
+ non_blank = 0
114
+ for i in t:
115
+ if i != 0:
116
+ token = i
117
+ break
118
+ print("{:.2f} {:.2f} {}".format(begin, begin + duration,
119
+ char_dict[token]))
120
+ labformat.append("{:.2f} {:.2f} {}\n".format(
121
+ begin, begin + duration, char_dict[token]))
122
+ begin_time += len(t) * 0.01 * subsample
123
+ return labformat
124
+
125
+
126
+ if __name__ == '__main__':
127
+ parser = argparse.ArgumentParser(
128
+ description='use ctc to generate alignment')
129
+ parser.add_argument('--config', required=True, help='config file')
130
+ parser.add_argument('--input_file', required=True, help='format data file')
131
+ parser.add_argument('--data_type',
132
+ default='raw',
133
+ choices=['raw', 'shard'],
134
+ help='train and cv data type')
135
+ parser.add_argument('--gpu',
136
+ type=int,
137
+ default=-1,
138
+ help='gpu id for this rank, -1 for cpu')
139
+ parser.add_argument('--device',
140
+ type=str,
141
+ default="cpu",
142
+ choices=["cpu", "npu", "cuda"],
143
+ help='accelerator to use')
144
+ parser.add_argument('--blank_thres',
145
+ default=0.999999,
146
+ type=float,
147
+ help='ctc blank threshold')
148
+ parser.add_argument('--thres',
149
+ default=0.000001,
150
+ type=float,
151
+ help='ctc non-blank threshold')
152
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
153
+ parser.add_argument('--dict', required=True, help='dict file')
154
+ parser.add_argument(
155
+ '--non_lang_syms',
156
+ help="non-linguistic symbol file. One symbol per line.")
157
+ parser.add_argument('--result_file',
158
+ required=True,
159
+ help='alignment result file')
160
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size')
161
+ parser.add_argument('--gen_praat',
162
+ action='store_true',
163
+ help='convert alignment to a praat format')
164
+ parser.add_argument('--bpe_model',
165
+ default=None,
166
+ type=str,
167
+ help='bpe model for english part')
168
+
169
+ args = parser.parse_args()
170
+ print(args)
171
+ logging.basicConfig(level=logging.DEBUG,
172
+ format='%(asctime)s %(levelname)s %(message)s')
173
+ if args.gpu != -1:
174
+ # remain the original usage of gpu
175
+ args.device = "cuda"
176
+ if "cuda" in args.device:
177
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
178
+
179
+ if args.batch_size > 1:
180
+ logging.fatal('alignment mode must be running with batch_size == 1')
181
+ sys.exit(1)
182
+
183
+ with open(args.config, 'r') as fin:
184
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
185
+
186
+ # Load dict
187
+ char_dict = {}
188
+ with open(args.dict, 'r') as fin:
189
+ for line in fin:
190
+ arr = line.strip().split()
191
+ assert len(arr) == 2
192
+ char_dict[int(arr[1])] = arr[0]
193
+ eos = len(char_dict) - 1
194
+
195
+ # Init dataset and data loader
196
+ ali_conf = copy.deepcopy(configs['dataset_conf'])
197
+
198
+ ali_conf['filter_conf']['max_length'] = 102400
199
+ ali_conf['filter_conf']['min_length'] = 0
200
+ ali_conf['filter_conf']['token_max_length'] = 102400
201
+ ali_conf['filter_conf']['token_min_length'] = 0
202
+ ali_conf['filter_conf']['max_output_input_ratio'] = 102400
203
+ ali_conf['filter_conf']['min_output_input_ratio'] = 0
204
+ ali_conf['speed_perturb'] = False
205
+ ali_conf['spec_aug'] = False
206
+ ali_conf['spec_trim'] = False
207
+ ali_conf['shuffle'] = False
208
+ ali_conf['sort'] = False
209
+ ali_conf['fbank_conf']['dither'] = 0.0
210
+ ali_conf['batch_conf']['batch_type'] = "static"
211
+ ali_conf['batch_conf']['batch_size'] = args.batch_size
212
+
213
+ tokenizer = init_tokenizer(configs)
214
+ ali_dataset = Dataset(args.data_type,
215
+ args.input_file,
216
+ tokenizer,
217
+ ali_conf,
218
+ partition=False)
219
+
220
+ ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)
221
+
222
+ # Init asr model from configs
223
+ model, configs = init_model(args, configs)
224
+
225
+ device = torch.device(args.device)
226
+ model = model.to(device)
227
+
228
+ model.eval()
229
+ with torch.no_grad(), open(args.result_file, 'w',
230
+ encoding='utf-8') as fout:
231
+ for batch_idx, batch in enumerate(ali_data_loader):
232
+ print("#" * 80)
233
+ key, feat, target, feats_length, target_length = batch
234
+
235
+ feat = feat.to(device)
236
+ target = target.to(device)
237
+ feats_length = feats_length.to(device)
238
+ target_length = target_length.to(device)
239
+ # Let's assume B = batch_size and N = beam_size
240
+ # 1. Encoder
241
+ encoder_out, encoder_mask = model._forward_encoder(
242
+ feat, feats_length) # (B, maxlen, encoder_dim)
243
+ maxlen = encoder_out.size(1)
244
+ ctc_probs = model.ctc.log_softmax(
245
+ encoder_out) # (1, maxlen, vocab_size)
246
+ # print(ctc_probs.size(1))
247
+ ctc_probs = ctc_probs.squeeze(0)
248
+ target = target.squeeze(0)
249
+ alignment = force_align(ctc_probs, target)
250
+ fout.write('{} {}\n'.format(key[0], alignment))
251
+
252
+ if args.gen_praat:
253
+ timestamp = get_frames_timestamp(alignment, ctc_probs,
254
+ args.blank_thres, args.thres)
255
+ subsample = get_subsample(configs)
256
+ labformat = get_labformat(timestamp, subsample)
257
+
258
+ lab_path = os.path.join(os.path.dirname(args.result_file),
259
+ key[0] + ".lab")
260
+ with open(lab_path, 'w', encoding='utf-8') as f:
261
+ f.writelines(labformat)
262
+
263
+ textgrid_path = os.path.join(os.path.dirname(args.result_file),
264
+ key[0] + ".TextGrid")
265
+ generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 *
266
+ subsample,
267
+ lines=labformat,
268
+ output=textgrid_path)
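
The core per-utterance step above is a Viterbi forced alignment over the CTC log-posteriors. A minimal sketch with dummy tensors (real inputs come from the data loader and model; force_align is assumed to return one token id per frame with 0 as the CTC blank, as the code above consumes it):

    import torch
    from wenet.utils.ctc_utils import force_align

    frames, vocab_size = 20, 10
    ctc_probs = torch.randn(frames, vocab_size).log_softmax(dim=-1)  # (maxlen, vocab_size)
    target = torch.tensor([3, 5, 5, 2])                              # token ids (hypothetical)

    alignment = force_align(ctc_probs, target)   # frame-level ids, 0 = blank
    print(len(alignment))                        # one label per frame
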
wenet/bin/average_model.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import argparse
17
+ import glob
18
+ import sys
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser(description='average model')
26
+ parser.add_argument('--dst_model', required=True, help='averaged model')
27
+ parser.add_argument('--src_path',
28
+ required=True,
29
+ help='src model path for average')
30
+ parser.add_argument('--val_best',
31
+ action="store_true",
32
+ help='select checkpoints by best validation loss')
33
+ parser.add_argument('--num',
34
+ default=5,
35
+ type=int,
36
+ help='nums for averaged model')
37
+ parser.add_argument('--min_epoch',
38
+ default=0,
39
+ type=int,
40
+ help='min epoch used for averaging model')
41
+ parser.add_argument('--max_epoch',
42
+ default=sys.maxsize,
43
+ type=int,
44
+ help='max epoch used for averaging model')
45
+ parser.add_argument('--min_step',
46
+ default=0,
47
+ type=int,
48
+ help='min step used for averaging model')
49
+ parser.add_argument('--max_step',
50
+ default=sys.maxsize,
51
+ type=int,
52
+ help='max step used for averaging model')
53
+ parser.add_argument('--mode',
54
+ default="hybrid",
55
+ choices=["hybrid", "epoch", "step"],
56
+ type=str,
57
+ help='average mode')
58
+
59
+ args = parser.parse_args()
60
+ print(args)
61
+ return args
62
+
63
+
64
+ def main():
65
+ args = get_args()
66
+ checkpoints = []
67
+ val_scores = []
68
+ if args.val_best:
69
+ if args.mode == "hybrid":
70
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
71
+ yamls = [
72
+ f for f in yamls
73
+ if not (os.path.basename(f).startswith('train')
74
+ or os.path.basename(f).startswith('init'))
75
+ ]
76
+ elif args.mode == "step":
77
+ yamls = glob.glob('{}/step_*.yaml'.format(args.src_path))
78
+ else:
79
+ yamls = glob.glob('{}/epoch_*.yaml'.format(args.src_path))
80
+ for y in yamls:
81
+ with open(y, 'r') as f:
82
+ dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
83
+ loss = dic_yaml['loss_dict']['loss']
84
+ epoch = dic_yaml['epoch']
85
+ step = dic_yaml['step']
86
+ tag = dic_yaml['tag']
87
+ if epoch >= args.min_epoch and epoch <= args.max_epoch \
88
+ and step >= args.min_step and step <= args.max_step:
89
+ val_scores += [[epoch, step, loss, tag]]
90
+ sorted_val_scores = sorted(val_scores,
91
+ key=lambda x: x[2],
92
+ reverse=False)
93
+ print("best val (epoch, step, loss, tag) = " +
94
+ str(sorted_val_scores[:args.num]))
95
+ path_list = [
96
+ args.src_path + '/{}.pt'.format(score[-1])
97
+ for score in sorted_val_scores[:args.num]
98
+ ]
99
+ else:
100
+ path_list = glob.glob('{}/[!init]*.pt'.format(args.src_path))
101
+ path_list = sorted(path_list, key=os.path.getmtime)
102
+ path_list = path_list[-args.num:]
103
+ print(path_list)
104
+ avg = {}
105
+ num = args.num
106
+ assert num == len(path_list)
107
+ for path in path_list:
108
+ print('Processing {}'.format(path))
109
+ states = torch.load(path, map_location=torch.device('cpu'))
110
+ for k in states.keys():
111
+ if k not in avg.keys():
112
+ avg[k] = states[k].clone()
113
+ else:
114
+ avg[k] += states[k]
115
+ # average
116
+ for k in avg.keys():
117
+ if avg[k] is not None:
118
+ # pytorch 1.6 use true_divide instead of /=
119
+ avg[k] = torch.true_divide(avg[k], num)
120
+ print('Saving to {}'.format(args.dst_model))
121
+ torch.save(avg, args.dst_model)
122
+
123
+
124
+ if __name__ == '__main__':
125
+ main()
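
A typical invocation sketch; the flags mirror get_args() above and the paths are hypothetical:

    import sys
    from wenet.bin.average_model import main

    # Average the 5 checkpoints with the lowest validation loss found under exp/.
    sys.argv = ['average_model.py',
                '--dst_model', 'exp/avg_best5.pt',
                '--src_path', 'exp',
                '--num', '5',
                '--val_best',
                '--mode', 'hybrid']
    main()
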
wenet/bin/export_ipex.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (C) 2021-2023 Intel Corporation
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import print_function
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+ import yaml
12
+
13
+ from wenet.utils.init_model import init_model
14
+ import intel_extension_for_pytorch as ipex
15
+ from intel_extension_for_pytorch.quantization import prepare, convert
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser(description='export your script model')
20
+ parser.add_argument('--config', required=True, help='config file')
21
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
22
+ parser.add_argument('--output_file', default=None, help='output file')
23
+ parser.add_argument('--dtype',
24
+ default="fp32",
25
+ help='choose the dtype to run:[fp32,bf16]')
26
+ parser.add_argument('--output_quant_file',
27
+ default=None,
28
+ help='output quantized model file')
29
+ args = parser.parse_args()
30
+ return args
31
+
32
+
33
+ def scripting(model):
34
+ with torch.inference_mode():
35
+ script_model = torch.jit.script(model)
36
+ script_model = torch.jit.freeze(
37
+ script_model,
38
+ preserved_attrs=[
39
+ "forward_encoder_chunk", "ctc_activation",
40
+ "forward_attention_decoder", "subsampling_rate",
41
+ "right_context", "sos_symbol", "eos_symbol",
42
+ "is_bidirectional_decoder"
43
+ ])
44
+ return script_model
45
+
46
+
47
+ def main():
48
+ args = get_args()
49
+ logging.basicConfig(level=logging.DEBUG,
50
+ format='%(asctime)s %(levelname)s %(message)s')
51
+ # No GPU is needed for model export
52
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
53
+
54
+ with open(args.config, 'r') as fin:
55
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
56
+ model, configs = init_model(args, configs)
57
+ print(model)
58
+
59
+ # Apply IPEX optimization
60
+ model.eval()
61
+ torch._C._jit_set_texpr_fuser_enabled(False)
62
+ model.to(memory_format=torch.channels_last)
63
+ if args.dtype == "fp32":
64
+ ipex_model = ipex.optimize(model)
65
+ elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR)
66
+ ipex_model = ipex.optimize(model,
67
+ dtype=torch.bfloat16,
68
+ weights_prepack=False)
69
+
70
+ # Export jit torch script model
71
+ if args.output_file:
72
+ if args.dtype == "fp32":
73
+ script_model = scripting(ipex_model)
74
+ elif args.dtype == "bf16":
75
+ torch._C._jit_set_autocast_mode(True)
76
+ with torch.cpu.amp.autocast():
77
+ script_model = scripting(ipex_model)
78
+ script_model.save(args.output_file)
79
+ print('Export model successfully, see {}'.format(args.output_file))
80
+
81
+ # Export quantized jit torch script model
82
+ if args.output_quant_file:
83
+ dynamic_qconfig = ipex.quantization.default_dynamic_qconfig
84
+ dummy_data = (torch.zeros(1, 67, 80), 16, -16,
85
+ torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7))
86
+ model = prepare(model, dynamic_qconfig, dummy_data)
87
+ model = convert(model)
88
+ script_quant_model = scripting(model)
89
+ script_quant_model.save(args.output_quant_file)
90
+ print('Export quantized model successfully, '
91
+ 'see {}'.format(args.output_quant_file))
92
+
93
+
94
+ if __name__ == '__main__':
95
+ main()
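
An invocation sketch for the IPEX export; it assumes intel_extension_for_pytorch is installed, and the flags mirror get_args() above (paths are hypothetical):

    import sys
    from wenet.bin.export_ipex import main

    sys.argv = ['export_ipex.py',
                '--config', 'exp/train.yaml',
                '--checkpoint', 'exp/final.pt',
                '--dtype', 'bf16',
                '--output_file', 'exp/final_ipex.zip',
                '--output_quant_file', 'exp/final_ipex_quant.zip']
    main()
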
wenet/bin/export_jit.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ import os
20
+
21
+ import torch
22
+ import yaml
23
+
24
+ from wenet.utils.init_model import init_model
25
+
26
+
27
+ def get_args():
28
+ parser = argparse.ArgumentParser(description='export your script model')
29
+ parser.add_argument('--config', required=True, help='config file')
30
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
31
+ parser.add_argument('--output_file', default=None, help='output file')
32
+ parser.add_argument('--output_quant_file',
33
+ default=None,
34
+ help='output quantized model file')
35
+ args = parser.parse_args()
36
+ return args
37
+
38
+
39
+ def main():
40
+ args = get_args()
41
+ args.jit = True
42
+ logging.basicConfig(level=logging.DEBUG,
43
+ format='%(asctime)s %(levelname)s %(message)s')
44
+ # No GPU is needed for model export
45
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
46
+
47
+ with open(args.config, 'r') as fin:
48
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
49
+ model, configs = init_model(args, configs)
50
+ model.eval()
51
+ print(model)
52
+ # Export jit torch script model
53
+
54
+ if args.output_file:
55
+ script_model = torch.jit.script(model)
56
+ script_model.save(args.output_file)
57
+ print('Export model successfully, see {}'.format(args.output_file))
58
+
59
+ # Export quantized jit torch script model
60
+ if args.output_quant_file:
61
+ quantized_model = torch.quantization.quantize_dynamic(
62
+ model, {torch.nn.Linear}, dtype=torch.qint8)
63
+ print(quantized_model)
64
+ script_quant_model = torch.jit.script(quantized_model)
65
+ script_quant_model.save(args.output_quant_file)
66
+ print('Export quantized model successfully, '
67
+ 'see {}'.format(args.output_quant_file))
68
+
69
+
70
+ if __name__ == '__main__':
71
+ main()
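
The quantized export path above relies on PyTorch dynamic quantization of Linear layers followed by scripting. A standalone illustration on a toy module (not the WeNet model itself; the output path is hypothetical):

    import torch

    toy = torch.nn.Sequential(torch.nn.Linear(80, 256),
                              torch.nn.ReLU(),
                              torch.nn.Linear(256, 100))
    quantized = torch.quantization.quantize_dynamic(toy, {torch.nn.Linear},
                                                    dtype=torch.qint8)
    scripted = torch.jit.script(quantized)
    scripted.save('toy_quant.pt')
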
wenet/bin/export_onnx_bpu.py ADDED
@@ -0,0 +1,1065 @@
1
+ # Copyright (c) 2022, Horizon Inc. Xingchen Song ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """NOTE(xcsong): Currently, we only support
15
+ 1. specific conformer encoder architecture, see:
16
+ encoder: conformer
17
+ encoder_conf:
18
+ activation_type: **must be** relu
19
+ attention_heads: 2 or 4 or 8 or any number divisible by output_size
20
+ causal: **must be** true
21
+ cnn_module_kernel: 1 ~ 7
22
+ cnn_module_norm: **must be** batch_norm
23
+ input_layer: **must be** conv2d8
24
+ linear_units: 1 ~ 2048
25
+ normalize_before: **must be** true
26
+ num_blocks: 1 ~ 12
27
+ output_size: 1 ~ 512
28
+ pos_enc_layer_type: **must be** no_pos
29
+ selfattention_layer_type: **must be** selfattn
30
+ use_cnn_module: **must be** true
31
+ use_dynamic_chunk: **must be** true
32
+ use_dynamic_left_chunk: **must be** true
33
+
34
+ 2. specific decoding method: ctc_greedy_search
35
+ """
36
+
37
+ from __future__ import print_function
38
+
39
+ import os
40
+ import sys
41
+ import copy
42
+ import math
43
+ import yaml
44
+ import logging
45
+ from typing import Tuple
46
+
47
+ import torch
48
+ import numpy as np
49
+
50
+ from wenet.transformer.embedding import NoPositionalEncoding
51
+ from wenet.utils.init_model import init_model
52
+ from wenet.bin.export_onnx_cpu import (get_args, to_numpy,
53
+ print_input_output_info)
54
+
55
+ try:
56
+ import onnx
57
+ import onnxruntime
58
+ except ImportError:
59
+ print('Please install onnx and onnxruntime!')
60
+ sys.exit(1)
61
+
62
+ logger = logging.getLogger(__file__)
63
+ logger.setLevel(logging.INFO)
64
+
65
+
66
+ class BPULayerNorm(torch.nn.Module):
67
+ """Refactor torch.nn.LayerNorm to meet 4-D dataflow."""
68
+
69
+ def __init__(self, module, chunk_size=8, run_on_bpu=False):
70
+ super().__init__()
71
+ original = copy.deepcopy(module)
72
+ self.hidden = module.weight.size(0)
73
+ self.chunk_size = chunk_size
74
+ self.run_on_bpu = run_on_bpu
75
+
76
+ if self.run_on_bpu:
77
+ self.weight = torch.nn.Parameter(
78
+ module.weight.reshape(1, self.hidden, 1,
79
+ 1).repeat(1, 1, 1, chunk_size))
80
+ self.bias = torch.nn.Parameter(
81
+ module.bias.reshape(1, self.hidden, 1,
82
+ 1).repeat(1, 1, 1, chunk_size))
83
+ self.negtive = torch.nn.Parameter(
84
+ torch.ones((1, self.hidden, 1, chunk_size)) * -1.0)
85
+ self.eps = torch.nn.Parameter(
86
+ torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps)
87
+ self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
88
+ self.mean_conv_1.weight = torch.nn.Parameter(
89
+ torch.ones(self.hidden, self.hidden, 1, 1) /
90
+ (1.0 * self.hidden))
91
+ self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
92
+ self.mean_conv_2.weight = torch.nn.Parameter(
93
+ torch.ones(self.hidden, self.hidden, 1, 1) /
94
+ (1.0 * self.hidden))
95
+ else:
96
+ self.norm = module
97
+
98
+ self.check_equal(original)
99
+
100
+ def check_equal(self, module):
101
+ random_data = torch.randn(1, self.chunk_size, self.hidden)
102
+ orig_out = module(random_data)
103
+ new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
104
+ np.testing.assert_allclose(to_numpy(orig_out),
105
+ to_numpy(
106
+ new_out.squeeze(2).transpose(1, 2)),
107
+ rtol=1e-02,
108
+ atol=1e-03)
109
+
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
+ if self.run_on_bpu:
112
+ u = self.mean_conv_1(x) # (1, h, 1, c)
113
+ numerator = x + u * self.negtive # (1, h, 1, c)
114
+ s = torch.pow(numerator, 2) # (1, h, 1, c)
115
+ s = self.mean_conv_2(s) # (1, h, 1, c)
116
+ denominator = torch.sqrt(s + self.eps) # (1, h, 1, c)
117
+ x = torch.div(numerator, denominator) # (1, h, 1, c)
118
+ x = x * self.weight + self.bias
119
+ else:
120
+ x = x.squeeze(2).transpose(1, 2).contiguous()
121
+ x = self.norm(x)
122
+ x = x.transpose(1, 2).contiguous().unsqueeze(2)
123
+ return x
124
+
125
+
126
+ class BPUIdentity(torch.nn.Module):
127
+ """Refactor torch.nn.Identity().
128
+ For inserting BPU node whose input == output.
129
+ """
130
+
131
+ def __init__(self, channels):
132
+ super().__init__()
133
+ self.channels = channels
134
+ self.identity_conv = torch.nn.Conv2d(channels,
135
+ channels,
136
+ 1,
137
+ groups=channels,
138
+ bias=False)
139
+ torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels)
140
+
141
+ self.check_equal()
142
+
143
+ def check_equal(self):
144
+ random_data = torch.randn(1, self.channels, 1, 10)
145
+ result = self.forward(random_data)
146
+ np.testing.assert_allclose(to_numpy(random_data),
147
+ to_numpy(result),
148
+ rtol=1e-02,
149
+ atol=1e-03)
150
+
151
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ """Identity with 4-D dataflow, input == output.
153
+ Args:
154
+ x (torch.Tensor): (batch, in_channel, 1, time)
155
+
156
+ Returns:
157
+ (torch.Tensor): (batch, in_channel, 1, time).
158
+ """
159
+ return self.identity_conv(x)
160
+
161
+
162
+ class BPULinear(torch.nn.Module):
163
+ """Refactor torch.nn.Linear or pointwise_conv"""
164
+
165
+ def __init__(self, module, is_pointwise_conv=False):
166
+ super().__init__()
167
+ # Unchanged submodules and attributes
168
+ original = copy.deepcopy(module)
169
+ self.idim = module.weight.size(1)
170
+ self.odim = module.weight.size(0)
171
+ self.is_pointwise_conv = is_pointwise_conv
172
+
173
+ # Modify weight & bias
174
+ self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
175
+ if is_pointwise_conv:
176
+ # (odim, idim, kernel=1) -> (odim, idim, 1, 1)
177
+ self.linear.weight = torch.nn.Parameter(
178
+ module.weight.unsqueeze(-1))
179
+ else:
180
+ # (odim, idim) -> (odim, idim, 1, 1)
181
+ self.linear.weight = torch.nn.Parameter(
182
+ module.weight.unsqueeze(2).unsqueeze(3))
183
+ self.linear.bias = module.bias
184
+
185
+ self.check_equal(original)
186
+
187
+ def check_equal(self, module):
188
+ random_data = torch.randn(1, 8, self.idim)
189
+ if self.is_pointwise_conv:
190
+ random_data = random_data.transpose(1, 2)
191
+ original_result = module(random_data)
192
+ if self.is_pointwise_conv:
193
+ random_data = random_data.transpose(1, 2)
194
+ original_result = original_result.transpose(1, 2)
195
+ random_data = random_data.transpose(1, 2).unsqueeze(2)
196
+ new_result = self.forward(random_data)
197
+ np.testing.assert_allclose(to_numpy(original_result),
198
+ to_numpy(
199
+ new_result.squeeze(2).transpose(1, 2)),
200
+ rtol=1e-02,
201
+ atol=1e-03)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ """Linear with 4-D dataflow.
205
+ Args:
206
+ x (torch.Tensor): (batch, in_channel, 1, time)
207
+ Returns:
208
+ (torch.Tensor): (batch, out_channel, 1, time).
209
+ """
210
+ return self.linear(x)
211
+
212
+
213
+ class BPUGlobalCMVN(torch.nn.Module):
214
+ """Refactor wenet/transformer/cmvn.py::GlobalCMVN"""
215
+
216
+ def __init__(self, module):
217
+ super().__init__()
218
+ # Unchanged submodules and attributes
219
+ self.norm_var = module.norm_var
220
+
221
+ # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
222
+ self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
223
+ self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
224
+
225
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
226
+ """CMVN with 4-D dataflow.
227
+ Args:
228
+ x (torch.Tensor): (batch, 1, mel_dim, time)
229
+ Returns:
230
+ (torch.Tensor): normalized feature with same shape.
231
+ """
232
+ x = x - self.mean
233
+ if self.norm_var:
234
+ x = x * self.istd
235
+ return x
236
+
237
+
238
+ class BPUConv2dSubsampling8(torch.nn.Module):
239
+ """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8
240
+
241
+ NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
242
+ """
243
+
244
+ def __init__(self, module):
245
+ super().__init__()
246
+ # Unchanged submodules and attributes
247
+ original = copy.deepcopy(module)
248
+ self.right_context = module.right_context
249
+ self.subsampling_rate = module.subsampling_rate
250
+ assert isinstance(module.pos_enc, NoPositionalEncoding)
251
+
252
+ # 1. Modify self.conv
253
+ # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
254
+ # to (1, 1, mel_dim, frames) for more efficient computation.
255
+ self.conv = module.conv
256
+ for idx in [0, 2, 4]:
257
+ self.conv[idx].weight = torch.nn.Parameter(
258
+ module.conv[idx].weight.transpose(2, 3))
259
+
260
+ # 2. Modify self.linear
261
+ # NOTE(xcsong): Split final projection to meet the requirement of
262
+ # maximum kernel_size (7 for XJ3)
263
+ self.linear = torch.nn.ModuleList()
264
+ odim = module.linear.weight.size(0) # 512, in this case
265
+ freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9
266
+ self.odim, self.freq = odim, freq
267
+ weight = module.linear.weight.reshape(
268
+ odim, odim, freq,
269
+ 1) # (odim, odim * freq) -> (odim, odim, freq, 1)
270
+ self.split_size = []
271
+ num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7
272
+ slice_begin = 0
273
+ for idx in range(num_split):
274
+ kernel_size = min(freq, (idx + 1) * 7) - idx * 7
275
+ conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1),
276
+ (kernel_size, 1))
277
+ conv_ele.weight = torch.nn.Parameter(
278
+ weight[:, :, slice_begin:slice_begin + kernel_size, :])
279
+ conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias))
280
+ self.linear.append(conv_ele)
281
+ self.split_size.append(kernel_size)
282
+ slice_begin += kernel_size
283
+ self.linear[0].bias = torch.nn.Parameter(module.linear.bias)
284
+
285
+ self.check_equal(original)
286
+
287
+ def check_equal(self, module):
288
+ random_data = torch.randn(1, 67, 80)
289
+ mask = torch.zeros(1, 1, 67)
290
+ original_result, _, _ = module(random_data, mask) # (1, 8, 512)
291
+ random_data = random_data.transpose(1,
292
+ 2).unsqueeze(0) # (1, 1, 80, 67)
293
+ new_result = self.forward(random_data) # (1, 512, 1, 8)
294
+ np.testing.assert_allclose(to_numpy(original_result),
295
+ to_numpy(
296
+ new_result.squeeze(2).transpose(1, 2)),
297
+ rtol=1e-02,
298
+ atol=1e-03)
299
+
300
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
301
+ """Subsample x with 4-D dataflow.
302
+ Args:
303
+ x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).
304
+
305
+ Returns:
306
+ torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
307
+ where time' = time // 8.
308
+ """
309
+ x = self.conv(x) # (1, odim, freq, time')
310
+ x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
311
+ x = torch.split(x, self.split_size, dim=2)
312
+ for idx, (x_part, layer) in enumerate(zip(x, self.linear)):
313
+ x_out += layer(x_part)
314
+ return x_out
315
+
316
+
317
+ class BPUMultiHeadedAttention(torch.nn.Module):
318
+ """Refactor wenet/transformer/attention.py::MultiHeadedAttention
319
+
320
+ NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
321
+ we do not consider RelPositionMultiHeadedAttention currently.
322
+ """
323
+
324
+ def __init__(self, module, chunk_size, left_chunks):
325
+ super().__init__()
326
+ # Unchanged submodules and attributes
327
+ original = copy.deepcopy(module)
328
+ self.d_k = module.d_k
329
+ self.h = module.h
330
+ n_feat = self.d_k * self.h
331
+ self.chunk_size = chunk_size
332
+ self.left_chunks = left_chunks
333
+ self.time = chunk_size * (left_chunks + 1)
334
+ self.activation = torch.nn.Softmax(dim=-1)
335
+
336
+ # 1. Modify self.linear_x
337
+ self.linear_q = BPULinear(module.linear_q)
338
+ self.linear_k = BPULinear(module.linear_k)
339
+ self.linear_v = BPULinear(module.linear_v)
340
+ self.linear_out = BPULinear(module.linear_out)
341
+ # 2. denom
342
+ self.register_buffer(
343
+ "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k)))
344
+
345
+ self.check_equal(original)
346
+
347
+ def check_equal(self, module):
348
+ random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
349
+ mask = torch.ones((1, self.h, self.chunk_size, self.time),
350
+ dtype=torch.bool)
351
+ cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks,
352
+ self.d_k * 2)
353
+ original_out, original_cache = module(random_data, random_data,
354
+ random_data, mask[:, 0, :, :],
355
+ torch.empty(0), cache)
356
+ random_data = random_data.transpose(1, 2).unsqueeze(2)
357
+ cache = cache.reshape(1, self.h, self.d_k * 2,
358
+ self.chunk_size * self.left_chunks)
359
+ new_out, new_cache = self.forward(random_data, random_data,
360
+ random_data, mask, cache)
361
+ np.testing.assert_allclose(to_numpy(original_out),
362
+ to_numpy(
363
+ new_out.squeeze(2).transpose(1, 2)),
364
+ rtol=1e-02,
365
+ atol=1e-03)
366
+ np.testing.assert_allclose(to_numpy(original_cache),
367
+ to_numpy(new_cache.transpose(2, 3)),
368
+ rtol=1e-02,
369
+ atol=1e-03)
370
+
371
+ def forward(
372
+ self,
373
+ q: torch.Tensor,
374
+ k: torch.Tensor,
375
+ v: torch.Tensor,
376
+ mask: torch.Tensor,
377
+ cache: torch.Tensor,
378
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
379
+ """Compute scaled dot product attention.
380
+
381
+ Args:
382
+ q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
383
+ k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
384
+ v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
385
+ mask (torch.Tensor): Mask tensor,
386
+ (#batch, head, chunk_size, cache_t + chunk_size).
387
+ cache (torch.Tensor): Cache tensor
388
+ (1, head, d_k * 2, cache_t),
389
+ where `cache_t == chunk_size * left_chunks`.
390
+
391
+
392
+ Returns:
393
+ torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
394
+ torch.Tensor: Cache tensor
395
+ (1, head, d_k * 2, cache_t + chunk_size)
396
+ where `cache_t == chunk_size * left_chunks`
397
+ """
398
+ # 1. Forward QKV
399
+ q = self.linear_q(q) # (1, d, 1, c) d == size, c == chunk_size
400
+ k = self.linear_k(k) # (1, d, 1, c)
401
+ v = self.linear_v(v) # (1, d, 1, c)
402
+ q = q.view(1, self.h, self.d_k, self.chunk_size)
403
+ k = k.view(1, self.h, self.d_k, self.chunk_size)
404
+ v = v.view(1, self.h, self.d_k, self.chunk_size)
405
+ q = q.transpose(2, 3) # (batch, head, time1, d_k)
406
+ k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
407
+ k = torch.cat((k_cache, k), dim=3)
408
+ v = torch.cat((v_cache, v), dim=3)
409
+ new_cache = torch.cat((k, v), dim=2)
410
+ # 2. (Q^T)K
411
+ scores = torch.matmul(q, k) * self.denom # (#b, n_head, time1, time2)
412
+ # 3. Forward attention
413
+ mask = mask.eq(0)
414
+ scores = scores.masked_fill(mask, -float('inf'))
415
+ attn = self.activation(scores).masked_fill(mask, 0.0)
416
+ attn = attn.transpose(2, 3)
417
+ x = torch.matmul(v, attn)
418
+ x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
419
+ x_out = self.linear_out(x)
420
+ return x_out, new_cache
421
+
422
+
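A minimal, self-contained sketch of the cache layout used by the attention wrapper above, in plain torch with hypothetical sizes (it is not part of the exported module): keys and values are stacked along the channel dimension and one chunk of new keys/values is appended per call.

import torch

h, d_k, chunk_size, left_chunks = 4, 64, 16, 4
cache_t = chunk_size * left_chunks            # history kept between chunks
cache = torch.zeros(1, h, d_k * 2, cache_t)   # [K; V] stacked along dim 2

k_new = torch.randn(1, h, d_k, chunk_size)    # keys of the current chunk
v_new = torch.randn(1, h, d_k, chunk_size)    # values of the current chunk

k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
k = torch.cat((k_cache, k_new), dim=3)        # (1, h, d_k, cache_t + chunk_size)
v = torch.cat((v_cache, v_new), dim=3)
new_cache = torch.cat((k, v), dim=2)          # (1, h, d_k * 2, cache_t + chunk_size)
assert new_cache.shape == (1, h, d_k * 2, cache_t + chunk_size)

The encoder layer later keeps only new_att_cache[:, :, :, chunk_size:], so the cache fed to the next call stays fixed at chunk_size * left_chunks frames.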
423
+ class BPUConvolution(torch.nn.Module):
424
+ """Refactor wenet/transformer/convolution.py::ConvolutionModule
425
+
426
+ NOTE(xcsong): Only support use_layer_norm == False
427
+ """
428
+
429
+ def __init__(self, module):
430
+ super().__init__()
431
+ # Unchanged submodules and attributes
432
+ original = copy.deepcopy(module)
433
+ self.lorder = module.lorder
434
+ self.use_layer_norm = False
435
+ self.activation = module.activation
436
+ channels = module.pointwise_conv1.weight.size(1)
437
+ self.channels = channels
438
+ kernel_size = module.depthwise_conv.weight.size(2)
439
+ assert module.use_layer_norm is False
440
+
441
+ # 1. Modify self.pointwise_conv1
442
+ self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)
443
+
444
+ # 2. Modify self.depthwise_conv
445
+ self.depthwise_conv = torch.nn.Conv2d(channels,
446
+ channels, (1, kernel_size),
447
+ stride=1,
448
+ groups=channels)
449
+ self.depthwise_conv.weight = torch.nn.Parameter(
450
+ module.depthwise_conv.weight.unsqueeze(-2))
451
+ self.depthwise_conv.bias = torch.nn.Parameter(
452
+ module.depthwise_conv.bias)
453
+
454
+ # 3. Modify self.norm, Only support batchnorm2d
455
+ self.norm = torch.nn.BatchNorm2d(channels)
456
+ self.norm.training = False
457
+ self.norm.num_features = module.norm.num_features
458
+ self.norm.eps = module.norm.eps
459
+ self.norm.momentum = module.norm.momentum
460
+ self.norm.weight = torch.nn.Parameter(module.norm.weight)
461
+ self.norm.bias = torch.nn.Parameter(module.norm.bias)
462
+ self.norm.running_mean = module.norm.running_mean
463
+ self.norm.running_var = module.norm.running_var
464
+
465
+ # 4. Modify self.pointwise_conv2
466
+ self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)
467
+
468
+ # 5. Identity conv, for running `concat` on BPU
469
+ self.identity = BPUIdentity(channels)
470
+
471
+ self.check_equal(original)
472
+
473
+ def check_equal(self, module):
474
+ random_data = torch.randn(1, 8, self.channels)
475
+ cache = torch.zeros((1, self.channels, self.lorder))
476
+ original_out, original_cache = module(random_data, cache=cache)
477
+ random_data = random_data.transpose(1, 2).unsqueeze(2)
478
+ cache = cache.unsqueeze(2)
479
+ new_out, new_cache = self.forward(random_data, cache)
480
+ np.testing.assert_allclose(to_numpy(original_out),
481
+ to_numpy(
482
+ new_out.squeeze(2).transpose(1, 2)),
483
+ rtol=1e-02,
484
+ atol=1e-03)
485
+ np.testing.assert_allclose(to_numpy(original_cache),
486
+ to_numpy(new_cache.squeeze(2)),
487
+ rtol=1e-02,
488
+ atol=1e-03)
489
+
490
+ def forward(self, x: torch.Tensor,
491
+ cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
492
+ """Compute convolution module.
493
+ Args:
494
+ x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
495
+ cache (torch.Tensor): left context cache, which is only
496
+ used in causal convolution (#batch, channels, 1, cache_t).
497
+ Returns:
498
+ torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
499
+ torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
500
+ """
501
+ # Concat cache
502
+ x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
503
+ new_cache = x[:, :, :, -self.lorder:]
504
+
505
+ # GLU mechanism
506
+ x = self.pointwise_conv1(x) # (batch, 2*channel, 1, dim)
507
+ x = torch.nn.functional.glu(x, dim=1) # (b, channel, 1, dim)
508
+
509
+ # Depthwise Conv
510
+ x = self.depthwise_conv(x)
511
+ x = self.activation(self.norm(x))
512
+ x = self.pointwise_conv2(x)
513
+ return x, new_cache
514
+
515
+
516
+ class BPUFFN(torch.nn.Module):
517
+ """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward
518
+ """
519
+
520
+ def __init__(self, module):
521
+ super().__init__()
522
+ # Unchanged submodules and attributes
523
+ original = copy.deepcopy(module)
524
+ self.activation = module.activation
525
+
526
+ # 1. Modify self.w_x
527
+ self.w_1 = BPULinear(module.w_1)
528
+ self.w_2 = BPULinear(module.w_2)
529
+
530
+ self.check_equal(original)
531
+
532
+ def check_equal(self, module):
533
+ random_data = torch.randn(1, 8, self.w_1.idim)
534
+ original_out = module(random_data)
535
+ random_data = random_data.transpose(1, 2).unsqueeze(2)
536
+ new_out = self.forward(random_data)
537
+ np.testing.assert_allclose(to_numpy(original_out),
538
+ to_numpy(
539
+ new_out.squeeze(2).transpose(1, 2)),
540
+ rtol=1e-02,
541
+ atol=1e-03)
542
+
543
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
544
+ """Forward function.
545
+
546
+ Args:
547
+ xs: input tensor (B, D, 1, L)
548
+ Returns:
549
+ output tensor, (B, D, 1, L)
550
+ """
551
+ return self.w_2(self.activation(self.w_1(x)))
552
+
553
+
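All of these BPU wrappers trade wenet's usual (batch, time, dim) layout for a Conv2d-friendly (batch, dim, 1, time) layout, which is also how the various check_equal helpers compare outputs. A minimal sketch of the round trip, plain torch with hypothetical sizes:

import torch

x_btd = torch.randn(1, 16, 256)              # (batch, time, dim), as in the original modules
x_nchw = x_btd.transpose(1, 2).unsqueeze(2)  # (batch, dim, 1, time), as fed to the BPU modules
assert x_nchw.shape == (1, 256, 1, 16)

x_back = x_nchw.squeeze(2).transpose(1, 2)   # back to (batch, time, dim) for comparison
assert torch.equal(x_back, x_btd)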
554
+ class BPUConformerEncoderLayer(torch.nn.Module):
555
+ """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer
556
+ """
557
+
558
+ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
559
+ super().__init__()
560
+ # Unchanged submodules and attributes
561
+ original = copy.deepcopy(module)
562
+ self.size = module.size
563
+ assert module.normalize_before is True
564
+ assert module.concat_after is False
565
+
566
+ # 1. Modify submodules
567
+ self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
568
+ self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size,
569
+ left_chunks)
570
+ self.conv_module = BPUConvolution(module.conv_module)
571
+ self.feed_forward = BPUFFN(module.feed_forward)
572
+
573
+ # 2. Modify norms
574
+ self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
575
+ self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size,
576
+ ln_run_on_bpu)
577
+ self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, chunk_size,
578
+ ln_run_on_bpu)
579
+ self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size,
580
+ ln_run_on_bpu)
581
+ self.norm_final = BPULayerNorm(module.norm_final, chunk_size,
582
+ ln_run_on_bpu)
583
+
584
+ # 3. 4-D ff_scale
585
+ self.register_buffer("ff_scale",
586
+ torch.full((1, self.size, 1, 1), module.ff_scale))
587
+
588
+ self.check_equal(original)
589
+
590
+ def check_equal(self, module):
591
+ time1 = self.self_attn.chunk_size
592
+ time2 = self.self_attn.time
593
+ h, d_k = self.self_attn.h, self.self_attn.d_k
594
+ random_x = torch.randn(1, time1, self.size)
595
+ att_mask = torch.ones(1, h, time1, time2)
596
+ att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
597
+ cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
598
+ original_x, _, original_att_cache, original_cnn_cache = module(
599
+ random_x,
600
+ att_mask[:, 0, :, :],
601
+ torch.empty(0),
602
+ att_cache=att_cache,
603
+ cnn_cache=cnn_cache)
604
+ random_x = random_x.transpose(1, 2).unsqueeze(2)
605
+ att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
606
+ cnn_cache = cnn_cache.unsqueeze(2)
607
+ new_x, new_att_cache, new_cnn_cache = self.forward(
608
+ random_x, att_mask, att_cache, cnn_cache)
609
+ np.testing.assert_allclose(to_numpy(original_att_cache),
610
+ to_numpy(new_att_cache.transpose(2, 3)),
611
+ rtol=1e-02,
612
+ atol=1e-03)
613
+ np.testing.assert_allclose(to_numpy(original_x),
614
+ to_numpy(new_x.squeeze(2).transpose(1, 2)),
615
+ rtol=1e-02,
616
+ atol=1e-03)
617
+ np.testing.assert_allclose(to_numpy(original_cnn_cache),
618
+ to_numpy(new_cnn_cache.squeeze(2)),
619
+ rtol=1e-02,
620
+ atol=1e-03)
621
+
622
+ def forward(
623
+ self, x: torch.Tensor, att_mask: torch.Tensor, att_cache: torch.Tensor,
624
+ cnn_cache: torch.Tensor
625
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
626
+ """Compute encoded features.
627
+
628
+ Args:
629
+ x (torch.Tensor): (#batch, size, 1, chunk_size)
630
+ att_mask (torch.Tensor): Mask tensor for the input
631
+ (#batch, head, chunk_size, cache_t1 + chunk_size),
632
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
633
+ (#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
634
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
635
+ (#batch=1, size, 1, cache_t2)
636
+ Returns:
637
+ torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
638
+ torch.Tensor: att_cache tensor,
639
+ (1, head, d_k * 2, cache_t1 + chunk_size).
640
+ torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
641
+ """
642
+ # 1. ffn_macaron
643
+ residual = x
644
+ x = self.norm_ff_macron(x)
645
+ x = residual + self.ff_scale * self.feed_forward_macaron(x)
646
+
647
+ # 2. attention
648
+ residual = x
649
+ x = self.norm_mha(x)
650
+ x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache)
651
+ x = residual + x_att
652
+
653
+ # 3. convolution
654
+ residual = x
655
+ x = self.norm_conv(x)
656
+ x, new_cnn_cache = self.conv_module(x, cnn_cache)
657
+ x = residual + x
658
+
659
+ # 4. ffn
660
+ residual = x
661
+ x = self.norm_ff(x)
662
+ x = residual + self.ff_scale * self.feed_forward(x)
663
+
664
+ # 5. final post-norm
665
+ x = self.norm_final(x)
666
+
667
+ return x, new_att_cache, new_cnn_cache
668
+
669
+
670
+ class BPUConformerEncoder(torch.nn.Module):
671
+ """Refactor wenet/transformer/encoder.py::ConformerEncoder
672
+ """
673
+
674
+ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
675
+ super().__init__()
676
+ # Unchanged submodules and attributes
677
+ original = copy.deepcopy(module)
678
+ output_size = module.output_size()
679
+ self._output_size = module.output_size()
680
+ self.after_norm = module.after_norm
681
+ self.chunk_size = chunk_size
682
+ self.left_chunks = left_chunks
683
+ self.head = module.encoders[0].self_attn.h
684
+ self.layers = len(module.encoders)
685
+
686
+ # 1. Modify submodules
687
+ self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
688
+ self.embed = BPUConv2dSubsampling8(module.embed)
689
+ self.encoders = torch.nn.ModuleList()
690
+ for layer in module.encoders:
691
+ self.encoders.append(
692
+ BPUConformerEncoderLayer(layer, chunk_size, left_chunks,
693
+ ln_run_on_bpu))
694
+
695
+ # 2. Auxiliary conv
696
+ self.identity_cnncache = BPUIdentity(output_size)
697
+
698
+ self.check_equal(original)
699
+
700
+ def check_equal(self, module):
701
+ time1 = self.encoders[0].self_attn.chunk_size
702
+ time2 = self.encoders[0].self_attn.time
703
+ layers = self.layers
704
+ h, d_k = self.head, self.encoders[0].self_attn.d_k
705
+ decoding_window = (self.chunk_size - 1) * \
706
+ module.embed.subsampling_rate + \
707
+ module.embed.right_context + 1
708
+ lorder = self.encoders[0].conv_module.lorder
709
+ random_x = torch.randn(1, decoding_window, 80)
710
+ att_mask = torch.ones(1, h, time1, time2)
711
+ att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
712
+ cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
713
+ orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
714
+ random_x,
715
+ 0,
716
+ time2 - time1,
717
+ att_mask=att_mask[:, 0, :, :],
718
+ att_cache=att_cache,
719
+ cnn_cache=cnn_cache)
720
+ random_x = random_x.unsqueeze(0)
721
+ att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
722
+ cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
723
+ new_x, new_att_cache, new_cnn_cache = self.forward(
724
+ random_x, att_cache, cnn_cache, att_mask)
725
+ caches = torch.split(new_att_cache, h, dim=1)
726
+ caches = [c.transpose(2, 3) for c in caches]
727
+ np.testing.assert_allclose(to_numpy(orig_att_cache),
728
+ to_numpy(torch.cat(caches, dim=0)),
729
+ rtol=1e-02,
730
+ atol=1e-03)
731
+ np.testing.assert_allclose(to_numpy(orig_x),
732
+ to_numpy(new_x.squeeze(2).transpose(1, 2)),
733
+ rtol=1e-02,
734
+ atol=1e-03)
735
+ np.testing.assert_allclose(
736
+ to_numpy(orig_cnn_cache),
737
+ to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
738
+ rtol=1e-02,
739
+ atol=1e-03)
740
+
741
+ def forward(
742
+ self, xs: torch.Tensor, att_cache: torch.Tensor,
743
+ cnn_cache: torch.Tensor, att_mask: torch.Tensor
744
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
745
+ """ Forward just one chunk
746
+
747
+ Args:
748
+ xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
749
+ where `time == (chunk_size - 1) * subsample_rate + \
750
+ subsample.right_context + 1`
751
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
752
+ transformer/conformer attention, with shape
753
+ (1, head * elayers, d_k * 2, cache_t1), where
754
+ `head * d_k == hidden-dim` and
755
+ `cache_t1 == chunk_size * left_chunks`.
756
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
757
+ (1, hidden-dim, elayers, cache_t2), where
758
+ `cache_t2 == cnn.lorder - 1`
759
+ att_mask (torch.Tensor): Mask tensor for the input
760
+ (#batch, head, chunk_size, cache_t1 + chunk_size),
761
+
762
+ Returns:
763
+ torch.Tensor: output of current input xs,
764
+ with shape (b=1, hidden-dim, 1, chunk_size).
765
+ torch.Tensor: new attention cache required for next chunk, with
766
+ same shape as the original att_cache.
767
+ torch.Tensor: new conformer cnn cache required for next chunk, with
768
+ same shape as the original cnn_cache.
769
+ """
770
+ # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
771
+ xs = xs.transpose(2, 3)
772
+ xs = self.global_cmvn(xs)
773
+ # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
774
+ xs = self.embed(xs)
775
+
776
+ att_cache = torch.split(att_cache, self.head, dim=1)
777
+ cnn_cache = self.identity_cnncache(cnn_cache)
778
+ cnn_cache = torch.split(cnn_cache, 1, dim=2)
779
+ r_att_cache = []
780
+ r_cnn_cache = []
781
+ for i, layer in enumerate(self.encoders):
782
+ xs, new_att_cache, new_cnn_cache = layer(xs,
783
+ att_mask,
784
+ att_cache=att_cache[i],
785
+ cnn_cache=cnn_cache[i])
786
+ r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:])
787
+ r_cnn_cache.append(new_cnn_cache)
788
+ r_att_cache = torch.cat(r_att_cache, dim=1)
789
+ r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2))
790
+
791
+ xs = xs.squeeze(2).transpose(1, 2).contiguous()
792
+ xs = self.after_norm(xs)
793
+ # NOTE(xcsong): 4D in, 4D out to meet the requirement of CTC input.
794
+ xs = xs.transpose(1, 2).contiguous().unsqueeze(2) # (B, C, 1, T)
795
+
796
+ return (xs, r_att_cache, r_cnn_cache)
797
+
798
+
799
+ class BPUCTC(torch.nn.Module):
800
+ """Refactor wenet/transformer/ctc.py::CTC
801
+ """
802
+
803
+ def __init__(self, module):
804
+ super().__init__()
805
+ # Unchanged submodules and attributes
806
+ original = copy.deepcopy(module)
807
+ self.idim = module.ctc_lo.weight.size(1)
808
+ num_class = module.ctc_lo.weight.size(0)
809
+
810
+ # 1. Modify self.ctc_lo, Split final projection to meet the
811
+ # requirement of maximum in/out channels (2048 for XJ3)
812
+ self.ctc_lo = torch.nn.ModuleList()
813
+ self.split_size = []
814
+ num_split = (num_class - 1) // 2048 + 1
815
+ for idx in range(num_split):
816
+ out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
817
+ conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
818
+ self.ctc_lo.append(conv_ele)
819
+ self.split_size.append(out_channel)
820
+ orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0)
821
+ orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
822
+ for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
823
+ w = w.unsqueeze(2).unsqueeze(3)
824
+ self.ctc_lo[i].weight = torch.nn.Parameter(w)
825
+ self.ctc_lo[i].bias = torch.nn.Parameter(b)
826
+
827
+ self.check_equal(original)
828
+
829
+ def check_equal(self, module):
830
+ random_data = torch.randn(1, 100, self.idim)
831
+ original_result = module.ctc_lo(random_data)
832
+ random_data = random_data.transpose(1, 2).unsqueeze(2)
833
+ new_result = self.forward(random_data)
834
+ np.testing.assert_allclose(to_numpy(original_result),
835
+ to_numpy(
836
+ new_result.squeeze(2).transpose(1, 2)),
837
+ rtol=1e-02,
838
+ atol=1e-03)
839
+
840
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
841
+ """frame activations, without softmax.
842
+
843
+ Args:
844
+ Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
845
+ Returns:
846
+ torch.Tensor: (B, num_class, 1, chunk_size)
847
+ """
848
+ out = []
849
+ for i, layer in enumerate(self.ctc_lo):
850
+ out.append(layer(x))
851
+ out = torch.cat(out, dim=1)
852
+ return out
853
+
854
+
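The split sizes built in BPUCTC.__init__ simply carve the output vocabulary into 1x1 convolutions of at most 2048 channels each. The same arithmetic as a standalone snippet, with a purely hypothetical vocabulary size:

num_class = 5538                              # hypothetical vocabulary size
num_split = (num_class - 1) // 2048 + 1       # -> 3 projections
split_size = [min(num_class, (i + 1) * 2048) - i * 2048 for i in range(num_split)]
print(num_split, split_size)                  # 3 [2048, 2048, 1442]
assert sum(split_size) == num_class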
855
+ def export_encoder(asr_model, args):
856
+ logger.info("Stage-1: export encoder")
857
+ decode_window, mel_dim = args.decoding_window, args.feature_size
858
+ encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size,
859
+ args.num_decoding_left_chunks,
860
+ args.ln_run_on_bpu)
861
+ encoder.eval()
862
+ encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx')
863
+
864
+ logger.info("Stage-1.1: prepare inputs for encoder")
865
+ chunk = torch.randn((1, 1, decode_window, mel_dim))
866
+ required_cache_size = encoder.chunk_size * encoder.left_chunks
867
+ kv_time = required_cache_size + encoder.chunk_size
868
+ hidden, layers = encoder._output_size, len(encoder.encoders)
869
+ head = encoder.encoders[0].self_attn.h
870
+ d_k = hidden // head
871
+ lorder = encoder.encoders[0].conv_module.lorder
872
+ att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
873
+ att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
874
+ att_mask[:, :, :, :required_cache_size] = 0
875
+ cnn_cache = torch.zeros((1, hidden, layers, lorder))
876
+ inputs = (chunk, att_cache, cnn_cache, att_mask)
877
+ logger.info("chunk.size(): {} att_cache.size(): {} "
878
+ "cnn_cache.size(): {} att_mask.size(): {}".format(
879
+ list(chunk.size()), list(att_cache.size()),
880
+ list(cnn_cache.size()), list(att_mask.size())))
881
+
882
+ logger.info("Stage-1.2: torch.onnx.export")
883
+ # NOTE(xcsong): Below attributes will be used in
884
+ # onnx2horizonbin.py::generate_config()
885
+ attributes = {}
886
+ attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask"
887
+ attributes['output_name'] = "output;r_att_cache;r_cnn_cache"
888
+ attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap"
889
+ attributes['norm_type'] = \
890
+ "no_preprocess;no_preprocess;no_preprocess;no_preprocess"
891
+ attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW"
892
+ attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW"
893
+ attributes['input_shape'] = \
894
+ "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
895
+ chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3),
896
+ att_cache.size(0), att_cache.size(1), att_cache.size(2),
897
+ att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1),
898
+ cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0),
899
+ att_mask.size(1), att_mask.size(2), att_mask.size(3)
900
+ )
901
+ torch.onnx.export( # NOTE(xcsong): only support opset==11
902
+ encoder,
903
+ inputs,
904
+ encoder_outpath,
905
+ opset_version=11,
906
+ export_params=True,
907
+ do_constant_folding=True,
908
+ input_names=attributes['input_name'].split(';'),
909
+ output_names=attributes['output_name'].split(';'),
910
+ dynamic_axes=None,
911
+ verbose=False)
912
+ onnx_encoder = onnx.load(encoder_outpath)
913
+ for k in vars(args):
914
+ meta = onnx_encoder.metadata_props.add()
915
+ meta.key, meta.value = str(k), str(getattr(args, k))
916
+ for k in attributes:
917
+ meta = onnx_encoder.metadata_props.add()
918
+ meta.key, meta.value = str(k), str(attributes[k])
919
+ onnx.checker.check_model(onnx_encoder)
920
+ onnx.helper.printable_graph(onnx_encoder.graph)
921
+ onnx.save(onnx_encoder, encoder_outpath)
922
+ print_input_output_info(onnx_encoder, "onnx_encoder")
923
+ logger.info('Export onnx_encoder, done! see {}'.format(encoder_outpath))
924
+
925
+ logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
926
+ torch_output = []
927
+ torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask)
928
+ torch_att_cache = copy.deepcopy(att_cache)
929
+ torch_cnn_cache = copy.deepcopy(cnn_cache)
930
+ for i in range(10):
931
+ logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
932
+ ", att_mask: {}".format(i, list(torch_chunk.size()),
933
+ list(torch_att_cache.size()),
934
+ list(torch_cnn_cache.size()),
935
+ list(torch_att_mask.size())))
936
+ torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
937
+ out, torch_att_cache, torch_cnn_cache = encoder(
938
+ torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask)
939
+ torch_output.append(out)
940
+ torch_output = torch.cat(torch_output, dim=-1)
941
+
942
+ onnx_output = []
943
+ onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
944
+ onnx_att_cache = to_numpy(att_cache)
945
+ onnx_cnn_cache = to_numpy(cnn_cache)
946
+ ort_session = onnxruntime.InferenceSession(encoder_outpath)
947
+ input_names = [node.name for node in onnx_encoder.graph.input]
948
+ for i in range(10):
949
+ logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
950
+ " att_mask: {}".format(i, onnx_chunk.shape,
951
+ onnx_att_cache.shape,
952
+ onnx_cnn_cache.shape,
953
+ onnx_att_mask.shape))
954
+ onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
955
+ ort_inputs = {
956
+ 'chunk': onnx_chunk,
957
+ 'att_cache': onnx_att_cache,
958
+ 'cnn_cache': onnx_cnn_cache,
959
+ 'att_mask': onnx_att_mask,
960
+ }
961
+ ort_outs = ort_session.run(None, ort_inputs)
962
+ onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
963
+ onnx_output.append(ort_outs[0])
964
+ onnx_output = np.concatenate(onnx_output, axis=-1)
965
+
966
+ np.testing.assert_allclose(to_numpy(torch_output),
967
+ onnx_output,
968
+ rtol=1e-03,
969
+ atol=1e-04)
970
+ meta = ort_session.get_modelmeta()
971
+ logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
972
+ logger.info("Check onnx_encoder, pass!")
973
+ return encoder, ort_session
974
+
975
+
976
+ def export_ctc(asr_model, args):
977
+ logger.info("Stage-2: export ctc")
978
+ ctc = BPUCTC(asr_model.ctc).eval()
979
+ ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx')
980
+
981
+ logger.info("Stage-2.1: prepare inputs for ctc")
982
+ hidden = torch.randn((1, args.output_size, 1, args.chunk_size))
983
+
984
+ logger.info("Stage-2.2: torch.onnx.export")
985
+ # NOTE(xcsong): Below attributes will be used in
986
+ # onnx2horizonbin.py::generate_config()
987
+ attributes = {}
988
+ attributes['input_name'], attributes['input_type'] = "hidden", "featuremap"
989
+ attributes['norm_type'] = "no_preprocess"
990
+ attributes['input_layout_train'] = "NCHW"
991
+ attributes['input_layout_rt'] = "NCHW"
992
+ attributes['input_shape'] = "{}x{}x{}x{}".format(
993
+ hidden.size(0),
994
+ hidden.size(1),
995
+ hidden.size(2),
996
+ hidden.size(3),
997
+ )
998
+ torch.onnx.export(ctc,
999
+ hidden,
1000
+ ctc_outpath,
1001
+ opset_version=11,
1002
+ export_params=True,
1003
+ do_constant_folding=True,
1004
+ input_names=['hidden'],
1005
+ output_names=['probs'],
1006
+ dynamic_axes=None,
1007
+ verbose=False)
1008
+ onnx_ctc = onnx.load(ctc_outpath)
1009
+ for k in vars(args):
1010
+ meta = onnx_ctc.metadata_props.add()
1011
+ meta.key, meta.value = str(k), str(getattr(args, k))
1012
+ for k in attributes:
1013
+ meta = onnx_ctc.metadata_props.add()
1014
+ meta.key, meta.value = str(k), str(attributes[k])
1015
+ onnx.checker.check_model(onnx_ctc)
1016
+ onnx.helper.printable_graph(onnx_ctc.graph)
1017
+ onnx.save(onnx_ctc, ctc_outpath)
1018
+ print_input_output_info(onnx_ctc, "onnx_ctc")
1019
+ logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath))
1020
+
1021
+ logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
1022
+ torch_output = ctc(hidden)
1023
+ ort_session = onnxruntime.InferenceSession(ctc_outpath)
1024
+ onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})
1025
+
1026
+ np.testing.assert_allclose(to_numpy(torch_output),
1027
+ onnx_output[0],
1028
+ rtol=1e-03,
1029
+ atol=1e-04)
1030
+ meta = ort_session.get_modelmeta()
1031
+ logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
1032
+ logger.info("Check onnx_ctc, pass!")
1033
+ return ctc, ort_session
1034
+
1035
+
1036
+ def export_decoder(asr_model, args):
1037
+ logger.info("Currently, Decoder is not supported.")
1038
+
1039
+
1040
+ if __name__ == '__main__':
1041
+ torch.manual_seed(777)
1042
+ args = get_args()
1043
+ args.ln_run_on_bpu = False
1044
+ # NOTE(xcsong): XJ3 BPU only supports static shapes
1045
+ assert args.chunk_size > 0
1046
+ assert args.num_decoding_left_chunks > 0
1047
+ os.system("mkdir -p " + args.output_dir)
1048
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
1049
+
1050
+ with open(args.config, 'r') as fin:
1051
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
1052
+
1053
+ model, configs = init_model(args, configs)
1054
+ model.eval()
1055
+ print(model)
1056
+
1057
+ args.feature_size = configs['input_dim']
1058
+ args.output_size = model.encoder.output_size()
1059
+ args.decoding_window = (args.chunk_size - 1) * \
1060
+ model.encoder.embed.subsampling_rate + \
1061
+ model.encoder.embed.right_context + 1
1062
+
1063
+ export_encoder(model, args)
1064
+ export_ctc(model, args)
1065
+ export_decoder(model, args)
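For orientation, the decoding_window computed above is the number of raw feature frames one streaming chunk consumes before subsampling. A quick check of the formula under assumed front-end values (subsampling_rate = 8 and right_context = 14 are typical for an 8x subsampling module, but are only illustrative here):

chunk_size, subsampling_rate, right_context = 16, 8, 14   # assumed values
decoding_window = (chunk_size - 1) * subsampling_rate + right_context + 1
print(decoding_window)                                    # 135 feature frames per chunk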
wenet/bin/export_onnx_cpu.py ADDED
@@ -0,0 +1,470 @@
1
+ # Copyright (c) 2022, Xingchen Song ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ import os
20
+ import copy
21
+ import sys
22
+
23
+ import torch
24
+ import yaml
25
+ import numpy as np
26
+
27
+ from wenet.utils.init_model import init_model
28
+
29
+ try:
30
+ import onnx
31
+ import onnxruntime
32
+ from onnxruntime.quantization import quantize_dynamic, QuantType
33
+ except ImportError:
34
+ print('Please install onnx and onnxruntime!')
35
+ sys.exit(1)
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='export your script model')
40
+ parser.add_argument('--config', required=True, help='config file')
41
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
42
+ parser.add_argument('--output_dir', required=True, help='output directory')
43
+ parser.add_argument('--chunk_size',
44
+ required=True,
45
+ type=int,
46
+ help='decoding chunk size')
47
+ parser.add_argument('--num_decoding_left_chunks',
48
+ required=True,
49
+ type=int,
50
+ help='cache chunks')
51
+ parser.add_argument('--reverse_weight',
52
+ default=0.5,
53
+ type=float,
54
+ help='reverse_weight in attention_rescoing')
55
+ args = parser.parse_args()
56
+ return args
57
+
58
+
59
+ def to_numpy(tensor):
60
+ if tensor.requires_grad:
61
+ return tensor.detach().cpu().numpy()
62
+ else:
63
+ return tensor.cpu().numpy()
64
+
65
+
66
+ def print_input_output_info(onnx_model, name, prefix="\t\t"):
67
+ input_names = [node.name for node in onnx_model.graph.input]
68
+ input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
69
+ for node in onnx_model.graph.input]
70
+ output_names = [node.name for node in onnx_model.graph.output]
71
+ output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
72
+ for node in onnx_model.graph.output]
73
+ print("{}{} inputs : {}".format(prefix, name, input_names))
74
+ print("{}{} input shapes : {}".format(prefix, name, input_shapes))
75
+ print("{}{} outputs: {}".format(prefix, name, output_names))
76
+ print("{}{} output shapes : {}".format(prefix, name, output_shapes))
77
+
78
+
79
+ def export_encoder(asr_model, args):
80
+ print("Stage-1: export encoder")
81
+ encoder = asr_model.encoder
82
+ encoder.forward = encoder.forward_chunk
83
+ encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx')
84
+
85
+ print("\tStage-1.1: prepare inputs for encoder")
86
+ chunk = torch.randn(
87
+ (args['batch'], args['decoding_window'], args['feature_size']))
88
+ offset = 0
89
+ # NOTE(xcsong): The uncertainty of `next_cache_start` only appears
90
+ # in the first few chunks, this is caused by dynamic att_cache shape, i,e
91
+ # (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
92
+ # chunks. One way to ease the ONNX export is to keep `next_cache_start`
93
+ # as a fixed value. To do this, for the **first** chunk, if
94
+ # left_chunks > 0, we feed real cache & real mask to the model, otherwise
95
+ # fake cache & fake mask. In this way, we get:
96
+ # 1. 16/-1 mode: next_cache_start == 0 for all chunks
97
+ # 2. 16/4 mode: next_cache_start == chunk_size for all chunks
98
+ # 3. 16/0 mode: next_cache_start == chunk_size for all chunks
99
+ # 4. -1/-1 mode: next_cache_start == 0 for all chunks
100
+ # NO MORE DYNAMIC CHANGES!!
101
+ #
102
+ # NOTE(Mddct): We retain the current design for the convenience of supporting some
103
+ # inference frameworks without dynamic shapes. If you're interested in an all-in-one
104
+ # model that supports different chunk sizes, please see:
105
+ # https://github.com/wenet-e2e/wenet/pull/1174
106
+
107
+ if args['left_chunks'] > 0: # 16/4
108
+ required_cache_size = args['chunk_size'] * args['left_chunks']
109
+ offset = required_cache_size
110
+ # Real cache
111
+ att_cache = torch.zeros(
112
+ (args['num_blocks'], args['head'], required_cache_size,
113
+ args['output_size'] // args['head'] * 2))
114
+ # Real mask
115
+ att_mask = torch.ones(
116
+ (args['batch'], 1, required_cache_size + args['chunk_size']),
117
+ dtype=torch.bool)
118
+ att_mask[:, :, :required_cache_size] = 0
119
+ elif args['left_chunks'] <= 0: # 16/-1, -1/-1, 16/0
120
+ required_cache_size = -1 if args['left_chunks'] < 0 else 0
121
+ # Fake cache
122
+ att_cache = torch.zeros((args['num_blocks'], args['head'], 0,
123
+ args['output_size'] // args['head'] * 2))
124
+ # Fake mask
125
+ att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
126
+ cnn_cache = torch.zeros(
127
+ (args['num_blocks'], args['batch'], args['output_size'],
128
+ args['cnn_module_kernel'] - 1))
129
+ inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache,
130
+ att_mask)
131
+ print("\t\tchunk.size(): {}\n".format(chunk.size()),
132
+ "\t\toffset: {}\n".format(offset),
133
+ "\t\trequired_cache: {}\n".format(required_cache_size),
134
+ "\t\tatt_cache.size(): {}\n".format(att_cache.size()),
135
+ "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
136
+ "\t\tatt_mask.size(): {}\n".format(att_mask.size()))
137
+
138
+ print("\tStage-1.2: torch.onnx.export")
139
+ dynamic_axes = {
140
+ 'chunk': {
141
+ 1: 'T'
142
+ },
143
+ 'att_cache': {
144
+ 2: 'T_CACHE'
145
+ },
146
+ 'att_mask': {
147
+ 2: 'T_ADD_T_CACHE'
148
+ },
149
+ 'output': {
150
+ 1: 'T'
151
+ },
152
+ 'r_att_cache': {
153
+ 2: 'T_CACHE'
154
+ },
155
+ }
156
+ # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is
157
+ # to avoid padding the last chunk (which usually contains less
158
+ # frames than required). For users who want static axes, just pop
159
+ # out specific axis.
160
+ # if args['chunk_size'] > 0: # 16/4, 16/-1, 16/0
161
+ # dynamic_axes.pop('chunk')
162
+ # dynamic_axes.pop('output')
163
+ # if args['left_chunks'] >= 0: # 16/4, 16/0
164
+ # # NOTE(xsong): since we feed real cache & real mask into the
165
+ # # model when left_chunks > 0, the shape of cache will never
166
+ # # be changed.
167
+ # dynamic_axes.pop('att_cache')
168
+ # dynamic_axes.pop('r_att_cache')
169
+ torch.onnx.export(encoder,
170
+ inputs,
171
+ encoder_outpath,
172
+ opset_version=13,
173
+ export_params=True,
174
+ do_constant_folding=True,
175
+ input_names=[
176
+ 'chunk', 'offset', 'required_cache_size',
177
+ 'att_cache', 'cnn_cache', 'att_mask'
178
+ ],
179
+ output_names=['output', 'r_att_cache', 'r_cnn_cache'],
180
+ dynamic_axes=dynamic_axes,
181
+ verbose=False)
182
+ onnx_encoder = onnx.load(encoder_outpath)
183
+ for (k, v) in args.items():
184
+ meta = onnx_encoder.metadata_props.add()
185
+ meta.key, meta.value = str(k), str(v)
186
+ onnx.checker.check_model(onnx_encoder)
187
+ onnx.helper.printable_graph(onnx_encoder.graph)
188
+ # NOTE(xcsong): to add those metadatas we need to reopen
189
+ # the file and resave it.
190
+ onnx.save(onnx_encoder, encoder_outpath)
191
+ print_input_output_info(onnx_encoder, "onnx_encoder")
192
+ # Dynamic quantization
193
+ model_fp32 = encoder_outpath
194
+ model_quant = os.path.join(args['output_dir'], 'encoder.quant.onnx')
195
+ quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
196
+ print('\t\tExport onnx_encoder, done! see {}'.format(encoder_outpath))
197
+
198
+ print("\tStage-1.3: check onnx_encoder and torch_encoder")
199
+ torch_output = []
200
+ torch_chunk = copy.deepcopy(chunk)
201
+ torch_offset = copy.deepcopy(offset)
202
+ torch_required_cache_size = copy.deepcopy(required_cache_size)
203
+ torch_att_cache = copy.deepcopy(att_cache)
204
+ torch_cnn_cache = copy.deepcopy(cnn_cache)
205
+ torch_att_mask = copy.deepcopy(att_mask)
206
+ for i in range(10):
207
+ print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
208
+ " cnn_cache: {}, att_mask: {}".format(
209
+ i, list(torch_chunk.size()), torch_offset,
210
+ list(torch_att_cache.size()), list(torch_cnn_cache.size()),
211
+ list(torch_att_mask.size())))
212
+ # NOTE(xsong): att_mask for the first few chunks needs to change if
213
+ # we use 16/4 mode.
214
+ if args['left_chunks'] > 0: # 16/4
215
+ torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
216
+ out, torch_att_cache, torch_cnn_cache = encoder(
217
+ torch_chunk, torch_offset, torch_required_cache_size,
218
+ torch_att_cache, torch_cnn_cache, torch_att_mask)
219
+ torch_output.append(out)
220
+ torch_offset += out.size(1)
221
+ torch_output = torch.cat(torch_output, dim=1)
222
+
223
+ onnx_output = []
224
+ onnx_chunk = to_numpy(chunk)
225
+ onnx_offset = np.array((offset)).astype(np.int64)
226
+ onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
227
+ onnx_att_cache = to_numpy(att_cache)
228
+ onnx_cnn_cache = to_numpy(cnn_cache)
229
+ onnx_att_mask = to_numpy(att_mask)
230
+ ort_session = onnxruntime.InferenceSession(
231
+ encoder_outpath, providers=['CPUExecutionProvider'])
232
+ input_names = [node.name for node in onnx_encoder.graph.input]
233
+ for i in range(10):
234
+ print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {},"
235
+ " cnn_cache: {}, att_mask: {}".format(i, onnx_chunk.shape,
236
+ onnx_offset,
237
+ onnx_att_cache.shape,
238
+ onnx_cnn_cache.shape,
239
+ onnx_att_mask.shape))
240
+ # NOTE(xsong): att_mask for the first few chunks needs to change if
241
+ # we use 16/4 mode.
242
+ if args['left_chunks'] > 0: # 16/4
243
+ onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
244
+ ort_inputs = {
245
+ 'chunk': onnx_chunk,
246
+ 'offset': onnx_offset,
247
+ 'required_cache_size': onnx_required_cache_size,
248
+ 'att_cache': onnx_att_cache,
249
+ 'cnn_cache': onnx_cnn_cache,
250
+ 'att_mask': onnx_att_mask
251
+ }
252
+ # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
253
+ # will be hardcoded to 0 or chunk_size by ONNX, thus
254
+ # required_cache_size and att_mask are no more needed and they will
255
+ # be removed by ONNX automatically.
256
+ for k in list(ort_inputs):
257
+ if k not in input_names:
258
+ ort_inputs.pop(k)
259
+ ort_outs = ort_session.run(None, ort_inputs)
260
+ onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
261
+ onnx_output.append(ort_outs[0])
262
+ onnx_offset += ort_outs[0].shape[1]
263
+ onnx_output = np.concatenate(onnx_output, axis=1)
264
+
265
+ np.testing.assert_allclose(to_numpy(torch_output),
266
+ onnx_output,
267
+ rtol=1e-03,
268
+ atol=1e-05)
269
+ meta = ort_session.get_modelmeta()
270
+ print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
271
+ print("\t\tCheck onnx_encoder, pass!")
272
+
273
+
274
+ def export_ctc(asr_model, args):
275
+ print("Stage-2: export ctc")
276
+ ctc = asr_model.ctc
277
+ ctc.forward = ctc.log_softmax
278
+ ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx')
279
+
280
+ print("\tStage-2.1: prepare inputs for ctc")
281
+ hidden = torch.randn(
282
+ (args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16,
283
+ args['output_size']))
284
+
285
+ print("\tStage-2.2: torch.onnx.export")
286
+ dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}}
287
+ torch.onnx.export(ctc,
288
+ hidden,
289
+ ctc_outpath,
290
+ opset_version=13,
291
+ export_params=True,
292
+ do_constant_folding=True,
293
+ input_names=['hidden'],
294
+ output_names=['probs'],
295
+ dynamic_axes=dynamic_axes,
296
+ verbose=False)
297
+ onnx_ctc = onnx.load(ctc_outpath)
298
+ for (k, v) in args.items():
299
+ meta = onnx_ctc.metadata_props.add()
300
+ meta.key, meta.value = str(k), str(v)
301
+ onnx.checker.check_model(onnx_ctc)
302
+ onnx.helper.printable_graph(onnx_ctc.graph)
303
+ onnx.save(onnx_ctc, ctc_outpath)
304
+ print_input_output_info(onnx_ctc, "onnx_ctc")
305
+ # Dynamic quantization
306
+ model_fp32 = ctc_outpath
307
+ model_quant = os.path.join(args['output_dir'], 'ctc.quant.onnx')
308
+ quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
309
+ print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath))
310
+
311
+ print("\tStage-2.3: check onnx_ctc and torch_ctc")
312
+ torch_output = ctc(hidden)
313
+ ort_session = onnxruntime.InferenceSession(
314
+ ctc_outpath, providers=['CPUExecutionProvider'])
315
+ onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})
316
+
317
+ np.testing.assert_allclose(to_numpy(torch_output),
318
+ onnx_output[0],
319
+ rtol=1e-03,
320
+ atol=1e-05)
321
+ print("\t\tCheck onnx_ctc, pass!")
322
+
323
+
324
+ def export_decoder(asr_model, args):
325
+ print("Stage-3: export decoder")
326
+ decoder = asr_model
327
+ # NOTE(lzhin): parameters of encoder will be automatically removed
328
+ # since they are not used during rescoring.
329
+ decoder.forward = decoder.forward_attention_decoder
330
+ decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx')
331
+
332
+ print("\tStage-3.1: prepare inputs for decoder")
333
+ # hardcode time->200 nbest->10 len->20, they are dynamic axes.
334
+ encoder_out = torch.randn((1, 200, args['output_size']))
335
+ hyps = torch.randint(low=0, high=args['vocab_size'], size=[10, 20])
336
+ hyps[:, 0] = args['vocab_size'] - 1 # <sos>
337
+ hyps_lens = torch.randint(low=15, high=21, size=[10])
338
+
339
+ print("\tStage-3.2: torch.onnx.export")
340
+ dynamic_axes = {
341
+ 'hyps': {
342
+ 0: 'NBEST',
343
+ 1: 'L'
344
+ },
345
+ 'hyps_lens': {
346
+ 0: 'NBEST'
347
+ },
348
+ 'encoder_out': {
349
+ 1: 'T'
350
+ },
351
+ 'score': {
352
+ 0: 'NBEST',
353
+ 1: 'L'
354
+ },
355
+ 'r_score': {
356
+ 0: 'NBEST',
357
+ 1: 'L'
358
+ }
359
+ }
360
+ inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight'])
361
+ torch.onnx.export(
362
+ decoder,
363
+ inputs,
364
+ decoder_outpath,
365
+ opset_version=13,
366
+ export_params=True,
367
+ do_constant_folding=True,
368
+ input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'],
369
+ output_names=['score', 'r_score'],
370
+ dynamic_axes=dynamic_axes,
371
+ verbose=False)
372
+ onnx_decoder = onnx.load(decoder_outpath)
373
+ for (k, v) in args.items():
374
+ meta = onnx_decoder.metadata_props.add()
375
+ meta.key, meta.value = str(k), str(v)
376
+ onnx.checker.check_model(onnx_decoder)
377
+ onnx.helper.printable_graph(onnx_decoder.graph)
378
+ onnx.save(onnx_decoder, decoder_outpath)
379
+ print_input_output_info(onnx_decoder, "onnx_decoder")
380
+ model_fp32 = decoder_outpath
381
+ model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx')
382
+ quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
383
+ print('\t\tExport onnx_decoder, done! see {}'.format(decoder_outpath))
384
+
385
+ print("\tStage-3.3: check onnx_decoder and torch_decoder")
386
+ torch_score, torch_r_score = decoder(hyps, hyps_lens, encoder_out,
387
+ args['reverse_weight'])
388
+ ort_session = onnxruntime.InferenceSession(
389
+ decoder_outpath, providers=['CPUExecutionProvider'])
390
+ input_names = [node.name for node in onnx_decoder.graph.input]
391
+ ort_inputs = {
392
+ 'hyps': to_numpy(hyps),
393
+ 'hyps_lens': to_numpy(hyps_lens),
394
+ 'encoder_out': to_numpy(encoder_out),
395
+ 'reverse_weight': np.array((args['reverse_weight'])),
396
+ }
397
+ for k in list(ort_inputs):
398
+ if k not in input_names:
399
+ ort_inputs.pop(k)
400
+ onnx_output = ort_session.run(None, ort_inputs)
401
+
402
+ np.testing.assert_allclose(to_numpy(torch_score),
403
+ onnx_output[0],
404
+ rtol=1e-03,
405
+ atol=1e-05)
406
+ if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0:
407
+ np.testing.assert_allclose(to_numpy(torch_r_score),
408
+ onnx_output[1],
409
+ rtol=1e-03,
410
+ atol=1e-05)
411
+ print("\t\tCheck onnx_decoder, pass!")
412
+
413
+
414
+ def main():
415
+ torch.manual_seed(777)
416
+ args = get_args()
417
+ logging.basicConfig(level=logging.DEBUG,
418
+ format='%(asctime)s %(levelname)s %(message)s')
419
+ output_dir = args.output_dir
420
+ os.system("mkdir -p " + output_dir)
421
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
422
+
423
+ with open(args.config, 'r') as fin:
424
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
425
+
426
+ model, configs = init_model(args, configs)
427
+ model.eval()
428
+ print(model)
429
+
430
+ arguments = {}
431
+ arguments['output_dir'] = output_dir
432
+ arguments['batch'] = 1
433
+ arguments['chunk_size'] = args.chunk_size
434
+ arguments['left_chunks'] = args.num_decoding_left_chunks
435
+ arguments['reverse_weight'] = args.reverse_weight
436
+ arguments['output_size'] = configs['encoder_conf']['output_size']
437
+ arguments['num_blocks'] = configs['encoder_conf']['num_blocks']
438
+ arguments['cnn_module_kernel'] = configs['encoder_conf'].get(
439
+ 'cnn_module_kernel', 1)
440
+ arguments['head'] = configs['encoder_conf']['attention_heads']
441
+ arguments['feature_size'] = configs['input_dim']
442
+ arguments['vocab_size'] = configs['output_dim']
443
+ # NOTE(xcsong): if chunk_size == -1, hardcode to 67
444
+ arguments['decoding_window'] = (args.chunk_size - 1) * \
445
+ model.encoder.embed.subsampling_rate + \
446
+ model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67
447
+ arguments['encoder'] = configs['encoder']
448
+ arguments['decoder'] = configs['decoder']
449
+ arguments['subsampling_rate'] = model.subsampling_rate()
450
+ arguments['right_context'] = model.right_context()
451
+ arguments['sos_symbol'] = model.sos_symbol()
452
+ arguments['eos_symbol'] = model.eos_symbol()
453
+ arguments['is_bidirectional_decoder'] = 1 \
454
+ if model.is_bidirectional_decoder() else 0
455
+
456
+ # NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
457
+ # not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
458
+ # streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
459
+ # want to use 16/-1 or any other streaming mode in `decoder_main`,
460
+ # please export onnx in the same config.
461
+ if arguments['left_chunks'] > 0:
462
+ assert arguments['chunk_size'] > 0 # -1/4 not supported
463
+
464
+ export_encoder(model, arguments)
465
+ export_ctc(model, arguments)
466
+ export_decoder(model, arguments)
467
+
468
+
469
+ if __name__ == '__main__':
470
+ main()
wenet/bin/export_onnx_gpu.py ADDED
@@ -0,0 +1,1263 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import os
19
+ import sys
20
+
21
+ import torch
22
+ import yaml
23
+ import logging
24
+
25
+ import torch.nn.functional as F
26
+ from wenet.transformer.ctc import CTC
27
+ from wenet.transformer.decoder import TransformerDecoder
28
+ from wenet.transformer.encoder import BaseEncoder
29
+ from wenet.utils.init_model import init_model
30
+ from wenet.utils.mask import make_pad_mask
31
+
32
+ try:
33
+ import onnxruntime
34
+ except ImportError:
35
+ print("Please install onnxruntime-gpu!")
36
+ sys.exit(1)
37
+
38
+ logger = logging.getLogger(__file__)
39
+ logger.setLevel(logging.INFO)
40
+
41
+
42
+ class Encoder(torch.nn.Module):
43
+
44
+ def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
45
+ super().__init__()
46
+ self.encoder = encoder
47
+ self.ctc = ctc
48
+ self.beam_size = beam_size
49
+
50
+ def forward(
51
+ self,
52
+ speech: torch.Tensor,
53
+ speech_lengths: torch.Tensor,
54
+ ):
55
+ """Encoder
56
+ Args:
57
+ speech: (Batch, Length, ...)
58
+ speech_lengths: (Batch, )
59
+ Returns:
60
+ encoder_out: B x T x F
61
+ encoder_out_lens: B
62
+ ctc_log_probs: B x T x V
63
+ beam_log_probs: B x T x beam_size
64
+ beam_log_probs_idx: B x T x beam_size
65
+ """
66
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1,
67
+ -1)
68
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
69
+ ctc_log_probs = self.ctc.log_softmax(encoder_out)
70
+ encoder_out_lens = encoder_out_lens.int()
71
+ beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
72
+ self.beam_size,
73
+ dim=2)
74
+ return (
75
+ encoder_out,
76
+ encoder_out_lens,
77
+ ctc_log_probs,
78
+ beam_log_probs,
79
+ beam_log_probs_idx,
80
+ )
81
+
82
+
83
+ class StreamingEncoder(torch.nn.Module):
84
+
85
+ def __init__(
86
+ self,
87
+ model,
88
+ required_cache_size,
89
+ beam_size,
90
+ transformer=False,
91
+ return_ctc_logprobs=False,
92
+ ):
93
+ super().__init__()
94
+ self.ctc = model.ctc
95
+ self.subsampling_rate = model.encoder.embed.subsampling_rate
96
+ self.embed = model.encoder.embed
97
+ self.global_cmvn = model.encoder.global_cmvn
98
+ self.required_cache_size = required_cache_size
99
+ self.beam_size = beam_size
100
+ self.encoder = model.encoder
101
+ self.transformer = transformer
102
+ self.return_ctc_logprobs = return_ctc_logprobs
103
+
104
+ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
105
+ cache_mask):
106
+ """Streaming Encoder
107
+ Args:
108
+ xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
109
+ where `time == (chunk_size - 1) * subsample_rate + \
110
+ subsample.right_context + 1`
111
+ offset (torch.Tensor): offset with shape (b, 1)
112
+ 1 is retained for triton deployment
113
+ required_cache_size (int): cache size required for next chunk
114
+ compuation
115
+ > 0: actual cache size
116
+ <= 0: not allowed in streaming gpu encoder `
117
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
118
+ transformer/conformer attention, with shape
119
+ (b, elayers, head, cache_t1, d_k * 2), where
120
+ `head * d_k == hidden-dim` and
121
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
122
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
123
+ (b, elayers, b, hidden-dim, cache_t2), where
124
+ `cache_t2 == cnn.lorder - 1`
125
+ cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
126
+ in a batch of request, each request may have different
127
+ history cache. Cache mask is used to indidate the effective
128
+ cache for each request
129
+ Returns:
130
+ torch.Tensor: log probabilities of ctc output and cutoff by beam size
131
+ with shape (b, chunk_size, beam)
132
+ torch.Tensor: index of top beam size probabilities for each timestep
133
+ with shape (b, chunk_size, beam)
134
+ torch.Tensor: output of current input xs,
135
+ with shape (b, chunk_size, hidden-dim).
136
+ torch.Tensor: new attention cache required for next chunk, with
137
+ same shape (b, elayers, head, cache_t1, d_k * 2)
138
+ as the original att_cache
139
+ torch.Tensor: new conformer cnn cache required for next chunk, with
140
+ same shape as the original cnn_cache.
141
+ torch.Tensor: new cache mask, with same shape as the original
142
+ cache mask
143
+ """
144
+ offset = offset.squeeze(1)
145
+ T = chunk_xs.size(1)
146
+ chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
147
+ # B X 1 X T
148
+ chunk_mask = chunk_mask.to(chunk_xs.dtype)
149
+ # transpose batch & num_layers dim
150
+ att_cache = torch.transpose(att_cache, 0, 1)
151
+ cnn_cache = torch.transpose(cnn_cache, 0, 1)
152
+
153
+ # rewrite encoder.forward_chunk
154
+ # <---------forward_chunk START--------->
155
+ xs = self.global_cmvn(chunk_xs)
156
+ # chunk mask is important for batch inferencing since
157
+ # different sequence in a batch has different length
158
+ xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
159
+ cache_size = att_cache.size(3) # required cache size
160
+ masks = torch.cat((cache_mask, chunk_mask), dim=2)
161
+ index = offset - cache_size
162
+
163
+ pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
164
+ pos_emb = pos_emb.to(dtype=xs.dtype)
165
+
166
+ next_cache_start = -self.required_cache_size
167
+ r_cache_mask = masks[:, :, next_cache_start:]
168
+
169
+ r_att_cache = []
170
+ r_cnn_cache = []
171
+ for i, layer in enumerate(self.encoder.encoders):
172
+ xs, _, new_att_cache, new_cnn_cache = layer(
173
+ xs,
174
+ masks,
175
+ pos_emb,
176
+ att_cache=att_cache[i],
177
+ cnn_cache=cnn_cache[i],
178
+ )
179
+ # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
180
+ # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
181
+ r_att_cache.append(
182
+ new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
183
+ if not self.transformer:
184
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
185
+ if self.encoder.normalize_before:
186
+ chunk_out = self.encoder.after_norm(xs)
187
+ else:
188
+ chunk_out = xs
189
+
190
+ r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
191
+ if not self.transformer:
192
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
193
+
194
+ # <---------forward_chunk END--------->
195
+
196
+ log_ctc_probs = self.ctc.log_softmax(chunk_out)
197
+ log_probs, log_probs_idx = torch.topk(log_ctc_probs,
198
+ self.beam_size,
199
+ dim=2)
200
+ log_probs = log_probs.to(chunk_xs.dtype)
201
+
202
+ r_offset = offset + chunk_out.shape[1]
203
+ # the below ops not supported in Tensorrt
204
+ # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
205
+ # rounding_mode='floor')
206
+ chunk_out_lens = chunk_lens // self.subsampling_rate
207
+ r_offset = r_offset.unsqueeze(1)
208
+ if self.return_ctc_logprobs:
209
+ return (
210
+ log_ctc_probs,
211
+ chunk_out,
212
+ chunk_out_lens,
213
+ r_offset,
214
+ r_att_cache,
215
+ r_cnn_cache,
216
+ r_cache_mask,
217
+ )
218
+ else:
219
+ return (
220
+ log_probs,
221
+ log_probs_idx,
222
+ chunk_out,
223
+ chunk_out_lens,
224
+ r_offset,
225
+ r_att_cache,
226
+ r_cnn_cache,
227
+ r_cache_mask,
228
+ )
229
+
230
+
231
+ class StreamingSqueezeformerEncoder(torch.nn.Module):
232
+
233
+ def __init__(self, model, required_cache_size, beam_size):
234
+ super().__init__()
235
+ self.ctc = model.ctc
236
+ self.subsampling_rate = model.encoder.embed.subsampling_rate
237
+ self.embed = model.encoder.embed
238
+ self.global_cmvn = model.encoder.global_cmvn
239
+ self.required_cache_size = required_cache_size
240
+ self.beam_size = beam_size
241
+ self.encoder = model.encoder
242
+ self.reduce_idx = model.encoder.reduce_idx
243
+ self.recover_idx = model.encoder.recover_idx
244
+ if self.reduce_idx is None:
245
+ self.time_reduce = None
246
+ else:
247
+ if self.recover_idx is None:
248
+ self.time_reduce = "normal" # no recovery at the end
249
+ else:
250
+ self.time_reduce = "recover" # recovery at the end
251
+ assert len(self.reduce_idx) == len(self.recover_idx)
252
+
253
+ def calculate_downsampling_factor(self, i: int) -> int:
254
+ if self.reduce_idx is None:
255
+ return 1
256
+ else:
257
+ reduce_exp, recover_exp = 0, 0
258
+ for exp, rd_idx in enumerate(self.reduce_idx):
259
+ if i >= rd_idx:
260
+ reduce_exp = exp + 1
261
+ if self.recover_idx is not None:
262
+ for exp, rc_idx in enumerate(self.recover_idx):
263
+ if i >= rc_idx:
264
+ recover_exp = exp + 1
265
+ return int(2**(reduce_exp - recover_exp))
266
+
267
+ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
268
+ cache_mask):
269
+ """Streaming Encoder
270
+ Args:
271
+ xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
272
+ where `time == (chunk_size - 1) * subsample_rate + \
273
+ subsample.right_context + 1`
274
+ offset (torch.Tensor): offset with shape (b, 1)
275
+ 1 is retained for triton deployment
276
+ required_cache_size (int): cache size required for next chunk
277
+ computation
278
+ > 0: actual cache size
279
+ <= 0: not allowed in streaming gpu encoder `
280
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
281
+ transformer/conformer attention, with shape
282
+ (b, elayers, head, cache_t1, d_k * 2), where
283
+ `head * d_k == hidden-dim` and
284
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
285
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
286
+ (b, elayers, hidden-dim, cache_t2), where
287
+ `cache_t2 == cnn.lorder - 1`
288
+ cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
289
+ in a batch of request, each request may have different
290
+ history cache. Cache mask is used to indicate the effective
291
+ cache for each request
292
+ Returns:
293
+ torch.Tensor: log probabilities of ctc output and cutoff by beam size
294
+ with shape (b, chunk_size, beam)
295
+ torch.Tensor: index of top beam size probabilities for each timestep
296
+ with shape (b, chunk_size, beam)
297
+ torch.Tensor: output of current input xs,
298
+ with shape (b, chunk_size, hidden-dim).
299
+ torch.Tensor: new attention cache required for next chunk, with
300
+ same shape (b, elayers, head, cache_t1, d_k * 2)
301
+ as the original att_cache
302
+ torch.Tensor: new conformer cnn cache required for next chunk, with
303
+ same shape as the original cnn_cache.
304
+ torch.Tensor: new cache mask, with same shape as the original
305
+ cache mask
306
+ """
307
+ offset = offset.squeeze(1)
308
+ T = chunk_xs.size(1)
309
+ chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
310
+ # B X 1 X T
311
+ chunk_mask = chunk_mask.to(chunk_xs.dtype)
312
+ # transpose batch & num_layers dim
313
+ att_cache = torch.transpose(att_cache, 0, 1)
314
+ cnn_cache = torch.transpose(cnn_cache, 0, 1)
315
+
316
+ # rewrite encoder.forward_chunk
317
+ # <---------forward_chunk START--------->
318
+ xs = self.global_cmvn(chunk_xs)
319
+ # chunk mask is important for batch inferencing since
320
+ # different sequence in a batch has different length
321
+ xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
322
+ elayers, cache_size = att_cache.size(0), att_cache.size(3)
323
+ att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
324
+ index = offset - cache_size
325
+
326
+ pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
327
+ pos_emb = pos_emb.to(dtype=xs.dtype)
328
+
329
+ next_cache_start = -self.required_cache_size
330
+ r_cache_mask = att_mask[:, :, next_cache_start:]
331
+
332
+ r_att_cache = []
333
+ r_cnn_cache = []
334
+ mask_pad = torch.ones(1,
335
+ xs.size(1),
336
+ device=xs.device,
337
+ dtype=torch.bool)
338
+ mask_pad = mask_pad.unsqueeze(1)
339
+ max_att_len: int = 0
340
+ recover_activations: List[Tuple[torch.Tensor, torch.Tensor,
341
+ torch.Tensor, torch.Tensor]] = []
342
+ index = 0
343
+ xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
344
+ xs = self.encoder.preln(xs)
345
+ for i, layer in enumerate(self.encoder.encoders):
346
+ if self.reduce_idx is not None:
347
+ if self.time_reduce is not None and i in self.reduce_idx:
348
+ recover_activations.append(
349
+ (xs, att_mask, pos_emb, mask_pad))
350
+ (
351
+ xs,
352
+ xs_lens,
353
+ att_mask,
354
+ mask_pad,
355
+ ) = self.encoder.time_reduction_layer(
356
+ xs, xs_lens, att_mask, mask_pad)
357
+ pos_emb = pos_emb[:, ::2, :]
358
+ if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
359
+ pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
360
+ index += 1
361
+
362
+ if self.recover_idx is not None:
363
+ if self.time_reduce == "recover" and i in self.recover_idx:
364
+ index -= 1
365
+ (
366
+ recover_tensor,
367
+ recover_att_mask,
368
+ recover_pos_emb,
369
+ recover_mask_pad,
370
+ ) = recover_activations[index]
371
+ # recover output length for ctc decode
372
+ xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
373
+ xs = self.encoder.time_recover_layer(xs)
374
+ recoverd_t = recover_tensor.size(1)
375
+ xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
376
+ att_mask = recover_att_mask
377
+ pos_emb = recover_pos_emb
378
+ mask_pad = recover_mask_pad
379
+
380
+ factor = self.calculate_downsampling_factor(i)
381
+
382
+ xs, _, new_att_cache, new_cnn_cache = layer(
383
+ xs,
384
+ att_mask,
385
+ pos_emb,
386
+ att_cache=att_cache[i][:, :, ::factor, :]
387
+ [:, :, :pos_emb.size(1) - xs.size(1), :]
388
+ if elayers > 0 else att_cache[:, :, ::factor, :],
389
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
390
+ )
391
+ cached_att = new_att_cache[:, :, next_cache_start // factor:, :]
392
+ cached_cnn = new_cnn_cache.unsqueeze(1)
393
+ cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor,
394
+ 1).flatten(2, 3))
395
+ if i == 0:
396
+ # record length for the first block as max length
397
+ max_att_len = cached_att.size(2)
398
+ r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
399
+ r_cnn_cache.append(cached_cnn)
400
+
401
+ chunk_out = xs
402
+ r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
403
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
404
+
405
+ # <---------forward_chunk END--------->
406
+
407
+ log_ctc_probs = self.ctc.log_softmax(chunk_out)
408
+ log_probs, log_probs_idx = torch.topk(log_ctc_probs,
409
+ self.beam_size,
410
+ dim=2)
411
+ log_probs = log_probs.to(chunk_xs.dtype)
412
+
413
+ r_offset = offset + chunk_out.shape[1]
414
+ # the below ops not supported in Tensorrt
415
+ # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
416
+ # rounding_mode='floor')
417
+ chunk_out_lens = chunk_lens // self.subsampling_rate
418
+ r_offset = r_offset.unsqueeze(1)
419
+
420
+ return (
421
+ log_probs,
422
+ log_probs_idx,
423
+ chunk_out,
424
+ chunk_out_lens,
425
+ r_offset,
426
+ r_att_cache,
427
+ r_cnn_cache,
428
+ r_cache_mask,
429
+ )
430
+
431
+
432
+ class StreamingEfficientConformerEncoder(torch.nn.Module):
433
+
434
+ def __init__(self, model, required_cache_size, beam_size):
435
+ super().__init__()
436
+ self.ctc = model.ctc
437
+ self.subsampling_rate = model.encoder.embed.subsampling_rate
438
+ self.embed = model.encoder.embed
439
+ self.global_cmvn = model.encoder.global_cmvn
440
+ self.required_cache_size = required_cache_size
441
+ self.beam_size = beam_size
442
+ self.encoder = model.encoder
443
+
444
+ # Efficient Conformer
445
+ self.stride_layer_idx = model.encoder.stride_layer_idx
446
+ self.stride = model.encoder.stride
447
+ self.num_blocks = model.encoder.num_blocks
448
+ self.cnn_module_kernel = model.encoder.cnn_module_kernel
449
+
450
+ def calculate_downsampling_factor(self, i: int) -> int:
451
+ factor = 1
452
+ for idx, stride_idx in enumerate(self.stride_layer_idx):
453
+ if i > stride_idx:
454
+ factor *= self.stride[idx]
455
+ return factor
456
+
457
+ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
458
+ cache_mask):
459
+ """Streaming Encoder
460
+ Args:
461
+ chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
462
+ where `time == (chunk_size - 1) * subsample_rate + \
463
+ subsample.right_context + 1`
464
+ chunk_lens (torch.Tensor):
465
+ offset (torch.Tensor): offset with shape (b, 1)
466
+ 1 is retained for triton deployment
467
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
468
+ transformer/conformer attention, with shape
469
+ (b, elayers, head, cache_t1, d_k * 2), where
470
+ `head * d_k == hidden-dim` and
471
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
472
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
473
+ (b, elayers, hidden-dim, cache_t2), where
474
+ `cache_t2 == cnn.lorder - 1`
475
+ cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
476
+ in a batch of request, each request may have different
477
+ history cache. Cache mask is used to indicate the effective
478
+ cache for each request
479
+ Returns:
480
+ torch.Tensor: log probabilities of ctc output and cutoff by beam size
481
+ with shape (b, chunk_size, beam)
482
+ torch.Tensor: index of top beam size probabilities for each timestep
483
+ with shape (b, chunk_size, beam)
484
+ torch.Tensor: output of current input xs,
485
+ with shape (b, chunk_size, hidden-dim).
486
+ torch.Tensor: new attention cache required for next chunk, with
487
+ same shape (b, elayers, head, cache_t1, d_k * 2)
488
+ as the original att_cache
489
+ torch.Tensor: new conformer cnn cache required for next chunk, with
490
+ same shape as the original cnn_cache.
491
+ torch.Tensor: new cache mask, with same shape as the original
492
+ cache mask
493
+ """
494
+ offset = offset.squeeze(1) # (b, )
495
+ offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
496
+
497
+ T = chunk_xs.size(1)
498
+ chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # (b, 1, T)
499
+ # B X 1 X T
500
+ chunk_mask = chunk_mask.to(chunk_xs.dtype)
501
+ # transpose batch & num_layers dim
502
+ # Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2)
503
+ # Shape(cnn_cache): (elayers, b, outsize, cnn_kernel)
504
+ att_cache = torch.transpose(att_cache, 0, 1)
505
+ cnn_cache = torch.transpose(cnn_cache, 0, 1)
506
+
507
+ # rewrite encoder.forward_chunk
508
+ # <---------forward_chunk START--------->
509
+ xs = self.global_cmvn(chunk_xs)
510
+ # chunk mask is important for batch inferencing since
511
+ # different sequence in a batch has different length
512
+ xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
513
+ cache_size = att_cache.size(3) # required cache size
514
+ masks = torch.cat((cache_mask, chunk_mask), dim=2)
515
+ att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
516
+ index = offset - cache_size
517
+
518
+ pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
519
+ pos_emb = pos_emb.to(dtype=xs.dtype)
520
+
521
+ next_cache_start = -self.required_cache_size
522
+ r_cache_mask = masks[:, :, next_cache_start:]
523
+
524
+ r_att_cache = []
525
+ r_cnn_cache = []
526
+ mask_pad = chunk_mask.to(torch.bool)
527
+ max_att_len, max_cnn_len = (
528
+ 0,
529
+ 0,
530
+ ) # for repeat_interleave of new_att_cache
531
+ for i, layer in enumerate(self.encoder.encoders):
532
+ factor = self.calculate_downsampling_factor(i)
533
+ # NOTE(xcsong): Before layer.forward
534
+ # shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2),
535
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
536
+ # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
537
+ att_cache_trunc = 0
538
+ if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1):
539
+ # The time step is not divisible by the downsampling multiple
540
+ # We propose to double the chunk_size.
541
+ att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor -
542
+ pos_emb.size(1) + 1)
543
+ xs, _, new_att_cache, new_cnn_cache = layer(
544
+ xs,
545
+ att_mask,
546
+ pos_emb,
547
+ mask_pad=mask_pad,
548
+ att_cache=att_cache[i][:, :, ::factor, :][:, :,
549
+ att_cache_trunc:, :],
550
+ cnn_cache=cnn_cache[i, :, :, :]
551
+ if cnn_cache.size(0) > 0 else cnn_cache,
552
+ )
553
+
554
+ if i in self.stride_layer_idx:
555
+ # compute time dimension for next block
556
+ efficient_index = self.stride_layer_idx.index(i)
557
+ att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
558
+ stride[efficient_index], ]
559
+ mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
560
+ stride[efficient_index], ]
561
+ pos_emb = pos_emb[:, ::self.stride[efficient_index], :]
562
+
563
+ # shape(new_att_cache) = [batch, head, time2, outdim]
564
+ new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
565
+ # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2]
566
+ new_cnn_cache = new_cnn_cache.unsqueeze(1) # shape(1):layerID
567
+
568
+ # use repeat_interleave to new_att_cache
569
+ # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
570
+ new_att_cache = (new_att_cache.unsqueeze(3).repeat(
571
+ 1, 1, 1, factor, 1).flatten(2, 3))
572
+ # pad new_cnn_cache to cnn.lorder for causal convolution
573
+ new_cnn_cache = F.pad(
574
+ new_cnn_cache,
575
+ (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0),
576
+ )
577
+
578
+ if i == 0:
579
+ # record length for the first block as max length
580
+ max_att_len = new_att_cache.size(2)
581
+ max_cnn_len = new_cnn_cache.size(3)
582
+
583
+ # update real shape of att_cache and cnn_cache
584
+ r_att_cache.append(new_att_cache[:, :,
585
+ -max_att_len:, :].unsqueeze(1))
586
+ r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
587
+
588
+ if self.encoder.normalize_before:
589
+ chunk_out = self.encoder.after_norm(xs)
590
+ else:
591
+ chunk_out = xs
592
+
593
+ # shape of r_att_cache: (b, elayers, head, time2, outdim)
594
+ r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
595
+ # shape of r_cnn_cache: (b, elayers, outdim, cache_t2)
596
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
597
+
598
+ # <---------forward_chunk END--------->
599
+
600
+ log_ctc_probs = self.ctc.log_softmax(chunk_out)
601
+ log_probs, log_probs_idx = torch.topk(log_ctc_probs,
602
+ self.beam_size,
603
+ dim=2)
604
+ log_probs = log_probs.to(chunk_xs.dtype)
605
+
606
+ r_offset = offset + chunk_out.shape[1]
607
+ # the below ops not supported in Tensorrt
608
+ # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
609
+ # rounding_mode='floor')
610
+ chunk_out_lens = (
611
+ chunk_lens // self.subsampling_rate //
612
+ self.calculate_downsampling_factor(self.num_blocks + 1))
613
+ chunk_out_lens += 1
614
+ r_offset = r_offset.unsqueeze(1)
615
+
616
+ return (
617
+ log_probs,
618
+ log_probs_idx,
619
+ chunk_out,
620
+ chunk_out_lens,
621
+ r_offset,
622
+ r_att_cache,
623
+ r_cnn_cache,
624
+ r_cache_mask,
625
+ )
626
+
627
+
628
+ class Decoder(torch.nn.Module):
629
+
630
+ def __init__(
631
+ self,
632
+ decoder: TransformerDecoder,
633
+ ctc_weight: float = 0.5,
634
+ reverse_weight: float = 0.0,
635
+ beam_size: int = 10,
636
+ decoder_fastertransformer: bool = False,
637
+ ):
638
+ super().__init__()
639
+ self.decoder = decoder
640
+ self.ctc_weight = ctc_weight
641
+ self.reverse_weight = reverse_weight
642
+ self.beam_size = beam_size
643
+ self.decoder_fastertransformer = decoder_fastertransformer
644
+
645
+ def forward(
646
+ self,
647
+ encoder_out: torch.Tensor,
648
+ encoder_lens: torch.Tensor,
649
+ hyps_pad_sos_eos: torch.Tensor,
650
+ hyps_lens_sos: torch.Tensor,
651
+ r_hyps_pad_sos_eos: torch.Tensor,
652
+ ctc_score: torch.Tensor,
653
+ ):
654
+ """Encoder
655
+ Args:
656
+ encoder_out: B x T x F
657
+ encoder_lens: B
658
+ hyps_pad_sos_eos: B x beam x (T2+1),
659
+ hyps with sos & eos and padded by ignore id
660
+ hyps_lens_sos: B x beam, length for each hyp with sos
661
+ r_hyps_pad_sos_eos: B x beam x (T2+1),
662
+ reversed hyps with sos & eos and padded by ignore id
663
+ ctc_score: B x beam, ctc score for each hyp
664
+ Returns:
665
+ decoder_out: B x beam x T2 x V
666
+ r_decoder_out: B x beam x T2 x V
667
+ best_index: B
668
+ """
669
+ B, T, F = encoder_out.shape
670
+ bz = self.beam_size
671
+ B2 = B * bz
672
+ encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
673
+ encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
674
+ encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
675
+ T2 = hyps_pad_sos_eos.shape[2] - 1
676
+ hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
677
+ hyps_lens = hyps_lens_sos.view(B2, )
678
+ hyps_pad_sos = hyps_pad[:, :-1].contiguous()
679
+ hyps_pad_eos = hyps_pad[:, 1:].contiguous()
680
+
681
+ r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
682
+ r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
683
+ r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
684
+
685
+ decoder_out, r_decoder_out, _ = self.decoder(
686
+ encoder_out,
687
+ encoder_mask,
688
+ hyps_pad_sos,
689
+ hyps_lens,
690
+ r_hyps_pad_sos,
691
+ self.reverse_weight,
692
+ )
693
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
694
+ V = decoder_out.shape[-1]
695
+ decoder_out = decoder_out.view(B2, T2, V)
696
+ mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2
697
+ # mask index, remove ignore id
698
+ index = torch.unsqueeze(hyps_pad_eos * mask, 2)
699
+ score = decoder_out.gather(2, index).squeeze(2) # B2 X T2
700
+ # mask padded part
701
+ score = score * mask
702
+ decoder_out = decoder_out.view(B, bz, T2, V)
703
+ if self.reverse_weight > 0:
704
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out,
705
+ dim=-1)
706
+ r_decoder_out = r_decoder_out.view(B2, T2, V)
707
+ index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
708
+ r_score = r_decoder_out.gather(2, index).squeeze(2)
709
+ r_score = r_score * mask
710
+ score = (score * (1 - self.reverse_weight) +
711
+ self.reverse_weight * r_score)
712
+ r_decoder_out = r_decoder_out.view(B, bz, T2, V)
713
+ score = torch.sum(score, axis=1) # B2
714
+ score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
715
+ best_index = torch.argmax(score, dim=1)
716
+ if self.decoder_fastertransformer:
717
+ return decoder_out, best_index
718
+ else:
719
+ return best_index
720
+
721
+
722
+ def to_numpy(tensors):
723
+ out = []
724
+ if isinstance(tensors, torch.Tensor):
725
+ tensors = [tensors]
726
+ for tensor in tensors:
727
+ if tensor.requires_grad:
728
+ tensor = tensor.detach().cpu().numpy()
729
+ else:
730
+ tensor = tensor.cpu().numpy()
731
+ out.append(tensor)
732
+ return out
733
+
734
+
735
+ def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
736
+ for a, b in zip(xlist, blist):
737
+ try:
738
+ torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
739
+ except AssertionError as error:
740
+ if tolerate_small_mismatch:
741
+ print(error)
742
+ else:
743
+ raise
744
+
745
+
746
+ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
747
+ bz = 32
748
+ seq_len = 100
749
+ beam_size = args.beam_size
750
+ feature_size = configs["input_dim"]
751
+
752
+ speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
753
+ speech_lens = torch.randint(low=10,
754
+ high=seq_len,
755
+ size=(bz, ),
756
+ dtype=torch.int32)
757
+ encoder = Encoder(model.encoder, model.ctc, beam_size)
758
+ encoder.eval()
759
+
760
+ torch.onnx.export(
761
+ encoder,
762
+ (speech, speech_lens),
763
+ encoder_onnx_path,
764
+ export_params=True,
765
+ opset_version=13,
766
+ do_constant_folding=True,
767
+ input_names=["speech", "speech_lengths"],
768
+ output_names=[
769
+ "encoder_out",
770
+ "encoder_out_lens",
771
+ "ctc_log_probs",
772
+ "beam_log_probs",
773
+ "beam_log_probs_idx",
774
+ ],
775
+ dynamic_axes={
776
+ "speech": {
777
+ 0: "B",
778
+ 1: "T"
779
+ },
780
+ "speech_lengths": {
781
+ 0: "B"
782
+ },
783
+ "encoder_out": {
784
+ 0: "B",
785
+ 1: "T_OUT"
786
+ },
787
+ "encoder_out_lens": {
788
+ 0: "B"
789
+ },
790
+ "ctc_log_probs": {
791
+ 0: "B",
792
+ 1: "T_OUT"
793
+ },
794
+ "beam_log_probs": {
795
+ 0: "B",
796
+ 1: "T_OUT"
797
+ },
798
+ "beam_log_probs_idx": {
799
+ 0: "B",
800
+ 1: "T_OUT"
801
+ },
802
+ },
803
+ verbose=False,
804
+ )
805
+
806
+ with torch.no_grad():
807
+ o0, o1, o2, o3, o4 = encoder(speech, speech_lens)
808
+
809
+ providers = ["CUDAExecutionProvider"]
810
+ ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
811
+ providers=providers)
812
+ ort_inputs = {
813
+ "speech": to_numpy(speech),
814
+ "speech_lengths": to_numpy(speech_lens),
815
+ }
816
+ ort_outs = ort_session.run(None, ort_inputs)
817
+
818
+ # check encoder output
819
+ test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
820
+ logger.info("export offline onnx encoder succeed!")
821
+ onnx_config = {
822
+ "beam_size": args.beam_size,
823
+ "reverse_weight": args.reverse_weight,
824
+ "ctc_weight": args.ctc_weight,
825
+ "fp16": args.fp16,
826
+ }
827
+ return onnx_config
828
+
829
+
830
+ def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
831
+ decoding_chunk_size = args.decoding_chunk_size
832
+ subsampling = model.encoder.embed.subsampling_rate
833
+ context = model.encoder.embed.right_context + 1
834
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
835
+ batch_size = 32
836
+ audio_len = decoding_window
837
+ feature_size = configs["input_dim"]
838
+ output_size = configs["encoder_conf"]["output_size"]
839
+ num_layers = configs["encoder_conf"]["num_blocks"]
840
+ # in transformer the cnn module will not be available
841
+ transformer = False
842
+ cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
843
+ if not cnn_module_kernel:
844
+ transformer = True
845
+ num_decoding_left_chunks = args.num_decoding_left_chunks
846
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
847
+ if configs["encoder"] == "squeezeformer":
848
+ encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
849
+ args.beam_size)
850
+ elif configs["encoder"] == "efficientConformer":
851
+ encoder = StreamingEfficientConformerEncoder(model,
852
+ required_cache_size,
853
+ args.beam_size)
854
+ else:
855
+ encoder = StreamingEncoder(
856
+ model,
857
+ required_cache_size,
858
+ args.beam_size,
859
+ transformer,
860
+ args.return_ctc_logprobs,
861
+ )
862
+ encoder.eval()
863
+
864
+ # begin to export encoder
865
+ chunk_xs = torch.randn(batch_size,
866
+ audio_len,
867
+ feature_size,
868
+ dtype=torch.float32)
869
+ chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
870
+
871
+ offset = torch.arange(0, batch_size).unsqueeze(1)
872
+ # (elayers, b, head, cache_t1, d_k * 2)
873
+ head = configs["encoder_conf"]["attention_heads"]
874
+ d_k = configs["encoder_conf"]["output_size"] // head
875
+ att_cache = torch.randn(
876
+ batch_size,
877
+ num_layers,
878
+ head,
879
+ required_cache_size,
880
+ d_k * 2,
881
+ dtype=torch.float32,
882
+ )
883
+ cnn_cache = torch.randn(
884
+ batch_size,
885
+ num_layers,
886
+ output_size,
887
+ cnn_module_kernel,
888
+ dtype=torch.float32,
889
+ )
890
+
891
+ cache_mask = torch.ones(batch_size,
892
+ 1,
893
+ required_cache_size,
894
+ dtype=torch.float32)
895
+ input_names = [
896
+ "chunk_xs",
897
+ "chunk_lens",
898
+ "offset",
899
+ "att_cache",
900
+ "cnn_cache",
901
+ "cache_mask",
902
+ ]
903
+ output_names = [
904
+ "log_probs",
905
+ "log_probs_idx",
906
+ "chunk_out",
907
+ "chunk_out_lens",
908
+ "r_offset",
909
+ "r_att_cache",
910
+ "r_cnn_cache",
911
+ "r_cache_mask",
912
+ ]
913
+ if args.return_ctc_logprobs:
914
+ output_names = [
915
+ "ctc_log_probs",
916
+ "chunk_out",
917
+ "chunk_out_lens",
918
+ "r_offset",
919
+ "r_att_cache",
920
+ "r_cnn_cache",
921
+ "r_cache_mask",
922
+ ]
923
+ input_tensors = (
924
+ chunk_xs,
925
+ chunk_lens,
926
+ offset,
927
+ att_cache,
928
+ cnn_cache,
929
+ cache_mask,
930
+ )
931
+ if transformer:
932
+ assert (args.return_ctc_logprobs is
933
+ False), "return_ctc_logprobs is not supported in transformer"
934
+ output_names.pop(6)
935
+
936
+ all_names = input_names + output_names
937
+ dynamic_axes = {}
938
+ for name in all_names:
939
+ # only the first dimension is dynamic
940
+ # all other dimension is fixed
941
+ dynamic_axes[name] = {0: "B"}
942
+
943
+ torch.onnx.export(
944
+ encoder,
945
+ input_tensors,
946
+ encoder_onnx_path,
947
+ export_params=True,
948
+ opset_version=14,
949
+ do_constant_folding=True,
950
+ input_names=input_names,
951
+ output_names=output_names,
952
+ dynamic_axes=dynamic_axes,
953
+ verbose=False,
954
+ )
955
+
956
+ with torch.no_grad():
957
+ torch_outs = encoder(chunk_xs, chunk_lens, offset, att_cache,
958
+ cnn_cache, cache_mask)
959
+ if transformer:
960
+ torch_outs = list(torch_outs)
+ torch_outs.pop(6)  # drop r_cnn_cache, which the transformer export omits
961
+ ort_session = onnxruntime.InferenceSession(
962
+ encoder_onnx_path, providers=["CUDAExecutionProvider"])
963
+ ort_inputs = {}
964
+
965
+ input_tensors = to_numpy(input_tensors)
966
+ for idx, name in enumerate(input_names):
967
+ ort_inputs[name] = input_tensors[idx]
968
+ if transformer:
969
+ del ort_inputs["cnn_cache"]
970
+ ort_outs = ort_session.run(None, ort_inputs)
971
+ test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
972
+ logger.info("export to onnx streaming encoder succeed!")
973
+ onnx_config = {
974
+ "subsampling_rate": subsampling,
975
+ "context": context,
976
+ "decoding_chunk_size": decoding_chunk_size,
977
+ "num_decoding_left_chunks": num_decoding_left_chunks,
978
+ "beam_size": args.beam_size,
979
+ "fp16": args.fp16,
980
+ "feat_size": feature_size,
981
+ "decoding_window": decoding_window,
982
+ "cnn_module_kernel_cache": cnn_module_kernel,
983
+ "return_ctc_logprobs": args.return_ctc_logprobs,
984
+ }
985
+ return onnx_config
986
+
987
+
988
+ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
989
+ decoder_fastertransformer):
990
+ bz, seq_len = 32, 100
991
+ beam_size = args.beam_size
992
+ decoder = Decoder(
993
+ model.decoder,
994
+ model.ctc_weight,
995
+ model.reverse_weight,
996
+ beam_size,
997
+ decoder_fastertransformer,
998
+ )
999
+ decoder.eval()
1000
+
1001
+ hyps_pad_sos_eos = torch.randint(low=3,
1002
+ high=1000,
1003
+ size=(bz, beam_size, seq_len))
1004
+ hyps_lens_sos = torch.randint(low=3,
1005
+ high=seq_len,
1006
+ size=(bz, beam_size),
1007
+ dtype=torch.int32)
1008
+ r_hyps_pad_sos_eos = torch.randint(low=3,
1009
+ high=1000,
1010
+ size=(bz, beam_size, seq_len))
1011
+
1012
+ output_size = configs["encoder_conf"]["output_size"]
1013
+ encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
1014
+ encoder_out_lens = torch.randint(low=3,
1015
+ high=seq_len,
1016
+ size=(bz, ),
1017
+ dtype=torch.int32)
1018
+ ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)
1019
+
1020
+ input_names = [
1021
+ "encoder_out",
1022
+ "encoder_out_lens",
1023
+ "hyps_pad_sos_eos",
1024
+ "hyps_lens_sos",
1025
+ "r_hyps_pad_sos_eos",
1026
+ "ctc_score",
1027
+ ]
1028
+ output_names = ["best_index"]
1029
+ if decoder_fastertransformer:
1030
+ output_names.insert(0, "decoder_out")
1031
+
1032
+ torch.onnx.export(
1033
+ decoder,
1034
+ (
1035
+ encoder_out,
1036
+ encoder_out_lens,
1037
+ hyps_pad_sos_eos,
1038
+ hyps_lens_sos,
1039
+ r_hyps_pad_sos_eos,
1040
+ ctc_score,
1041
+ ),
1042
+ decoder_onnx_path,
1043
+ export_params=True,
1044
+ opset_version=13,
1045
+ do_constant_folding=True,
1046
+ input_names=input_names,
1047
+ output_names=output_names,
1048
+ dynamic_axes={
1049
+ "encoder_out": {
1050
+ 0: "B",
1051
+ 1: "T"
1052
+ },
1053
+ "encoder_out_lens": {
1054
+ 0: "B"
1055
+ },
1056
+ "hyps_pad_sos_eos": {
1057
+ 0: "B",
1058
+ 2: "T2"
1059
+ },
1060
+ "hyps_lens_sos": {
1061
+ 0: "B"
1062
+ },
1063
+ "r_hyps_pad_sos_eos": {
1064
+ 0: "B",
1065
+ 2: "T2"
1066
+ },
1067
+ "ctc_score": {
1068
+ 0: "B"
1069
+ },
1070
+ "best_index": {
1071
+ 0: "B"
1072
+ },
1073
+ },
1074
+ verbose=False,
1075
+ )
1076
+ with torch.no_grad():
1077
+ o0 = decoder(
1078
+ encoder_out,
1079
+ encoder_out_lens,
1080
+ hyps_pad_sos_eos,
1081
+ hyps_lens_sos,
1082
+ r_hyps_pad_sos_eos,
1083
+ ctc_score,
1084
+ )
1085
+ providers = ["CUDAExecutionProvider"]
1086
+ ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
1087
+ providers=providers)
1088
+
1089
+ input_tensors = [
1090
+ encoder_out,
1091
+ encoder_out_lens,
1092
+ hyps_pad_sos_eos,
1093
+ hyps_lens_sos,
1094
+ r_hyps_pad_sos_eos,
1095
+ ctc_score,
1096
+ ]
1097
+ ort_inputs = {}
1098
+ input_tensors = to_numpy(input_tensors)
1099
+ for idx, name in enumerate(input_names):
1100
+ ort_inputs[name] = input_tensors[idx]
1101
+
1102
+ # if model.reverse_weight == 0,
1103
+ # the r_hyps_pad will be removed
1104
+ # from the onnx decoder since it doesn't play any role
1105
+ if model.reverse_weight == 0:
1106
+ del ort_inputs["r_hyps_pad_sos_eos"]
1107
+ ort_outs = ort_session.run(None, ort_inputs)
1108
+
1109
+ # check decoder output
1110
+ if decoder_fastertransformer:
1111
+ test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
1112
+ else:
1113
+ test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
1114
+ logger.info("export to onnx decoder succeed!")
1115
+
1116
+
1117
+ if __name__ == "__main__":
1118
+ parser = argparse.ArgumentParser(description="export x86_gpu model")
1119
+ parser.add_argument("--config", required=True, help="config file")
1120
+ parser.add_argument("--checkpoint", required=True, help="checkpoint model")
1121
+ parser.add_argument(
1122
+ "--cmvn_file",
1123
+ required=False,
1124
+ default="",
1125
+ type=str,
1126
+ help="global_cmvn file, default path is in config file",
1127
+ )
1128
+ parser.add_argument(
1129
+ "--reverse_weight",
1130
+ default=-1.0,
1131
+ type=float,
1132
+ required=False,
1133
+ help="reverse weight for bitransformer," +
1134
+ "default value is in config file",
1135
+ )
1136
+ parser.add_argument(
1137
+ "--ctc_weight",
1138
+ default=-1.0,
1139
+ type=float,
1140
+ required=False,
1141
+ help="ctc weight, default value is in config file",
1142
+ )
1143
+ parser.add_argument(
1144
+ "--beam_size",
1145
+ default=10,
1146
+ type=int,
1147
+ required=False,
1148
+ help="beam size would be ctc output size",
1149
+ )
1150
+ parser.add_argument(
1151
+ "--output_onnx_dir",
1152
+ default="onnx_model",
1153
+ help="output onnx encoder and decoder directory",
1154
+ )
1155
+ parser.add_argument(
1156
+ "--fp16",
1157
+ action="store_true",
1158
+ help="whether to export fp16 model, default false",
1159
+ )
1160
+ # arguments for streaming encoder
1161
+ parser.add_argument(
1162
+ "--streaming",
1163
+ action="store_true",
1164
+ help="whether to export streaming encoder, default false",
1165
+ )
1166
+ parser.add_argument(
1167
+ "--decoding_chunk_size",
1168
+ default=16,
1169
+ type=int,
1170
+ required=False,
1171
+ help="the decoding chunk size, <=0 is not supported",
1172
+ )
1173
+ parser.add_argument(
1174
+ "--num_decoding_left_chunks",
1175
+ default=5,
1176
+ type=int,
1177
+ required=False,
1178
+ help="number of left chunks, <= 0 is not supported",
1179
+ )
1180
+ parser.add_argument(
1181
+ "--decoder_fastertransformer",
1182
+ action="store_true",
1183
+ help="return decoder_out and best_index for ft",
1184
+ )
1185
+ parser.add_argument(
1186
+ "--return_ctc_logprobs",
1187
+ action="store_true",
1188
+ help="return full ctc_log_probs for TLG streaming encoder",
1189
+ )
1190
+ args = parser.parse_args()
1191
+
1192
+ torch.manual_seed(0)
1193
+ torch.set_printoptions(precision=10)
1194
+
1195
+ with open(args.config, "r") as fin:
1196
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
1197
+ if args.cmvn_file and os.path.exists(args.cmvn_file):
1198
+ if 'cmvn' not in configs:
1199
+ configs['cmvn'] = "global_cmvn"
1200
+ configs['cmvn_conf'] = {}
1201
+ else:
1202
+ assert configs['cmvn'] == "global_cmvn"
1203
+ assert configs['cmvn_conf'] is not None
1204
+ configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
1205
+ if (args.reverse_weight != -1.0
1206
+ and "reverse_weight" in configs["model_conf"]):
1207
+ configs["model_conf"]["reverse_weight"] = args.reverse_weight
1208
+ print("Update reverse weight to", args.reverse_weight)
1209
+ if args.ctc_weight != -1:
1210
+ print("Update ctc weight to ", args.ctc_weight)
1211
+ configs["model_conf"]["ctc_weight"] = args.ctc_weight
1212
+ configs["encoder_conf"]["use_dynamic_chunk"] = False
1213
+
1214
+ model, configs = init_model(args, configs)
1215
+ model.eval()
1216
+
1217
+ if not os.path.exists(args.output_onnx_dir):
1218
+ os.mkdir(args.output_onnx_dir)
1219
+ encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder.onnx")
1220
+ export_enc_func = None
1221
+ if args.streaming:
1222
+ assert args.decoding_chunk_size > 0
1223
+ assert args.num_decoding_left_chunks > 0
1224
+ export_enc_func = export_online_encoder
1225
+ else:
1226
+ export_enc_func = export_offline_encoder
1227
+
1228
+ onnx_config = export_enc_func(model, configs, args, logger,
1229
+ encoder_onnx_path)
1230
+
1231
+ decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
1232
+ export_rescoring_decoder(
1233
+ model,
1234
+ configs,
1235
+ args,
1236
+ logger,
1237
+ decoder_onnx_path,
1238
+ args.decoder_fastertransformer,
1239
+ )
1240
+
1241
+ if args.fp16:
1242
+ try:
1243
+ import onnxmltools
1244
+ from onnxmltools.utils.float16_converter import (
1245
+ convert_float_to_float16, )
1246
+ except ImportError:
1247
+ print("Please install onnxmltools!")
1248
+ sys.exit(1)
1249
+ encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
1250
+ encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
1251
+ encoder_onnx_path = os.path.join(args.output_onnx_dir,
1252
+ "encoder_fp16.onnx")
1253
+ onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
1254
+ decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
1255
+ decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
1256
+ decoder_onnx_path = os.path.join(args.output_onnx_dir,
1257
+ "decoder_fp16.onnx")
1258
+ onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
1259
+ # dump configurations
1260
+
1261
+ config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
1262
+ with open(config_dir, "w") as out:
1263
+ yaml.dump(onnx_config, out)
wenet/bin/recognize.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import copy
19
+ import logging
20
+ import os
21
+
22
+ import torch
23
+ import yaml
24
+ from torch.utils.data import DataLoader
25
+
26
+ from wenet.dataset.dataset import Dataset
27
+ from wenet.utils.config import override_config
28
+ from wenet.utils.init_model import init_model
29
+ from wenet.utils.init_tokenizer import init_tokenizer
30
+ from wenet.utils.context_graph import ContextGraph
31
+ from wenet.utils.ctc_utils import get_blank_id
32
+ from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser(description='recognize with your model')
37
+ parser.add_argument('--config', required=True, help='config file')
38
+ parser.add_argument('--test_data', required=True, help='test data file')
39
+ parser.add_argument('--data_type',
40
+ default='raw',
41
+ # choices=['raw', 'shard'],
42
+ help='train and cv data type')
43
+ parser.add_argument('--gpu',
44
+ type=int,
45
+ default=-1,
46
+ help='gpu id for this rank, -1 for cpu')
47
+ parser.add_argument('--device',
48
+ type=str,
49
+ default="cpu",
50
+ choices=["cpu", "npu", "cuda"],
51
+ help='accelerator to use')
52
+ parser.add_argument('--dtype',
53
+ type=str,
54
+ default='fp32',
55
+ choices=['fp16', 'fp32', 'bf16'],
56
+ help='model\'s dtype')
57
+ parser.add_argument('--num_workers',
58
+ default=0,
59
+ type=int,
60
+ help='num of subprocess workers for reading')
61
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
62
+ parser.add_argument('--beam_size',
63
+ type=int,
64
+ default=10,
65
+ help='beam size for search')
66
+ parser.add_argument('--length_penalty',
67
+ type=float,
68
+ default=0.0,
69
+ help='length penalty')
70
+ parser.add_argument('--blank_penalty',
71
+ type=float,
72
+ default=0.0,
73
+ help='blank penalty')
74
+ parser.add_argument('--result_dir', required=True, help='asr result dir')
75
+ parser.add_argument('--batch_size',
76
+ type=int,
77
+ default=16,
78
+ help='batch size for decoding')
79
+ parser.add_argument('--modes',
80
+ nargs='+',
81
+ help="""decoding mode, support the following:
82
+ attention
83
+ ctc_greedy_search
84
+ ctc_prefix_beam_search
85
+ attention_rescoring
86
+ rnnt_greedy_search
87
+ rnnt_beam_search
88
+ rnnt_beam_attn_rescoring
89
+ ctc_beam_td_attn_rescoring
90
+ hlg_onebest
91
+ hlg_rescore
92
+ paraformer_greedy_search
93
+ paraformer_beam_search""")
94
+ parser.add_argument('--search_ctc_weight',
95
+ type=float,
96
+ default=1.0,
97
+ help='ctc weight for nbest generation')
98
+ parser.add_argument('--search_transducer_weight',
99
+ type=float,
100
+ default=0.0,
101
+ help='transducer weight for nbest generation')
102
+ parser.add_argument('--ctc_weight',
103
+ type=float,
104
+ default=0.0,
105
+ help='ctc weight for rescoring weight in \
106
+ attention rescoring decode mode \
107
+ ctc weight for rescoring weight in \
108
+ transducer attention rescore decode mode')
109
+
110
+ parser.add_argument('--transducer_weight',
111
+ type=float,
112
+ default=0.0,
113
+ help='transducer weight for rescoring weight in '
114
+ 'transducer attention rescore mode')
115
+ parser.add_argument('--attn_weight',
116
+ type=float,
117
+ default=0.0,
118
+ help='attention weight for rescoring weight in '
119
+ 'transducer attention rescore mode')
120
+ parser.add_argument('--decoding_chunk_size',
121
+ type=int,
122
+ default=-1,
123
+ help='''decoding chunk size,
124
+ <0: for decoding, use full chunk.
125
+ >0: for decoding, use fixed chunk size as set.
126
+ 0: used for training, it's prohibited here''')
127
+ parser.add_argument('--num_decoding_left_chunks',
128
+ type=int,
129
+ default=-1,
130
+ help='number of left chunks for decoding')
131
+ parser.add_argument('--simulate_streaming',
132
+ action='store_true',
133
+ help='simulate streaming inference')
134
+ parser.add_argument('--reverse_weight',
135
+ type=float,
136
+ default=0.0,
137
+ help='''right to left weight for attention rescoring
138
+ decode mode''')
139
+ parser.add_argument('--override_config',
140
+ action='append',
141
+ default=[],
142
+ help="override yaml config")
143
+
144
+ parser.add_argument('--word',
145
+ default='',
146
+ type=str,
147
+ help='word file, only used for hlg decode')
148
+ parser.add_argument('--hlg',
149
+ default='',
150
+ type=str,
151
+ help='hlg file, only used for hlg decode')
152
+ parser.add_argument('--lm_scale',
153
+ type=float,
154
+ default=0.0,
155
+ help='lm scale for hlg attention rescore decode')
156
+ parser.add_argument('--decoder_scale',
157
+ type=float,
158
+ default=0.0,
159
+ help='decoder scale for hlg attention rescore decode')
160
+ parser.add_argument('--r_decoder_scale',
161
+ type=float,
162
+ default=0.0,
163
+ help='right-to-left decoder scale for hlg attention rescore decode')
164
+
165
+ parser.add_argument(
166
+ '--context_bias_mode',
167
+ type=str,
168
+ default='',
169
+ help='''Context bias mode, selectable from the following
170
+ option: decoding-graph, deep-biasing''')
171
+ parser.add_argument('--context_list_path',
172
+ type=str,
173
+ default='',
174
+ help='Context list path')
175
+ parser.add_argument('--context_graph_score',
176
+ type=float,
177
+ default=0.0,
178
+ help='''The higher the score, the greater the degree of
179
+ bias using decoding-graph for biasing''')
180
+
181
+ parser.add_argument('--use_lora',
182
+ type=bool,
183
+ default=False,
184
+ help='''Whether to use lora for biasing''')
185
+ parser.add_argument("--lora_ckpt_path",
186
+ default=None,
187
+ type=str,
188
+ help="lora checkpoint path.")
189
+
190
+ parser.add_argument('--task',
191
+ type=str,
192
+ default='asr',
193
+ help='task for decoding')
194
+ parser.add_argument('--lang',
195
+ type=str,
196
+ default='zh',
197
+ help='language of the test data')
198
+ args = parser.parse_args()
199
+ print(args)
200
+ return args
201
+
202
+
203
+ def main():
204
+ args = get_args()
205
+ logging.basicConfig(level=logging.DEBUG,
206
+ format='%(asctime)s %(levelname)s %(message)s')
207
+ if args.gpu != -1:
208
+ # remain the original usage of gpu
209
+ args.device = "cuda"
210
+ if "cuda" in args.device:
211
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
212
+
213
+ with open(args.config, 'r') as fin:
214
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
215
+ if len(args.override_config) > 0:
216
+ configs = override_config(configs, args.override_config)
217
+
218
+ test_conf = copy.deepcopy(configs['dataset_conf'])
219
+
220
+ test_conf['filter_conf']['max_length'] = 102400
221
+ test_conf['filter_conf']['min_length'] = 0
222
+ test_conf['filter_conf']['token_max_length'] = 102400
223
+ test_conf['filter_conf']['token_min_length'] = 0
224
+ test_conf['filter_conf']['max_output_input_ratio'] = 102400
225
+ test_conf['filter_conf']['min_output_input_ratio'] = 0
226
+ test_conf['speed_perturb'] = False
227
+ test_conf['spec_aug'] = False
228
+ test_conf['spec_sub'] = False
229
+ test_conf['spec_trim'] = False
230
+ test_conf['shuffle'] = False
231
+ test_conf['sort'] = False
232
+ test_conf['cycle'] = 1
233
+ test_conf['list_shuffle'] = False
234
+ if 'fbank_conf' in test_conf:
235
+ test_conf['fbank_conf']['dither'] = 0.0
236
+ elif 'mfcc_conf' in test_conf:
237
+ test_conf['mfcc_conf']['dither'] = 0.0
238
+ test_conf['batch_conf']['batch_type'] = "static"
239
+ test_conf['batch_conf']['batch_size'] = args.batch_size
240
+
241
+ tokenizer = init_tokenizer(configs)
242
+ test_dataset = Dataset(args.data_type,
243
+ args.test_data,
244
+ tokenizer,
245
+ test_conf,
246
+ partition=False)
247
+
248
+ test_data_loader = DataLoader(test_dataset,
249
+ batch_size=None,
250
+ num_workers=args.num_workers)
251
+
252
+ # Init asr model from configs
253
+ args.jit = False
254
+ model, configs = init_model(args, configs)
255
+
256
+ device = torch.device(args.device)
257
+ model = model.to(device)
258
+ model.eval()
259
+ dtype = torch.float32
260
+ if args.dtype == 'fp16':
261
+ dtype = torch.float16
262
+ elif args.dtype == 'bf16':
263
+ dtype = torch.bfloat16
264
+ logging.info("compute dtype is {}".format(dtype))
265
+
266
+ context_graph = None
267
+ if 'decoding-graph' in args.context_bias_mode:
268
+ context_graph = ContextGraph(args.context_list_path,
269
+ tokenizer.symbol_table,
270
+ configs['tokenizer_conf']['bpe_path'],
271
+ args.context_graph_score)
272
+
273
+ _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
274
+ logging.info("blank_id is {}".format(blank_id))
275
+
276
+ # TODO(Dinghao Zhou): Support RNN-T related decoding
277
+ # TODO(Lv Xiang): Support k2 related decoding
278
+ # TODO(Kaixun Huang): Support context graph
279
+ files = {}
280
+ for mode in args.modes:
281
+ dir_name = os.path.join(args.result_dir, mode)
282
+ os.makedirs(dir_name, exist_ok=True)
283
+ file_name = os.path.join(dir_name, 'text')
284
+ files[mode] = open(file_name, 'w', encoding='utf-8')
285
+ max_format_len = max([len(mode) for mode in args.modes])
286
+
287
+ with torch.cuda.amp.autocast(enabled=True,
288
+ dtype=dtype,
289
+ cache_enabled=False):
290
+ with torch.no_grad():
291
+ utt_num = 0
292
+ # logging.info(f'utt_num: {utt_num}')
293
+ for batch_idx, batch in enumerate(test_data_loader):
294
+ keys = batch["keys"]
295
+ feats = batch["feats"].to(device)
296
+ target = batch["target"].to(device)
297
+ feats_lengths = batch["feats_lengths"].to(device)
298
+ target_lengths = batch["target_lengths"].to(device)
299
+ batch_size = feats.size(0)
300
+ # task_list = ["transcribe" for i in range(batch_size)]
301
+ task_list = [args.task for i in range(batch_size)]
302
+ lang_list = [args.lang for i in range(batch_size)]
303
+ infos = {"tasks": task_list, "langs":lang_list}
304
+ results = model.decode(
305
+ args.modes,
306
+ feats,
307
+ feats_lengths,
308
+ args.beam_size,
309
+ decoding_chunk_size=args.decoding_chunk_size,
310
+ num_decoding_left_chunks=args.num_decoding_left_chunks,
311
+ ctc_weight=args.ctc_weight,
312
+ simulate_streaming=args.simulate_streaming,
313
+ reverse_weight=args.reverse_weight,
314
+ context_graph=context_graph,
315
+ blank_id=blank_id,
316
+ blank_penalty=args.blank_penalty,
317
+ length_penalty=args.length_penalty,
318
+ infos=infos)
319
+ for i, key in enumerate(keys):
320
+ utt_num += 1
321
+ for mode, hyps in results.items():
322
+ tokens = hyps[i].tokens
323
+ line = '{} {}'.format(key,
324
+ tokenizer.detokenize(tokens)[0])
325
+ logging.info('{} {}'.format(mode.ljust(max_format_len),
326
+ line))
327
+ files[mode].write(line + '\n')
328
+ # if utt_num % 500 == 0:
329
+ # files[mode].flush()
330
+ for mode, f in files.items():
331
+ f.flush()  # force flush buffered content to the file
332
+ f.close()
333
+
334
+
335
+ if __name__ == '__main__':
336
+ main()
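The results object returned by model.decode above maps each decoding mode to one hypothesis per utterance in the batch. A minimal sketch, assuming only the attribute and tokenizer usage already shown in the writing loop above, of collecting the one-best text per mode into plain dictionaries:

def collect_results(results, keys, tokenizer):
    """Illustrative helper mirroring the writing loop in main():
    returns {mode: {utt_key: detokenized text}}."""
    texts = {mode: {} for mode in results}
    for i, key in enumerate(keys):
        for mode, hyps in results.items():
            # hyps[i].tokens holds the best hypothesis token ids, exactly as
            # consumed by tokenizer.detokenize() above.
            texts[mode][key] = tokenizer.detokenize(hyps[i].tokens)[0]
    return texts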
wenet/bin/recognize4llmasr.py ADDED
@@ -0,0 +1,340 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import copy
19
+ import logging
20
+ import os
21
+
22
+ import torch
23
+ import yaml
24
+ from gxl_ai_utils.utils.utils_model import set_random_seed
25
+ from torch.utils.data import DataLoader
26
+
27
+ from wenet.dataset.dataset import Dataset
28
+ from wenet.llm_asr.llmasr_model import LLMASR_Model
29
+ from wenet.utils.config import override_config
30
+ from wenet.utils.init_model import init_model
31
+ from wenet.utils.init_tokenizer import init_tokenizer
32
+ from wenet.utils.context_graph import ContextGraph
33
+ from wenet.utils.ctc_utils import get_blank_id
34
+ from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser(description='recognize with your model')
39
+ parser.add_argument('--config', required=True, help='config file')
40
+ parser.add_argument('--test_data', required=True, help='test data file')
41
+ parser.add_argument('--data_type',
42
+ default='raw',
43
+ # choices=['raw', 'shard'],
44
+ help='train and cv data type')
45
+ parser.add_argument('--gpu',
46
+ type=int,
47
+ default=-1,
48
+ help='gpu id for this rank, -1 for cpu')
49
+ parser.add_argument('--device',
50
+ type=str,
51
+ default="cpu",
52
+ choices=["cpu", "npu", "cuda"],
53
+ help='accelerator to use')
54
+ parser.add_argument('--dtype',
55
+ type=str,
56
+ default='fp32',
57
+ choices=['fp16', 'fp32', 'bf16'],
58
+ help='model\'s dtype')
59
+ parser.add_argument('--num_workers',
60
+ default=0,
61
+ type=int,
62
+ help='num of subprocess workers for reading')
63
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
64
+ parser.add_argument('--beam_size',
65
+ type=int,
66
+ default=10,
67
+ help='beam size for search')
68
+ parser.add_argument('--length_penalty',
69
+ type=float,
70
+ default=0.0,
71
+ help='length penalty')
72
+ parser.add_argument('--blank_penalty',
73
+ type=float,
74
+ default=0.0,
75
+ help='blank penalty')
76
+ parser.add_argument('--result_dir', required=True, help='asr result file')
77
+ parser.add_argument('--batch_size',
78
+ type=int,
79
+ default=16,
80
+ help='asr result file')
81
+ parser.add_argument('--modes',
82
+ nargs='+',
83
+ help="""decoding mode, support the following:
84
+ attention
85
+ ctc_greedy_search
86
+ ctc_prefix_beam_search
87
+ attention_rescoring
88
+ rnnt_greedy_search
89
+ rnnt_beam_search
90
+ rnnt_beam_attn_rescoring
91
+ ctc_beam_td_attn_rescoring
92
+ hlg_onebest
93
+ hlg_rescore
94
+ paraformer_greedy_search
95
+ paraformer_beam_search""")
96
+ parser.add_argument('--search_ctc_weight',
97
+ type=float,
98
+ default=1.0,
99
+ help='ctc weight for nbest generation')
100
+ parser.add_argument('--search_transducer_weight',
101
+ type=float,
102
+ default=0.0,
103
+ help='transducer weight for nbest generation')
104
+ parser.add_argument('--ctc_weight',
105
+ type=float,
106
+ default=0.0,
107
+ help='ctc weight for rescoring weight in \
108
+ attention rescoring decode mode \
109
+ ctc weight for rescoring weight in \
110
+ transducer attention rescore decode mode')
111
+
112
+ parser.add_argument('--transducer_weight',
113
+ type=float,
114
+ default=0.0,
115
+ help='transducer weight for rescoring weight in '
116
+ 'transducer attention rescore mode')
117
+ parser.add_argument('--attn_weight',
118
+ type=float,
119
+ default=0.0,
120
+ help='attention weight for rescoring weight in '
121
+ 'transducer attention rescore mode')
122
+ parser.add_argument('--decoding_chunk_size',
123
+ type=int,
124
+ default=-1,
125
+ help='''decoding chunk size,
126
+ <0: for decoding, use full chunk.
127
+ >0: for decoding, use fixed chunk size as set.
128
+ 0: used for training, it's prohibited here''')
129
+ parser.add_argument('--num_decoding_left_chunks',
130
+ type=int,
131
+ default=-1,
132
+ help='number of left chunks for decoding')
133
+ parser.add_argument('--simulate_streaming',
134
+ action='store_true',
135
+ help='simulate streaming inference')
136
+ parser.add_argument('--reverse_weight',
137
+ type=float,
138
+ default=0.0,
139
+ help='''right to left weight for attention rescoring
140
+ decode mode''')
141
+ parser.add_argument('--override_config',
142
+ action='append',
143
+ default=[],
144
+ help="override yaml config")
145
+
146
+ parser.add_argument('--word',
147
+ default='',
148
+ type=str,
149
+ help='word file, only used for hlg decode')
150
+ parser.add_argument('--hlg',
151
+ default='',
152
+ type=str,
153
+ help='hlg file, only used for hlg decode')
154
+ parser.add_argument('--lm_scale',
155
+ type=float,
156
+ default=0.0,
157
+ help='lm scale for hlg attention rescore decode')
158
+ parser.add_argument('--decoder_scale',
159
+ type=float,
160
+ default=0.0,
161
+ help='lm scale for hlg attention rescore decode')
162
+ parser.add_argument('--r_decoder_scale',
163
+ type=float,
164
+ default=0.0,
165
+ help='lm scale for hlg attention rescore decode')
166
+
167
+ parser.add_argument(
168
+ '--context_bias_mode',
169
+ type=str,
170
+ default='',
171
+ help='''Context bias mode, selectable from the following
172
+ option: decoding-graph, deep-biasing''')
173
+ parser.add_argument('--context_list_path',
174
+ type=str,
175
+ default='',
176
+ help='Context list path')
177
+ parser.add_argument('--context_graph_score',
178
+ type=float,
179
+ default=0.0,
180
+ help='''The higher the score, the greater the degree of
181
+ bias using decoding-graph for biasing''')
182
+
183
+ parser.add_argument('--use_lora',
184
+ type=bool,
185
+ default=False,
186
+ help='''Whether to use lora for biasing''')
187
+ parser.add_argument("--lora_ckpt_path",
188
+ default=None,
189
+ type=str,
190
+ help="lora checkpoint path.")
191
+
192
+ parser.add_argument('--task',
193
+ type=str,
194
+ default='asr',
195
+ help='Context list path')
196
+ parser.add_argument('--lang',
197
+ type=str,
198
+ default='zh',
199
+ help='Context list path')
200
+ args = parser.parse_args()
201
+ print(args)
202
+ return args
203
+
204
+
205
+ def main():
206
+ args = get_args()
207
+ logging.basicConfig(level=logging.DEBUG,
208
+ format='%(asctime)s %(levelname)s %(message)s')
209
+
210
+ set_random_seed(777)
211
+
212
+ if args.gpu != -1:
213
+ # remain the original usage of gpu
214
+ args.device = "cuda"
215
+ if "cuda" in args.device:
216
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
217
+
218
+ with open(args.config, 'r') as fin:
219
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
220
+ if len(args.override_config) > 0:
221
+ configs = override_config(configs, args.override_config)
222
+ configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
223
+ test_conf = copy.deepcopy(configs['dataset_conf'])
224
+
225
+ test_conf['filter_conf']['max_length'] = 3000 # whisper最长处理30s 102400
226
+ test_conf['filter_conf']['min_length'] = 0
227
+ test_conf['filter_conf']['token_max_length'] = 102400
228
+ test_conf['filter_conf']['token_min_length'] = 0
229
+ test_conf['filter_conf']['max_output_input_ratio'] = 102400
230
+ test_conf['filter_conf']['min_output_input_ratio'] = 0
231
+ test_conf['speed_perturb'] = False
232
+ test_conf['spec_aug'] = False
233
+ test_conf['spec_sub'] = False
234
+ test_conf['spec_trim'] = False
235
+ test_conf['shuffle'] = True
236
+ test_conf['sort'] = False
237
+ test_conf['cycle'] = 1
238
+ test_conf['list_shuffle'] = True
239
+ if 'fbank_conf' in test_conf:
240
+ test_conf['fbank_conf']['dither'] = 0.0
241
+ elif 'mfcc_conf' in test_conf:
242
+ test_conf['mfcc_conf']['dither'] = 0.0
243
+ test_conf['batch_conf']['batch_type'] = "static"
244
+ test_conf['batch_conf']['batch_size'] = 1
245
+ test_conf['split_num'] = 1
246
+
247
+
248
+ tokenizer = init_tokenizer(configs)
249
+ test_dataset = Dataset(args.data_type,
250
+ args.test_data,
251
+ tokenizer,
252
+ test_conf,
253
+ partition=False)
254
+
255
+ test_data_loader = DataLoader(test_dataset,
256
+ batch_size=None,
257
+ num_workers=args.num_workers)
258
+
259
+ # Init asr model from configs
260
+ args.jit = False
261
+ model, configs = init_model(args, configs)
262
+
263
+ device = torch.device(args.device)
264
+ model: LLMASR_Model = model.to(device)
265
+ model.eval()
266
+ dtype = torch.float32
267
+ if args.dtype == 'fp16':
268
+ dtype = torch.float16
269
+ elif args.dtype == 'bf16':
270
+ dtype = torch.bfloat16
271
+ logging.info("compute dtype is {}".format(dtype))
272
+
273
+ context_graph = None
274
+ if 'decoding-graph' in args.context_bias_mode:
275
+ context_graph = ContextGraph(args.context_list_path,
276
+ tokenizer.symbol_table,
277
+ configs['tokenizer_conf']['bpe_path'],
278
+ args.context_graph_score)
279
+
280
+ _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
281
+ logging.info("blank_id is {}".format(blank_id))
282
+
283
+ # TODO(Dinghao Zhou): Support RNN-T related decoding
284
+ # TODO(Lv Xiang): Support k2 related decoding
285
+ # TODO(Kaixun Huang): Support context graph
286
+ files = {}
287
+ modes = ['llmasr_decode']
288
+ for mode in modes:
289
+ dir_name = os.path.join(args.result_dir, mode)
290
+ os.makedirs(dir_name, exist_ok=True)
291
+ file_name = os.path.join(dir_name, 'text')
292
+ files[mode] = open(file_name, 'w', encoding='utf-8')
293
+ max_format_len = max([len(mode) for mode in args.modes])
294
+
295
+ # Get prompt config
296
+ from gxl_ai_utils.utils import utils_file
297
+ global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')
298
+
299
+ with torch.cuda.amp.autocast(enabled=True,
300
+ dtype=dtype,
301
+ cache_enabled=False):
302
+ with torch.no_grad():
303
+ # logging.info(f'utt_num: {utt_num}')
304
+ for batch_idx, batch in enumerate(test_data_loader):
305
+ keys = batch["keys"]
306
+ feats = batch["feats"].to(device)
307
+ target = batch["target"].to(device)
308
+ feats_lengths = batch["feats_lengths"].to(device)
309
+ target_lengths = batch["target_lengths"].to(device)
310
+ batch_size = feats.size(0)
311
+
312
+ import random
313
+ if '><' in args.task:
314
+ args.task = args.task.replace('><', '> <')
315
+ if args.task == "<TRANSCRIBE>" or args.task == "<transcribe>":
316
+ is_truncation = False
317
+ else:
318
+ is_truncation = True
319
+ random_index = random.randint(0, len(global_prompt_dict[args.task])-1)
320
+ prompt = global_prompt_dict[args.task][random_index]
321
+ # print(args.task, prompt)
322
+
323
+ res_text = model.generate(wavs=feats, wavs_len=feats_lengths, prompt=prompt)
324
+ for mode in modes:
325
+ line = "{}\t{}".format(keys[0], res_text[0])
326
+ files[mode].write(line+'\n')
327
+ utils_file.logging_print( '{} {} {}'.format(batch_idx, keys[0], res_text[0]))
328
+ if batch_idx % 100 == 0:
329
+ for mode, f in files.items():
330
+ f.flush() # force-flush the buffer to the output file
331
+ # if batch_idx >= 1000 and is_truncation:
332
+ # utils_file.logging_info('applying the truncate-to-3000 strategy')
333
+ # break
334
+ for mode, f in files.items():
335
+ f.flush() # force-flush the buffer to the output file
336
+ f.close()
337
+
338
+
339
+ if __name__ == '__main__':
340
+ main()
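Note: the loop above drives LLM-ASR decoding with a natural-language prompt drawn at random from the entry for the requested task in conf/prompt_stage4.yaml. A minimal standalone sketch of that selection step, with made-up task names and prompt strings standing in for the real YAML contents:

import random

# stands in for utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')
prompt_dict = {
    '<TRANSCRIBE>': ['Transcribe the speech.', 'Write down what is said.'],
    '<CAPTION>': ['Describe the sound in this clip.'],
}

def pick_prompt(task: str) -> str:
    task = task.replace('><', '> <')    # same normalization as the loop above
    candidates = prompt_dict[task]
    return candidates[random.randint(0, len(candidates) - 1)]

print(pick_prompt('<TRANSCRIBE>'))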
wenet/bin/recognize_onnx_gpu.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """
29
+ This script is for testing exported onnx encoder and decoder from
30
+ export_onnx_gpu.py. The exported onnx models only support batch offline ASR inference.
31
+ It requires a python wrapped c++ ctc decoder.
32
+ Please install it by following:
33
+ https://github.com/Slyne/ctc_decoder.git
34
+ """
35
+ from __future__ import print_function
36
+
37
+ import argparse
38
+ import copy
39
+ import logging
40
+ import os
41
+ import sys
42
+
43
+ import torch
44
+ import yaml
45
+ from torch.utils.data import DataLoader
46
+
47
+ from wenet.dataset.dataset import Dataset
48
+ from wenet.utils.common import IGNORE_ID
49
+ from wenet.utils.config import override_config
50
+ from wenet.utils.init_tokenizer import init_tokenizer
51
+
52
+ import onnxruntime as rt
53
+ import multiprocessing
54
+ import numpy as np
55
+
56
+ try:
57
+ from swig_decoders import map_batch, \
58
+ ctc_beam_search_decoder_batch, \
59
+ TrieVector, PathTrie
60
+ except ImportError:
61
+ print('Please install ctc decoders first by referring to\n' +
62
+ 'https://github.com/Slyne/ctc_decoder.git')
63
+ sys.exit(1)
64
+
65
+ def get_args():
66
+ parser = argparse.ArgumentParser(description='recognize with your model')
67
+ parser.add_argument('--config', required=True, help='config file')
68
+ parser.add_argument('--test_data', required=True, help='test data file')
69
+ parser.add_argument('--data_type',
70
+ default='raw',
71
+ choices=['raw', 'shard'],
72
+ help='train and cv data type')
73
+ parser.add_argument('--gpu',
74
+ type=int,
75
+ default=-1,
76
+ help='gpu id for this rank, -1 for cpu')
77
+ parser.add_argument('--dict', required=True, help='dict file')
78
+ parser.add_argument('--encoder_onnx',
79
+ required=True,
80
+ help='encoder onnx file')
81
+ parser.add_argument('--decoder_onnx',
82
+ required=True,
83
+ help='decoder onnx file')
84
+ parser.add_argument('--result_file', required=True, help='asr result file')
85
+ parser.add_argument('--batch_size',
86
+ type=int,
87
+ default=32,
88
+ help='batch size for decoding')
89
+ parser.add_argument('--mode',
90
+ choices=[
91
+ 'ctc_greedy_search', 'ctc_prefix_beam_search',
92
+ 'attention_rescoring'
93
+ ],
94
+ default='attention_rescoring',
95
+ help='decoding mode')
96
+ parser.add_argument('--bpe_model',
97
+ default=None,
98
+ type=str,
99
+ help='bpe model for english part')
100
+ parser.add_argument('--override_config',
101
+ action='append',
102
+ default=[],
103
+ help="override yaml config")
104
+ parser.add_argument('--fp16',
105
+ action='store_true',
106
+ help='whether to export fp16 model, default false')
107
+ args = parser.parse_args()
108
+ return args
109
+
110
+ def main():
111
+ args = get_args()
112
+ logging.basicConfig(level=logging.DEBUG,
113
+ format='%(asctime)s %(levelname)s %(message)s')
114
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
115
+
116
+ with open(args.config, 'r') as fin:
117
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
118
+ if len(args.override_config) > 0:
119
+ configs = override_config(configs, args.override_config)
120
+
121
+ reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
122
+ special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
123
+ test_conf = copy.deepcopy(configs['dataset_conf'])
124
+ test_conf['filter_conf']['max_length'] = 102400
125
+ test_conf['filter_conf']['min_length'] = 0
126
+ test_conf['filter_conf']['token_max_length'] = 102400
127
+ test_conf['filter_conf']['token_min_length'] = 0
128
+ test_conf['filter_conf']['max_output_input_ratio'] = 102400
129
+ test_conf['filter_conf']['min_output_input_ratio'] = 0
130
+ test_conf['speed_perturb'] = False
131
+ test_conf['spec_aug'] = False
132
+ test_conf['spec_sub'] = False
133
+ test_conf['spec_trim'] = False
134
+ test_conf['shuffle'] = False
135
+ test_conf['sort'] = False
136
+ test_conf['fbank_conf']['dither'] = 0.0
137
+ test_conf['batch_conf']['batch_type'] = "static"
138
+ test_conf['batch_conf']['batch_size'] = args.batch_size
139
+
140
+ tokenizer = init_tokenizer(configs)
141
+ test_dataset = Dataset(args.data_type,
142
+ args.test_data,
143
+ tokenizer,
144
+ test_conf,
145
+ partition=False)
146
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
147
+
148
+ # Init asr model from configs
149
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
150
+ if use_cuda:
151
+ EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
152
+ else:
153
+ EP_list = ['CPUExecutionProvider']
154
+
155
+ encoder_ort_session = rt.InferenceSession(args.encoder_onnx,
156
+ providers=EP_list)
157
+ decoder_ort_session = None
158
+ if args.mode == "attention_rescoring":
159
+ decoder_ort_session = rt.InferenceSession(args.decoder_onnx,
160
+ providers=EP_list)
161
+
162
+ # Load dict
163
+ vocabulary = []
164
+ char_dict = {}
165
+ with open(args.dict, 'r') as fin:
166
+ for line in fin:
167
+ arr = line.strip().split()
168
+ assert len(arr) == 2
169
+ char_dict[int(arr[1])] = arr[0]
170
+ vocabulary.append(arr[0])
171
+
172
+ vocab_size = len(char_dict)
173
+ sos = (vocab_size - 1 if special_tokens is None else
174
+ special_tokens.get("<sos>", vocab_size - 1))
175
+ eos = (vocab_size - 1 if special_tokens is None else
176
+ special_tokens.get("<eos>", vocab_size - 1))
177
+
178
+ with torch.no_grad(), open(args.result_file, 'w') as fout:
179
+ for _, batch in enumerate(test_data_loader):
180
+ keys = batch['keys']
181
+ feats = batch['feats']
182
+ feats_lengths = batch['feats_lengths']
183
+ feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
184
+ if args.fp16:
185
+ feats = feats.astype(np.float16)
186
+ ort_inputs = {
187
+ encoder_ort_session.get_inputs()[0].name: feats,
188
+ encoder_ort_session.get_inputs()[1].name: feats_lengths
189
+ }
190
+ ort_outs = encoder_ort_session.run(None, ort_inputs)
191
+ encoder_out, encoder_out_lens, ctc_log_probs, \
192
+ beam_log_probs, beam_log_probs_idx = ort_outs
193
+ beam_size = beam_log_probs.shape[-1]
194
+ batch_size = beam_log_probs.shape[0]
195
+ num_processes = min(multiprocessing.cpu_count(), batch_size)
196
+ if args.mode == 'ctc_greedy_search':
197
+ if beam_size != 1:
198
+ log_probs_idx = beam_log_probs_idx[:, :, 0]
199
+ batch_sents = []
200
+ for idx, seq in enumerate(log_probs_idx):
201
+ batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
202
+ hyps = map_batch(batch_sents, vocabulary, num_processes, True,
203
+ 0)
204
+ elif args.mode in ('ctc_prefix_beam_search',
205
+ "attention_rescoring"):
206
+ batch_log_probs_seq_list = beam_log_probs.tolist()
207
+ batch_log_probs_idx_list = beam_log_probs_idx.tolist()
208
+ batch_len_list = encoder_out_lens.tolist()
209
+ batch_log_probs_seq = []
210
+ batch_log_probs_ids = []
211
+ batch_start = [] # only effective in streaming deployment
212
+ batch_root = TrieVector()
213
+ root_dict = {}
214
+ for i in range(len(batch_len_list)):
215
+ num_sent = batch_len_list[i]
216
+ batch_log_probs_seq.append(
217
+ batch_log_probs_seq_list[i][0:num_sent])
218
+ batch_log_probs_ids.append(
219
+ batch_log_probs_idx_list[i][0:num_sent])
220
+ root_dict[i] = PathTrie()
221
+ batch_root.append(root_dict[i])
222
+ batch_start.append(True)
223
+ score_hyps = ctc_beam_search_decoder_batch(
224
+ batch_log_probs_seq, batch_log_probs_ids, batch_root,
225
+ batch_start, beam_size, num_processes, 0, -2, 0.99999)
226
+ if args.mode == 'ctc_prefix_beam_search':
227
+ hyps = []
228
+ for cand_hyps in score_hyps:
229
+ hyps.append(cand_hyps[0][1])
230
+ hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
231
+ if args.mode == 'attention_rescoring':
232
+ ctc_score, all_hyps = [], []
233
+ max_len = 0
234
+ for hyps in score_hyps:
235
+ cur_len = len(hyps)
236
+ if len(hyps) < beam_size:
237
+ hyps += (beam_size - cur_len) * [(-float("INF"),
238
+ (0, ))]
239
+ cur_ctc_score = []
240
+ for hyp in hyps:
241
+ cur_ctc_score.append(hyp[0])
242
+ all_hyps.append(list(hyp[1]))
243
+ if len(hyp[1]) > max_len:
244
+ max_len = len(hyp[1])
245
+ ctc_score.append(cur_ctc_score)
246
+ if args.fp16:
247
+ ctc_score = np.array(ctc_score, dtype=np.float16)
248
+ else:
249
+ ctc_score = np.array(ctc_score, dtype=np.float32)
250
+ hyps_pad_sos_eos = np.ones(
251
+ (batch_size, beam_size, max_len + 2),
252
+ dtype=np.int64) * IGNORE_ID
253
+ r_hyps_pad_sos_eos = np.ones(
254
+ (batch_size, beam_size, max_len + 2),
255
+ dtype=np.int64) * IGNORE_ID
256
+ hyps_lens_sos = np.ones((batch_size, beam_size),
257
+ dtype=np.int32)
258
+ k = 0
259
+ for i in range(batch_size):
260
+ for j in range(beam_size):
261
+ cand = all_hyps[k]
262
+ l = len(cand) + 2
263
+ hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
264
+ r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [
265
+ eos
266
+ ]
267
+ hyps_lens_sos[i][j] = len(cand) + 1
268
+ k += 1
269
+ decoder_ort_inputs = {
270
+ decoder_ort_session.get_inputs()[0].name: encoder_out,
271
+ decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
272
+ decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
273
+ decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
274
+ decoder_ort_session.get_inputs()[-1].name: ctc_score
275
+ }
276
+ if reverse_weight > 0:
277
+ r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs(
278
+ )[4].name
279
+ decoder_ort_inputs[
280
+ r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
281
+ best_index = decoder_ort_session.run(None,
282
+ decoder_ort_inputs)[0]
283
+ best_sents = []
284
+ k = 0
285
+ for idx in best_index:
286
+ cur_best_sent = all_hyps[k:k + beam_size][idx]
287
+ best_sents.append(cur_best_sent)
288
+ k += beam_size
289
+ hyps = map_batch(best_sents, vocabulary, num_processes)
290
+
291
+ for i, key in enumerate(keys):
292
+ content = hyps[i]
293
+ logging.info('{} {}'.format(key, content))
294
+ fout.write('{} {}\n'.format(key, content))
295
+
296
+ if __name__ == '__main__':
297
+ main()
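Note: the attention_rescoring branch above packs every n-best hypothesis as [sos] + tokens + [eos] (plus a reversed copy for the right-to-left decoder) before calling the ONNX decoder. A self-contained sketch of that padding with toy values (sos = eos = 10, IGNORE_ID = -1, two candidate hypotheses); the real ids come from the dict file and the CTC beam search:

import numpy as np

IGNORE_ID = -1
sos = eos = 10
all_hyps = [[3, 5, 8], [7]]                      # toy n-best token sequences
beam_size = len(all_hyps)
max_len = max(len(h) for h in all_hyps)

hyps_pad_sos_eos = np.full((1, beam_size, max_len + 2), IGNORE_ID, dtype=np.int64)
r_hyps_pad_sos_eos = np.full_like(hyps_pad_sos_eos, IGNORE_ID)
hyps_lens_sos = np.ones((1, beam_size), dtype=np.int32)
for j, cand in enumerate(all_hyps):
    l = len(cand) + 2
    hyps_pad_sos_eos[0, j, :l] = [sos] + cand + [eos]
    r_hyps_pad_sos_eos[0, j, :l] = [sos] + cand[::-1] + [eos]  # reversed copy
    hyps_lens_sos[0, j] = len(cand) + 1                        # sos + tokens, no eos

print(hyps_pad_sos_eos[0])
print(hyps_lens_sos)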
wenet/bin/train.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import datetime
19
+ import logging
20
+ import os
21
+ import random
22
+
23
+ import numpy as np
24
+ import yaml
25
+ import torch
26
+
27
+ import torch.distributed as dist
28
+
29
+ from torch.distributed.elastic.multiprocessing.errors import record
30
+ from wenet.utils.common import lrs_to_str, TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu
31
+
32
+ from wenet.utils.executor import Executor
33
+ from wenet.utils.config import override_config
34
+ from wenet.utils.init_model import init_model
35
+ from wenet.utils.init_tokenizer import init_tokenizer
36
+ from wenet.utils.train_utils import (
37
+ add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
38
+ add_deepspeed_args, add_trace_args, init_distributed,
39
+ init_dataset_and_dataloader, check_modify_and_save_config,
40
+ init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
41
+ wrap_cuda_model, init_summarywriter, save_model, log_per_epoch,
42
+ add_lora_args, reinit_lora)
43
+ from gxl_ai_utils.utils import utils_file
44
+
45
+ try:
46
+ import torch_npu
47
+
48
+ torch_npu.npu.conv.allow_hf32 = False
49
+ # import deepspeed_npu
50
+ from torch_npu.npu import amp
51
+ from torch_npu.contrib import transfer_to_npu
52
+ except ImportError:
53
+ utils_file.logging_warning(
54
+ "torch_npu is not installed, please install torch_npu first if you want to use torch_npu")
55
+ torch.backends.cudnn.allow_tf32 = False
56
+ torch.backends.cuda.matmul.allow_tf32 = False
57
+
58
+ from msprobe.pytorch import seed_all
59
+ import gc
60
+
61
+ gc.set_threshold(700, 10, 10000) # set Python gc thresholds
62
+
63
+
64
+ # import deepspeed_npu
65
+ def get_args():
66
+ parser = argparse.ArgumentParser(description='training your network')
67
+ parser.add_argument('--train_engine',
68
+ default='torch_ddp',
69
+ choices=['torch_ddp', 'torch_fsdp', 'deepspeed'],
70
+ help='Engine for paralleled training')
71
+ # set default value of device to "cuda", avoiding the modify of original scripts
72
+ parser.add_argument('--device',
73
+ type=str,
74
+ default='cuda',
75
+ choices=["cpu", "npu", "cuda"],
76
+ help='accelerator for training')
77
+ # load deepspeed checkpoint
78
+ parser.add_argument('--load_dir',
79
+ type=str,
80
+ default=None)
81
+ parser.add_argument('--ckpt_id',
82
+ type=str,
83
+ default=None)
84
+ parser = add_model_args(parser)
85
+ parser = add_dataset_args(parser)
86
+ parser = add_ddp_args(parser)
87
+ parser = add_lora_args(parser)
88
+ parser = add_deepspeed_args(parser)
89
+ parser = add_fsdp_args(parser)
90
+ parser = add_trace_args(parser)
91
+ args = parser.parse_args()
92
+ if args.train_engine == "deepspeed":
93
+ args.deepspeed = True
94
+ assert args.deepspeed_config is not None
95
+ return args
96
+
97
+
98
+ # NOTE(xcsong): On worker errors, this record tool will summarize the
99
+ # details of the error (e.g. time, rank, host, pid, traceback, etc).
100
+ @record
101
+ def main():
102
+ args = get_args()
103
+ logging.basicConfig(level=logging.DEBUG,
104
+ format='%(asctime)s %(levelname)s %(message)s')
105
+
106
+ # Set random seed
107
+ torch.manual_seed(777)
108
+ random.seed(777)
109
+ np.random.seed(777)
110
+ utils_file.logging_info('start strict seeding')
111
+ seed_all(777)
112
+ utils_file.logging_info('finished strict seeding')
113
+ logging.info('Random seed set to {}'.format(777))
114
+
115
+ # Read config
116
+ with open(args.config, 'r') as fin:
117
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
118
+ if len(args.override_config) > 0:
119
+ configs = override_config(configs, args.override_config)
120
+
121
+ # init tokenizer
122
+ tokenizer = init_tokenizer(configs)
123
+
124
+ # Init env for ddp OR deepspeed
125
+ _, _, rank = init_distributed(args)
126
+
127
+ # Init asr model from configs
128
+ model, configs = init_model(args, configs)
129
+
130
+ # Get dataset & dataloader
131
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
132
+ init_dataset_and_dataloader(args, configs, tokenizer)
133
+
134
+ # Do some sanity checks and save config to args.model_dir
135
+ configs = check_modify_and_save_config(args, configs,
136
+ tokenizer.symbol_table)
137
+
138
+ if hasattr(args, 'lora_reinit') and args.lora_reinit:
139
+ reinit_lora(model, args, configs, tokenizer)
140
+
141
+ # Check model is jitable & print model architectures
142
+ trace_and_print_model(args, model)
143
+
144
+ # Tensorboard summary
145
+ writer = init_summarywriter(args)
146
+
147
+ # Dispatch model from cpu to gpu
148
+ model, device = wrap_cuda_model(args, model, configs)
149
+
150
+ # Get optimizer & scheduler
151
+ model, optimizer, scheduler = init_optimizer_and_scheduler(
152
+ args, configs, model)
153
+
154
+ # Load deepspeed checkpoint
155
+ if args.load_dir is not None and \
156
+ args.ckpt_id is not None:
157
+ _, client_sd = model.load_checkpoint(args.load_dir, args.ckpt_id)
158
+
159
+ # Save checkpoints
160
+ # save_model(model,
161
+ # info_dict={
162
+ # "save_time":
163
+ # datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
164
+ # "tag":
165
+ # "init",
166
+ # **configs
167
+ # })
168
+
169
+ # Get executor
170
+ tag = configs["init_infos"].get("tag", "init")
171
+ executor = Executor(global_step=configs["init_infos"].get('step', -1),
172
+ device=device)
173
+
174
+ # Init scaler, used for pytorch amp mixed precision training
175
+ scaler = init_scaler(args)
176
+
177
+ # Start training loop
178
+ start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
179
+ # if save_interval in configs, steps mode else epoch mode
180
+ end_epoch = configs.get('max_epoch', 100)
181
+ assert start_epoch <= end_epoch
182
+ configs.pop("init_infos", None)
183
+ final_epoch = None
184
+ for epoch in range(start_epoch, end_epoch):
185
+ configs['epoch'] = epoch
186
+
187
+ lrs = [group['lr'] for group in optimizer.param_groups]
188
+ logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format(
189
+ epoch, executor.step, lrs_to_str(lrs), rank))
190
+
191
+ dist.barrier(
192
+ ) # NOTE(xcsong): Ensure all ranks start Train at the same time.
193
+ # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join`
194
+ group_join = dist.new_group( # fix by zhaoyi for multi-node training
195
+ backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
196
+ # group_join = None
197
+ executor.train(model, optimizer, scheduler, train_data_loader,
198
+ cv_data_loader, writer, configs, scaler, group_join)
199
+ # dist.destroy_process_group(group_join)
200
+
201
+ dist.barrier(
202
+ ) # NOTE(xcsong): Ensure all ranks start CV at the same time.
203
+ loss_dict = executor.cv(model, cv_data_loader, configs)
204
+ info_dict = {
205
+ 'epoch': epoch,
206
+ 'lrs': [group['lr'] for group in optimizer.param_groups],
207
+ 'step': executor.step,
208
+ "loss_dict": loss_dict,
209
+ 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
210
+ 'tag': "epoch_{}".format(epoch),
211
+ 'loss_dict': loss_dict,
212
+ **configs
213
+ }
214
+ # epoch cv: tensorboard && log
215
+ log_per_epoch(writer, info_dict=info_dict)
216
+ save_model(model, info_dict=info_dict)
217
+
218
+ final_epoch = epoch
219
+
220
+ if final_epoch is not None and rank == 0:
221
+ final_model_path = os.path.join(args.model_dir, 'final.pt')
222
+ os.remove(final_model_path) if os.path.exists(
223
+ final_model_path) else None
224
+ os.symlink('{}.pt'.format(final_epoch), final_model_path)
225
+ writer.close()
226
+ dist.barrier(
227
+ ) # NOTE(yktian): Ensure all ranks end Train before destroy process group.
228
+ dist.destroy_process_group()
229
+
230
+
231
+ if __name__ == '__main__':
232
+ main()
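Note: the resume logic above is easy to misread, so here is a toy illustration with hypothetical checkpoint metadata: when the loaded tag marks a finished epoch ('epoch_N'), training continues from epoch N + 1, otherwise it restarts from the stored epoch.

init_infos = {'epoch': 3, 'tag': 'epoch_3', 'step': 12000}   # hypothetical init_infos
tag = init_infos.get('tag', 'init')
start_epoch = init_infos.get('epoch', 0) + int('epoch_' in tag)
print(start_epoch)   # -> 4, i.e. resume with the next epoch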
wenet/branchformer/__init__.py ADDED
File without changes
wenet/branchformer/cgmlp.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
2
+ # 2023 Voicecomm Inc (Kai Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """MLP with convolutional gating (cgMLP) definition.
17
+
18
+ References:
19
+ https://openreview.net/forum?id=RA-zVvZLYIy
20
+ https://arxiv.org/abs/2105.08050
21
+
22
+ """
23
+
24
+ from typing import Tuple
25
+ import torch
26
+ import torch.nn as nn
27
+ from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
28
+
29
+
30
+ class ConvolutionalSpatialGatingUnit(torch.nn.Module):
31
+ """Convolutional Spatial Gating Unit (CSGU)."""
32
+
33
+ def __init__(
34
+ self,
35
+ size: int,
36
+ kernel_size: int,
37
+ dropout_rate: float,
38
+ use_linear_after_conv: bool,
39
+ gate_activation: str,
40
+ causal: bool = True,
41
+ ):
42
+ super().__init__()
43
+
44
+ # split input channels
45
+ n_channels = size // 2
46
+ self.norm = nn.LayerNorm(n_channels)
47
+ # self.lorder is used to distinguish if it's a causal convolution,
48
+ # if self.lorder > 0: it's a causal convolution, the input will be
49
+ # padded with self.lorder frames on the left in forward.
50
+ # else: it's a symmetrical convolution
51
+ if causal:
52
+ padding = 0
53
+ self.lorder = kernel_size - 1
54
+ else:
55
+ # kernel_size should be an odd number for none causal convolution
56
+ assert (kernel_size - 1) % 2 == 0
57
+ padding = (kernel_size - 1) // 2
58
+ self.lorder = 0
59
+ self.conv = torch.nn.Conv1d(
60
+ n_channels,
61
+ n_channels,
62
+ kernel_size,
63
+ 1,
64
+ padding,
65
+ groups=n_channels,
66
+ )
67
+ if use_linear_after_conv:
68
+ self.linear = torch.nn.Linear(n_channels, n_channels)
69
+ else:
70
+ self.linear = None
71
+
72
+ if gate_activation == "identity":
73
+ self.act = torch.nn.Identity()
74
+ else:
75
+ self.act = WENET_ACTIVATION_CLASSES[gate_activation]()
76
+
77
+ self.dropout = torch.nn.Dropout(dropout_rate)
78
+
79
+ def espnet_initialization_fn(self):
80
+ torch.nn.init.normal_(self.conv.weight, std=1e-6)
81
+ torch.nn.init.ones_(self.conv.bias)
82
+ if self.linear is not None:
83
+ torch.nn.init.normal_(self.linear.weight, std=1e-6)
84
+ torch.nn.init.ones_(self.linear.bias)
85
+
86
+ def forward(
87
+ self, x: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0))
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Forward method
90
+
91
+ Args:
92
+ x (torch.Tensor): (batch, time, channels)
93
+ cache (torch.Tensor): left context cache, it is only
94
+ used in causal convolution (#batch, channels, cache_t),
95
+ (0, 0, 0) means fake cache.
96
+
97
+ Returns:
98
+ out (torch.Tensor): (batch, time, channels/2)
99
+ """
100
+
101
+ x_r, x_g = x.chunk(2, dim=-1)
102
+ # exchange the temporal dimension and the feature dimension
103
+ x_g = x_g.transpose(1, 2) # (#batch, channels, time)
104
+
105
+ if self.lorder > 0:
106
+ if cache.size(2) == 0: # cache_t == 0
107
+ x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0)
108
+ else:
109
+ assert cache.size(0) == x_g.size(0) # equal batch
110
+ assert cache.size(1) == x_g.size(1) # equal channel
111
+ x_g = torch.cat((cache, x_g), dim=2)
112
+ assert (x_g.size(2) > self.lorder)
113
+ new_cache = x_g[:, :, -self.lorder:]
114
+ else:
115
+ # It's better we just return None if no cache is required,
116
+ # However, for JIT export, here we just fake one tensor instead of
117
+ # None.
118
+ new_cache = torch.zeros((0, 0, 0),
119
+ dtype=x_g.dtype,
120
+ device=x_g.device)
121
+
122
+ x_g = x_g.transpose(1, 2)
123
+ x_g = self.norm(x_g) # (N, T, D/2)
124
+ x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) # (N, T, D/2)
125
+ if self.linear is not None:
126
+ x_g = self.linear(x_g)
127
+
128
+ x_g = self.act(x_g)
129
+ out = x_r * x_g # (N, T, D/2)
130
+ out = self.dropout(out)
131
+ return out, new_cache
132
+
133
+
134
+ class ConvolutionalGatingMLP(torch.nn.Module):
135
+ """Convolutional Gating MLP (cgMLP)."""
136
+
137
+ def __init__(
138
+ self,
139
+ size: int,
140
+ linear_units: int,
141
+ kernel_size: int,
142
+ dropout_rate: float,
143
+ use_linear_after_conv: bool,
144
+ gate_activation: str,
145
+ causal: bool = True,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.channel_proj1 = torch.nn.Sequential(
150
+ torch.nn.Linear(size, linear_units), torch.nn.GELU())
151
+ self.csgu = ConvolutionalSpatialGatingUnit(
152
+ size=linear_units,
153
+ kernel_size=kernel_size,
154
+ dropout_rate=dropout_rate,
155
+ use_linear_after_conv=use_linear_after_conv,
156
+ gate_activation=gate_activation,
157
+ causal=causal,
158
+ )
159
+ self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
160
+
161
+ def forward(
162
+ self,
163
+ x: torch.Tensor,
164
+ mask: torch.Tensor,
165
+ cache: torch.Tensor = torch.zeros((0, 0, 0))
166
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
167
+ """Forward method
168
+
169
+ Args:
170
+ x (torch.Tensor): (batch, time, channels)
171
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
172
+ (0, 0, 0) means fake mask. Not used yet
173
+ cache (torch.Tensor): left context cache, it is only
174
+ used in causal convolution (#batch, channels, cache_t),
175
+ (0, 0, 0) means fake cache.
176
+
177
+ Returns:
178
+ out (torch.Tensor): (batch, time, channels)
179
+ """
180
+
181
+ xs_pad = x
182
+
183
+ # size -> linear_units
184
+ xs_pad = self.channel_proj1(xs_pad)
185
+
186
+ # linear_units -> linear_units/2
187
+ xs_pad, new_cnn_cache = self.csgu(xs_pad, cache)
188
+
189
+ # linear_units/2 -> size
190
+ xs_pad = self.channel_proj2(xs_pad)
191
+
192
+ out = xs_pad
193
+
194
+ return out, new_cnn_cache
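Note: a quick shape check for the cgMLP block defined above, assuming this wenet package is importable. channel_proj1 expands size to linear_units, the CSGU halves the channels while gating them with a depthwise causal convolution, and channel_proj2 projects back to size:

import torch
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP

mlp = ConvolutionalGatingMLP(size=256, linear_units=2048, kernel_size=31,
                             dropout_rate=0.1, use_linear_after_conv=False,
                             gate_activation='identity', causal=True)
x = torch.randn(2, 50, 256)                    # (batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # mask_pad, not used by this forward yet
out, cnn_cache = mlp(x, mask)
print(out.shape)        # torch.Size([2, 50, 256])
print(cnn_cache.shape)  # torch.Size([2, 1024, 30]): depthwise-conv left context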
wenet/branchformer/encoder.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
2
+ # 2023 Voicecomm Inc (Kai Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder definition."""
17
+
18
+ import torch
19
+
20
+ from typing import List, Optional, Union
21
+
22
+ from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
23
+ from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
24
+ from wenet.transformer.encoder import BaseEncoder
25
+ from wenet.utils.class_utils import (
26
+ WENET_ATTENTION_CLASSES, )
27
+
28
+
29
+ class BranchformerEncoder(BaseEncoder):
30
+ """Branchformer encoder module."""
31
+
32
+ def __init__(
33
+ self,
34
+ input_size: int,
35
+ output_size: int = 256,
36
+ use_attn: bool = True,
37
+ attention_heads: int = 4,
38
+ selfattention_layer_type: str = "rel_selfattn",
39
+ pos_enc_layer_type: str = "rel_pos",
40
+ use_cgmlp: bool = True,
41
+ cgmlp_linear_units: int = 2048,
42
+ cgmlp_conv_kernel: int = 31,
43
+ use_linear_after_conv: bool = False,
44
+ gate_activation: str = "identity",
45
+ merge_method: str = "concat",
46
+ cgmlp_weight: Union[float, List[float]] = 0.5,
47
+ attn_branch_drop_rate: Union[float, List[float]] = 0.0,
48
+ num_blocks: int = 12,
49
+ dropout_rate: float = 0.1,
50
+ positional_dropout_rate: float = 0.1,
51
+ attention_dropout_rate: float = 0.0,
52
+ input_layer: str = "conv2d",
53
+ stochastic_depth_rate: Union[float, List[float]] = 0.0,
54
+ static_chunk_size: int = 0,
55
+ use_dynamic_chunk: bool = False,
56
+ global_cmvn: torch.nn.Module = None,
57
+ use_dynamic_left_chunk: bool = False,
58
+ causal: bool = False,
59
+ query_bias: bool = True,
60
+ key_bias: bool = True,
61
+ value_bias: bool = True,
62
+ gradient_checkpointing: bool = False,
63
+ use_sdpa: bool = False,
64
+ layer_norm_type: str = 'layer_norm',
65
+ norm_eps: float = 1e-5,
66
+ n_kv_head: Optional[int] = None,
67
+ head_dim: Optional[int] = None,
68
+ ):
69
+ super().__init__(input_size, output_size, attention_heads,
70
+ cgmlp_linear_units, num_blocks, dropout_rate,
71
+ positional_dropout_rate, attention_dropout_rate,
72
+ input_layer, pos_enc_layer_type, True,
73
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
74
+ use_dynamic_left_chunk, gradient_checkpointing,
75
+ use_sdpa, layer_norm_type, norm_eps)
76
+
77
+ encoder_selfattn_layer_args = (
78
+ attention_heads,
79
+ output_size,
80
+ attention_dropout_rate,
81
+ query_bias,
82
+ key_bias,
83
+ value_bias,
84
+ use_sdpa,
85
+ n_kv_head,
86
+ head_dim,
87
+ )
88
+
89
+ cgmlp_layer = ConvolutionalGatingMLP
90
+ cgmlp_layer_args = (
91
+ output_size,
92
+ cgmlp_linear_units,
93
+ cgmlp_conv_kernel,
94
+ dropout_rate,
95
+ use_linear_after_conv,
96
+ gate_activation,
97
+ causal,
98
+ )
99
+
100
+ if isinstance(stochastic_depth_rate, float):
101
+ stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
102
+ if len(stochastic_depth_rate) != num_blocks:
103
+ raise ValueError(
104
+ f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
105
+ f"should be equal to num_blocks ({num_blocks})")
106
+
107
+ if isinstance(cgmlp_weight, float):
108
+ cgmlp_weight = [cgmlp_weight] * num_blocks
109
+ if len(cgmlp_weight) != num_blocks:
110
+ raise ValueError(
111
+ f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
112
+ f"num_blocks ({num_blocks})")
113
+
114
+ if isinstance(attn_branch_drop_rate, float):
115
+ attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
116
+ if len(attn_branch_drop_rate) != num_blocks:
117
+ raise ValueError(
118
+ f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
119
+ f"should be equal to num_blocks ({num_blocks})")
120
+
121
+ self.encoders = LayerDropModuleList(
122
+ p=stochastic_depth_rate,
123
+ modules=[
124
+ BranchformerEncoderLayer(
125
+ output_size,
126
+ WENET_ATTENTION_CLASSES[selfattention_layer_type](
127
+ *encoder_selfattn_layer_args) if use_attn else None,
128
+ cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
129
+ dropout_rate,
130
+ merge_method,
131
+ cgmlp_weight[lnum],
132
+ attn_branch_drop_rate[lnum],
133
+ stochastic_depth_rate[lnum],
134
+ ) for lnum in range(num_blocks)
135
+ ])
136
+
137
+
138
+ # modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa
139
+ class LayerDropModuleList(torch.nn.ModuleList):
140
+ """
141
+ A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
142
+
143
+ We refresh the choice of which layers to drop every time we iterate
144
+ over the LayerDropModuleList instance. During evaluation we always
145
+ iterate over all layers.
146
+
147
+ Usage::
148
+
149
+ layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
150
+ for layer in layers: # this might iterate over layers 1 and 3
151
+ x = layer(x)
152
+ for layer in layers: # this might iterate over all layers
153
+ x = layer(x)
154
+ for layer in layers: # this might not iterate over any layers
155
+ x = layer(x)
156
+
157
+ Args:
158
+ p (float): probability of dropping out each layer
159
+ modules (iterable, optional): an iterable of modules to add
160
+
161
+ Limitations:
162
+ 1 can work with ddp when layer's gradient checkpoint disabled
163
+ 2 can't work with ddp when layer's gradient checkpoint enables
164
+ 3 can work with fsdp
165
+ 4 can work with deepspeed
166
+ """
167
+
168
+ def __init__(self, p: List[float], modules=None):
169
+ super().__init__(modules)
170
+ assert len(p) == len(self)
171
+ self.p = p
172
+
173
+ def __iter__(self):
174
+ dropout_probs = torch.empty(len(self)).uniform_()
175
+ for i, m in enumerate(super().__iter__()):
176
+ if not self.training or (dropout_probs[i] > self.p[i]):
177
+ yield m
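Note: a tiny demonstration of the LayerDrop behaviour documented above, assuming the module is importable; with p = 0.5 roughly half of the layers are skipped on each training pass, while evaluation always visits all of them:

import torch
from wenet.branchformer.encoder import LayerDropModuleList

layers = LayerDropModuleList(p=[0.5, 0.5, 0.5],
                             modules=[torch.nn.Identity() for _ in range(3)])
layers.train()
print(sum(1 for _ in layers))   # varies between 0 and 3 from pass to pass
layers.eval()
print(sum(1 for _ in layers))   # always 3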
wenet/branchformer/encoder_layer.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
2
+ # 2023 Voicecomm Inc (Kai Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """BranchformerEncoderLayer definition."""
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from typing import Optional, Tuple
21
+
22
+ from wenet.transformer.attention import T_CACHE
23
+
24
+
25
+ class BranchformerEncoderLayer(torch.nn.Module):
26
+ """Branchformer encoder layer module.
27
+
28
+ Args:
29
+ size (int): model dimension
30
+ attn: standard self-attention or efficient attention, optional
31
+ cgmlp: ConvolutionalGatingMLP, optional
32
+ dropout_rate (float): dropout probability
33
+ merge_method (str): concat, learned_ave, fixed_ave
34
+ cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
35
+ used if merge_method is fixed_ave
36
+ attn_branch_drop_rate (float): probability of dropping the attn branch,
37
+ used if merge_method is learned_ave
38
+ stochastic_depth_rate (float): stochastic depth probability
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ attn: Optional[torch.nn.Module],
45
+ cgmlp: Optional[torch.nn.Module],
46
+ dropout_rate: float,
47
+ merge_method: str,
48
+ cgmlp_weight: float = 0.5,
49
+ attn_branch_drop_rate: float = 0.0,
50
+ stochastic_depth_rate: float = 0.0,
51
+ ):
52
+ super().__init__()
53
+ assert (attn is not None) or (
54
+ cgmlp is not None), "At least one branch should be valid"
55
+
56
+ self.size = size
57
+ self.attn = attn
58
+ self.cgmlp = cgmlp
59
+ self.merge_method = merge_method
60
+ self.cgmlp_weight = cgmlp_weight
61
+ self.attn_branch_drop_rate = attn_branch_drop_rate
62
+ self.stochastic_depth_rate = stochastic_depth_rate
63
+ self.use_two_branches = (attn is not None) and (cgmlp is not None)
64
+
65
+ if attn is not None:
66
+ self.norm_mha = nn.LayerNorm(size) # for the MHA module
67
+ if cgmlp is not None:
68
+ self.norm_mlp = nn.LayerNorm(size) # for the MLP module
69
+ self.norm_final = nn.LayerNorm(
70
+ size) # for the final output of the block
71
+
72
+ self.dropout = torch.nn.Dropout(dropout_rate)
73
+
74
+ # # attention-based pooling for two branches
75
+ self.pooling_proj1 = torch.nn.Linear(size, 1)
76
+ self.pooling_proj2 = torch.nn.Linear(size, 1)
77
+
78
+ # # linear projections for calculating merging weights
79
+ self.weight_proj1 = torch.nn.Linear(size, 1)
80
+ self.weight_proj2 = torch.nn.Linear(size, 1)
81
+
82
+ if self.use_two_branches:
83
+ if self.merge_method == "concat":
84
+ self.merge_proj = torch.nn.Linear(size + size, size)
85
+
86
+ elif self.merge_method == "learned_ave":
87
+ # linear projection after weighted average
88
+ self.merge_proj = torch.nn.Linear(size, size)
89
+
90
+ elif self.merge_method == "fixed_ave":
91
+ assert (0.0 <= cgmlp_weight <=
92
+ 1.0), "cgmlp weight should be between 0.0 and 1.0"
93
+
94
+ # remove the other branch if only one branch is used
95
+ if cgmlp_weight == 0.0:
96
+ self.use_two_branches = False
97
+ self.cgmlp = None
98
+ self.norm_mlp = None
99
+ elif cgmlp_weight == 1.0:
100
+ self.use_two_branches = False
101
+ self.attn = None
102
+ self.norm_mha = None
103
+
104
+ # linear projection after weighted average
105
+ self.merge_proj = torch.nn.Linear(size, size)
106
+ else:
107
+ raise ValueError(f"unknown merge method: {merge_method}")
108
+ else:
109
+ self.merge_proj = torch.nn.Identity()
110
+
111
+ def _forward(
112
+ self,
113
+ x: torch.Tensor,
114
+ mask: torch.Tensor,
115
+ pos_emb: torch.Tensor,
116
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
117
+ att_cache: T_CACHE = (torch.zeros(
118
+ (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
119
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
120
+ stoch_layer_coeff: float = 1.0
121
+ ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
122
+ # Two branches
123
+ x1 = x
124
+ x2 = x
125
+
126
+ # Branch 1: multi-headed attention module
127
+ if self.attn is not None:
128
+ x1 = self.norm_mha(x1)
129
+ x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb,
130
+ att_cache)
131
+ x1 = self.dropout(x_att)
132
+
133
+ # Branch 2: convolutional gating mlp
134
+ # Fake new cnn cache here, and then change it in conv_module
135
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
136
+ if self.cgmlp is not None:
137
+ x2 = self.norm_mlp(x2)
138
+ x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache)
139
+ x2 = self.dropout(x2)
140
+
141
+ # Merge two branches
142
+ if self.use_two_branches:
143
+ if self.merge_method == "concat":
144
+ x = x + stoch_layer_coeff * self.dropout(
145
+ self.merge_proj(torch.cat([x1, x2], dim=-1)))
146
+ elif self.merge_method == "learned_ave":
147
+ if (self.training and self.attn_branch_drop_rate > 0
148
+ and torch.rand(1).item() < self.attn_branch_drop_rate):
149
+ # Drop the attn branch
150
+ w1, w2 = torch.tensor(0.0), torch.tensor(1.0)
151
+ else:
152
+ # branch1
153
+ score1 = (self.pooling_proj1(x1).transpose(1, 2) /
154
+ self.size**0.5)
155
+ score1 = score1.masked_fill(mask_pad.eq(0), -float('inf'))
156
+ score1 = torch.softmax(score1, dim=-1).masked_fill(
157
+ mask_pad.eq(0), 0.0)
158
+
159
+ pooled1 = torch.matmul(score1,
160
+ x1).squeeze(1) # (batch, size)
161
+ weight1 = self.weight_proj1(pooled1) # (batch, 1)
162
+
163
+ # branch2
164
+ score2 = (self.pooling_proj2(x2).transpose(1, 2) /
165
+ self.size**0.5)
166
+ score2 = score2.masked_fill(mask_pad.eq(0), -float('inf'))
167
+ score2 = torch.softmax(score2, dim=-1).masked_fill(
168
+ mask_pad.eq(0), 0.0)
169
+
170
+ pooled2 = torch.matmul(score2,
171
+ x2).squeeze(1) # (batch, size)
172
+ weight2 = self.weight_proj2(pooled2) # (batch, 1)
173
+
174
+ # normalize weights of two branches
175
+ merge_weights = torch.softmax(torch.cat([weight1, weight2],
176
+ dim=-1),
177
+ dim=-1) # (batch, 2)
178
+ merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
179
+ -1) # (batch, 2, 1, 1)
180
+ w1, w2 = merge_weights[:,
181
+ 0], merge_weights[:,
182
+ 1] # (batch, 1, 1)
183
+
184
+ x = x + stoch_layer_coeff * self.dropout(
185
+ self.merge_proj(w1 * x1 + w2 * x2))
186
+ elif self.merge_method == "fixed_ave":
187
+ x = x + stoch_layer_coeff * self.dropout(
188
+ self.merge_proj((1.0 - self.cgmlp_weight) * x1 +
189
+ self.cgmlp_weight * x2))
190
+ else:
191
+ raise RuntimeError(
192
+ f"unknown merge method: {self.merge_method}")
193
+ else:
194
+ if self.attn is None:
195
+ x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
196
+ elif self.cgmlp is None:
197
+ x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
198
+ else:
199
+ # This should not happen
200
+ raise RuntimeError(
201
+ "Both branches are not None, which is unexpected.")
202
+
203
+ x = self.norm_final(x)
204
+
205
+ return x, mask, new_att_cache, new_cnn_cache
206
+
207
+ def forward(
208
+ self,
209
+ x: torch.Tensor,
210
+ mask: torch.Tensor,
211
+ pos_emb: torch.Tensor,
212
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
213
+ att_cache: T_CACHE = (torch.zeros(
214
+ (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
215
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
216
+ ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
217
+ """Compute encoded features.
218
+
219
+ Args:
220
+ x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
221
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
222
+ pos_emb (torch.Tensor): positional encoding, must not be None
223
+ for BranchformerEncoderLayer.
224
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
225
+ (#batch, 1, time), (0, 0, 0) means fake mask.
226
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
227
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
228
+ cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
229
+ (#batch=1, size, cache_t2)
230
+
231
+ Returns:
232
+ torch.Tensor: Output tensor (#batch, time, size).
233
+ torch.Tensor: Mask tensor (#batch, time, time).
234
+ torch.Tensor: att_cache tensor,
235
+ (#batch=1, head, cache_t1 + time, d_k * 2).
236
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
237
+ """
238
+
239
+ stoch_layer_coeff = 1.0
240
+ # with stochastic depth, residual connection `x + f(x)` becomes
241
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
242
+ if self.training:
243
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
244
+ return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
245
+ stoch_layer_coeff)
wenet/cli/__init__.py ADDED
File without changes
wenet/cli/hub.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2022 Mddct([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import requests
17
+ import sys
18
+ import tarfile
19
+ from pathlib import Path
20
+ from urllib.request import urlretrieve
21
+
22
+ import tqdm
23
+
24
+
25
+ def download(url: str, dest: str, only_child=True):
26
+ """ download from url to dest
27
+ """
28
+ assert os.path.exists(dest)
29
+ print('Downloading {} to {}'.format(url, dest))
30
+
31
+ def progress_hook(t):
32
+ last_b = [0]
33
+
34
+ def update_to(b=1, bsize=1, tsize=None):
35
+ if tsize not in (None, -1):
36
+ t.total = tsize
37
+ displayed = t.update((b - last_b[0]) * bsize)
38
+ last_b[0] = b
39
+ return displayed
40
+
41
+ return update_to
42
+
43
+ # *.tar.gz
44
+ name = url.split('?')[0].split('/')[-1]
45
+ tar_path = os.path.join(dest, name)
46
+ with tqdm.tqdm(unit='B',
47
+ unit_scale=True,
48
+ unit_divisor=1024,
49
+ miniters=1,
50
+ desc=(name)) as t:
51
+ urlretrieve(url,
52
+ filename=tar_path,
53
+ reporthook=progress_hook(t),
54
+ data=None)
55
+ t.total = t.n
56
+
57
+ with tarfile.open(tar_path) as f:
58
+ if not only_child:
59
+ f.extractall(dest)
60
+ else:
61
+ for tarinfo in f:
62
+ if "/" not in tarinfo.name:
63
+ continue
64
+ name = os.path.basename(tarinfo.name)
65
+ fileobj = f.extractfile(tarinfo)
66
+ with open(os.path.join(dest, name), "wb") as writer:
67
+ writer.write(fileobj.read())
68
+
69
+
70
+ class Hub(object):
71
+ """Hub for wenet pretrain runtime model
72
+ """
73
+ # TODO(Mddct): make assets class to support other language
74
+ Assets = {
75
+ # wenetspeech
76
+ "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz",
77
+ # gigaspeech
78
+ "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
79
+ # paraformer
80
+ "paraformer": "paraformer.tar.gz"
81
+ }
82
+
83
+ def __init__(self) -> None:
84
+ pass
85
+
86
+ @staticmethod
87
+ def get_model_by_lang(lang: str) -> str:
88
+ if lang not in Hub.Assets.keys():
89
+ print('ERROR: Unsupported language {} !!!'.format(lang))
90
+ sys.exit(1)
91
+
92
+ # NOTE(Mddct): model_dir structure
93
+ # Path.Home()/.wenet
94
+ # - chs
95
+ # - units.txt
96
+ # - final.zip
97
+ # - en
98
+ # - units.txt
99
+ # - final.zip
100
+ model = Hub.Assets[lang]
101
+ model_dir = os.path.join(Path.home(), ".wenet", lang)
102
+ if not os.path.exists(model_dir):
103
+ os.makedirs(model_dir)
104
+ # TODO(Mddct): model metadata
105
+ if set(["final.zip",
106
+ "units.txt"]).issubset(set(os.listdir(model_dir))):
107
+ return model_dir
108
+ # If not exist, download
109
+ response = requests.get(
110
+ "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa
111
+ )
112
+ model_info = next(data for data in response.json()["Data"]
113
+ if data["Key"] == model)
114
+ model_url = model_info['Url']
115
+ download(model_url, model_dir, only_child=True)
116
+ return model_dir
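Note: a usage sketch for the hub above. The first call downloads the pretrained runtime model from ModelScope (network access required) and caches it under ~/.wenet/<lang>; later calls just return the cached directory:

from wenet.cli.hub import Hub

model_dir = Hub.get_model_by_lang('chinese')
print(model_dir)   # e.g. ~/.wenet/chinese, containing final.zip and units.txt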
wenet/cli/model.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2023 Binbin Zhang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import torch
18
+ import torchaudio
19
+ import torchaudio.compliance.kaldi as kaldi
20
+
21
+ from wenet.cli.hub import Hub
22
+ from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
23
+ gen_timestamps_from_peak)
24
+ from wenet.utils.file_utils import read_symbol_table
25
+ from wenet.transformer.search import (attention_rescoring,
26
+ ctc_prefix_beam_search, DecodeResult)
27
+ from wenet.utils.context_graph import ContextGraph
28
+ from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu
29
+
30
+
31
+ class Model:
32
+
33
+ def __init__(self,
34
+ model_dir: str,
35
+ gpu: int = -1,
36
+ beam: int = 5,
37
+ context_path: str = None,
38
+ context_score: float = 6.0,
39
+ resample_rate: int = 16000):
40
+ model_path = os.path.join(model_dir, 'final.zip')
41
+ units_path = os.path.join(model_dir, 'units.txt')
42
+ self.model = torch.jit.load(model_path)
43
+ self.resample_rate = resample_rate
44
+ self.model.eval()
45
+ if gpu >= 0:
46
+ device = 'cuda:{}'.format(gpu)
47
+ else:
48
+ device = 'cpu'
49
+ self.device = torch.device(device)
50
+ self.model.to(device)
51
+ self.symbol_table = read_symbol_table(units_path)
52
+ self.char_dict = {v: k for k, v in self.symbol_table.items()}
53
+ self.beam = beam
54
+ if context_path is not None:
55
+ self.context_graph = ContextGraph(context_path,
56
+ self.symbol_table,
57
+ context_score=context_score)
58
+ else:
59
+ self.context_graph = None
60
+
61
+ def compute_feats(self, audio_file: str) -> torch.Tensor:
62
+ waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
63
+ waveform = waveform.to(torch.float)
64
+ if sample_rate != self.resample_rate:
65
+ waveform = torchaudio.transforms.Resample(
66
+ orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
67
+ # NOTE (MengqingCao): complex dtype not supported in torch_npu.abs() now,
68
+ # thus, delay placing data on NPU after the calculation of fbank.
69
+ # revert me after complex dtype is supported.
70
+ if "npu" not in self.device.__str__():
71
+ waveform = waveform.to(self.device)
72
+ feats = kaldi.fbank(waveform,
73
+ num_mel_bins=80,
74
+ frame_length=25,
75
+ frame_shift=10,
76
+ energy_floor=0.0,
77
+ sample_frequency=self.resample_rate)
78
+ if "npu" in self.device.__str__():
79
+ feats = feats.to(self.device)
80
+ feats = feats.unsqueeze(0)
81
+ return feats
82
+
83
+ @torch.no_grad()
84
+ def _decode(self,
85
+ audio_file: str,
86
+ tokens_info: bool = False,
87
+ label: str = None) -> dict:
88
+ feats = self.compute_feats(audio_file)
89
+ encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
90
+ encoder_lens = torch.tensor([encoder_out.size(1)],
91
+ dtype=torch.long,
92
+ device=encoder_out.device)
93
+ ctc_probs = self.model.ctc_activation(encoder_out)
94
+ if label is None:
95
+ ctc_prefix_results = ctc_prefix_beam_search(
96
+ ctc_probs,
97
+ encoder_lens,
98
+ self.beam,
99
+ context_graph=self.context_graph)
100
+ else: # force align mode, construct ctc prefix result from alignment
101
+ label_t = self.tokenize(label)
102
+ alignment = force_align(ctc_probs.squeeze(0),
103
+ torch.tensor(label_t, dtype=torch.long))
104
+ peaks = gen_ctc_peak_time(alignment)
105
+ ctc_prefix_results = [
106
+ DecodeResult(tokens=label_t,
107
+ score=0.0,
108
+ times=peaks,
109
+ nbest=[label_t],
110
+ nbest_scores=[0.0],
111
+ nbest_times=[peaks])
112
+ ]
113
+ rescoring_results = attention_rescoring(self.model, ctc_prefix_results,
114
+ encoder_out, encoder_lens, 0.3,
115
+ 0.5)
116
+ res = rescoring_results[0]
117
+ result = {}
118
+ result['text'] = ''.join([self.char_dict[x] for x in res.tokens])
119
+ result['confidence'] = res.confidence
120
+
121
+ if tokens_info:
122
+ frame_rate = self.model.subsampling_rate(
123
+ ) * 0.01 # 0.01 seconds per frame
124
+ max_duration = encoder_out.size(1) * frame_rate
125
+ times = gen_timestamps_from_peak(res.times, max_duration,
126
+ frame_rate, 1.0)
127
+ tokens_info = []
128
+ for i, x in enumerate(res.tokens):
129
+ tokens_info.append({
130
+ 'token': self.char_dict[x],
131
+ 'start': round(times[i][0], 3),
132
+ 'end': round(times[i][1], 3),
133
+ 'confidence': round(res.tokens_confidence[i], 2)
134
+ })
135
+ result['tokens'] = tokens_info
136
+ return result
137
+
138
+ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
139
+ return self._decode(audio_file, tokens_info)
140
+
141
+ def tokenize(self, label: str):
142
+ # TODO(Binbin Zhang): Support BPE
143
+ tokens = []
144
+ for c in label:
145
+ if c == ' ':
146
+ c = "▁"
147
+ tokens.append(c)
148
+ token_list = []
149
+ for c in tokens:
150
+ if c in self.symbol_table:
151
+ token_list.append(self.symbol_table[c])
152
+ elif '<unk>' in self.symbol_table:
153
+ token_list.append(self.symbol_table['<unk>'])
154
+ return token_list
155
+
156
+ def align(self, audio_file: str, label: str) -> dict:
157
+ return self._decode(audio_file, True, label)
158
+
159
+
160
+ def load_model(language: str = None,
161
+ model_dir: str = None,
162
+ gpu: int = -1,
163
+ beam: int = 5,
164
+ context_path: str = None,
165
+ context_score: float = 6.0,
166
+ device: str = "cpu") -> Model:
167
+ if model_dir is None:
168
+ model_dir = Hub.get_model_by_lang(language)
169
+
170
+ if gpu != -1:
171
+ # remain the original usage of gpu
172
+ device = "cuda"
173
+ model = Model(model_dir, gpu, beam, context_path, context_score)
174
+ model.device = torch.device(device)
175
+ model.model.to(device)
176
+ return model
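Note: a usage sketch for the CLI wrapper above; 'test.wav' is a hypothetical 16 kHz recording, and the model itself is fetched via Hub on the first call:

from wenet.cli.model import load_model

model = load_model(language='chinese')            # CPU decoding by default
result = model.transcribe('test.wav', tokens_info=True)
print(result['text'], result['confidence'])
for tok in result['tokens']:                      # per-token timestamps and confidences
    print(tok['token'], tok['start'], tok['end'], tok['confidence'])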
wenet/cli/paraformer_model.py ADDED
@@ -0,0 +1,82 @@
+ import os
+
+ import torch
+ import torchaudio
+ import torchaudio.compliance.kaldi as kaldi
+
+ from wenet.cli.hub import Hub
+ from wenet.paraformer.search import (gen_timestamps_from_peak,
+                                      paraformer_greedy_search)
+ from wenet.text.paraformer_tokenizer import ParaformerTokenizer
+ from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu
+
+
+ class Paraformer:
+
+     def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
+
+         model_path = os.path.join(model_dir, 'final.zip')
+         units_path = os.path.join(model_dir, 'units.txt')
+         self.model = torch.jit.load(model_path)
+         self.resample_rate = resample_rate
+         self.device = torch.device("cpu")
+         self.tokenizer = ParaformerTokenizer(symbol_table=units_path)
+
+     def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
+         waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
+         waveform = waveform.to(torch.float).to(self.device)
+         if sample_rate != self.resample_rate:
+             waveform = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
+         feats = kaldi.fbank(waveform,
+                             num_mel_bins=80,
+                             frame_length=25,
+                             frame_shift=10,
+                             energy_floor=0.0,
+                             sample_frequency=self.resample_rate,
+                             window_type="hamming")
+         feats = feats.unsqueeze(0)
+         feats_lens = torch.tensor([feats.size(1)],
+                                   dtype=torch.int64,
+                                   device=feats.device)
+
+         decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
+             feats, feats_lens)
+         cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
+         res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
+         result = {}
+         result['confidence'] = res.confidence
+         result['text'] = self.tokenizer.detokenize(res.tokens)[0]
+         if tokens_info:
+             tokens_info = []
+             times = gen_timestamps_from_peak(res.times,
+                                              num_frames=tp_alphas.size(1),
+                                              frame_rate=0.02)
+
+             for i, x in enumerate(res.tokens):
+                 tokens_info.append({
+                     'token': self.tokenizer.char_dict[x],
+                     'start': round(times[i][0], 3),
+                     'end': round(times[i][1], 3),
+                     'confidence': round(res.tokens_confidence[i], 2)
+                 })
+             result['tokens'] = tokens_info
+
+         return result
+
+     def align(self, audio_file: str, label: str) -> dict:
+         raise NotImplementedError("Align is currently not supported")
+
+
+ def load_model(model_dir: str = None,
+                gpu: int = -1,
+                device: str = "cpu") -> Paraformer:
+     if model_dir is None:
+         model_dir = Hub.get_model_by_lang('paraformer')
+     if gpu != -1:
+         # remain the original usage of gpu
+         device = "cuda"
+     paraformer = Paraformer(model_dir)
+     paraformer.device = torch.device(device)
+     paraformer.model.to(device)
+     return paraformer
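A hedged usage sketch for the Paraformer wrapper above: the model is fetched via Hub when model_dir is None, and 'test.wav' is a placeholder audio file.

from wenet.cli.paraformer_model import load_model as load_paraformer

model = load_paraformer(model_dir=None, device='cpu')
result = model.transcribe('test.wav', tokens_info=True)
print(result['confidence'], result['text'])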
wenet/cli/transcribe.py ADDED
@@ -0,0 +1,87 @@
+ # Copyright (c) 2023 Binbin Zhang ([email protected])
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+
+ from wenet.cli.paraformer_model import load_model as load_paraformer
+ from wenet.cli.model import load_model
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='')
+     parser.add_argument('audio_file', help='audio file to transcribe')
+     parser.add_argument('-l',
+                         '--language',
+                         choices=[
+                             'chinese',
+                             'english',
+                         ],
+                         default='chinese',
+                         help='language type')
+     parser.add_argument('-m',
+                         '--model_dir',
+                         default=None,
+                         help='specify your own model dir')
+     parser.add_argument('-g',
+                         '--gpu',
+                         type=int,
+                         default='-1',
+                         help='gpu id to decode, default is cpu.')
+     parser.add_argument('--device',
+                         type=str,
+                         default='cpu',
+                         choices=["cpu", "npu", "cuda"],
+                         help='accelerator to use')
+     parser.add_argument('-t',
+                         '--show_tokens_info',
+                         action='store_true',
+                         help='whether to output token(word) level information'
+                         ', such times/confidence')
+     parser.add_argument('--align',
+                         action='store_true',
+                         help='force align the input audio and transcript')
+     parser.add_argument('--label', type=str, help='the input label to align')
+     parser.add_argument('--paraformer',
+                         action='store_true',
+                         help='whether to use the best chinese model')
+     parser.add_argument('--beam', type=int, default=5, help="beam size")
+     parser.add_argument('--context_path',
+                         type=str,
+                         default=None,
+                         help='context list file')
+     parser.add_argument('--context_score',
+                         type=float,
+                         default=6.0,
+                         help='context score')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     if args.paraformer:
+         model = load_paraformer(args.model_dir, args.gpu, args.device)
+     else:
+         model = load_model(args.language, args.model_dir, args.gpu, args.beam,
+                            args.context_path, args.context_score, args.device)
+     if args.align:
+         result = model.align(args.audio_file, args.label)
+     else:
+         result = model.transcribe(args.audio_file, args.show_tokens_info)
+     print(result)
+
+
+ if __name__ == "__main__":
+     main()
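Since the installed console-script name can differ between packagings, a minimal sketch that drives this entry point programmatically; the audio path and flags are placeholders.

import sys
from wenet.cli.transcribe import main

sys.argv = ['wenet-transcribe', 'test.wav', '--language', 'chinese', '--show_tokens_info']
main()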
wenet/ctl_model/asr_model_ctl.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2023 NetEase Inc. (authors: Yuting Yang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet) and
16
+ # fairseq(https://github.com/facebookresearch/fairseq)
17
+
18
+ from typing import Dict, Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from wenet.transformer.ctc import CTC
23
+ from wenet.transformer.decoder import TransformerDecoder
24
+ from wenet.ctl_model.encoder import TransformerEncoder
25
+ from wenet.transformer.asr_model import ASRModel
26
+ from wenet.utils.common import IGNORE_ID
27
+
28
+
29
+ class CTLModel(ASRModel):
30
+ """
31
+        Implementation of the Interspeech 2023 paper:
32
+ 'Enhancing the Unified Streaming and Non-streaming Model
33
+ with Contrastive Learning'
34
+ https://arxiv.org/abs/2306.00755
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ vocab_size: int,
40
+ encoder: TransformerEncoder,
41
+ decoder: TransformerDecoder,
42
+ ctc: CTC,
43
+ ctc_weight: float = 0.5,
44
+ ignore_id: int = IGNORE_ID,
45
+ reverse_weight: float = 0.0,
46
+ lsm_weight: float = 0.0,
47
+ length_normalized_loss: bool = False,
48
+ logit_temp: float = 0.1,
49
+ n_negatives: int = 0,
50
+ ctl_weight: float = 1,
51
+ special_tokens: dict = None,
52
+ ):
53
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
54
+ super().__init__(vocab_size,
55
+ encoder,
56
+ decoder,
57
+ ctc,
58
+ ctc_weight,
59
+ ignore_id,
60
+ reverse_weight,
61
+ lsm_weight,
62
+ length_normalized_loss,
63
+ special_tokens=special_tokens)
64
+
65
+ # For CTL Loss
66
+ self.n_negatives = n_negatives
67
+ self.ctl_weight = ctl_weight
68
+ self.logit_temp = logit_temp
69
+
70
+ @torch.jit.unused
71
+ def forward(
72
+ self,
73
+ batch: dict,
74
+ device: torch.device,
75
+ ) -> Dict[str, Optional[torch.Tensor]]:
76
+
77
+ speech = batch['feats'].to(device)
78
+ speech_lengths = batch['feats_lengths'].to(device)
79
+ text = batch['target'].to(device)
80
+ text_lengths = batch['target_lengths'].to(device)
81
+ loss_full, encoder_out_full, _, _ = self.forward_full(
82
+ speech, speech_lengths, text, text_lengths)
83
+ loss_chunk, encoder_out, lens_chunk, encoder_mask = self.forward_chunk(
84
+ speech, speech_lengths, text, text_lengths)
85
+
86
+ ctl_loss = 0.0
87
+ if self.ctl_weight > 0 and self.n_negatives > 0:
88
+ num = encoder_out_full.size(1)
89
+ targets = encoder_out_full
90
+ src = encoder_out
91
+ negs, negs_idxs = self.sample_negatives(targets,
92
+ targets.size(1),
93
+ speech_lengths=lens_chunk)
94
+ ctl_loss = self.CTL(src, targets, negs, encoder_mask)
95
+
96
+ loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss
97
+ return {
98
+ "loss": loss,
99
+ "loss_full": loss_full,
100
+ "loss_chunk": loss_chunk,
101
+ "loss_ctl": ctl_loss
102
+ }
103
+
104
+ def forward_full(
105
+ self,
106
+ speech: torch.Tensor,
107
+ speech_lengths: torch.Tensor,
108
+ text: torch.Tensor,
109
+ text_lengths: torch.Tensor,
110
+ ):
111
+ """Full context mode
112
+ Frontend + Encoder + Decoder + Calc loss
113
+
114
+ Args:
115
+ speech: (Batch, Length, ...)
116
+ speech_lengths: (Batch, )
117
+ text: (Batch, Length)
118
+ text_lengths: (Batch,)
119
+ """
120
+
121
+ assert text_lengths.dim() == 1, text_lengths.shape
122
+ # Check that batch_size is unified
123
+ assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
124
+ text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
125
+ text.shape, text_lengths.shape)
126
+ # 1. Encoder
127
+ encoder_out, encoder_mask = self.encoder.forward_full(
128
+ speech, speech_lengths)
129
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
130
+
131
+ # 2a. Attention-decoder branch
132
+ if self.ctc_weight != 1.0:
133
+ loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
134
+ text, text_lengths)
135
+ else:
136
+ loss_att = None
137
+
138
+ # 2b. CTC branch
139
+ if self.ctc_weight != 0.0:
140
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
141
+ text_lengths)
142
+ else:
143
+ loss_ctc = None
144
+
145
+ if loss_ctc is None:
146
+ loss = loss_att
147
+ elif loss_att is None:
148
+ loss = loss_ctc
149
+ else:
150
+ loss = self.ctc_weight * loss_ctc[0] + (1 -
151
+ self.ctc_weight) * loss_att
152
+ return loss, encoder_out, encoder_out_lens, encoder_mask
153
+
154
+ def forward_chunk(
155
+ self,
156
+ speech: torch.Tensor,
157
+ speech_lengths: torch.Tensor,
158
+ text: torch.Tensor,
159
+ text_lengths: torch.Tensor,
160
+ ):
161
+ """Chunk-based context mode
162
+ Frontend + Encoder + Decoder + Calc loss
163
+
164
+ Args:
165
+ speech: (Batch, Length, ...)
166
+ speech_lengths: (Batch, )
167
+ text: (Batch, Length)
168
+ text_lengths: (Batch,)
169
+ """
170
+
171
+ assert text_lengths.dim() == 1, text_lengths.shape
172
+ # Check that batch_size is unified
173
+ assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
174
+ text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
175
+ text.shape, text_lengths.shape)
176
+ # 1. Encoder
177
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
178
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
179
+
180
+ # 2a. Attention-decoder branch
181
+ if self.ctc_weight != 1.0:
182
+ loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
183
+ text, text_lengths)
184
+ else:
185
+ loss_att = None
186
+
187
+ # 2b. CTC branch
188
+ if self.ctc_weight != 0.0:
189
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
190
+ text_lengths)
191
+ else:
192
+ loss_ctc = None
193
+
194
+ if loss_ctc is None:
195
+ loss = loss_att
196
+ elif loss_att is None:
197
+ loss = loss_ctc
198
+ else:
199
+ loss = self.ctc_weight * loss_ctc[0] + (1 -
200
+ self.ctc_weight) * loss_att
201
+ return loss, encoder_out, encoder_out_lens, encoder_mask
202
+
203
+ def sample_negatives(self, y, num, padding_count=0, speech_lengths=None):
204
+ if self.n_negatives == 0:
205
+ return y.new(0)
206
+ bsz, tsz, fsz = y.shape
207
+ y = y.reshape(-1, fsz) # BTC => (BxT)C
208
+
209
+ # FIXME: what happens if padding_count is specified?
210
+ high = tsz - (padding_count or 0)
211
+ with torch.no_grad():
212
+ assert high > 1, f"{bsz,tsz,fsz}"
213
+
214
+ if self.n_negatives > 0:
215
+ tszs = (torch.arange(num).unsqueeze(-1).expand(
216
+ -1, self.n_negatives).flatten())
217
+ if speech_lengths is not None:
218
+ neg_idxs = [
219
+ torch.randint(low=0,
220
+ high=speech_lengths[i].item() - 1,
221
+ size=(1, self.n_negatives * tsz))
222
+ for i in range(len(speech_lengths))
223
+ ]
224
+ neg_idxs = torch.cat(neg_idxs).reshape(
225
+ bsz, self.n_negatives * tsz)
226
+ else:
227
+ neg_idxs = torch.randint(low=0,
228
+ high=num - 1,
229
+ size=(bsz,
230
+ self.n_negatives * tsz))
231
+ neg_idxs[neg_idxs >= tszs] += 1
232
+
233
+ if self.n_negatives > 0:
234
+ neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
235
+
236
+ negs = y[neg_idxs.view(-1)]
237
+ negs = negs.contiguous().view(bsz, num, self.n_negatives,
238
+ fsz).permute(2, 0, 1, 3) # to NxBxTxC
239
+ return negs, neg_idxs
240
+
241
+ def compute_preds(self, x, y, negatives):
242
+ neg_is_pos = (y == negatives).all(-1)
243
+ y = y.unsqueeze(0)
244
+ targets = torch.cat([y, negatives], dim=0)
245
+
246
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
247
+ logits = logits / self.logit_temp
248
+ logits = logits.type_as(x)
249
+
250
+ if neg_is_pos.any():
251
+ if not hasattr(self, "_inftensor"):
252
+ self._inftensor = float("-inf")
253
+ # logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
254
+ logits[1:][neg_is_pos] = self._inftensor
255
+ logits = logits.transpose(0, 2)
256
+ logits = logits.transpose(0, 1)
257
+ logits = logits.reshape(-1, logits.size(-1))
258
+ return logits
259
+
260
+ def CTL(self, x, y, negs, mask=None):
261
+ # Step1: compute cosine similarity, shape [B*T, n_negatives+1]
262
+ logits = self.compute_preds(x, y, negs)
263
+
264
+ # Step2: target shape [B*T]
265
+ target = x.new_zeros(x.size(0) * x.size(1), dtype=torch.long)
266
+
267
+ # Step3: compute CTL loss
268
+ if mask is not None:
269
+ normalize_length = mask.sum()
270
+ bz, sz = mask.size(0), mask.size(-1)
271
+ mask = mask.squeeze(1).reshape(bz * sz).eq(0)
272
+ ce = F.cross_entropy(logits, target, reduction='none')
273
+ loss = ce.masked_fill(mask, 0).sum() / normalize_length
274
+ else:
275
+ loss = F.cross_entropy(logits, target)
276
+
277
+ return loss
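A standalone sketch of the InfoNCE-style objective that sample_negatives/compute_preds/CTL implement above, run on dummy tensors so no CTLModel has to be built; the shapes and the 0.1 temperature mirror the defaults in this file, and all sizes are illustrative.

import torch
import torch.nn.functional as F

B, T, C, N = 2, 5, 8, 4            # batch, frames, feature dim, negatives per frame
src = torch.randn(B, T, C)          # chunk (streaming) encoder output
tgt = torch.randn(B, T, C)          # full-context encoder output (positives)
negs = torch.randn(N, B, T, C)      # sampled negatives, N x B x T x C as in sample_negatives

targets = torch.cat([tgt.unsqueeze(0), negs], dim=0)                 # (N+1, B, T, C)
logits = torch.cosine_similarity(src.float(), targets.float(), dim=-1) / 0.1
logits = logits.permute(1, 2, 0).reshape(-1, N + 1)                  # (B*T, N+1)
target = torch.zeros(B * T, dtype=torch.long)                        # the positive is class 0
loss = F.cross_entropy(logits, target)
print(loss)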
wenet/ctl_model/encoder.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2023 NetEase Inc
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+
22
+ from wenet.utils.mask import make_pad_mask
23
+ from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
24
+
25
+
26
+ class DualTransformerEncoder(TransformerEncoder):
27
+ """Transformer encoder module."""
28
+
29
+ def __init__(
30
+ self,
31
+ input_size: int,
32
+ output_size: int = 256,
33
+ attention_heads: int = 4,
34
+ linear_units: int = 2048,
35
+ num_blocks: int = 6,
36
+ dropout_rate: float = 0.1,
37
+ positional_dropout_rate: float = 0.1,
38
+ attention_dropout_rate: float = 0.0,
39
+ input_layer: str = "conv2d",
40
+ pos_enc_layer_type: str = "abs_pos",
41
+ normalize_before: bool = True,
42
+ static_chunk_size: int = 0,
43
+ use_dynamic_chunk: bool = False,
44
+ global_cmvn: torch.nn.Module = None,
45
+ use_dynamic_left_chunk: bool = False,
46
+ query_bias: bool = True,
47
+ key_bias: bool = True,
48
+ value_bias: bool = True,
49
+ activation_type: str = "relu",
50
+ gradient_checkpointing: bool = False,
51
+ use_sdpa: bool = False,
52
+ layer_norm_type: str = 'layer_norm',
53
+ norm_eps: float = 1e-5,
54
+ n_kv_head: Optional[int] = None,
55
+ head_dim: Optional[int] = None,
56
+ selfattention_layer_type: str = "selfattn",
57
+ mlp_type: str = 'position_wise_feed_forward',
58
+ mlp_bias: bool = True,
59
+ n_expert: int = 8,
60
+ n_expert_activated: int = 2,
61
+ ):
62
+ """ Construct DualTransformerEncoder
63
+ Support both the full context mode and the streaming mode separately
64
+ """
65
+ super().__init__(input_size, output_size, attention_heads,
66
+ linear_units, num_blocks, dropout_rate,
67
+ positional_dropout_rate, attention_dropout_rate,
68
+ input_layer, pos_enc_layer_type, normalize_before,
69
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
70
+ use_dynamic_left_chunk, query_bias, key_bias,
71
+ value_bias, activation_type, gradient_checkpointing,
72
+ use_sdpa, layer_norm_type, norm_eps, n_kv_head,
73
+ head_dim, selfattention_layer_type, mlp_type,
74
+ mlp_bias, n_expert, n_expert_activated)
75
+
76
+ def forward_full(
77
+ self,
78
+ xs: torch.Tensor,
79
+ xs_lens: torch.Tensor,
80
+ decoding_chunk_size: int = 0,
81
+ num_decoding_left_chunks: int = -1,
82
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
83
+ T = xs.size(1)
84
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
85
+ if self.global_cmvn is not None:
86
+ xs = self.global_cmvn(xs)
87
+ xs, pos_emb, masks = self.embed(xs, masks)
88
+ mask_pad = masks # (B, 1, T/subsample_rate)
89
+ for layer in self.encoders:
90
+ xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
91
+ if self.normalize_before:
92
+ xs = self.after_norm(xs)
93
+ return xs, masks
94
+
95
+
96
+ class DualConformerEncoder(ConformerEncoder):
97
+ """Conformer encoder module."""
98
+
99
+ def __init__(
100
+ self,
101
+ input_size: int,
102
+ output_size: int = 256,
103
+ attention_heads: int = 4,
104
+ linear_units: int = 2048,
105
+ num_blocks: int = 6,
106
+ dropout_rate: float = 0.1,
107
+ positional_dropout_rate: float = 0.1,
108
+ attention_dropout_rate: float = 0.0,
109
+ input_layer: str = "conv2d",
110
+ pos_enc_layer_type: str = "rel_pos",
111
+ normalize_before: bool = True,
112
+ static_chunk_size: int = 0,
113
+ use_dynamic_chunk: bool = False,
114
+ global_cmvn: torch.nn.Module = None,
115
+ use_dynamic_left_chunk: bool = False,
116
+ positionwise_conv_kernel_size: int = 1,
117
+ macaron_style: bool = True,
118
+ selfattention_layer_type: str = "rel_selfattn",
119
+ activation_type: str = "swish",
120
+ use_cnn_module: bool = True,
121
+ cnn_module_kernel: int = 15,
122
+ causal: bool = False,
123
+ cnn_module_norm: str = "batch_norm",
124
+ query_bias: bool = True,
125
+ key_bias: bool = True,
126
+ value_bias: bool = True,
127
+ conv_bias: bool = True,
128
+ gradient_checkpointing: bool = False,
129
+ use_sdpa: bool = False,
130
+ layer_norm_type: str = 'layer_norm',
131
+ norm_eps: float = 1e-5,
132
+ n_kv_head: Optional[int] = None,
133
+ head_dim: Optional[int] = None,
134
+ mlp_type: str = 'position_wise_feed_forward',
135
+ mlp_bias: bool = True,
136
+ n_expert: int = 8,
137
+ n_expert_activated: int = 2,
138
+ ):
139
+ """ Construct DualConformerEncoder
140
+ Support both the full context mode and the streaming mode separately
141
+ """
142
+ super().__init__(
143
+ input_size, output_size, attention_heads, linear_units, num_blocks,
144
+ dropout_rate, positional_dropout_rate, attention_dropout_rate,
145
+ input_layer, pos_enc_layer_type, normalize_before,
146
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
147
+ use_dynamic_left_chunk, positionwise_conv_kernel_size,
148
+ macaron_style, selfattention_layer_type, activation_type,
149
+ use_cnn_module, cnn_module_kernel, causal, cnn_module_norm,
150
+ query_bias, key_bias, value_bias, conv_bias,
151
+ gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps,
152
+ n_kv_head, head_dim, mlp_type, mlp_bias, n_expert,
153
+ n_expert_activated)
154
+
155
+ def forward_full(
156
+ self,
157
+ xs: torch.Tensor,
158
+ xs_lens: torch.Tensor,
159
+ decoding_chunk_size: int = 0,
160
+ num_decoding_left_chunks: int = -1,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ T = xs.size(1)
163
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
164
+ if self.global_cmvn is not None:
165
+ xs = self.global_cmvn(xs)
166
+ xs, pos_emb, masks = self.embed(xs, masks)
167
+ mask_pad = masks # (B, 1, T/subsample_rate)
168
+ for layer in self.encoders:
169
+ xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
170
+ if self.normalize_before:
171
+ xs = self.after_norm(xs)
172
+ return xs, masks
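A small sketch of how the dual-mode encoders above are exercised during CTL training: the same features go through forward_full (full context) and the regular forward (dynamic-chunk/streaming). The sizes and constructor arguments here are illustrative only.

import torch
from wenet.ctl_model.encoder import DualTransformerEncoder

enc = DualTransformerEncoder(input_size=80, output_size=256, num_blocks=2,
                             use_dynamic_chunk=True)
xs = torch.randn(2, 100, 80)                        # (batch, frames, fbank dim)
xs_lens = torch.tensor([100, 80])
full_out, full_mask = enc.forward_full(xs, xs_lens)  # full-context pass
chunk_out, chunk_mask = enc(xs, xs_lens)             # chunk/streaming pass
print(full_out.shape, chunk_out.shape)               # roughly (2, 24, 256) after conv2d subsampling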
wenet/dataset/__init__.py ADDED
File without changes
wenet/dataset/datapipes.py ADDED
@@ -0,0 +1,470 @@
1
+ # Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+ from collections.abc import Callable
17
+ import copy
18
+ import sys
19
+ import tarfile
20
+ import logging
21
+ from typing import List, Optional
22
+ import numpy as np
23
+ import torch
24
+ from torch.utils.data import IterDataPipe, functional_datapipe
25
+ from torch.utils.data import datapipes
26
+ from torch.utils.data.datapipes.iter import Mapper
27
+ from torch.utils.data.datapipes.iter.sharding import (
28
+ SHARDING_PRIORITIES, ShardingFilterIterDataPipe)
29
+ from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
30
+
31
+ from wenet.dataset.processor import parse_url
32
+
33
+
34
+ @functional_datapipe("map_ignore_error")
35
+ class MapperIgnoreErrorDataPipe(Mapper):
36
+
37
+ def __init__(self,
38
+ dataset: IterDataPipe,
39
+ fn: Callable,
40
+ input_col=None,
41
+ output_col=None,
42
+ log_error: bool = True) -> None:
43
+ super().__init__(dataset, fn, input_col, output_col)
44
+ self._iter = None
45
+ self.log_error = log_error
46
+
47
+ def __iter__(self):
48
+ if self._iter is None:
49
+ self._iter = iter(self.datapipe)
50
+
51
+ while True:
52
+ try:
53
+ elem = next(self._iter)
54
+ yield self._apply_fn(elem)
55
+ except StopIteration:
56
+ self._iter = None
57
+ return
58
+ except Exception as ex:
59
+ if self.log_error:
60
+ logging.warning(str(ex))
61
+
62
+
63
+ @functional_datapipe('bucket_by_sequence_length')
64
+ class BucketBySequenceLengthDataPipe(IterDataPipe):
65
+
66
+ def __init__(
67
+ self,
68
+ dataset: IterDataPipe,
69
+ elem_length_func,
70
+ bucket_boundaries: List[int],
71
+ bucket_batch_sizes: List[int],
72
+ wrapper_class=None,
73
+ ) -> None:
74
+ super().__init__()
75
+ _check_unpickable_fn(elem_length_func)
76
+ assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1
77
+ self.bucket_batch_sizes = bucket_batch_sizes
78
+ self.bucket_boundaries = bucket_boundaries + [sys.maxsize]
79
+ self.elem_length_func = elem_length_func
80
+
81
+ self._group_dp = GroupByWindowDataPipe(dataset,
82
+ self._element_to_bucket_id,
83
+ self._window_size_func,
84
+ wrapper_class=wrapper_class)
85
+
86
+ def __iter__(self):
87
+ yield from self._group_dp
88
+
89
+ def _element_to_bucket_id(self, elem):
90
+ seq_len = self.elem_length_func(elem)
91
+ bucket_id = 0
92
+ for (i, b) in enumerate(self.bucket_boundaries):
93
+ if seq_len < b:
94
+ bucket_id = i
95
+ break
96
+ return bucket_id
97
+
98
+ def _window_size_func(self, bucket_id):
99
+ return self.bucket_batch_sizes[bucket_id]
100
+
101
+
102
+ @functional_datapipe("group_by_window")
103
+ class GroupByWindowDataPipe(datapipes.iter.Grouper):
104
+
105
+ def __init__(
106
+ self,
107
+ dataset: IterDataPipe,
108
+ key_func,
109
+ window_size_func,
110
+ wrapper_class=None,
111
+ ):
112
+ super().__init__(dataset,
113
+ key_func,
114
+ keep_key=False,
115
+ group_size=None,
116
+ drop_remaining=False)
117
+ _check_unpickable_fn(window_size_func)
118
+ self.dp = dataset
119
+ self.window_size_func = window_size_func
120
+ if wrapper_class is not None:
121
+ _check_unpickable_fn(wrapper_class)
122
+ del self.wrapper_class
123
+ self.wrapper_class = wrapper_class
124
+
125
+ def __iter__(self):
126
+ for x in self.datapipe:
127
+ key = self.group_key_fn(x)
128
+
129
+ self.buffer_elements[key].append(x)
130
+ self.curr_buffer_size += 1
131
+
132
+ group_size = self.window_size_func(key)
133
+ if group_size == len(self.buffer_elements[key]):
134
+ result = self.wrapper_class(self.buffer_elements[key])
135
+ yield result
136
+ self.curr_buffer_size -= len(self.buffer_elements[key])
137
+ del self.buffer_elements[key]
138
+
139
+ if self.curr_buffer_size == self.max_buffer_size:
140
+ result_to_yield = self._remove_biggest_key()
141
+ if result_to_yield is not None:
142
+ result = self.wrapper_class(result_to_yield)
143
+ yield result
144
+
145
+ for key in tuple(self.buffer_elements.keys()):
146
+ result = self.wrapper_class(self.buffer_elements.pop(key))
147
+ self.curr_buffer_size -= len(result)
148
+ yield result
149
+
150
+
151
+ @functional_datapipe("sort")
152
+ class SortDataPipe(IterDataPipe):
153
+
154
+ def __init__(self,
155
+ dataset: IterDataPipe,
156
+ buffer_size: int = 500,
157
+ key_func=None,
158
+ reverse=False) -> None:
159
+ if key_func is not None:
160
+ _check_unpickable_fn(key_func)
161
+ self.buffer_size = buffer_size
162
+ super().__init__()
163
+ self.dp = dataset
164
+ self._buffer = []
165
+ self.key_func = key_func
166
+ self.reverse = reverse
167
+
168
+ def __iter__(self):
169
+ for elem in self.dp:
170
+ self._buffer.append(elem)
171
+ if len(self._buffer) >= self.buffer_size:
172
+ self._buffer.sort(key=self.key_func, reverse=self.reverse)
173
+ for x in self._buffer:
174
+ yield x
175
+ del self._buffer
176
+ self._buffer = []
177
+ # The sample left over
178
+ self._buffer.sort(key=self.key_func, reverse=self.reverse)
179
+ for x in self._buffer:
180
+ yield x
181
+ del self._buffer
182
+ self._buffer = []
183
+
184
+
185
+ @functional_datapipe("dynamic_batch")
186
+ class DynamicBatchDataPipe(IterDataPipe):
187
+
188
+ def __init__(self, dataset: IterDataPipe, window_class,
189
+ wrapper_class) -> None:
190
+ _check_unpickable_fn(window_class)
191
+ _check_unpickable_fn(wrapper_class)
192
+ super().__init__()
193
+ self.dp = dataset
194
+ assert window_class is not None
195
+ assert wrapper_class is not None
196
+ self.window_class = window_class
197
+ self._buffer = []
198
+ self._wrappr_class = wrapper_class
199
+
200
+ def __iter__(self):
201
+ for elem in self.dp:
202
+ if not self.window_class(elem, len(self._buffer)):
203
+ self._buffer.append(elem)
204
+ else:
205
+ if len(self._buffer) > 0:
206
+ yield self._wrappr_class(self._buffer)
207
+ del self._buffer
208
+ self._buffer = [elem]
209
+ if len(self._buffer) > 0:
210
+ yield self._wrappr_class(self._buffer)
211
+ del self._buffer
212
+ self._buffer = []
213
+
214
+
215
+ @functional_datapipe("prefetch")
216
+ class PrefetchDataPipe(IterDataPipe):
217
+ """Performs prefetching"""
218
+
219
+ def __init__(
220
+ self,
221
+ dataset: IterDataPipe,
222
+ buffer_size: int = 500,
223
+ ):
224
+ # TODO(Mddct): support multiprocessing pool with shared-memory to
225
+ # prefetch
226
+ super().__init__()
227
+ self.dp = dataset
228
+ self._iter = None
229
+ self._prefetch_buffer_size = buffer_size
230
+ self._buffer = None
231
+ if self._prefetch_buffer_size > 0:
232
+ self._buffer = collections.deque(maxlen=self._prefetch_buffer_size)
233
+
234
+ def __iter__(self):
235
+ if self._prefetch_buffer_size > 0:
236
+ if self._iter is None:
237
+ self._iter = iter(self.dp)
238
+ assert self._buffer is not None
239
+
240
+ while True:
241
+ if len(self._buffer) <= self._prefetch_buffer_size // 2:
242
+ while len(self._buffer) < self._prefetch_buffer_size:
243
+ try:
244
+ self._buffer.append(next(self._iter))
245
+ except StopIteration:
246
+ if len(self._buffer) != 0:
247
+ while len(self._buffer) > 0:
248
+ yield self._buffer.popleft()
249
+ self._iter = None
250
+ return
251
+ while len(self._buffer) > self._prefetch_buffer_size // 2:
252
+ elem = self._buffer.popleft()
253
+ yield elem
254
+
255
+ else:
256
+ yield from self.dp
257
+
258
+
259
+ @functional_datapipe("repeat")
260
+ class RepeatDatapipe(IterDataPipe):
261
+
262
+ def __init__(self, dataset: IterDataPipe, count: int = -1):
263
+ super().__init__()
264
+ self.dp = dataset
265
+ self.count = count
266
+
267
+ def __iter__(self):
268
+ if self.count == 1:
269
+ yield from self.dp
270
+ return
271
+ i = 0
272
+ while self.count < 0 or i < self.count:
273
+ for elem in self.dp:
274
+ new_elem = copy.copy(elem)
275
+ yield new_elem
276
+ i += 1
277
+
278
+
279
+ @functional_datapipe("shard")
280
+ class ShardDataPipe(ShardingFilterIterDataPipe):
281
+
282
+ def __init__(self, dataset: IterDataPipe, partition: bool = False):
283
+ super().__init__(dataset, None)
284
+ self.partition = partition
285
+ self.dp = dataset
286
+
287
+ def apply_sharding(self, num_of_instances: int, instance_id: int,
288
+ sharding_group: SHARDING_PRIORITIES):
289
+ if self.partition:
290
+ return super().apply_sharding(num_of_instances, instance_id,
291
+ sharding_group)
292
+ else:
293
+ # We can not handle uneven data for CV on DDP, so we don't
294
+ # sample data by rank, that means every GPU gets the same
295
+ # and all the CV data
296
+ info = torch.utils.data.get_worker_info()
297
+ if info is None:
298
+ self.num_of_instances = 1
299
+ self.instance_id = 0
300
+ else:
301
+ n_workers_per_device = info.num_workers
302
+ self.num_of_instances = n_workers_per_device
303
+ self.instance_id = info.id
304
+
305
+
306
+ @functional_datapipe("interleave")
307
+ class InterlaveDataPipe(IterDataPipe):
308
+
309
+ def __init__(
310
+ self,
311
+ source_datapipes: List[IterDataPipe],
312
+ weights: Optional[List[float]] = None,
313
+ seed=2027,
314
+ ):
315
+ super().__init__()
316
+ self.rng = np.random.default_rng(seed)
317
+ self.source_datapipes = source_datapipes
318
+ self.weights = weights
319
+ if weights is None:
320
+ self.weights = [1 / len(self.source_datapipes)] * len(
321
+ self.source_datapipes)
322
+ else:
323
+ self.weights = [weight / sum(weights) for weight in weights]
324
+ self.iters = None
325
+
326
+ def __iter__(self):
327
+ weights = copy.deepcopy(self.weights)
328
+ exhausted = len(self.source_datapipes) * [False]
329
+ if self.iters is None:
330
+ self.iters = [(i, iter(d))
331
+ for i, d in enumerate(self.source_datapipes)]
332
+ while True:
333
+ # TODO(Mddct): rng
334
+ index_iter = self.rng.choice(self.iters, p=weights)
335
+ i, ite = index_iter
336
+ try:
337
+ elem = next(ite)
338
+ yield elem
339
+ except StopIteration:
340
+ weights[i] = 0.
341
+ exhausted[i] = True
342
+ if all(exhausted):
343
+ return
344
+ weights = [weight / sum(weights) for weight in weights]
345
+
346
+
347
+ class TextLineDataPipe(IterDataPipe):
348
+    """ Streaming text line
349
+ """
350
+
351
+ def __init__(self, filenames, mode='r'):
352
+ super().__init__()
353
+ _dp = datapipes.iter.FileLister(filenames)
354
+ _dp = datapipes.iter.FileOpener(_dp, mode=mode)
355
+ self.dp = _dp
356
+
357
+ def __iter__(self):
358
+ for fname, stream in self.dp:
359
+ for line in stream:
360
+ line = line.strip('\n')
361
+ yield {"file_name": fname, "line": line}
362
+ stream.close()
363
+
364
+
365
+ @functional_datapipe("tar_file_and_group")
366
+ class TarsDataPipe(IterDataPipe):
367
+    """ Decode wenet's tar shards, yield {'key': ..., 'txt': ..., 'wav': ...}
368
+ """
369
+
370
+ def __init__(self, dataset: IterDataPipe) -> None:
371
+ super().__init__()
372
+ self.dp = dataset
373
+
374
+ def __iter__(self):
375
+ from wenet.dataset.processor import AUDIO_FORMAT_SETS
376
+ for sample in self.dp:
377
+ assert 'file_name' in sample
378
+ assert 'line' in sample
379
+ assert 'stream' in sample
380
+ try:
381
+ with tarfile.open(fileobj=sample['stream'],
382
+ mode="r:*") as stream:
383
+ prev_prefix = None
384
+ example = {
385
+ 'file_name': sample['file_name'],
386
+ 'tar_file_name': sample['line']
387
+ }
388
+ valid = True
389
+ for tarinfo in stream:
390
+ name = tarinfo.name
391
+ pos = name.rfind('.')
392
+ assert pos > 0
393
+ prefix, postfix = name[:pos], name[pos + 1:]
394
+ if prev_prefix is not None and prefix != prev_prefix:
395
+ example['key'] = prev_prefix
396
+ if valid:
397
+ yield example
398
+ example = {
399
+ 'file_name': sample['file_name'],
400
+ 'tar_file_name': sample['line']
401
+ }
402
+ valid = True
403
+ with stream.extractfile(tarinfo) as file_obj:
404
+ try:
405
+ if postfix == 'txt':
406
+ example['txt'] = file_obj.read().decode(
407
+ 'utf8').strip()
408
+ elif postfix in AUDIO_FORMAT_SETS:
409
+ example['wav'] = file_obj.read()
410
+ else:
411
+ example[postfix] = file_obj.read()
412
+ except Exception as ex:
413
+ valid = False
414
+ logging.warning(
415
+ 'error to parse {}'.format(name))
416
+ prev_prefix = prefix
417
+ if prev_prefix is not None:
418
+ example['key'] = prev_prefix
419
+ yield example
420
+ except Exception as ex:
421
+ msg = 'In tar_file_and_group: {} when processing {}'.format(
422
+ ex, sample['line'])
423
+ logging.warning(msg)
424
+ finally:
425
+ if 'process' in sample:
426
+ sample['process'].communicate()
427
+ sample['stream'].close()
428
+
429
+
430
+ class WenetRawDatasetSource(IterDataPipe):
431
+
432
+ def __init__(self,
433
+ filenames: str,
434
+ prefetch: int = 500,
435
+ partition: bool = True,
436
+ shuffle: bool = False,
437
+ shuffle_size: int = 10000,
438
+ cycle: int = 1) -> None:
439
+ super().__init__()
440
+ self.dp = TextLineDataPipe(filenames)
441
+ if shuffle:
442
+ self.dp = self.dp.shuffle(buffer_size=shuffle_size)
443
+ self.dp = self.dp.repeat(cycle).prefetch(prefetch)
444
+ self.dp = self.dp.shard(partition)
445
+
446
+ def __iter__(self):
447
+ for d in self.dp:
448
+ yield d
449
+
450
+
451
+ class WenetTarShardDatasetSource(IterDataPipe):
452
+
453
+ def __init__(self,
454
+ filenames: str,
455
+ prefetch: int = 500,
456
+ partition: bool = True,
457
+ shuffle: bool = False,
458
+ shuffle_size: int = 10000,
459
+ cycle: int = 1) -> None:
460
+ super().__init__()
461
+ self.dp = TextLineDataPipe(filenames)
462
+ if shuffle:
463
+ self.dp = self.dp.shuffle(buffer_size=shuffle_size)
464
+ self.dp = self.dp.repeat(cycle)
465
+ self.dp = self.dp.shard(partition).map_ignore_error(
466
+ parse_url).tar_file_and_group().prefetch(prefetch)
467
+
468
+ def __iter__(self):
469
+ for d in self.dp:
470
+ yield d
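A brief sketch of composing these datapipes; 'data.list' is a placeholder text file with one entry per line in wenet's raw data-list format, and identity stands in for a real per-sample transform.

from wenet.dataset.datapipes import WenetRawDatasetSource

def identity(sample):
    return sample

dp = WenetRawDatasetSource('data.list', shuffle=False)
dp = dp.map_ignore_error(identity)   # registered above via @functional_datapipe
for sample in dp:
    print(sample['file_name'], sample['line'])
    break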
wenet/dataset/dataset.py ADDED
@@ -0,0 +1,234 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from torch.utils.data import IterableDataset
20
+
21
+ import wenet.dataset.deprecated.processor as processor
22
+ from wenet.text.base_tokenizer import BaseTokenizer
23
+ from wenet.utils.file_utils import read_lists
24
+
25
+
26
+ class Processor(IterableDataset):
27
+
28
+ def __init__(self, source, f, *args, **kw):
29
+ assert callable(f)
30
+ self.source = source
31
+ self.f = f
32
+ self.args = args
33
+ self.kw = kw
34
+
35
+ def set_epoch(self, epoch):
36
+ self.source.set_epoch(epoch)
37
+
38
+ def __iter__(self):
39
+ """ Return an iterator over the source dataset processed by the
40
+ given processor.
41
+ """
42
+ assert self.source is not None
43
+ assert callable(self.f)
44
+ return self.f(iter(self.source), *self.args, **self.kw)
45
+
46
+ def apply(self, f):
47
+ assert callable(f)
48
+ return Processor(self, f, *self.args, **self.kw)
49
+
50
+
51
+ class DistributedSampler:
52
+
53
+ def __init__(self, shuffle=True, partition=True, split_num=1):
54
+ self.epoch = -1
55
+ self.update()
56
+ self.shuffle = shuffle
57
+ self.partition = partition
58
+ self.split_num = split_num
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def split_data(self, total_num):
84
+ data = list(range(total_num))
85
+ sub_epoch = self.epoch + 1
86
+ full_epoch = sub_epoch // self.split_num
87
+ num_per_sub_epochs = total_num // self.split_num
88
+ random.Random(full_epoch).shuffle(data)
89
+
90
+ split_index = sub_epoch - full_epoch * self.split_num
91
+ begin = split_index * num_per_sub_epochs
92
+ end = (begin + num_per_sub_epochs
93
+ if (split_index + 1) < self.split_num else
94
+ total_num)
95
+
96
+ # print(f'begin: {begin}, end: {end}, world_size: {self.world_size}')
97
+ return data[begin:end]
98
+
99
+ def sample(self, data, split_num=1):
100
+ """ Sample data according to rank/world_size/num_workers
101
+
102
+ Args:
103
+ data(List): input data list
104
+
105
+ Returns:
106
+ List: data list after sample
107
+ """
108
+ if self.split_num == 1:
109
+ data = list(range(len(data)))
110
+ else:
111
+ data = self.split_data(len(data))
112
+ # TODO(Binbin Zhang): fix this
113
+        # We cannot handle uneven data for CV on DDP, so we don't
+        # shard CV data by rank: every GPU sees the complete CV set
116
+ if self.partition:
117
+ if self.shuffle:
118
+ random.Random(self.epoch).shuffle(data)
119
+ data = data[self.rank::self.world_size]
120
+ # print(f'num dataset: {len(data)}')
121
+ data = data[self.worker_id::self.num_workers]
122
+ self.epoch += 1
123
+ return data
124
+
125
+
126
+ class DataList(IterableDataset):
127
+
128
+ def __init__(self, lists, shuffle=True, partition=True, split_num=1):
129
+ self.lists = lists
130
+ self.sampler = DistributedSampler(shuffle, partition, split_num)
131
+
132
+ def set_epoch(self, epoch):
133
+ self.sampler.set_epoch(epoch)
134
+
135
+ def __iter__(self):
136
+ sampler_info = self.sampler.update()
137
+ indexes = self.sampler.sample(self.lists)
138
+ for index in indexes:
139
+ # yield dict(src=src)
140
+ data = dict(src=self.lists[index])
141
+ data.update(sampler_info)
142
+ yield data
143
+
144
+
145
+ def Dataset(data_type,
146
+ data_list_file,
147
+ tokenizer: BaseTokenizer,
148
+ conf,
149
+ partition=True):
150
+ """ Construct dataset from arguments
151
+
152
+    We have two shuffle stages in the Dataset. The first is a global
+    shuffle at the shard (tar/raw file) level. The second is a global
+    shuffle at the training-sample level.
155
+
156
+ Args:
157
+ data_type(str): raw/shard
158
+ bpe_model(str): model for english bpe part
159
+ partition(bool): whether to do data partition in terms of rank
160
+ """
161
+ assert data_type in ['raw', 'shard', 'shard_full_data']
162
+ lists = read_lists(data_list_file)
163
+ shuffle = conf.get('shuffle', True)
164
+ split_num = conf.get('split_num', 1)
165
+ dataset = DataList(lists, shuffle=shuffle, partition=partition, split_num=split_num)
166
+ if data_type == 'shard':
167
+ dataset = Processor(dataset, processor.url_opener)
168
+ dataset = Processor(dataset, processor.tar_file_and_group)
169
+ elif data_type == 'shard_full_data':
170
+ dataset = Processor(dataset, processor.url_opener)
171
+ dataset = Processor(dataset, processor.tar_file_and_group_full_data)
172
+ else:
173
+ dataset = Processor(dataset, processor.parse_raw)
174
+
175
+ speaker_conf = conf.get('speaker_conf', None)
176
+ if speaker_conf is not None:
177
+ dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)
178
+
179
+ if conf.get('eod_id', None) is not None:
180
+ tokenizer.eod_id = conf['eod_id']
181
+ # prompt dict
182
+ from gxl_ai_utils.utils import utils_file
183
+ global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')
184
+ dataset = Processor(dataset, processor.tokenize, tokenizer,
185
+ global_prompt_dict=global_prompt_dict)
186
+ filter_conf = conf.get('filter_conf', {})
187
+ dataset = Processor(dataset, processor.filter, **filter_conf)
188
+
189
+ resample_conf = conf.get('resample_conf', {})
190
+ dataset = Processor(dataset, processor.resample, **resample_conf)
191
+
192
+ speed_perturb = conf.get('speed_perturb', False)
193
+ if speed_perturb:
194
+ dataset = Processor(dataset, processor.speed_perturb)
195
+
196
+ feats_type = conf.get('feats_type', 'fbank')
197
+ assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
198
+ if feats_type == 'fbank':
199
+ fbank_conf = conf.get('fbank_conf', {})
200
+ dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
201
+ elif feats_type == 'mfcc':
202
+ mfcc_conf = conf.get('mfcc_conf', {})
203
+ dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
204
+ elif feats_type == 'log_mel_spectrogram':
205
+ log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
206
+ dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
207
+ **log_mel_spectrogram_conf)
208
+
209
+ spec_aug = conf.get('spec_aug', True)
210
+ spec_sub = conf.get('spec_sub', False)
211
+ spec_trim = conf.get('spec_trim', False)
212
+ if spec_aug:
213
+ spec_aug_conf = conf.get('spec_aug_conf', {})
214
+ dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
215
+ if spec_sub:
216
+ spec_sub_conf = conf.get('spec_sub_conf', {})
217
+ dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
218
+ if spec_trim:
219
+ spec_trim_conf = conf.get('spec_trim_conf', {})
220
+ dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)
221
+
222
+ if shuffle:
223
+ shuffle_conf = conf.get('shuffle_conf', {})
224
+ dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
225
+
226
+ sort = conf.get('sort', True)
227
+ if sort:
228
+ sort_conf = conf.get('sort_conf', {})
229
+ dataset = Processor(dataset, processor.sort, **sort_conf)
230
+
231
+ batch_conf = conf.get('batch_conf', {})
232
+ dataset = Processor(dataset, processor.batch, **batch_conf)
233
+ dataset = Processor(dataset, processor.padding)
234
+ return dataset
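A hedged sketch of how this Dataset factory is typically called. Note that, as written above, it also expects gxl_ai_utils to be importable and conf/prompt_stage4.yaml to exist; the tokenizer, file paths, and conf values below are placeholders drawn from a typical training YAML.

from wenet.dataset.dataset import Dataset
from wenet.text.char_tokenizer import CharTokenizer   # placeholder tokenizer choice

tokenizer = CharTokenizer('units.txt')                 # placeholder symbol table
conf = {
    'shuffle': False,
    'sort': False,
    'fbank_conf': {'num_mel_bins': 80, 'frame_shift': 10, 'frame_length': 25},
    'batch_conf': {'batch_type': 'static', 'batch_size': 4},
}
dataset = Dataset('raw', 'data.list', tokenizer, conf, partition=True)
for batch in dataset:    # yields padded training batches
    break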
wenet/dataset/deprecated/dataset.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from torch.utils.data import IterableDataset
20
+
21
+ import wenet.dataset.deprecated.processor as processor
22
+ from wenet.text.base_tokenizer import BaseTokenizer
23
+ from wenet.utils.file_utils import read_lists
24
+
25
+
26
+ class Processor(IterableDataset):
27
+
28
+ def __init__(self, source, f, *args, **kw):
29
+ assert callable(f)
30
+ self.source = source
31
+ self.f = f
32
+ self.args = args
33
+ self.kw = kw
34
+
35
+ def set_epoch(self, epoch):
36
+ self.source.set_epoch(epoch)
37
+
38
+ def __iter__(self):
39
+ """ Return an iterator over the source dataset processed by the
40
+ given processor.
41
+ """
42
+ assert self.source is not None
43
+ assert callable(self.f)
44
+ return self.f(iter(self.source), *self.args, **self.kw)
45
+
46
+ def apply(self, f):
47
+ assert callable(f)
48
+ return Processor(self, f, *self.args, **self.kw)
49
+
50
+
51
+ class DistributedSampler:
52
+
53
+ def __init__(self, shuffle=True, partition=True):
54
+ self.epoch = -1
55
+ self.update()
56
+ self.shuffle = shuffle
57
+ self.partition = partition
58
+
59
+ def update(self):
60
+ assert dist.is_available()
61
+ if dist.is_initialized():
62
+ self.rank = dist.get_rank()
63
+ self.world_size = dist.get_world_size()
64
+ else:
65
+ self.rank = 0
66
+ self.world_size = 1
67
+ worker_info = torch.utils.data.get_worker_info()
68
+ if worker_info is None:
69
+ self.worker_id = 0
70
+ self.num_workers = 1
71
+ else:
72
+ self.worker_id = worker_info.id
73
+ self.num_workers = worker_info.num_workers
74
+ return dict(rank=self.rank,
75
+ world_size=self.world_size,
76
+ worker_id=self.worker_id,
77
+ num_workers=self.num_workers)
78
+
79
+ def set_epoch(self, epoch):
80
+ self.epoch = epoch
81
+
82
+ def sample(self, data):
83
+ """ Sample data according to rank/world_size/num_workers
84
+
85
+ Args:
86
+ data(List): input data list
87
+
88
+ Returns:
89
+ List: data list after sample
90
+ """
91
+ data = list(range(len(data)))
92
+ # TODO(Binbin Zhang): fix this
93
+ # We can not handle uneven data for CV on DDP, so we don't
94
+ # sample data by rank, that means every GPU gets the same
95
+ # and all the CV data
96
+ if self.partition:
97
+ if self.shuffle:
98
+ random.Random(self.epoch).shuffle(data)
99
+ data = data[self.rank::self.world_size]
100
+ data = data[self.worker_id::self.num_workers]
101
+ return data
102
+
103
+
104
+ class DataList(IterableDataset):
105
+
106
+ def __init__(self, lists, shuffle=True, partition=True):
107
+ self.lists = lists
108
+ self.sampler = DistributedSampler(shuffle, partition)
109
+
110
+ def set_epoch(self, epoch):
111
+ self.sampler.set_epoch(epoch)
112
+
113
+ def __iter__(self):
114
+ sampler_info = self.sampler.update()
115
+ indexes = self.sampler.sample(self.lists)
116
+ for index in indexes:
117
+ # yield dict(src=src)
118
+ data = dict(src=self.lists[index])
119
+ data.update(sampler_info)
120
+ yield data
121
+
122
+
123
+ def Dataset(data_type,
124
+ data_list_file,
125
+ tokenizer: BaseTokenizer,
126
+ conf,
127
+ partition=True):
128
+ """ Construct dataset from arguments
129
+
130
+    We have two shuffle stages in the Dataset. The first is a global
+    shuffle at the shard (tar/raw file) level. The second is a global
+    shuffle at the training-sample level.
133
+
134
+ Args:
135
+ data_type(str): raw/shard
136
+ bpe_model(str): model for english bpe part
137
+ partition(bool): whether to do data partition in terms of rank
138
+ """
139
+ assert data_type in ['raw', 'shard']
140
+ lists = read_lists(data_list_file)
141
+ shuffle = conf.get('shuffle', True)
142
+ dataset = DataList(lists, shuffle=shuffle, partition=partition)
143
+ if data_type == 'shard':
144
+ dataset = Processor(dataset, processor.url_opener)
145
+ dataset = Processor(dataset, processor.tar_file_and_group)
146
+ else:
147
+ dataset = Processor(dataset, processor.parse_raw)
148
+
149
+ speaker_conf = conf.get('speaker_conf', None)
150
+ if speaker_conf is not None:
151
+ dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)
152
+
153
+ dataset = Processor(dataset, processor.tokenize, tokenizer)
154
+ filter_conf = conf.get('filter_conf', {})
155
+ dataset = Processor(dataset, processor.filter, **filter_conf)
156
+
157
+ resample_conf = conf.get('resample_conf', {})
158
+ dataset = Processor(dataset, processor.resample, **resample_conf)
159
+
160
+ speed_perturb = conf.get('speed_perturb', False)
161
+ if speed_perturb:
162
+ dataset = Processor(dataset, processor.speed_perturb)
163
+
164
+ feats_type = conf.get('feats_type', 'fbank')
165
+ assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
166
+ if feats_type == 'fbank':
167
+ fbank_conf = conf.get('fbank_conf', {})
168
+ dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
169
+ elif feats_type == 'mfcc':
170
+ mfcc_conf = conf.get('mfcc_conf', {})
171
+ dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
172
+ elif feats_type == 'log_mel_spectrogram':
173
+ log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
174
+ dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
175
+ **log_mel_spectrogram_conf)
176
+
177
+ spec_aug = conf.get('spec_aug', True)
178
+ spec_sub = conf.get('spec_sub', False)
179
+ spec_trim = conf.get('spec_trim', False)
180
+ if spec_aug:
181
+ spec_aug_conf = conf.get('spec_aug_conf', {})
182
+ dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
183
+ if spec_sub:
184
+ spec_sub_conf = conf.get('spec_sub_conf', {})
185
+ dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
186
+ if spec_trim:
187
+ spec_trim_conf = conf.get('spec_trim_conf', {})
188
+ dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)
189
+
190
+ if shuffle:
191
+ shuffle_conf = conf.get('shuffle_conf', {})
192
+ dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
193
+
194
+ sort = conf.get('sort', True)
195
+ if sort:
196
+ sort_conf = conf.get('sort_conf', {})
197
+ dataset = Processor(dataset, processor.sort, **sort_conf)
198
+
199
+ batch_conf = conf.get('batch_conf', {})
200
+ dataset = Processor(dataset, processor.batch, **batch_conf)
201
+ dataset = Processor(dataset, processor.padding)
202
+ return dataset
wenet/dataset/deprecated/processor.py ADDED
@@ -0,0 +1,1023 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import librosa
17
+ import logging
18
+ import json
19
+ import random
20
+ import tarfile
21
+ from subprocess import PIPE, Popen
22
+ from urllib.parse import urlparse
23
+
24
+ import torch
25
+ import torchaudio
26
+ import torchaudio.compliance.kaldi as kaldi
27
+ import torch.nn.functional as F
28
+ from gxl_ai_utils.utils import utils_file
29
+ from torch.nn.utils.rnn import pad_sequence
30
+ from wenet.text.base_tokenizer import BaseTokenizer
31
+
32
+ # torchaudio.utils.sox_utils.set_buffer_size(16500)
33
+ torchaudio.set_audio_backend("soundfile")
34
+
35
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
36
+
37
+
38
+ def url_opener(data):
39
+ """ Give url or local file, return file descriptor
40
+ Inplace operation.
41
+
42
+ Args:
43
+ data(Iterable[str]): url or local file list
44
+
45
+ Returns:
46
+ Iterable[{src, stream}]
47
+ """
48
+ for sample in data:
49
+ assert 'src' in sample
50
+ # TODO(Binbin Zhang): support HTTP
51
+ url = sample['src']
52
+ try:
53
+ pr = urlparse(url)
54
+ # local file
55
+ if pr.scheme == '' or pr.scheme == 'file':
56
+ stream = open(url, 'rb')
57
+ # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
58
+ else:
59
+ cmd = f'wget -q -O - {url}'
60
+ process = Popen(cmd, shell=True, stdout=PIPE)
61
+ sample.update(process=process)
62
+ stream = process.stdout
63
+ sample.update(stream=stream)
64
+ yield sample
65
+ except Exception as ex:
66
+ logging.warning('Failed to open {}'.format(url))
67
+
68
+
69
+ def tar_file_and_group(data):
70
+ """ Expand a stream of open tar files into a stream of tar file contents.
71
+ And groups the file with same prefix
72
+
73
+ Args:
74
+ data: Iterable[{src, stream}]
75
+
76
+ Returns:
77
+ Iterable[{key, wav, txt, sample_rate}]
78
+ """
79
+ for sample in data:
80
+ assert 'stream' in sample
81
+ stream = None
82
+ try:
83
+ stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
84
+ prev_prefix = None
85
+ example = {}
86
+ valid = True
87
+ for tarinfo in stream:
88
+ name = tarinfo.name
89
+ pos = name.rfind('.')
90
+ assert pos > 0
91
+ prefix, postfix = name[:pos], name[pos + 1:]
92
+ if prev_prefix is not None and prefix != prev_prefix:
93
+ example['key'] = prev_prefix
94
+ if valid:
95
+ yield example
96
+ example = {}
97
+ valid = True
98
+ with stream.extractfile(tarinfo) as file_obj:
99
+ try:
100
+ if postfix == 'txt':
101
+ example['txt'] = file_obj.read().decode(
102
+ 'utf8').strip()
103
+ elif postfix in AUDIO_FORMAT_SETS:
104
+ waveform, sample_rate = torchaudio.load(file_obj)
105
+ example['wav'] = waveform
106
+ example['sample_rate'] = sample_rate
107
+ else:
108
+ example[postfix] = file_obj.read()
109
+ except Exception as ex:
110
+ valid = False
111
+ logging.warning('failed to parse {}'.format(name))
112
+ prev_prefix = prefix
113
+ if prev_prefix is not None:
114
+ example['key'] = prev_prefix
115
+ yield example
116
+ except Exception as ex:
117
+ logging.warning(
118
+ 'In tar_file_and_group: {} when processing {}'.format(
119
+ ex, sample['src']))
120
+ finally:
121
+ if stream is not None:
122
+ stream.close()
123
+ if 'process' in sample:
124
+ sample['process'].communicate()
125
+ sample['stream'].close()
126
+
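
For reference, a minimal usage sketch of how the two stages above compose when reading a shard; it is only illustrative and the shard path is a hypothetical placeholder.

# url_opener and tar_file_and_group are plain generators over dicts,
# so composing them is just nested iteration.
shards = [{'src': '/data/shards/train_000000.tar'}]  # hypothetical path
for example in tar_file_and_group(url_opener(shards)):
    # each example groups the tar members sharing a prefix: key, wav, txt, sample_rate
    print(example['key'], example.get('txt', ''))
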
127
+
128
+ def tar_file_and_group_full_data(data):
129
+ """ Expand a stream of open tar files into a stream of tar file contents.
130
+ And groups the file with same prefix
131
+
132
+ Args:
133
+ data: Iterable[{src, stream}]
134
+
135
+ Returns:
136
+ Iterable[{key, wav, txt, sample_rate}]
137
+ """
138
+ for sample in data:
139
+ assert 'stream' in sample
140
+ stream = None
141
+ try:
142
+ stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
143
+ prev_prefix = None
144
+ example = {}
145
+ valid = True
146
+ for tarinfo in stream:
147
+ name = tarinfo.name
148
+ pos = name.rfind('.')
149
+ assert pos > 0
150
+ prefix, postfix = name[:pos], name[pos + 1:]
151
+ if prev_prefix is not None and prefix != prev_prefix:
152
+ example['key'] = prev_prefix
153
+ if valid:
154
+ # assert 'txt' in example
155
+ if 'txt' not in example:
156
+ example['txt'] = ''
157
+ yield example
158
+ example = {}
159
+ valid = True
160
+ with stream.extractfile(tarinfo) as file_obj:
161
+ try:
162
+ if postfix == 'txt':
163
+ example['txt'] = file_obj.read().decode(
164
+ 'utf8').strip()
165
+ elif postfix == 'lang':
166
+ example['lang'] = file_obj.read().decode(
167
+ 'utf8').strip()
168
+ elif postfix == 'speaker':
169
+ try:
170
+ example['speaker'] = file_obj.read().decode(
171
+ 'utf8').strip()
172
+ except Exception as ex:
173
+ example['speaker'] = "none"
174
+ elif postfix == 'emotion':
175
+ example['emotion'] = file_obj.read().decode(
176
+ 'utf8').strip()
177
+ elif postfix == 'gender':
178
+ example['gender'] = file_obj.read().decode(
179
+ 'utf8').strip()
180
+ elif postfix == 'task':
181
+ example['task'] = file_obj.read().decode(
182
+ 'utf8').strip()
183
+ elif postfix == 'speech_token':
184
+ example['speech_token'] = file_obj.read()
185
+ elif postfix == 'duration':
186
+ duration_str = file_obj.read().decode(
187
+ 'utf8').strip()
188
+ try:
189
+ duration_float = float(duration_str)
190
+ example['duration'] = duration_float
191
+ except Exception as ex:
192
+ logging.warning(f'failed to parse duration {duration_str}')
193
+ example['duration'] = 0
194
+
195
+ elif postfix in AUDIO_FORMAT_SETS:
196
+ waveform, sample_rate = torchaudio.load(file_obj)
197
+ # Check the number of audio channels
198
+ num_channels = waveform.shape[0]
199
+ # If the audio is multi-channel, average over channels
200
+ if num_channels > 1:
201
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
202
+ example['wav'] = waveform
203
+ example['sample_rate'] = sample_rate
204
+ else:
205
+ example[postfix] = file_obj.read()
206
+ except Exception as ex:
207
+ valid = False
208
+ # logging.warning('error to parse {}'.format(name))
209
+ prev_prefix = prefix
210
+ if prev_prefix is not None:
211
+ example['key'] = prev_prefix
212
+ if 'txt' in example:
213
+ yield example
214
+
215
+ except Exception as ex:
216
+ logging.warning(
217
+ 'In tar_file_and_group: {} when processing {}'.format(
218
+ ex, sample['src']))
219
+ finally:
220
+ if stream is not None:
221
+ stream.close()
222
+ if 'process' in sample:
223
+ sample['process'].communicate()
224
+ sample['stream'].close()
225
+
226
+
227
+ def parse_raw(data):
228
+ """ Parse key/wav/txt from json line
229
+
230
+ Args:
231
+ data: Iterable[str], str is a json line has key/wav/txt
232
+
233
+ Returns:
234
+ Iterable[{key, wav, txt, sample_rate}]
235
+ """
236
+ for sample in data:
237
+ assert 'src' in sample
238
+ json_line = sample['src']
239
+ obj = json.loads(json_line)
240
+ assert 'key' in obj
241
+ assert 'wav' in obj
242
+ assert 'txt' in obj
243
+ key = obj['key']
244
+ wav_file = obj['wav']
245
+ txt = obj['txt']
246
+ try:
247
+ if 'start' in obj:
248
+ assert 'end' in obj
249
+ sample_rate = torchaudio.info(wav_file).sample_rate
250
+ start_frame = int(obj['start'] * sample_rate)
251
+ end_frame = int(obj['end'] * sample_rate)
252
+ waveform, _ = torchaudio.load(filepath=wav_file,
253
+ num_frames=end_frame -
254
+ start_frame,
255
+ frame_offset=start_frame)
256
+ else:
257
+ waveform, sample_rate = torchaudio.load(wav_file)
258
+ # Check the number of audio channels
259
+ num_channels = waveform.shape[0]
260
+ # If the audio is multi-channel, average over channels
261
+ if num_channels > 1:
262
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
263
+ example = copy.deepcopy(obj) # copy and keep all the fields
264
+ example['wav'] = waveform # overwrite wav
265
+ example['sample_rate'] = sample_rate
266
+ yield example
267
+ except Exception as ex:
268
+ logging.warning('Failed to read {}'.format(wav_file))
269
+
270
+
271
+ def parse_speaker(data, speaker_table_path):
272
+ speaker_dict = {}
273
+ with open(speaker_table_path, 'r', encoding='utf8') as fin:
274
+ for line in fin:
275
+ arr = line.strip().split()
276
+ speaker_dict[arr[0]] = int(arr[1])
277
+ for sample in data:
278
+ assert 'speaker' in sample
279
+ speaker = sample['speaker']
280
+ sample['speaker'] = speaker_dict.get(speaker, 0)
281
+ yield sample
282
+
283
+
284
+ def filter(data,
285
+ max_length=1200,
286
+ min_length=10,
287
+ token_max_length=250,
288
+ token_min_length=1,
289
+ min_output_input_ratio=0.00005,
290
+ max_output_input_ratio=1,
291
+ filter_no_extra_info: bool = False,
292
+ max_seq_len=1000):
293
+ """ Filter sample according to feature and label length
294
+ Inplace operation.
295
+
296
+ Args:
297
+ data: Iterable[{key, wav, label, sample_rate}]
298
+ max_length: drop utterance which is greater than max_length(10ms)
299
+ min_length: drop utterance which is less than min_length(10ms)
300
+ token_max_length: drop utterance which is greater than
301
+ token_max_length, especially when use char unit for
302
+ english modeling
303
+ token_min_length: drop utterance which is
304
+ less than token_max_length
305
+ min_output_input_ratio: minimal ratio of
306
+ token_length / feats_length(10ms)
307
+ max_output_input_ratio: maximum ratio of
308
+ token_length / feats_length(10ms)
309
+
310
+ Returns:
311
+ Iterable[{key, wav, label, sample_rate}]
312
+ """
313
+ for sample in data:
314
+ try:
315
+ assert 'sample_rate' in sample
316
+ assert 'wav' in sample
317
+ assert 'label' in sample
318
+ except:
319
+ continue
320
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
321
+ num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
322
+
323
+ # filter for shard_in_common
324
+ if filter_no_extra_info:
325
+ if 'lang' not in sample:
326
+ continue
327
+ if 'task' not in sample:
328
+ continue
329
+
330
+ if num_frames < min_length:
331
+ continue
332
+
333
+ # if "output_type" in sample and sample["output_type"] == "speech2text_token":
334
+ # max_length = int(max_length / 2)
335
+ # if "output_type" in sample and sample["output_type"] == "text2token":
336
+ # max_length = int(max_length / 1.5)
337
+ if num_frames > max_length:
338
+ # continue
339
+ if 'task' in sample and sample['task'] == '<CAPTION>':
340
+ # utils_file.logging_limit_print('applied random cropping')
341
+ # Randomly choose a start point and crop to max_length
342
+ start_frame = random.randint(0, int(num_frames - max_length))
343
+ end_frame = start_frame + max_length
344
+ sample['wav'] = sample['wav'][:, int(start_frame / 100 * sample['sample_rate']): int(
345
+ end_frame / 100 * sample['sample_rate'])]
346
+ # print('sample[', sample['wav'].shape)
347
+ else:
348
+ continue
349
+ if len(sample['label']) < token_min_length:
350
+ continue
351
+ if len(sample['label']) > token_max_length:
352
+ continue
353
+ # if num_frames != 0:
354
+ # if len(sample['label']) / num_frames < min_output_input_ratio:
355
+ # continue
356
+ # if len(sample['label']) / num_frames > max_output_input_ratio:
357
+ # continue
358
+
359
+ if sample["output_type"] == "speech2text_token":
360
+ seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label']) + len(sample['speech_token'])
361
+ elif sample["output_type"] == "text2token":
362
+ seq_len = len(sample['prompt']) + len(sample['label']) + len(sample['speech_token'])
363
+ else:
364
+ seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label'])
365
+ utils_file.logging_limit_print(f'seqlen: {seq_len}, output_type:{sample["output_type"]},len(sample["prompt"]):{len(sample["prompt"])},num_frames / 8:{num_frames / 8},len(sample["label"]):{len(sample["label"])},len(sample["speech_token"]):{len(sample["speech_token"])} ')
366
+ if max_seq_len > 0 and max_seq_len < seq_len:
367
+ utils_file.logging_limit_print(f"seqlen: {seq_len} exceeds max_seq_len {max_seq_len}, skipping")
368
+ continue
369
+ yield sample
370
+
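
A worked example of the sequence-length budget used in the speech2text_token branch above, assuming the stated 100 feature frames per second and the 8x reduction implied by num_frames / 8; the numbers are illustrative.

num_frames = 1000                      # 10 s of audio -> 1000 frames
prompt_len, label_len, token_len = 20, 60, 150
seq_len = prompt_len + num_frames / 8 + label_len + token_len
print(seq_len)                         # 355.0; dropped only if 0 < max_seq_len < seq_len
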
371
+
372
+ def resample(data, resample_rate=16000):
373
+ """ Resample data.
374
+ Inplace operation.
375
+
376
+ Args:
377
+ data: Iterable[{key, wav, label, sample_rate}]
378
+ resample_rate: target resample rate
379
+
380
+ Returns:
381
+ Iterable[{key, wav, label, sample_rate}]
382
+ """
383
+ for sample in data:
384
+ assert 'sample_rate' in sample
385
+ assert 'wav' in sample
386
+ sample_rate = sample['sample_rate']
387
+ waveform = sample['wav']
388
+ if sample_rate != resample_rate:
389
+ sample['sample_rate'] = resample_rate
390
+ sample['wav'] = torchaudio.transforms.Resample(
391
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
392
+ yield sample
393
+
394
+
395
+ def speed_perturb(data, speeds=None):
396
+ """ Apply speed perturb to the data.
397
+ Inplace operation.
398
+
399
+ Args:
400
+ data: Iterable[{key, wav, label, sample_rate}]
401
+ speeds(List[float]): optional speed
402
+
403
+ Returns:
404
+ Iterable[{key, wav, label, sample_rate}]
405
+ """
406
+ if speeds is None:
407
+ speeds = [0.9, 1.0, 1.1]
408
+ for sample in data:
409
+ assert 'sample_rate' in sample
410
+ assert 'wav' in sample
411
+ sample_rate = sample['sample_rate']
412
+ waveform = sample['wav']
413
+ speed = random.choice(speeds)
414
+ if speed != 1.0:
415
+ wav, _ = torchaudio.sox_effects.apply_effects_tensor(
416
+ waveform, sample_rate,
417
+ [['speed', str(speed)], ['rate', str(sample_rate)]])
418
+ sample['wav'] = wav
419
+
420
+ yield sample
421
+
422
+
423
+ def compute_fbank(data,
424
+ num_mel_bins=23,
425
+ frame_length=25,
426
+ frame_shift=10,
427
+ dither=0.0):
428
+ """ Extract fbank
429
+
430
+ Args:
431
+ data: Iterable[{key, wav, label, sample_rate}]
432
+
433
+ Returns:
434
+ Iterable[{key, feat, label}]
435
+ """
436
+ for sample in data:
437
+ assert 'sample_rate' in sample
438
+ assert 'wav' in sample
439
+ assert 'key' in sample
440
+ assert 'label' in sample
441
+ sample_rate = sample['sample_rate']
442
+ waveform = sample['wav']
443
+ waveform = waveform * (1 << 15)
444
+ # Only keep key, feat, label
445
+ mat = kaldi.fbank(waveform,
446
+ num_mel_bins=num_mel_bins,
447
+ frame_length=frame_length,
448
+ frame_shift=frame_shift,
449
+ dither=dither,
450
+ energy_floor=0.0,
451
+ sample_frequency=sample_rate)
452
+ sample['feat'] = mat
453
+ yield sample
454
+
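
A minimal sketch of the underlying kaldi.fbank call on a dummy one-second 16 kHz waveform; the (1 << 15) scaling mirrors the stage above, since torchaudio loads float waveforms in [-1, 1] while kaldi.fbank expects 16-bit sample magnitudes. Parameter values are illustrative.

import torch
import torchaudio.compliance.kaldi as kaldi

wav = torch.rand(1, 16000) * 2 - 1          # (channel, samples) in [-1, 1]
feat = kaldi.fbank(wav * (1 << 15),
                   num_mel_bins=80,
                   frame_length=25,
                   frame_shift=10,
                   sample_frequency=16000)
print(feat.shape)                            # about (98, 80): (frames, mel bins)
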
455
+
456
+ def compute_mfcc(data,
457
+ num_mel_bins=23,
458
+ frame_length=25,
459
+ frame_shift=10,
460
+ dither=0.0,
461
+ num_ceps=40,
462
+ high_freq=0.0,
463
+ low_freq=20.0):
464
+ """ Extract mfcc
465
+
466
+ Args:
467
+ data: Iterable[{key, wav, label, sample_rate}]
468
+
469
+ Returns:
470
+ Iterable[{key, feat, label}]
471
+ """
472
+ for sample in data:
473
+ assert 'sample_rate' in sample
474
+ assert 'wav' in sample
475
+ assert 'key' in sample
476
+ assert 'label' in sample
477
+ sample_rate = sample['sample_rate']
478
+ waveform = sample['wav']
479
+ waveform = waveform * (1 << 15)
480
+ # Only keep key, feat, label
481
+ mat = kaldi.mfcc(waveform,
482
+ num_mel_bins=num_mel_bins,
483
+ frame_length=frame_length,
484
+ frame_shift=frame_shift,
485
+ dither=dither,
486
+ num_ceps=num_ceps,
487
+ high_freq=high_freq,
488
+ low_freq=low_freq,
489
+ sample_frequency=sample_rate)
490
+ sample['feat'] = mat
491
+ yield sample
492
+
493
+
494
+ def compute_log_mel_spectrogram(data,
495
+ n_fft=400,
496
+ hop_length=160,
497
+ num_mel_bins=80,
498
+ padding=0):
499
+ """ Extract log mel spectrogram, modified from openai-whisper, see:
500
+ - https://github.com/openai/whisper/blob/main/whisper/audio.py
501
+ - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
502
+
503
+ Args:
504
+ data: Iterable[{key, wav, label, sample_rate}]
505
+
506
+ Returns:
507
+ Iterable[{key, feat, label}]
508
+ """
509
+ for sample in data:
510
+ assert 'sample_rate' in sample
511
+ assert 'wav' in sample
512
+ assert 'key' in sample
513
+ assert 'label' in sample
514
+ sample_rate = sample['sample_rate']
515
+ waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
516
+ # print(f'wavform shape: {waveform.shape}')
517
+ if padding > 0:
518
+ waveform = F.pad(waveform, (0, padding))
519
+ window = torch.hann_window(n_fft)
520
+ stft = torch.stft(waveform,
521
+ n_fft,
522
+ hop_length,
523
+ window=window,
524
+ return_complex=True)
525
+ magnitudes = stft[..., :-1].abs() ** 2
526
+
527
+ filters = torch.from_numpy(
528
+ librosa.filters.mel(sr=sample_rate,
529
+ n_fft=n_fft,
530
+ n_mels=num_mel_bins))
531
+ mel_spec = filters @ magnitudes
532
+
533
+ # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
534
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
535
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
536
+ log_spec = (log_spec + 4.0) / 4.0
537
+ sample['feat'] = log_spec.transpose(0, 1)
538
+ yield sample
539
+
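
A minimal sketch that pushes one fake sample through the stage above; the waveform is random and only illustrates the expected fields and the output shape.

import torch

fake = [{'key': 'utt1', 'label': [1, 2, 3], 'sample_rate': 16000,
         'wav': torch.rand(1, 16000) * 2 - 1}]
sample = next(iter(compute_log_mel_spectrogram(fake, num_mel_bins=80)))
print(sample['feat'].shape)   # (100, 80): 16000 samples / 160 hop -> 100 frames
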
540
+
541
+ import re
542
+
543
+
544
+ def process_text(text):
545
+ # 1. Remove spaces around Chinese characters
546
+ text = re.sub(r'\s*([\u4e00-\u9fff])\s*', r'\1', text)
547
+ # 2. Lowercase English text
548
+ text = text.lower()
549
+ # 3. Remove spaces around the '<' and '>' symbols
550
+ text = re.sub(r'\s*<\s*', '<', text)
551
+ text = re.sub(r'\s*>\s*', '>', text)
552
+ return text
553
+
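
An illustrative call to the normalization above (the input string is hypothetical):

print(process_text('你好 世界 Hello WORLD < TRANSCRIBE >'))
# -> '你好世界hello world<transcribe>'
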
554
+
555
+ global_style_dict = {
556
+ "朗读": "新闻科普",
557
+ "科普百科": "新闻科普",
558
+ "悬疑恐怖": "恐怖故事",
559
+ "童话故事": "童话故事",
560
+ "客服": "客服",
561
+ "诗歌": "诗歌散文",
562
+ "散文": "诗歌散文",
563
+ "武侠评书": "有声书",
564
+ "小说": "有声书",
565
+ "历史": "有声书",
566
+ "科幻": "有声书",
567
+ "对话": "日常口语",
568
+ "口语": "日常口语",
569
+ "幽默": "其他",
570
+ "其他": "其他",
571
+ }
572
+
573
+
574
+ def replace_keys_in_brackets(input_str, key_value_dict):
575
+ for key, value in key_value_dict.items():
576
+ # Build a regex pattern that matches the '<key>' form
577
+ pattern = re.compile(r'<{}>'.format(key))
578
+ input_str = pattern.sub(f"<{value}>", input_str)
579
+ return input_str
580
+
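
An illustrative remapping using the global_style_dict defined above:

print(replace_keys_in_brackets('<科普百科>', global_style_dict))   # -> '<新闻科普>'
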
581
+
582
+ def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
583
+ """ Decode text to chars or BPE
584
+ Inplace operation
585
+
586
+ Args:
587
+ data: Iterable[{key, wav, txt, sample_rate}]
588
+
589
+ Returns:
590
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
591
+ """
592
+ for sample in data:
593
+ try:
594
+ assert 'txt' in sample
595
+ except:
596
+ print(f'tokenize: {sample}')
597
+ exit()
598
+ if 'task' in sample:
599
+ task_name = sample['task']
600
+ # if "<AGE>" in task_name:
601
+ # txt = sample['txt'].replace("<YOUTH>", "<ADULT>").replace("<MIDDLE_AGE>", "<ADULT>").replace("<MIDDLE>", "<ADULT>")
602
+ if "<STYLE>" in sample['task']:
603
+ txt = replace_keys_in_brackets(sample['txt'], global_style_dict)
604
+ else:
605
+ txt = sample['txt']
606
+ else:
607
+ txt = sample['txt']
608
+
609
+ tokens, label = tokenizer.tokenize(process_text(txt))
610
+ sample['tokens'] = tokens # tokens are characters, label is the id sequence
611
+ sample['label'] = label + [tokenizer.eod_id]
612
+ if 'task' in sample:
613
+ task_name = sample['task']
614
+ try:
615
+ random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
616
+ prompt = global_prompt_dict[task_name][random_index]
617
+ sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
618
+ except:
619
+ pass
620
+ else:
621
+ task_name = '<TRANSCRIBE>'
622
+ try:
623
+ random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
624
+ prompt = global_prompt_dict[task_name][random_index]
625
+ sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
626
+ except:
627
+ pass
628
+
629
+ if 'speech_token' in sample:
630
+ old_task_name = sample['task']
631
+ if old_task_name == "<TRANSCRIBE>":
632
+ task_name = '<TEXT2SPEECH_TOKEN>'
633
+ sample['output_type'] = 'text2token'
634
+ elif old_task_name == "<S2TCHAT>":
635
+ task_name = '<SPEECH2TEXT_SPEECH_TOKEN>'
636
+ sample['output_type'] = 'speech2text_token'
637
+ else:
638
+ task_name = old_task_name
639
+ try:
640
+ random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
641
+ prompt = global_prompt_dict[task_name][random_index]
642
+ sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
643
+ except:
644
+ pass
645
+ # Fix from sywang: only needed at inference time with raw-format data; tar-format shards already carry the tokens as an int list
646
+ # try:
647
+ # utils_file.logging_limit_print("type of sample['speech_token']: ", type(sample['speech_token']))
648
+ # speech_tokens = ast.literal_eval(sample['speech_token']) # parse the string into a list
649
+ # except (ValueError, SyntaxError) as e:
650
+ # print(f"parse error: {e} in {speech_tokens}")
651
+ # speech_tokens = []
652
+ # speech_token = [int(x) for x in speech_tokens]
653
+ speech_token = [int(x) for x in sample['speech_token']]
654
+ sample['speech_token'] = speech_token + [4096]
655
+ else:
656
+ sample['output_type'] = 'text'
657
+ sample['speech_token'] = [4096]
658
+ yield sample
659
+
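
A minimal sketch of feeding one raw-text sample through the stage above; my_tokenizer stands for any tokenizer exposing the tokenize()/eod_id interface used here and prompt_dict for the task-to-prompt mapping, both of which are assumptions for illustration.

samples = [{'key': 'utt1', 'txt': 'hello world', 'task': '<TRANSCRIBE>'}]
prompt_dict = {'<TRANSCRIBE>': ['Transcribe the audio into text.']}
for s in tokenize(samples, my_tokenizer, global_prompt_dict=prompt_dict):
    # adds: tokens (characters), label (ids + eod), prompt (ids),
    # output_type, and a speech_token list terminated by 4096
    print(s['output_type'], len(s['label']), s['speech_token'][-1])
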
660
+
661
+ def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
662
+ """ Do spec augmentation
663
+ Inplace operation
664
+
665
+ Args:
666
+ data: Iterable[{key, feat, label}]
667
+ num_t_mask: number of time mask to apply
668
+ num_f_mask: number of freq mask to apply
669
+ max_t: max width of time mask
670
+ max_f: max width of freq mask
671
+ max_w: max width of time warp
672
+
673
+ Returns
674
+ Iterable[{key, feat, label}]
675
+ """
676
+ for sample in data:
677
+ assert 'feat' in sample
678
+ x = sample['feat']
679
+ assert isinstance(x, torch.Tensor)
680
+ y = x.clone().detach()
681
+ max_frames = y.size(0)
682
+ max_freq = y.size(1)
683
+ # time mask
684
+ for i in range(num_t_mask):
685
+ start = random.randint(0, max_frames - 1)
686
+ length = random.randint(1, max_t)
687
+ end = min(max_frames, start + length)
688
+ y[start:end, :] = 0
689
+ # freq mask
690
+ for i in range(num_f_mask):
691
+ start = random.randint(0, max_freq - 1)
692
+ length = random.randint(1, max_f)
693
+ end = min(max_freq, start + length)
694
+ y[:, start:end] = 0
695
+ sample['feat'] = y
696
+ yield sample
697
+
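
A minimal sketch applying the masking above to a dummy feature matrix; the shapes are illustrative.

import torch

feats = [{'key': 'utt1', 'label': [1], 'feat': torch.rand(200, 80)}]
masked = next(iter(spec_aug(feats, num_t_mask=2, num_f_mask=2)))
print((masked['feat'] == 0).any())   # some time/frequency bands are zeroed
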
698
+
699
+ def spec_sub(data, max_t=20, num_t_sub=3):
700
+ """ Do spec substitute
701
+ Inplace operation
702
+ ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]
703
+
704
+ Args:
705
+ data: Iterable[{key, feat, label}]
706
+ max_t: max width of time substitute
707
+ num_t_sub: number of time substitute to apply
708
+
709
+ Returns
710
+ Iterable[{key, feat, label}]
711
+ """
712
+ for sample in data:
713
+ assert 'feat' in sample
714
+ x = sample['feat']
715
+ assert isinstance(x, torch.Tensor)
716
+ y = x.clone().detach()
717
+ max_frames = y.size(0)
718
+ for i in range(num_t_sub):
719
+ start = random.randint(0, max_frames - 1)
720
+ length = random.randint(1, max_t)
721
+ end = min(max_frames, start + length)
722
+ # only substitute the earlier time chosen randomly for current time
723
+ pos = random.randint(0, start)
724
+ y[start:end, :] = x[start - pos:end - pos, :]
725
+ sample['feat'] = y
726
+ yield sample
727
+
728
+
729
+ def spec_trim(data, max_t=20):
730
+ """ Trim tailing frames. Inplace operation.
731
+ ref: TrimTail [https://arxiv.org/abs/2211.00522]
732
+
733
+ Args:
734
+ data: Iterable[{key, feat, label}]
735
+ max_t: max width of length trimming
736
+
737
+ Returns
738
+ Iterable[{key, feat, label}]
739
+ """
740
+ for sample in data:
741
+ assert 'feat' in sample
742
+ x = sample['feat']
743
+ assert isinstance(x, torch.Tensor)
744
+ max_frames = x.size(0)
745
+ length = random.randint(1, max_t)
746
+ if length < max_frames / 2:
747
+ y = x.clone().detach()[:max_frames - length]
748
+ sample['feat'] = y
749
+ yield sample
750
+
751
+
752
+ def shuffle(data, shuffle_size=10000):
753
+ """ Local shuffle the data
754
+
755
+ Args:
756
+ data: Iterable[{key, feat, label}]
757
+ shuffle_size: buffer size for shuffle
758
+
759
+ Returns:
760
+ Iterable[{key, feat, label}]
761
+ """
762
+ buf = []
763
+ for sample in data:
764
+ buf.append(sample)
765
+ if len(buf) >= shuffle_size:
766
+ random.shuffle(buf)
767
+ for x in buf:
768
+ yield x
769
+ buf = []
770
+ # The sample left over
771
+ random.shuffle(buf)
772
+ for x in buf:
773
+ yield x
774
+
775
+
776
+ def sort(data, sort_size=500):
777
+ """ Sort the data by feature length.
778
+ Sort is used after shuffle and before batch, so we can group
779
+ utts with similar lengths into a batch, and `sort_size` should
780
+ be less than `shuffle_size`
781
+
782
+ Args:
783
+ data: Iterable[{key, feat, label}]
784
+ sort_size: buffer size for sort
785
+
786
+ Returns:
787
+ Iterable[{key, feat, label}]
788
+ """
789
+
790
+ buf = []
791
+ for sample in data:
792
+ buf.append(sample)
793
+ if len(buf) >= sort_size:
794
+ buf.sort(key=lambda x: x['feat'].size(0))
795
+ for x in buf:
796
+ yield x
797
+ buf = []
798
+ # The sample left over
799
+ buf.sort(key=lambda x: x['feat'].size(0))
800
+ for x in buf:
801
+ yield x
802
+
803
+
804
+ def static_batch(data, batch_size=16):
805
+ """ Static batch the data by `batch_size`
806
+
807
+ Args:
808
+ data: Iterable[{key, feat, label}]
809
+ batch_size: batch size
810
+
811
+ Returns:
812
+ Iterable[List[{key, feat, label}]]
813
+ """
814
+ buf = []
815
+ for sample in data:
816
+ buf.append(sample)
817
+ if len(buf) >= batch_size:
818
+ yield buf
819
+ buf = []
820
+ if len(buf) > 0:
821
+ yield buf
822
+
823
+
824
+ def dynamic_batch(data, max_frames_in_batch=12000, max_seq_in_batch=10000000):
825
+ """ Dynamic batch the data until the total frames in batch
826
+ reach `max_frames_in_batch`
827
+
828
+ Args:
829
+ data: Iterable[{key, feat, label}]
830
+ max_frames_in_batch: max_frames in one batch
831
+
832
+ Returns:
833
+ Iterable[List[{key, feat, label}]]
834
+ """
835
+ buf = []
836
+ longest_frames = 0
837
+ longest_seq = 0
838
+ max_frames_in_batch = max_frames_in_batch
839
+
840
+ buf_speech_token = []
841
+ longest_frames_token = 0
842
+ longest_seq_token = 0
843
+ max_frames_in_batch_token = int(max_frames_in_batch)
844
+
845
+ buf_speech_token_with_text = []
846
+ longest_frames_token_with_text = 0
847
+ longest_seq_token_with_text = 0
848
+ max_frames_in_batch_token_with_text = int(max_frames_in_batch / 2.5)
849
+
850
+ for sample in data:
851
+ assert 'feat' in sample
852
+ assert isinstance(sample['feat'], torch.Tensor)
853
+ new_sample_frames = sample['feat'].size(0)
854
+ if "output_type" in sample and sample["output_type"] == "speech2text_token":
855
+ new_seq = sample['feat'].size(0) / 8 + len(sample['label']) + len(sample.get('prompt', [])) + len(
856
+ sample.get('speech_token', []))
857
+ longest_seq_token = max(longest_seq_token, new_seq)
858
+ utils_file.logging_limit_print(
859
+ f'batch fn: current item new_seq={new_seq}, longest_seq_token={longest_seq_token}')
860
+ longest_frames_token = max(longest_frames_token, new_sample_frames)
861
+ frames_after_padding_token = longest_frames_token * (len(buf_speech_token)+1)
862
+ seq_after_padding_token = longest_seq_token * (len(buf_speech_token)+1)
863
+ utils_file.logging_limit_print(
864
+ f'batch fn: current item new_seq={new_seq}, longest_seq_token={longest_seq_token}, seq_after_padding_token={seq_after_padding_token}')
865
+ utils_file.logging_limit_print(
866
+ f'batch fn: current item new_sample_frames={new_sample_frames}, longest_frames_token={longest_frames_token}, frames_after_padding_token={frames_after_padding_token}')
867
+ if frames_after_padding_token > max_frames_in_batch_token or seq_after_padding_token > max_seq_in_batch:
868
+ yield buf_speech_token
869
+ buf_speech_token = [sample]
870
+ longest_frames_token = new_sample_frames
871
+ longest_seq_token = new_seq
872
+ else:
873
+ buf_speech_token.append(sample)
874
+ elif "output_type" in sample and sample["output_type"] == "text2token":
875
+ new_seq = len(sample['label']) + len(sample.get('prompt', [])) + len(
876
+ sample.get('speech_token', []))
877
+ longest_seq_token_with_text = max(longest_seq_token_with_text, new_seq)
878
+ longest_frames_token_with_text = max(longest_frames_token_with_text, new_sample_frames)
879
+ frames_after_padding_token_with_text = longest_frames_token_with_text * (len(buf_speech_token_with_text)+1)
880
+ seq_after_padding_token_with_text = longest_seq_token_with_text * (len(buf_speech_token_with_text)+1)
881
+ if frames_after_padding_token_with_text > max_frames_in_batch_token_with_text or seq_after_padding_token_with_text > max_seq_in_batch:
882
+ yield buf_speech_token_with_text
883
+ buf_speech_token_with_text = [sample]
884
+ longest_frames_token_with_text = new_sample_frames
885
+ longest_seq_token_with_text = new_seq
886
+ else:
887
+ buf_speech_token_with_text.append(sample)
888
+ else:
889
+ new_seq = sample['feat'].size(0) / 8 + len(sample['label']) + len(sample.get('prompt', []))
890
+ longest_seq = max(longest_seq, new_seq)
891
+ longest_frames = max(longest_frames, new_sample_frames)
892
+ frames_after_padding = longest_frames * (len(buf)+1)
893
+ seq_after_padding = longest_seq * (len(buf)+1)
894
+ if frames_after_padding > max_frames_in_batch or seq_after_padding > max_seq_in_batch:
895
+ yield buf
896
+ buf = [sample]
897
+ longest_frames = new_sample_frames
898
+ longest_seq = new_seq
899
+ else:
900
+ buf.append(sample)
901
+ if len(buf) > 0:
902
+ yield buf
903
+
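
A worked example of the padding budget used above for the plain-text branch; a buffered batch is emitted as soon as padding the new item would overflow either budget. The numbers are illustrative.

max_frames_in_batch = 12000
longest_frames = 1500                         # longest utterance seen so far
print(max_frames_in_batch // longest_frames)  # at most 8 such utterances per batch
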
904
+
905
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, max_seq_in_batch=10000000):
906
+ """ Wrapper for static/dynamic batch
907
+ """
908
+ if batch_type == 'static':
909
+ return static_batch(data, batch_size)
910
+ elif batch_type == 'dynamic':
911
+ return dynamic_batch(data, max_frames_in_batch, max_seq_in_batch=max_seq_in_batch)
912
+ else:
913
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
914
+
915
+
916
+ def padding(data):
917
+ """ Padding the data into training data
918
+
919
+ Args:
920
+ data: Iterable[List[{key, feat, label}]]
921
+
922
+ Returns:
923
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
924
+ """
925
+ for sample in data:
926
+ assert isinstance(sample, list)
927
+ feats_length = torch.tensor([x['feat'].size(0) for x in sample],
928
+ dtype=torch.int32)
929
+ order = torch.argsort(feats_length, descending=True)
930
+ feats_lengths = torch.tensor(
931
+ [sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
932
+ sorted_feats = [sample[i]['feat'] for i in order]
933
+ sorted_keys = [sample[i]['key'] for i in order]
934
+ sorted_labels = [
935
+ torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
936
+ ]
937
+ sorted_speech_tokens = [
938
+ torch.tensor(sample[i]['speech_token'], dtype=torch.int64) for i in order
939
+ ]
940
+
941
+ sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
942
+ label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
943
+ dtype=torch.int32)
944
+ speech_token_lengths = torch.tensor([x.size(0) for x in sorted_speech_tokens],
945
+ dtype=torch.int32)
946
+ wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
947
+ dtype=torch.int32)
948
+ # print('------------------')
949
+ # for feat_item in sorted_feats:
950
+ # print(feat_item.shape)
951
+ # print('------------------')
952
+
953
+ padded_feats = pad_sequence(sorted_feats,
954
+ batch_first=True,
955
+ padding_value=0)
956
+ padding_labels = pad_sequence(sorted_labels,
957
+ batch_first=True,
958
+ padding_value=-100)
959
+
960
+ padding_speech_tokens = pad_sequence(sorted_speech_tokens,
961
+ batch_first=True,
962
+ padding_value=-100)
963
+ padded_wavs = pad_sequence(sorted_wavs,
964
+ batch_first=True,
965
+ padding_value=0)
966
+
967
+ sorted_lang = [
968
+ sample[i].get('lang', 'cn') for i in order
969
+ ]
970
+
971
+ sorted_speaker = [
972
+ sample[i].get('speaker', 'None') for i in order
973
+ ]
974
+
975
+ sorted_emotion = [
976
+ sample[i].get('emotion', 'None') for i in order
977
+ ]
978
+ sorted_gender = [
979
+ sample[i].get('gender', 'None') for i in order
980
+ ]
981
+ # sorted_duration = [
982
+ # sample[i]['duration'] for i in order
983
+ # ]
984
+ sorted_task = [
985
+ sample[i].get('task', '<TRANSCRIBE>') for i in order
986
+ ]
987
+
988
+ batch = {
989
+ "keys": sorted_keys,
990
+ "feats": padded_feats,
991
+ "target": padding_labels,
992
+ "feats_lengths": feats_lengths,
993
+ "target_lengths": label_lengths,
994
+ "pcm": padded_wavs,
995
+ "pcm_length": wav_lengths,
996
+ "speech_tokens": padding_speech_tokens,
997
+ "speech_tokens_length": speech_token_lengths,
998
+ "lang": sorted_lang,
999
+ "speaker": sorted_speaker,
1000
+ "emotion": sorted_emotion,
1001
+ "gender": sorted_gender,
1002
+ "task": sorted_task
1003
+ }
1004
+ if 'prompt' in sample[0]:
1005
+ sorted_prompts = [
1006
+ torch.tensor(sample[i]['prompt'], dtype=torch.int64
1007
+ ) for i in order
1008
+ ]
1009
+ prompt_lengths = torch.tensor([x.size(0) for x in
1010
+ sorted_prompts], dtype=torch.int32)
1011
+ padding_prompts = pad_sequence(sorted_prompts,
1012
+ batch_first=True,
1013
+ padding_value=-1)
1014
+ batch['prompt'] = padding_prompts
1015
+ batch['prompt_lengths'] = prompt_lengths
1016
+
1017
+ if 'output_type' in sample[0] and sample[0]['output_type'] == 'speech2text_token':
1018
+ batch['output_type'] = 'speech2text_token'
1019
+ elif 'output_type' in sample[0] and sample[0]['output_type'] == 'text2token':
1020
+ batch['output_type'] = 'text2token'
1021
+ else:
1022
+ batch['output_type'] = 'text'
1023
+ yield batch
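
Since every stage above is a generator over sample dicts, a raw json-lines pipeline can be assembled by simple nesting. The sketch below is illustrative only (filter/shuffle/sort are omitted for brevity, and the tokenizer and prompt dict are assumptions).

def build_pipeline(json_lines, tokenizer, prompt_dict):
    data = ({'src': line} for line in json_lines)   # one json line per utterance
    data = parse_raw(data)
    data = resample(data, resample_rate=16000)
    data = tokenize(data, tokenizer, global_prompt_dict=prompt_dict)
    data = compute_fbank(data, num_mel_bins=80)
    data = batch(data, batch_type='dynamic', max_frames_in_batch=12000)
    return padding(data)
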
wenet/dataset/kaldi_io.py ADDED
@@ -0,0 +1,772 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2014-2016 Brno University of Technology (author: Karel Vesely)
5
+ # Licensed under the Apache License, Version 2.0 (the "License")
6
+
7
+ import numpy as np
8
+ import sys, os, re, gzip, struct
9
+
10
+ #################################################
11
+ # Adding kaldi tools to shell path,
12
+
13
+ # Select kaldi,
14
+ if not 'KALDI_ROOT' in os.environ:
15
+ # Default! To change run python with 'export KALDI_ROOT=/some_dir python'
16
+ os.environ['KALDI_ROOT'] = '/mnt/matylda5/iveselyk/Tools/kaldi-trunk'
17
+
18
+ # Add kaldi tools to path,
19
+ os.environ['PATH'] = os.popen(
20
+ 'echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/'
21
+ ).readline().strip() + ':' + os.environ['PATH']
22
+
23
+
24
+ #################################################
25
+ # Define all custom exceptions,
26
+ class UnsupportedDataType(Exception):
27
+ pass
28
+
29
+
30
+ class UnknownVectorHeader(Exception):
31
+ pass
32
+
33
+
34
+ class UnknownMatrixHeader(Exception):
35
+ pass
36
+
37
+
38
+ class BadSampleSize(Exception):
39
+ pass
40
+
41
+
42
+ class BadInputFormat(Exception):
43
+ pass
44
+
45
+
46
+ class SubprocessFailed(Exception):
47
+ pass
48
+
49
+
50
+ #################################################
51
+ # Data-type independent helper functions,
52
+
53
+
54
+ def open_or_fd(file, mode='rb'):
55
+ """ fd = open_or_fd(file)
56
+ Open file, gzipped file, pipe, or forward the file-descriptor.
57
+ Eventually seeks, if the 'file' argument contains a ':offset' suffix.
58
+ """
59
+ offset = None
60
+ try:
61
+ # strip 'ark:' prefix from r{x,w}filename (optional),
62
+ if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:',
63
+ file):
64
+ (prefix, file) = file.split(':', 1)
65
+ # separate offset from filename (optional),
66
+ if re.search(':[0-9]+$', file):
67
+ (file, offset) = file.rsplit(':', 1)
68
+ # input pipe?
69
+ if file[-1] == '|':
70
+ fd = popen(file[:-1], 'rb') # custom,
71
+ # output pipe?
72
+ elif file[0] == '|':
73
+ fd = popen(file[1:], 'wb') # custom,
74
+ # is it gzipped?
75
+ elif file.split('.')[-1] == 'gz':
76
+ fd = gzip.open(file, mode)
77
+ # a normal file...
78
+ else:
79
+ fd = open(file, mode)
80
+ except TypeError:
81
+ # 'file' is opened file descriptor,
82
+ fd = file
83
+ # Eventually seek to offset,
84
+ if offset != None: fd.seek(int(offset))
85
+ return fd
86
+
87
+
88
+ # based on '/usr/local/lib/python3.4/os.py'
89
+ def popen(cmd, mode="rb"):
90
+ if not isinstance(cmd, str):
91
+ raise TypeError("invalid cmd type (%s, expected string)" % type(cmd))
92
+
93
+ import subprocess, io, threading
94
+
95
+ # cleanup function for subprocesses,
96
+ def cleanup(proc, cmd):
97
+ ret = proc.wait()
98
+ if ret > 0:
99
+ raise SubprocessFailed('cmd %s returned %d !' % (cmd, ret))
100
+ return
101
+
102
+ # text-mode,
103
+ if mode == "r":
104
+ proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
105
+ threading.Thread(target=cleanup,
106
+ args=(proc, cmd)).start() # clean-up thread,
107
+ return io.TextIOWrapper(proc.stdout)
108
+ elif mode == "w":
109
+ proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
110
+ threading.Thread(target=cleanup,
111
+ args=(proc, cmd)).start() # clean-up thread,
112
+ return io.TextIOWrapper(proc.stdin)
113
+ # binary,
114
+ elif mode == "rb":
115
+ proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
116
+ threading.Thread(target=cleanup,
117
+ args=(proc, cmd)).start() # clean-up thread,
118
+ return proc.stdout
119
+ elif mode == "wb":
120
+ proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
121
+ threading.Thread(target=cleanup,
122
+ args=(proc, cmd)).start() # clean-up thread,
123
+ return proc.stdin
124
+ # sanity,
125
+ else:
126
+ raise ValueError("invalid mode %s" % mode)
127
+
128
+
129
+ def read_key(fd):
130
+ """ [key] = read_key(fd)
131
+ Read the utterance-key from the opened ark/stream descriptor 'fd'.
132
+ """
133
+ key = ''
134
+ while 1:
135
+ char = fd.read(1).decode("latin1")
136
+ if char == '': break
137
+ if char == ' ': break
138
+ key += char
139
+ key = key.strip()
140
+ if key == '': return None # end of file,
141
+ assert (re.match('^\S+$', key) != None) # check format (no whitespace!)
142
+ return key
143
+
144
+
145
+ #################################################
146
+ # Integer vectors (alignments, ...),
147
+
148
+
149
+ def read_ali_ark(file_or_fd):
150
+ """ Alias to 'read_vec_int_ark()' """
151
+ return read_vec_int_ark(file_or_fd)
152
+
153
+
154
+ def read_vec_int_ark(file_or_fd):
155
+ """ generator(key,vec) = read_vec_int_ark(file_or_fd)
156
+ Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
157
+ file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
158
+
159
+ Read ark to a 'dictionary':
160
+ d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) }
161
+ """
162
+ fd = open_or_fd(file_or_fd)
163
+ try:
164
+ key = read_key(fd)
165
+ while key:
166
+ ali = read_vec_int(fd)
167
+ yield key, ali
168
+ key = read_key(fd)
169
+ finally:
170
+ if fd is not file_or_fd: fd.close()
171
+
172
+
173
+ def read_vec_int_scp(file_or_fd):
174
+ """ generator(key,vec) = read_vec_int_scp(file_or_fd)
175
+ Returns generator of (key,vector<int>) tuples, read according to kaldi scp.
176
+ file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
177
+
178
+ Iterate the scp:
179
+ for key,vec in kaldi_io.read_vec_int_scp(file):
180
+ ...
181
+
182
+ Read scp to a 'dictionary':
183
+ d = { key:vec for key,mat in kaldi_io.read_vec_int_scp(file) }
184
+ """
185
+ fd = open_or_fd(file_or_fd)
186
+ try:
187
+ for line in fd:
188
+ (key, rxfile) = line.decode().split(' ')
189
+ vec = read_vec_int(rxfile)
190
+ yield key, vec
191
+ finally:
192
+ if fd is not file_or_fd: fd.close()
193
+
194
+
195
+ def read_vec_int(file_or_fd):
196
+ """ [int-vec] = read_vec_int(file_or_fd)
197
+ Read kaldi integer vector, ascii or binary input,
198
+ """
199
+ fd = open_or_fd(file_or_fd)
200
+ binary = fd.read(2).decode()
201
+ if binary == '\0B': # binary flag
202
+ assert (fd.read(1).decode() == '\4')
203
+ # int-size
204
+ vec_size = np.frombuffer(fd.read(4), dtype='int32',
205
+ count=1)[0] # vector dim
206
+ # Elements from int32 vector are stored in tuples: (sizeof(int32), value),
207
+ vec = np.frombuffer(fd.read(vec_size * 5),
208
+ dtype=[('size', 'int8'), ('value', 'int32')],
209
+ count=vec_size)
210
+ assert (vec[0]['size'] == 4) # int32 size,
211
+ ans = vec[:]['value'] # values are in 2nd column,
212
+ else: # ascii,
213
+ arr = (binary + fd.readline().decode()).strip().split()
214
+ try:
215
+ arr.remove('[')
216
+ arr.remove(']') # optionally
217
+ except ValueError:
218
+ pass
219
+ ans = np.array(arr, dtype=int)
220
+ if fd is not file_or_fd: fd.close() # cleanup
221
+ return ans
222
+
223
+
224
+ # Writing,
225
+ def write_vec_int(file_or_fd, v, key=''):
226
+ """ write_vec_int(f, v, key='')
227
+ Write a binary kaldi integer vector to filename or stream.
228
+ Arguments:
229
+ file_or_fd : filename or opened file descriptor for writing,
230
+ v : the vector to be stored,
231
+ key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
232
+
233
+ Example of writing single vector:
234
+ kaldi_io.write_vec_int(filename, vec)
235
+
236
+ Example of writing arkfile:
237
+ with open(ark_file,'w') as f:
238
+ for key,vec in dict.iteritems():
239
+ kaldi_io.write_vec_flt(f, vec, key=key)
240
+ """
241
+ fd = open_or_fd(file_or_fd, mode='wb')
242
+ if sys.version_info[0] == 3: assert (fd.mode == 'wb')
243
+ try:
244
+ if key != '':
245
+ fd.write(
246
+ (key +
247
+ ' ').encode("latin1")) # ark-files have keys (utterance-id),
248
+ fd.write('\0B'.encode()) # we write binary!
249
+ # dim,
250
+ fd.write('\4'.encode()) # int32 type,
251
+ fd.write(struct.pack(np.dtype('int32').char, v.shape[0]))
252
+ # data,
253
+ for i in range(len(v)):
254
+ fd.write('\4'.encode()) # int32 type,
255
+ fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary,
256
+ finally:
257
+ if fd is not file_or_fd: fd.close()
258
+
259
+
260
+ #################################################
261
+ # Float vectors (confidences, ivectors, ...),
262
+
263
+
264
+ # Reading,
265
+ def read_vec_flt_scp(file_or_fd):
266
+ """ generator(key,mat) = read_vec_flt_scp(file_or_fd)
267
+ Returns generator of (key,vector) tuples, read according to kaldi scp.
268
+ file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
269
+
270
+ Iterate the scp:
271
+ for key,vec in kaldi_io.read_vec_flt_scp(file):
272
+ ...
273
+
274
+ Read scp to a 'dictionary':
275
+ d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
276
+ """
277
+ fd = open_or_fd(file_or_fd)
278
+ try:
279
+ for line in fd:
280
+ (key, rxfile) = line.decode().split(' ')
281
+ vec = read_vec_flt(rxfile)
282
+ yield key, vec
283
+ finally:
284
+ if fd is not file_or_fd: fd.close()
285
+
286
+
287
+ def read_vec_flt_ark(file_or_fd):
288
+ """ generator(key,vec) = read_vec_flt_ark(file_or_fd)
289
+ Create generator of (key,vector<float>) tuples, reading from an ark file/stream.
290
+ file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
291
+
292
+ Read ark to a 'dictionary':
293
+ d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) }
294
+ """
295
+ fd = open_or_fd(file_or_fd)
296
+ try:
297
+ key = read_key(fd)
298
+ while key:
299
+ ali = read_vec_flt(fd)
300
+ yield key, ali
301
+ key = read_key(fd)
302
+ finally:
303
+ if fd is not file_or_fd: fd.close()
304
+
305
+
306
+ def read_vec_flt(file_or_fd):
307
+ """ [flt-vec] = read_vec_flt(file_or_fd)
308
+ Read kaldi float vector, ascii or binary input,
309
+ """
310
+ fd = open_or_fd(file_or_fd)
311
+ binary = fd.read(2).decode()
312
+ if binary == '\0B': # binary flag
313
+ # Data type,
314
+ header = fd.read(3).decode()
315
+ if header == 'FV ': sample_size = 4 # floats
316
+ elif header == 'DV ': sample_size = 8 # doubles
317
+ else: raise UnknownVectorHeader("The header contained '%s'" % header)
318
+ assert (sample_size > 0)
319
+ # Dimension,
320
+ assert (fd.read(1).decode() == '\4')
321
+ # int-size
322
+ vec_size = np.frombuffer(fd.read(4), dtype='int32',
323
+ count=1)[0] # vector dim
324
+ # Read whole vector,
325
+ buf = fd.read(vec_size * sample_size)
326
+ if sample_size == 4: ans = np.frombuffer(buf, dtype='float32')
327
+ elif sample_size == 8: ans = np.frombuffer(buf, dtype='float64')
328
+ else: raise BadSampleSize
329
+ return ans
330
+ else: # ascii,
331
+ arr = (binary + fd.readline().decode()).strip().split()
332
+ try:
333
+ arr.remove('[')
334
+ arr.remove(']') # optionally
335
+ except ValueError:
336
+ pass
337
+ ans = np.array(arr, dtype=float)
338
+ if fd is not file_or_fd: fd.close() # cleanup
339
+ return ans
340
+
341
+
342
+ # Writing,
343
+ def write_vec_flt(file_or_fd, v, key=''):
344
+ """ write_vec_flt(f, v, key='')
345
+ Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats.
346
+ Arguments:
347
+ file_or_fd : filename or opened file descriptor for writing,
348
+ v : the vector to be stored,
349
+ key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
350
+
351
+ Example of writing single vector:
352
+ kaldi_io.write_vec_flt(filename, vec)
353
+
354
+ Example of writing arkfile:
355
+ with open(ark_file,'w') as f:
356
+ for key,vec in dict.iteritems():
357
+ kaldi_io.write_vec_flt(f, vec, key=key)
358
+ """
359
+ fd = open_or_fd(file_or_fd, mode='wb')
360
+ if sys.version_info[0] == 3: assert (fd.mode == 'wb')
361
+ try:
362
+ if key != '':
363
+ fd.write(
364
+ (key +
365
+ ' ').encode("latin1")) # ark-files have keys (utterance-id),
366
+ fd.write('\0B'.encode()) # we write binary!
367
+ # Data-type,
368
+ if v.dtype == 'float32': fd.write('FV '.encode())
369
+ elif v.dtype == 'float64': fd.write('DV '.encode())
370
+ else:
371
+ raise UnsupportedDataType(
372
+ "'%s', please use 'float32' or 'float64'" % v.dtype)
373
+ # Dim,
374
+ fd.write('\04'.encode())
375
+ fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim
376
+ # Data,
377
+ fd.write(v.tobytes())
378
+ finally:
379
+ if fd is not file_or_fd: fd.close()
380
+
381
+
382
+ #################################################
383
+ # Float matrices (features, transformations, ...),
384
+
385
+
386
+ # Reading,
387
+ def read_mat_scp(file_or_fd):
388
+ """ generator(key,mat) = read_mat_scp(file_or_fd)
389
+ Returns generator of (key,matrix) tuples, read according to kaldi scp.
390
+ file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
391
+
392
+ Iterate the scp:
393
+ for key,mat in kaldi_io.read_mat_scp(file):
394
+ ...
395
+
396
+ Read scp to a 'dictionary':
397
+ d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
398
+ """
399
+ fd = open_or_fd(file_or_fd)
400
+ try:
401
+ for line in fd:
402
+ (key, rxfile) = line.decode().split(' ')
403
+ mat = read_mat(rxfile)
404
+ yield key, mat
405
+ finally:
406
+ if fd is not file_or_fd: fd.close()
407
+
408
+
409
+ def read_mat_ark(file_or_fd):
410
+ """ generator(key,mat) = read_mat_ark(file_or_fd)
411
+ Returns generator of (key,matrix) tuples, read from ark file/stream.
412
+ file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
413
+
414
+ Iterate the ark:
415
+ for key,mat in kaldi_io.read_mat_ark(file):
416
+ ...
417
+
418
+ Read ark to a 'dictionary':
419
+ d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) }
420
+ """
421
+ fd = open_or_fd(file_or_fd)
422
+ try:
423
+ key = read_key(fd)
424
+ while key:
425
+ mat = read_mat(fd)
426
+ yield key, mat
427
+ key = read_key(fd)
428
+ finally:
429
+ if fd is not file_or_fd: fd.close()
430
+
431
+
432
+ def read_mat(file_or_fd):
433
+ """ [mat] = read_mat(file_or_fd)
434
+ Reads single kaldi matrix, supports ascii and binary.
435
+ file_or_fd : file, gzipped file, pipe or opened file descriptor.
436
+ """
437
+ fd = open_or_fd(file_or_fd)
438
+ try:
439
+ binary = fd.read(2).decode()
440
+ if binary == '\0B':
441
+ mat = _read_mat_binary(fd)
442
+ else:
443
+ assert (binary == ' [')
444
+ mat = _read_mat_ascii(fd)
445
+ finally:
446
+ if fd is not file_or_fd: fd.close()
447
+ return mat
448
+
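
A hedged round-trip sketch for the matrix writer/readers in this module; the ark path is a hypothetical placeholder.

import numpy as np

mat = np.random.rand(5, 13).astype('float32')
with open('feats.ark', 'wb') as f:          # hypothetical output path
    write_mat(f, mat, key='utt1')
for key, m in read_mat_ark('feats.ark'):
    print(key, m.shape)                     # utt1 (5, 13)
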
449
+
450
+ def _read_mat_binary(fd):
451
+ # Data type
452
+ header = fd.read(3).decode()
453
+ # 'CM', 'CM2', 'CM3' are possible values,
454
+ if header.startswith('CM'): return _read_compressed_mat(fd, header)
455
+ elif header == 'FM ': sample_size = 4 # floats
456
+ elif header == 'DM ': sample_size = 8 # doubles
457
+ else: raise UnknownMatrixHeader("The header contained '%s'" % header)
458
+ assert (sample_size > 0)
459
+ # Dimensions
460
+ s1, rows, s2, cols = np.frombuffer(fd.read(10),
461
+ dtype='int8,int32,int8,int32',
462
+ count=1)[0]
463
+ # Read whole matrix
464
+ buf = fd.read(rows * cols * sample_size)
465
+ if sample_size == 4: vec = np.frombuffer(buf, dtype='float32')
466
+ elif sample_size == 8: vec = np.frombuffer(buf, dtype='float64')
467
+ else: raise BadSampleSize
468
+ mat = np.reshape(vec, (rows, cols))
469
+ return mat
470
+
471
+
472
+ def _read_mat_ascii(fd):
473
+ rows = []
474
+ while 1:
475
+ line = fd.readline().decode()
476
+ if (len(line) == 0): raise BadInputFormat # eof, should not happen!
477
+ if len(line.strip()) == 0: continue # skip empty line
478
+ arr = line.strip().split()
479
+ if arr[-1] != ']':
480
+ rows.append(np.array(arr, dtype='float32')) # not last line
481
+ else:
482
+ rows.append(np.array(arr[:-1], dtype='float32')) # last line
483
+ mat = np.vstack(rows)
484
+ return mat
485
+
486
+
487
+ def _read_compressed_mat(fd, format):
488
+ """ Read a compressed matrix,
489
+ see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h
490
+ methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...),
491
+ """
492
+ assert (format == 'CM ') # The formats CM2, CM3 are not supported...
493
+
494
+ # Format of header 'struct',
495
+ global_header = np.dtype([('minvalue', 'float32'), ('range', 'float32'),
496
+ ('num_rows', 'int32'), ('num_cols', 'int32')
497
+ ]) # member '.format' is not written,
498
+ per_col_header = np.dtype([('percentile_0', 'uint16'),
499
+ ('percentile_25', 'uint16'),
500
+ ('percentile_75', 'uint16'),
501
+ ('percentile_100', 'uint16')])
502
+
503
+ # Mapping for percentiles in col-headers,
504
+ def uint16_to_float(value, min, range):
505
+ return np.float32(min + range * 1.52590218966964e-05 * value)
506
+
507
+ # Mapping for matrix elements,
508
+ def uint8_to_float_v2(vec, p0, p25, p75, p100):
509
+ # Split the vector by masks,
510
+ mask_0_64 = (vec <= 64)
511
+ mask_193_255 = (vec > 192)
512
+ mask_65_192 = (~(mask_0_64 | mask_193_255))
513
+ # Sanity check (useful but slow...),
514
+ # assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255])))
515
+ # assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0)))
516
+ # Build the float vector,
517
+ ans = np.empty(len(vec), dtype='float32')
518
+ ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64]
519
+ ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64)
520
+ ans[mask_193_255] = p75 + (p100 - p75) / 63. * (vec[mask_193_255] -
521
+ 192)
522
+ return ans
523
+
524
+ # Read global header,
525
+ globmin, globrange, rows, cols = np.frombuffer(fd.read(16),
526
+ dtype=global_header,
527
+ count=1)[0]
528
+
529
+ # The data is structured as [Colheader, ... , Colheader, Data, Data , .... ]
530
+ # { cols }{ size }
531
+ col_headers = np.frombuffer(fd.read(cols * 8),
532
+ dtype=per_col_header,
533
+ count=cols)
534
+ data = np.reshape(np.frombuffer(fd.read(cols * rows),
535
+ dtype='uint8',
536
+ count=cols * rows),
537
+ newshape=(cols, rows)) # stored as col-major,
538
+
539
+ mat = np.empty((cols, rows), dtype='float32')
540
+ for i, col_header in enumerate(col_headers):
541
+ col_header_flt = [
542
+ uint16_to_float(percentile, globmin, globrange)
543
+ for percentile in col_header
544
+ ]
545
+ mat[i] = uint8_to_float_v2(data[i], *col_header_flt)
546
+
547
+ return mat.T # transpose! col-major -> row-major,
548
+
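
A worked numeric check of the percentile mapping above: the constant 1.52590218966964e-05 is approximately 1/65535, so the largest uint16 value maps back to minvalue + range. The values below are illustrative.

minvalue, value_range = -10.0, 20.0
print(minvalue + value_range * 1.52590218966964e-05 * 65535)   # ~= 10.0
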
549
+
550
+ def write_ark_scp(key, mat, ark_fout, scp_out):
551
+ mat_offset = write_mat(ark_fout, mat, key)
552
+ scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset)
553
+ scp_out.write(scp_line)
554
+ scp_out.write('\n')
555
+
556
+
557
+ # Writing,
558
+ def write_mat(file_or_fd, m, key=''):
559
+ """ write_mat(f, m, key='')
560
+ Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats.
561
+ Arguments:
562
+ file_or_fd : filename or opened file descriptor for writing,
563
+ m : the matrix to be stored,
564
+ key (optional) : used for writing ark-file, the utterance-id gets written before the matrix.
565
+
566
+ Example of writing single matrix:
567
+ kaldi_io.write_mat(filename, mat)
568
+
569
+ Example of writing arkfile:
570
+ with open(ark_file,'w') as f:
571
+ for key,mat in dict.iteritems():
572
+ kaldi_io.write_mat(f, mat, key=key)
573
+ """
574
+ mat_offset = 0
575
+ fd = open_or_fd(file_or_fd, mode='wb')
576
+ if sys.version_info[0] == 3: assert (fd.mode == 'wb')
577
+ try:
578
+ if key != '':
579
+ fd.write(
580
+ (key +
581
+ ' ').encode("latin1")) # ark-files have keys (utterance-id),
582
+ mat_offset = fd.tell()
583
+ fd.write('\0B'.encode()) # we write binary!
584
+ # Data-type,
585
+ if m.dtype == 'float32': fd.write('FM '.encode())
586
+ elif m.dtype == 'float64': fd.write('DM '.encode())
587
+ else:
588
+ raise UnsupportedDataType(
589
+ "'%s', please use 'float32' or 'float64'" % m.dtype)
590
+ # Dims,
591
+ fd.write('\04'.encode())
592
+ fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows
593
+ fd.write('\04'.encode())
594
+ fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols
595
+ # Data,
596
+ fd.write(m.tobytes())
597
+ finally:
598
+ if fd is not file_or_fd: fd.close()
599
+ return mat_offset
600
+
601
+
602
+ #################################################
603
+ # 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...)
604
+ # Corresponds to: vector<vector<tuple<int,float> > >
605
+ # - outer vector: time axis
606
+ # - inner vector: records at the time
607
+ # - tuple: int = index, float = value
608
+ #
609
+
610
+
611
+ def read_cnet_ark(file_or_fd):
612
+ """ Alias of function 'read_post_ark()', 'cnet' = confusion network """
613
+ return read_post_ark(file_or_fd)
614
+
615
+
616
+ def read_post_ark(file_or_fd):
617
+ """ generator(key,vec<vec<int,float>>) = read_post_ark(file)
618
+ Returns generator of (key,posterior) tuples, read from ark file.
619
+ file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
620
+
621
+ Iterate the ark:
622
+ for key,post in kaldi_io.read_post_ark(file):
623
+ ...
624
+
625
+ Read ark to a 'dictionary':
626
+ d = { key:post for key,post in kaldi_io.read_post_ark(file) }
627
+ """
628
+ fd = open_or_fd(file_or_fd)
629
+ try:
630
+ key = read_key(fd)
631
+ while key:
632
+ post = read_post(fd)
633
+ yield key, post
634
+ key = read_key(fd)
635
+ finally:
636
+ if fd is not file_or_fd: fd.close()
637
+
638
+
639
+ def read_post(file_or_fd):
640
+ """ [post] = read_post(file_or_fd)
641
+ Reads single kaldi 'Posterior' in binary format.
642
+
643
+ The 'Posterior' is C++ type 'vector<vector<tuple<int,float> > >',
644
+ the outer-vector is usually time axis, inner-vector are the records
645
+ at given time, and the tuple is composed of an 'index' (integer)
646
+ and a 'float-value'. The 'float-value' can represent a probability
647
+ or any other numeric value.
648
+
649
+ Returns vector of vectors of tuples.
650
+ """
651
+ fd = open_or_fd(file_or_fd)
652
+ ans = []
653
+ binary = fd.read(2).decode()
654
+ assert (binary == '\0B')
655
+ # binary flag
656
+ assert (fd.read(1).decode() == '\4')
657
+ # int-size
658
+ outer_vec_size = np.frombuffer(fd.read(4), dtype='int32',
659
+ count=1)[0] # number of frames (or bins)
660
+
661
+ # Loop over 'outer-vector',
662
+ for i in range(outer_vec_size):
663
+ assert (fd.read(1).decode() == '\4')
664
+ # int-size
665
+ inner_vec_size = np.frombuffer(
666
+ fd.read(4), dtype='int32',
667
+ count=1)[0] # number of records for frame (or bin)
668
+ data = np.frombuffer(fd.read(inner_vec_size * 10),
669
+ dtype=[('size_idx', 'int8'), ('idx', 'int32'),
670
+ ('size_post', 'int8'),
671
+ ('post', 'float32')],
672
+ count=inner_vec_size)
673
+ assert (data[0]['size_idx'] == 4)
674
+ assert (data[0]['size_post'] == 4)
675
+ ans.append(data[['idx', 'post']].tolist())
676
+
677
+ if fd is not file_or_fd: fd.close()
678
+ return ans
679
+
680
+
681
+ #################################################
682
+ # Kaldi Confusion Network bin begin/end times,
683
+ # (kaldi stores CNs time info separately from the Posterior).
684
+ #
685
+
686
+
687
+ def read_cntime_ark(file_or_fd):
688
+ """ generator(key,vec<tuple<float,float>>) = read_cntime_ark(file_or_fd)
689
+ Returns generator of (key,cntime) tuples, read from ark file.
690
+ file_or_fd : file, gzipped file, pipe or opened file descriptor.
691
+
692
+ Iterate the ark:
693
+ for key,time in kaldi_io.read_cntime_ark(file):
694
+ ...
695
+
696
+ Read ark to a 'dictionary':
697
+ d = { key:time for key,time in kaldi_io.read_post_ark(file) }
698
+ """
699
+ fd = open_or_fd(file_or_fd)
700
+ try:
701
+ key = read_key(fd)
702
+ while key:
703
+ cntime = read_cntime(fd)
704
+ yield key, cntime
705
+ key = read_key(fd)
706
+ finally:
707
+ if fd is not file_or_fd: fd.close()
708
+
709
+
710
+ def read_cntime(file_or_fd):
711
+ """ [cntime] = read_cntime(file_or_fd)
712
+ Reads single kaldi 'Confusion Network time info', in binary format:
713
+ C++ type: vector<tuple<float,float> >.
714
+ (begin/end times of bins at the confusion network).
715
+
716
+ Binary layout is '<num-bins> <beg1> <end1> <beg2> <end2> ...'
717
+
718
+ file_or_fd : file, gzipped file, pipe or opened file descriptor.
719
+
720
+ Returns vector of tuples.
721
+ """
722
+ fd = open_or_fd(file_or_fd)
723
+ binary = fd.read(2).decode()
724
+ assert (binary == '\0B')
725
+ # assuming it's binary
726
+
727
+ assert (fd.read(1).decode() == '\4')
728
+ # int-size
729
+ vec_size = np.frombuffer(fd.read(4), dtype='int32',
730
+ count=1)[0] # number of frames (or bins)
731
+
732
+ data = np.frombuffer(fd.read(vec_size * 10),
733
+ dtype=[('size_beg', 'int8'), ('t_beg', 'float32'),
734
+ ('size_end', 'int8'), ('t_end', 'float32')],
735
+ count=vec_size)
736
+ assert (data[0]['size_beg'] == 4)
737
+ assert (data[0]['size_end'] == 4)
738
+ ans = data[['t_beg',
739
+ 't_end']].tolist() # Return vector of tuples (t_beg,t_end),
740
+
741
+ if fd is not file_or_fd: fd.close()
742
+ return ans
743
+
744
+
745
+ #################################################
746
+ # Segments related,
747
+ #
748
+
749
+
750
+ # Segments as 'Bool vectors' can be handy,
751
+ # - for 'superposing' the segmentations,
752
+ # - for frame-selection in Speaker-ID experiments,
753
+ def read_segments_as_bool_vec(segments_file):
754
+ """ [ bool_vec ] = read_segments_as_bool_vec(segments_file)
755
+ using kaldi 'segments' file for 1 wav, format : '<utt> <rec> <t-beg> <t-end>'
756
+ - t-beg, t-end is in seconds,
757
+ - assumed 100 frames/second,
758
+ """
759
+ segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1)
760
+ # Sanity checks,
761
+ assert (len(segs) > 0) # empty segmentation is an error,
762
+ assert (len(np.unique([rec[1] for rec in segs])) == 1
763
+ ) # segments with only 1 wav-file,
764
+ # Convert time to frame-indexes,
765
+ start = np.rint([100 * rec[2] for rec in segs]).astype(int)
766
+ end = np.rint([100 * rec[3] for rec in segs]).astype(int)
767
+ # Taken from 'read_lab_to_bool_vec', htk.py,
768
+ frms = np.repeat(
769
+ np.r_[np.tile([False, True], len(end)), False],
770
+ np.r_[np.c_[start - np.r_[0, end[:-1]], end - start].flat, 0])
771
+ assert np.sum(end - start) == np.sum(frms)
772
+ return frms
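
A minimal usage sketch for the posterior and segment helpers added above; the 'post.ark' and 'segments' paths are hypothetical placeholders.

    import numpy as np
    import wenet.dataset.kaldi_io as kaldi_io

    # Lazily iterate (key, posterior) pairs from an ark file; each posterior is
    # a list of frames, each frame a list of (index, value) tuples.
    for key, post in kaldi_io.read_post_ark('post.ark'):
        print(key, len(post))

    # Convert a kaldi 'segments' file (single recording) into a frame-level
    # bool mask at the assumed 100 frames/second.
    frms = kaldi_io.read_segments_as_bool_vec('segments')
    print('speech frames:', int(np.sum(frms)))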
wenet/dataset/processor.py ADDED
@@ -0,0 +1,694 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import librosa
17
+ import logging
18
+ import json
19
+ import random
20
+ import tarfile
21
+ from subprocess import PIPE, Popen
22
+ from urllib.parse import urlparse
23
+
24
+ import torch
25
+ import torchaudio
26
+ import torchaudio.compliance.kaldi as kaldi
27
+ import torch.nn.functional as F
28
+ from torch.nn.utils.rnn import pad_sequence
29
+ from wenet.text.base_tokenizer import BaseTokenizer
30
+
31
+ torchaudio.utils.sox_utils.set_buffer_size(16500)
32
+
33
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
34
+
35
+
36
+ def url_opener(data):
37
+ """ Give url or local file, return file descriptor
38
+ Inplace operation.
39
+
40
+ Args:
41
+ data(Iterable[str]): url or local file list
42
+
43
+ Returns:
44
+ Iterable[{src, stream}]
45
+ """
46
+ for sample in data:
47
+ assert 'src' in sample
48
+ # TODO(Binbin Zhang): support HTTP
49
+ url = sample['src']
50
+ try:
51
+ pr = urlparse(url)
52
+ # local file
53
+ if pr.scheme == '' or pr.scheme == 'file':
54
+ stream = open(url, 'rb')
55
+ # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
56
+ else:
57
+ cmd = f'wget -q -O - {url}'
58
+ process = Popen(cmd, shell=True, stdout=PIPE)
59
+ sample.update(process=process)
60
+ stream = process.stdout
61
+ sample.update(stream=stream)
62
+ yield sample
63
+ except Exception as ex:
64
+ logging.warning('Failed to open {}'.format(url))
65
+
66
+
67
+ def tar_file_and_group(data):
68
+ """ Expand a stream of open tar files into a stream of tar file contents.
69
+ And group files with the same prefix
70
+
71
+ Args:
72
+ data: Iterable[{src, stream}]
73
+
74
+ Returns:
75
+ Iterable[{key, wav, txt, sample_rate}]
76
+ """
77
+ for sample in data:
78
+ assert 'stream' in sample
79
+ stream = None
80
+ try:
81
+ stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
82
+ prev_prefix = None
83
+ example = {}
84
+ valid = True
85
+ for tarinfo in stream:
86
+ name = tarinfo.name
87
+ pos = name.rfind('.')
88
+ assert pos > 0
89
+ prefix, postfix = name[:pos], name[pos + 1:]
90
+ if prev_prefix is not None and prefix != prev_prefix:
91
+ example['key'] = prev_prefix
92
+ if valid:
93
+ yield example
94
+ example = {}
95
+ valid = True
96
+ with stream.extractfile(tarinfo) as file_obj:
97
+ try:
98
+ if postfix == 'txt':
99
+ example['txt'] = file_obj.read().decode(
100
+ 'utf8').strip()
101
+ elif postfix in AUDIO_FORMAT_SETS:
102
+ waveform, sample_rate = torchaudio.load(file_obj)
103
+ example['wav'] = waveform
104
+ example['sample_rate'] = sample_rate
105
+ else:
106
+ example[postfix] = file_obj.read()
107
+ except Exception as ex:
108
+ valid = False
109
+ logging.warning('failed to parse {}'.format(name))
110
+ prev_prefix = prefix
111
+ if prev_prefix is not None:
112
+ example['key'] = prev_prefix
113
+ yield example
114
+ except Exception as ex:
115
+ logging.warning(
116
+ 'In tar_file_and_group: {} when processing {}'.format(
117
+ ex, sample['src']))
118
+ finally:
119
+ if stream is not None:
120
+ stream.close()
121
+ if 'process' in sample:
122
+ sample['process'].communicate()
123
+ sample['stream'].close()
124
+
125
+
126
+ def parse_raw(data):
127
+ """ Parse key/wav/txt from json line
128
+
129
+ Args:
130
+ data: Iterable[str], where each str is a json line with key/wav/txt
131
+
132
+ Returns:
133
+ Iterable[{key, wav, txt, sample_rate}]
134
+ """
135
+ for sample in data:
136
+ assert 'src' in sample
137
+ json_line = sample['src']
138
+ obj = json.loads(json_line)
139
+ assert 'key' in obj
140
+ assert 'wav' in obj
141
+ assert 'txt' in obj
142
+ key = obj['key']
143
+ wav_file = obj['wav']
144
+ txt = obj['txt']
145
+ try:
146
+ if 'start' in obj:
147
+ assert 'end' in obj
148
+ sample_rate = torchaudio.info(wav_file).sample_rate
149
+ start_frame = int(obj['start'] * sample_rate)
150
+ end_frame = int(obj['end'] * sample_rate)
151
+ waveform, _ = torchaudio.load(filepath=wav_file,
152
+ num_frames=end_frame -
153
+ start_frame,
154
+ frame_offset=start_frame)
155
+ else:
156
+ waveform, sample_rate = torchaudio.load(wav_file)
157
+ example = copy.deepcopy(obj) # copy and keep all the fields
158
+ example['wav'] = waveform # overwrite wav
159
+ example['sample_rate'] = sample_rate
160
+ yield example
161
+ except Exception as ex:
162
+ logging.warning('Failed to read {}'.format(wav_file))
163
+
164
+
165
+ def parse_speaker(data, speaker_table_path):
166
+ speaker_dict = {}
167
+ with open(speaker_table_path, 'r', encoding='utf8') as fin:
168
+ for line in fin:
169
+ arr = line.strip().split()
170
+ speaker_dict[arr[0]] = int(arr[1])
171
+ for sample in data:
172
+ assert 'speaker' in sample
173
+ speaker = sample['speaker']
174
+ sample['speaker'] = speaker_dict.get(speaker, 0)
175
+ yield sample
176
+
177
+
178
+ def filter(data,
179
+ max_length=10240,
180
+ min_length=10,
181
+ token_max_length=200,
182
+ token_min_length=1,
183
+ min_output_input_ratio=0.0005,
184
+ max_output_input_ratio=1):
185
+ """ Filter sample according to feature and label length
186
+ Inplace operation.
187
+
188
+ Args:
189
+ data: Iterable[{key, wav, label, sample_rate}]
190
+ max_length: drop utterance which is greater than max_length(10ms)
191
+ min_length: drop utterance which is less than min_length(10ms)
192
+ token_max_length: drop utterance which is greater than
193
+ token_max_length, especially when use char unit for
194
+ english modeling
195
+ token_min_length: drop utterance which is
196
+ less than token_min_length
197
+ min_output_input_ratio: minimal ratio of
198
+ token_length / feats_length(10ms)
199
+ max_output_input_ratio: maximum ratio of
200
+ token_length / feats_length(10ms)
201
+
202
+ Returns:
203
+ Iterable[{key, wav, label, sample_rate}]
204
+ """
205
+ for sample in data:
206
+ try:
207
+ assert 'sample_rate' in sample
208
+ assert 'wav' in sample
209
+ assert 'label' in sample
210
+ except:
211
+ continue
212
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
213
+ num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
214
+ if num_frames < min_length:
215
+ continue
216
+ if num_frames > max_length:
217
+ continue
218
+ if len(sample['label']) < token_min_length:
219
+ continue
220
+ if len(sample['label']) > token_max_length:
221
+ continue
222
+ if num_frames != 0:
223
+ if len(sample['label']) / num_frames < min_output_input_ratio:
224
+ continue
225
+ if len(sample['label']) / num_frames > max_output_input_ratio:
226
+ continue
227
+ yield sample
228
+
229
+
230
+ def resample(data, resample_rate=16000):
231
+ """ Resample data.
232
+ Inplace operation.
233
+
234
+ Args:
235
+ data: Iterable[{key, wav, label, sample_rate}]
236
+ resample_rate: target resample rate
237
+
238
+ Returns:
239
+ Iterable[{key, wav, label, sample_rate}]
240
+ """
241
+ for sample in data:
242
+ assert 'sample_rate' in sample
243
+ assert 'wav' in sample
244
+ sample_rate = sample['sample_rate']
245
+ waveform = sample['wav']
246
+ if sample_rate != resample_rate:
247
+ sample['sample_rate'] = resample_rate
248
+ sample['wav'] = torchaudio.transforms.Resample(
249
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
250
+ yield sample
251
+
252
+
253
+ def speed_perturb(data, speeds=None):
254
+ """ Apply speed perturb to the data.
255
+ Inplace operation.
256
+
257
+ Args:
258
+ data: Iterable[{key, wav, label, sample_rate}]
259
+ speeds(List[float]): optional speed
260
+
261
+ Returns:
262
+ Iterable[{key, wav, label, sample_rate}]
263
+ """
264
+ if speeds is None:
265
+ speeds = [0.9, 1.0, 1.1]
266
+ for sample in data:
267
+ assert 'sample_rate' in sample
268
+ assert 'wav' in sample
269
+ sample_rate = sample['sample_rate']
270
+ waveform = sample['wav']
271
+ speed = random.choice(speeds)
272
+ if speed != 1.0:
273
+ wav, _ = torchaudio.sox_effects.apply_effects_tensor(
274
+ waveform, sample_rate,
275
+ [['speed', str(speed)], ['rate', str(sample_rate)]])
276
+ sample['wav'] = wav
277
+
278
+ yield sample
279
+
280
+
281
+ def compute_fbank(data,
282
+ num_mel_bins=23,
283
+ frame_length=25,
284
+ frame_shift=10,
285
+ dither=0.0):
286
+ """ Extract fbank
287
+
288
+ Args:
289
+ data: Iterable[{key, wav, label, sample_rate}]
290
+
291
+ Returns:
292
+ Iterable[{key, feat, label}]
293
+ """
294
+ for sample in data:
295
+ assert 'sample_rate' in sample
296
+ assert 'wav' in sample
297
+ assert 'key' in sample
298
+ assert 'label' in sample
299
+ sample_rate = sample['sample_rate']
300
+ waveform = sample['wav']
301
+ waveform = waveform * (1 << 15)
302
+ # Only keep key, feat, label
303
+ mat = kaldi.fbank(waveform,
304
+ num_mel_bins=num_mel_bins,
305
+ frame_length=frame_length,
306
+ frame_shift=frame_shift,
307
+ dither=dither,
308
+ energy_floor=0.0,
309
+ sample_frequency=sample_rate)
310
+ sample['feat'] = mat
311
+ yield sample
312
+
313
+
314
+ def compute_mfcc(data,
315
+ num_mel_bins=23,
316
+ frame_length=25,
317
+ frame_shift=10,
318
+ dither=0.0,
319
+ num_ceps=40,
320
+ high_freq=0.0,
321
+ low_freq=20.0):
322
+ """ Extract mfcc
323
+
324
+ Args:
325
+ data: Iterable[{key, wav, label, sample_rate}]
326
+
327
+ Returns:
328
+ Iterable[{key, feat, label}]
329
+ """
330
+ for sample in data:
331
+ assert 'sample_rate' in sample
332
+ assert 'wav' in sample
333
+ assert 'key' in sample
334
+ assert 'label' in sample
335
+ sample_rate = sample['sample_rate']
336
+ waveform = sample['wav']
337
+ waveform = waveform * (1 << 15)
338
+ # Only keep key, feat, label
339
+ mat = kaldi.mfcc(waveform,
340
+ num_mel_bins=num_mel_bins,
341
+ frame_length=frame_length,
342
+ frame_shift=frame_shift,
343
+ dither=dither,
344
+ num_ceps=num_ceps,
345
+ high_freq=high_freq,
346
+ low_freq=low_freq,
347
+ sample_frequency=sample_rate)
348
+ sample['feat'] = mat
349
+ yield sample
350
+
351
+
352
+ def compute_log_mel_spectrogram(data,
353
+ n_fft=400,
354
+ hop_length=160,
355
+ num_mel_bins=80,
356
+ padding=0):
357
+ """ Extract log mel spectrogram, modified from openai-whisper, see:
358
+ - https://github.com/openai/whisper/blob/main/whisper/audio.py
359
+ - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
360
+
361
+ Args:
362
+ data: Iterable[{key, wav, label, sample_rate}]
363
+
364
+ Returns:
365
+ Iterable[{key, feat, label}]
366
+ """
367
+ for sample in data:
368
+ assert 'sample_rate' in sample
369
+ assert 'wav' in sample
370
+ assert 'key' in sample
371
+ assert 'label' in sample
372
+ sample_rate = sample['sample_rate']
373
+ waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
374
+ if padding > 0:
375
+ waveform = F.pad(waveform, (0, padding))
376
+ window = torch.hann_window(n_fft)
377
+ stft = torch.stft(waveform,
378
+ n_fft,
379
+ hop_length,
380
+ window=window,
381
+ return_complex=True)
382
+ magnitudes = stft[..., :-1].abs()**2
383
+
384
+ filters = torch.from_numpy(
385
+ librosa.filters.mel(sr=sample_rate,
386
+ n_fft=n_fft,
387
+ n_mels=num_mel_bins))
388
+ mel_spec = filters @ magnitudes
389
+
390
+ # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
391
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
392
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
393
+ log_spec = (log_spec + 4.0) / 4.0
394
+ sample['feat'] = log_spec.transpose(0, 1)
395
+ yield sample
396
+
397
+
398
+ def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
399
+ """ Decode text to chars or BPE
400
+ Inplace operation
401
+
402
+ Args:
403
+ data: Iterable[{key, wav, txt, sample_rate}]
404
+
405
+ Returns:
406
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
407
+ """
408
+ for sample in data:
409
+ assert 'txt' in sample
410
+ if 'task' in sample:
411
+ task_name = sample['task']
412
+ if "<AGE>" in task_name:
413
+ txt = sample['txt'].replace("<YOUTH>", "<ADULT>")
414
+ else:
415
+ txt = sample['txt']
416
+ else:
417
+ txt = sample['txt']
418
+ tokens, label = tokenizer.tokenize(txt)
419
+ sample['tokens'] = tokens
420
+ sample['label'] = label + [tokenizer.eod_id]
421
+ if 'task' in sample:
422
+ task_name = sample['task']
423
+ random_index = random.randint(0, len(global_prompt_dict[task_name])-1)
424
+ prompt = global_prompt_dict[task_name][random_index]
425
+ sample['prompt'] = tokenizer.tokenize(prompt)
426
+ yield sample
427
+
428
+
429
+ def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
430
+ """ Do spec augmentation
431
+ Inplace operation
432
+
433
+ Args:
434
+ data: Iterable[{key, feat, label}]
435
+ num_t_mask: number of time mask to apply
436
+ num_f_mask: number of freq mask to apply
437
+ max_t: max width of time mask
438
+ max_f: max width of freq mask
439
+ max_w: max width of time warp
440
+
441
+ Returns
442
+ Iterable[{key, feat, label}]
443
+ """
444
+ for sample in data:
445
+ assert 'feat' in sample
446
+ x = sample['feat']
447
+ assert isinstance(x, torch.Tensor)
448
+ y = x.clone().detach()
449
+ max_frames = y.size(0)
450
+ max_freq = y.size(1)
451
+ # time mask
452
+ for i in range(num_t_mask):
453
+ start = random.randint(0, max_frames - 1)
454
+ length = random.randint(1, max_t)
455
+ end = min(max_frames, start + length)
456
+ y[start:end, :] = 0
457
+ # freq mask
458
+ for i in range(num_f_mask):
459
+ start = random.randint(0, max_freq - 1)
460
+ length = random.randint(1, max_f)
461
+ end = min(max_freq, start + length)
462
+ y[:, start:end] = 0
463
+ sample['feat'] = y
464
+ yield sample
465
+
466
+
467
+ def spec_sub(data, max_t=20, num_t_sub=3):
468
+ """ Do spec substitute
469
+ Inplace operation
470
+ ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]
471
+
472
+ Args:
473
+ data: Iterable[{key, feat, label}]
474
+ max_t: max width of time substitute
475
+ num_t_sub: number of time substitute to apply
476
+
477
+ Returns
478
+ Iterable[{key, feat, label}]
479
+ """
480
+ for sample in data:
481
+ assert 'feat' in sample
482
+ x = sample['feat']
483
+ assert isinstance(x, torch.Tensor)
484
+ y = x.clone().detach()
485
+ max_frames = y.size(0)
486
+ for i in range(num_t_sub):
487
+ start = random.randint(0, max_frames - 1)
488
+ length = random.randint(1, max_t)
489
+ end = min(max_frames, start + length)
490
+ # only substitute the earlier time chosen randomly for current time
491
+ pos = random.randint(0, start)
492
+ y[start:end, :] = x[start - pos:end - pos, :]
493
+ sample['feat'] = y
494
+ yield sample
495
+
496
+
497
+ def spec_trim(data, max_t=20):
498
+ """ Trim tailing frames. Inplace operation.
499
+ ref: TrimTail [https://arxiv.org/abs/2211.00522]
500
+
501
+ Args:
502
+ data: Iterable[{key, feat, label}]
503
+ max_t: max width of length trimming
504
+
505
+ Returns
506
+ Iterable[{key, feat, label}]
507
+ """
508
+ for sample in data:
509
+ assert 'feat' in sample
510
+ x = sample['feat']
511
+ assert isinstance(x, torch.Tensor)
512
+ max_frames = x.size(0)
513
+ length = random.randint(1, max_t)
514
+ if length < max_frames / 2:
515
+ y = x.clone().detach()[:max_frames - length]
516
+ sample['feat'] = y
517
+ yield sample
518
+
519
+
520
+ def shuffle(data, shuffle_size=10000):
521
+ """ Local shuffle the data
522
+
523
+ Args:
524
+ data: Iterable[{key, feat, label}]
525
+ shuffle_size: buffer size for shuffle
526
+
527
+ Returns:
528
+ Iterable[{key, feat, label}]
529
+ """
530
+ buf = []
531
+ for sample in data:
532
+ buf.append(sample)
533
+ if len(buf) >= shuffle_size:
534
+ random.shuffle(buf)
535
+ for x in buf:
536
+ yield x
537
+ buf = []
538
+ # The sample left over
539
+ random.shuffle(buf)
540
+ for x in buf:
541
+ yield x
542
+
543
+
544
+ def sort(data, sort_size=500):
545
+ """ Sort the data by feature length.
546
+ Sort is used after shuffle and before batch, so we can group
547
+ utts with similar lengths into a batch, and `sort_size` should
548
+ be less than `shuffle_size`
549
+
550
+ Args:
551
+ data: Iterable[{key, feat, label}]
552
+ sort_size: buffer size for sort
553
+
554
+ Returns:
555
+ Iterable[{key, feat, label}]
556
+ """
557
+
558
+ buf = []
559
+ for sample in data:
560
+ buf.append(sample)
561
+ if len(buf) >= sort_size:
562
+ buf.sort(key=lambda x: x['feat'].size(0))
563
+ for x in buf:
564
+ yield x
565
+ buf = []
566
+ # The sample left over
567
+ buf.sort(key=lambda x: x['feat'].size(0))
568
+ for x in buf:
569
+ yield x
570
+
571
+
572
+ def static_batch(data, batch_size=16):
573
+ """ Static batch the data by `batch_size`
574
+
575
+ Args:
576
+ data: Iterable[{key, feat, label}]
577
+ batch_size: batch size
578
+
579
+ Returns:
580
+ Iterable[List[{key, feat, label}]]
581
+ """
582
+ buf = []
583
+ for sample in data:
584
+ buf.append(sample)
585
+ if len(buf) >= batch_size:
586
+ yield buf
587
+ buf = []
588
+ if len(buf) > 0:
589
+ yield buf
590
+
591
+
592
+ def dynamic_batch(data, max_frames_in_batch=12000):
593
+ """ Dynamic batch the data until the total frames in batch
594
+ reach `max_frames_in_batch`
595
+
596
+ Args:
597
+ data: Iterable[{key, feat, label}]
598
+ max_frames_in_batch: max_frames in one batch
599
+
600
+ Returns:
601
+ Iterable[List[{key, feat, label}]]
602
+ """
603
+ buf = []
604
+ longest_frames = 0
605
+ for sample in data:
606
+ assert 'feat' in sample
607
+ assert isinstance(sample['feat'], torch.Tensor)
608
+ new_sample_frames = sample['feat'].size(0)
609
+ longest_frames = max(longest_frames, new_sample_frames)
610
+ frames_after_padding = longest_frames * (len(buf) + 1)
611
+ if frames_after_padding > max_frames_in_batch:
612
+ yield buf
613
+ buf = [sample]
614
+ longest_frames = new_sample_frames
615
+ else:
616
+ buf.append(sample)
617
+ if len(buf) > 0:
618
+ yield buf
619
+
620
+
621
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
622
+ """ Wrapper for static/dynamic batch
623
+ """
624
+ if batch_type == 'static':
625
+ return static_batch(data, batch_size)
626
+ elif batch_type == 'dynamic':
627
+ return dynamic_batch(data, max_frames_in_batch)
628
+ else:
629
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
630
+
631
+
632
+ def padding(data):
633
+ """ Padding the data into training data
634
+
635
+ Args:
636
+ data: Iterable[List[{key, feat, label}]]
637
+
638
+ Returns:
639
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
640
+ """
641
+ for sample in data:
642
+ assert isinstance(sample, list)
643
+ feats_length = torch.tensor([x['feat'].size(0) for x in sample],
644
+ dtype=torch.int32)
645
+ order = torch.argsort(feats_length, descending=True)
646
+ feats_lengths = torch.tensor(
647
+ [sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
648
+ sorted_feats = [sample[i]['feat'] for i in order]
649
+ sorted_keys = [sample[i]['key'] for i in order]
650
+ sorted_labels = [
651
+ torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
652
+ ]
653
+ sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
654
+ label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
655
+ dtype=torch.int32)
656
+ wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
657
+ dtype=torch.int32)
658
+
659
+ padded_feats = pad_sequence(sorted_feats,
660
+ batch_first=True,
661
+ padding_value=0)
662
+ padding_labels = pad_sequence(sorted_labels,
663
+ batch_first=True,
664
+ padding_value=-1)
665
+ padded_wavs = pad_sequence(sorted_wavs,
666
+ batch_first=True,
667
+ padding_value=0)
668
+ batch = {
669
+ "keys": sorted_keys,
670
+ "feats": padded_feats,
671
+ "target": padding_labels,
672
+ "feats_lengths": feats_lengths,
673
+ "target_lengths": label_lengths,
674
+ "pcm": padded_wavs,
675
+ "pcm_length": wav_lengths,
676
+ }
677
+ if 'speaker' in sample[0]:
678
+ speaker = torch.tensor([sample[i]['speaker'] for i in order],
679
+ dtype=torch.int32)
680
+ batch['speaker'] = speaker
681
+ if 'prompt' in sample[0]:
682
+ sorted_prompts = [
683
+ torch.tensor(sample[i]['prompt'], dtype=torch.int64
684
+ ) for i in order
685
+ ]
686
+ prompt_lengths = torch.tensor([x.size(0) for x in
687
+ sorted_prompts], dtype=torch.int32)
688
+ padding_prompts = pad_sequence(sorted_prompts,
689
+ batch_first=True,
690
+ padding_value=-1)
691
+ batch['prompt'] = padding_prompts
692
+ batch['prompt_lengths'] = prompt_lengths
693
+
694
+ yield batch
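
The stages above are plain generators over dict samples, so a pipeline is just nested calls. A small sketch, assuming a local 'utt1.wav' exists (hypothetical path); tokenize/compute_fbank/batch/padding chain in the same way but additionally need a tokenizer and labels.

    import json
    from wenet.dataset import processor

    data = [{'src': json.dumps({'key': 'utt1', 'wav': 'utt1.wav', 'txt': 'hello'})}]
    pipe = processor.parse_raw(data)                      # load waveform + txt
    pipe = processor.resample(pipe, resample_rate=16000)  # unify sample rate
    pipe = processor.speed_perturb(pipe, speeds=[0.9, 1.0, 1.1])
    for sample in pipe:
        print(sample['key'], sample['wav'].shape, sample['sample_rate'])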
wenet/dataset/wav_distortion.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Chao Yang)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import random
17
+ import math
18
+
19
+ import torchaudio
20
+ import torch
21
+
22
+
23
+ def db2amp(db):
24
+ return pow(10, db / 20)
25
+
26
+
27
+ def amp2db(amp):
28
+ return 20 * math.log10(amp)
29
+
30
+
31
+ def make_poly_distortion(conf):
32
+ """Generate a db-domain ploynomial distortion function
33
+
34
+ f(x) = a * x^m * (1-x)^n + x
35
+
36
+ Args:
37
+ conf: a dict {'a': #int, 'm': #int, 'n': #int}
38
+
39
+ Returns:
40
+ The polynomial function, which can be applied to
41
+ a float amplitude value
42
+ """
43
+ a = conf['a']
44
+ m = conf['m']
45
+ n = conf['n']
46
+
47
+ def poly_distortion(x):
48
+ abs_x = abs(x)
49
+ if abs_x < 0.000001:
50
+ x = x
51
+ else:
52
+ db_norm = amp2db(abs_x) / 100 + 1
53
+ if db_norm < 0:
54
+ db_norm = 0
55
+ db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
56
+ if db_norm > 1:
57
+ db_norm = 1
58
+ db = (db_norm - 1) * 100
59
+ amp = db2amp(db)
60
+ if amp >= 0.9997:
61
+ amp = 0.9997
62
+ if x > 0:
63
+ x = amp
64
+ else:
65
+ x = -amp
66
+ return x
67
+
68
+ return poly_distortion
69
+
70
+
71
+ def make_quad_distortion():
72
+ return make_poly_distortion({'a': 1, 'm': 1, 'n': 1})
73
+
74
+
75
+ # the amplitude is set to max for all non-zero points
76
+ def make_max_distortion(conf):
77
+ """Generate a max distortion function
78
+
79
+ Args:
80
+ conf: a dict {'max_db': float }
81
+ 'max_db': the maximum value.
82
+
83
+ Returns:
84
+ The max function, which can be applied to
85
+ a float amplitude value
86
+ """
87
+ max_db = conf['max_db']
88
+ if max_db:
89
+ max_amp = db2amp(max_db) # < 0.997
90
+ else:
91
+ max_amp = 0.997
92
+
93
+ def max_distortion(x):
94
+ if x > 0:
95
+ x = max_amp
96
+ elif x < 0:
97
+ x = -max_amp
98
+ else:
99
+ x = 0.0
100
+ return x
101
+
102
+ return max_distortion
103
+
104
+
105
+ def make_amp_mask(db_mask=None):
106
+ """Get a amplitude domain mask from db domain mask
107
+
108
+ Args:
109
+ db_mask: Optional. A list of tuple. if None, using default value.
110
+
111
+ Returns:
112
+ A list of tuple. The amplitude domain mask
113
+ """
114
+ if db_mask is None:
115
+ db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
116
+ amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask]
117
+ return amp_mask
118
+
119
+
120
+ default_mask = make_amp_mask()
121
+
122
+
123
+ def generate_amp_mask(mask_num):
124
+ """Generate amplitude domain mask randomly in [-100db, 0db]
125
+
126
+ Args:
127
+ mask_num: the slot number of the mask
128
+
129
+ Returns:
130
+ A list of tuple. each tuple defines a slot.
131
+ e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
132
+ for #mask_num = 4
133
+ """
134
+ a = [0] * 2 * mask_num
135
+ a[0] = 0
136
+ m = []
137
+ for i in range(1, 2 * mask_num):
138
+ a[i] = a[i - 1] + random.uniform(0.5, 1)
139
+ max_val = a[2 * mask_num - 1]
140
+ for i in range(0, mask_num):
141
+ l = ((a[2 * i] - max_val) / max_val) * 100
142
+ r = ((a[2 * i + 1] - max_val) / max_val) * 100
143
+ m.append((l, r))
144
+ return make_amp_mask(m)
145
+
146
+
147
+ def make_fence_distortion(conf):
148
+ """Generate a fence distortion function
149
+
150
+ In this fence-like shape function, the values in mask slots are
151
+ set to maximum, while the values not in mask slots are set to 0.
152
+ Use separate masks for positive and negative amplitudes.
153
+
154
+ Args:
155
+ conf: a dict {'mask_number': int,'max_db': float }
156
+ 'mask_number': the slot number in mask.
157
+ 'max_db': the maximum value.
158
+
159
+ Returns:
160
+ The fence function, which can be applied to
161
+ a float amplitude value
162
+ """
163
+ mask_number = conf['mask_number']
164
+ max_db = conf['max_db']
165
+ max_amp = db2amp(max_db) # 0.997
166
+ if mask_number <= 0:
167
+ positive_mask = default_mask
168
+ negative_mask = make_amp_mask([(-50, 0)])
169
+ else:
170
+ positive_mask = generate_amp_mask(mask_number)
171
+ negative_mask = generate_amp_mask(mask_number)
172
+
173
+ def fence_distortion(x):
174
+ is_in_mask = False
175
+ if x > 0:
176
+ for mask in positive_mask:
177
+ if x >= mask[0] and x <= mask[1]:
178
+ is_in_mask = True
179
+ return max_amp
180
+ if not is_in_mask:
181
+ return 0.0
182
+ elif x < 0:
183
+ abs_x = abs(x)
184
+ for mask in negative_mask:
185
+ if abs_x >= mask[0] and abs_x <= mask[1]:
186
+ is_in_mask = True
187
+ return max_amp
188
+ if not is_in_mask:
189
+ return 0.0
190
+ return x
191
+
192
+ return fence_distortion
193
+
194
+
195
+ #
196
+ def make_jag_distortion(conf):
197
+ """Generate a jag distortion function
198
+
199
+ In this jag-like shape function, the values in mask slots are
200
+ not changed, while the values not in mask slots are set to 0.
201
+ Use separate masks for positive and negative amplitudes.
202
+
203
+ Args:
204
+ conf: a dict {'mask_number': #int}
205
+ 'mask_number': the slot number in mask.
206
+
207
+ Returns:
208
+ The jag function, which can be applied to
209
+ a float amplitude value
210
+ """
211
+ mask_number = conf['mask_number']
212
+ if mask_number <= 0:
213
+ positive_mask = default_mask
214
+ negative_mask = make_amp_mask([(-50, 0)])
215
+ else:
216
+ positive_mask = generate_amp_mask(mask_number)
217
+ negative_mask = generate_amp_mask(mask_number)
218
+
219
+ def jag_distortion(x):
220
+ is_in_mask = False
221
+ if x > 0:
222
+ for mask in positive_mask:
223
+ if x >= mask[0] and x <= mask[1]:
224
+ is_in_mask = True
225
+ return x
226
+ if not is_in_mask:
227
+ return 0.0
228
+ elif x < 0:
229
+ abs_x = abs(x)
230
+ for mask in negative_mask:
231
+ if abs_x >= mask[0] and abs_x <= mask[1]:
232
+ is_in_mask = True
233
+ return x
234
+ if not is_in_mask:
235
+ return 0.0
236
+ return x
237
+
238
+ return jag_distortion
239
+
240
+
241
+ # gaining 20db means amp = amp * 10
242
+ # gaining -20db means amp = amp / 10
243
+ def make_gain_db(conf):
244
+ """Generate a db domain gain function
245
+
246
+ Args:
247
+ conf: a dict {'db': #float}
248
+ 'db': the gaining value
249
+
250
+ Returns:
251
+ The db gain function, which could be applied on
252
+ a float amplitude value
253
+ """
254
+ db = conf['db']
255
+
256
+ def gain_db(x):
257
+ return min(0.997, x * pow(10, db / 20))
258
+
259
+ return gain_db
260
+
261
+
262
+ def distort(x, func, rate=0.8):
263
+ """Distort a waveform in sample point level
264
+
265
+ Args:
266
+ x: the original waveform
267
+ func: the distort function
268
+ rate: sample point-level distort probability
269
+
270
+ Returns:
271
+ the distorted waveform
272
+ """
273
+ for i in range(0, x.shape[1]):
274
+ a = random.uniform(0, 1)
275
+ if a < rate:
276
+ x[0][i] = func(float(x[0][i]))
277
+ return x
278
+
279
+
280
+ def distort_chain(x, funcs, rate=0.8):
281
+ for i in range(0, x.shape[1]):
282
+ a = random.uniform(0, 1)
283
+ if a < rate:
284
+ for func in funcs:
285
+ x[0][i] = func(float(x[0][i]))
286
+ return x
287
+
288
+
289
+ # x is numpy
290
+ def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
291
+ if distort_type == 'gain_db':
292
+ gain_db = make_gain_db(distort_conf)
293
+ x = distort(x, gain_db)
294
+ elif distort_type == 'max_distortion':
295
+ max_distortion = make_max_distortion(distort_conf)
296
+ x = distort(x, max_distortion, rate=rate)
297
+ elif distort_type == 'fence_distortion':
298
+ fence_distortion = make_fence_distortion(distort_conf)
299
+ x = distort(x, fence_distortion, rate=rate)
300
+ elif distort_type == 'jag_distortion':
301
+ jag_distortion = make_jag_distortion(distort_conf)
302
+ x = distort(x, jag_distortion, rate=rate)
303
+ elif distort_type == 'poly_distortion':
304
+ poly_distortion = make_poly_distortion(distort_conf)
305
+ x = distort(x, poly_distortion, rate=rate)
306
+ elif distort_type == 'quad_distortion':
307
+ quad_distortion = make_quad_distortion()
308
+ x = distort(x, quad_distortion, rate=rate)
309
+ elif distort_type == 'none_distortion':
310
+ pass
311
+ else:
312
+ print('unsupported distort type')
313
+ return x
314
+
315
+
316
+ def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in,
317
+ wav_out):
318
+ x, sr = torchaudio.load(wav_in)
319
+ x = x.detach().numpy()
320
+ out = distort_wav_conf(x, distort_type, distort_conf, rate)
321
+ torchaudio.save(wav_out, torch.from_numpy(out), sr)
322
+
323
+
324
+ if __name__ == "__main__":
325
+ distort_type = sys.argv[1]
326
+ wav_in = sys.argv[2]
327
+ wav_out = sys.argv[3]
328
+ conf = None
329
+ rate = 0.1
330
+ if distort_type == 'new_jag_distortion':
331
+ conf = {'mask_number': 4}
332
+ elif distort_type == 'new_fence_distortion':
333
+ conf = {'mask_number': 1, 'max_db': -30}
334
+ elif distort_type == 'poly_distortion':
335
+ conf = {'a': 4, 'm': 2, "n": 2}
336
+ distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)
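
Besides the command-line entry point, the distortions can be applied in memory. A short sketch, assuming hypothetical 'in.wav'/'out.wav' paths:

    import torch
    import torchaudio
    from wenet.dataset.wav_distortion import distort_wav_conf

    x, sr = torchaudio.load('in.wav')
    x = x.detach().numpy()
    # poly_distortion applies f(x) = a * x^m * (1-x)^n + x in the db domain
    # to roughly 10% of the sample points.
    y = distort_wav_conf(x, 'poly_distortion', {'a': 4, 'm': 2, 'n': 2}, rate=0.1)
    torchaudio.save('out.wav', torch.from_numpy(y), sr)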
wenet/e_branchformer/encoder.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
2
+ # 2023 Voicecomm Inc (Kai Li)
3
+ # 2023 Lucky Wong
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+
19
+ import torch
20
+ from typing import List, Optional, Union
21
+ from wenet.branchformer.encoder import LayerDropModuleList
22
+
23
+ from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
24
+ from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
25
+ from wenet.transformer.encoder import ConformerEncoder
26
+ from wenet.utils.class_utils import (
27
+ WENET_ACTIVATION_CLASSES,
28
+ WENET_ATTENTION_CLASSES,
29
+ WENET_MLP_CLASSES,
30
+ )
31
+
32
+
33
+ class EBranchformerEncoder(ConformerEncoder):
34
+ """E-Branchformer encoder module."""
35
+
36
+ def __init__(
37
+ self,
38
+ input_size: int,
39
+ output_size: int = 256,
40
+ attention_heads: int = 4,
41
+ linear_units: int = 2048,
42
+ selfattention_layer_type: str = "rel_selfattn",
43
+ pos_enc_layer_type: str = "rel_pos",
44
+ activation_type: str = "swish",
45
+ cgmlp_linear_units: int = 2048,
46
+ cgmlp_conv_kernel: int = 31,
47
+ use_linear_after_conv: bool = False,
48
+ gate_activation: str = "identity",
49
+ num_blocks: int = 12,
50
+ dropout_rate: float = 0.1,
51
+ positional_dropout_rate: float = 0.1,
52
+ attention_dropout_rate: float = 0.0,
53
+ input_layer: str = "conv2d",
54
+ stochastic_depth_rate: Union[float, List[float]] = 0.0,
55
+ static_chunk_size: int = 0,
56
+ use_dynamic_chunk: bool = False,
57
+ global_cmvn: torch.nn.Module = None,
58
+ use_dynamic_left_chunk: bool = False,
59
+ causal: bool = False,
60
+ merge_conv_kernel: int = 3,
61
+ use_ffn: bool = True,
62
+ macaron_style: bool = True,
63
+ query_bias: bool = True,
64
+ key_bias: bool = True,
65
+ value_bias: bool = True,
66
+ conv_bias: bool = True,
67
+ gradient_checkpointing: bool = False,
68
+ use_sdpa: bool = False,
69
+ layer_norm_type: str = 'layer_norm',
70
+ norm_eps: float = 1e-5,
71
+ n_kv_head: Optional[int] = None,
72
+ head_dim: Optional[int] = None,
73
+ mlp_type: str = 'position_wise_feed_forward',
74
+ mlp_bias: bool = True,
75
+ n_expert: int = 8,
76
+ n_expert_activated: int = 2,
77
+ ):
78
+ super().__init__(input_size,
79
+ output_size,
80
+ attention_heads,
81
+ linear_units,
82
+ num_blocks,
83
+ dropout_rate,
84
+ positional_dropout_rate,
85
+ attention_dropout_rate,
86
+ input_layer,
87
+ pos_enc_layer_type,
88
+ True,
89
+ static_chunk_size,
90
+ use_dynamic_chunk,
91
+ global_cmvn,
92
+ use_dynamic_left_chunk,
93
+ 1,
94
+ macaron_style,
95
+ selfattention_layer_type,
96
+ activation_type,
97
+ query_bias=query_bias,
98
+ key_bias=key_bias,
99
+ value_bias=value_bias,
100
+ conv_bias=conv_bias,
101
+ gradient_checkpointing=gradient_checkpointing,
102
+ use_sdpa=use_sdpa,
103
+ layer_norm_type=layer_norm_type,
104
+ norm_eps=norm_eps,
105
+ n_kv_head=n_kv_head,
106
+ head_dim=head_dim,
107
+ mlp_type=mlp_type,
108
+ mlp_bias=mlp_bias,
109
+ n_expert=n_expert,
110
+ n_expert_activated=n_expert_activated)
111
+
112
+ encoder_selfattn_layer_args = (
113
+ attention_heads,
114
+ output_size,
115
+ attention_dropout_rate,
116
+ query_bias,
117
+ key_bias,
118
+ value_bias,
119
+ use_sdpa,
120
+ n_kv_head,
121
+ head_dim,
122
+ )
123
+
124
+ cgmlp_layer = ConvolutionalGatingMLP
125
+ cgmlp_layer_args = (output_size, cgmlp_linear_units, cgmlp_conv_kernel,
126
+ dropout_rate, use_linear_after_conv,
127
+ gate_activation, causal)
128
+
129
+ # feed-forward module definition
130
+ mlp_class = WENET_MLP_CLASSES[mlp_type]
131
+ activation = WENET_ACTIVATION_CLASSES[activation_type]()
132
+ positionwise_layer_args = (
133
+ output_size,
134
+ linear_units,
135
+ dropout_rate,
136
+ activation,
137
+ mlp_bias,
138
+ n_expert,
139
+ n_expert_activated,
140
+ )
141
+
142
+ if isinstance(stochastic_depth_rate, float):
143
+ stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
144
+ if len(stochastic_depth_rate) != num_blocks:
145
+ raise ValueError(
146
+ f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
147
+ f"should be equal to num_blocks ({num_blocks})")
148
+
149
+ self.encoders = LayerDropModuleList(
150
+ p=stochastic_depth_rate,
151
+ modules=[
152
+ EBranchformerEncoderLayer(
153
+ output_size,
154
+ WENET_ATTENTION_CLASSES[selfattention_layer_type](
155
+ *encoder_selfattn_layer_args),
156
+ cgmlp_layer(*cgmlp_layer_args),
157
+ mlp_class(*positionwise_layer_args) if use_ffn else None,
158
+ mlp_class(*positionwise_layer_args)
159
+ if use_ffn and macaron_style else None,
160
+ dropout_rate,
161
+ merge_conv_kernel=merge_conv_kernel,
162
+ causal=causal,
163
+ stochastic_depth_rate=stochastic_depth_rate[lnum],
164
+ ) for lnum in range(num_blocks)
165
+ ])
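
A minimal smoke test for the encoder, assuming the standard wenet BaseEncoder forward(xs, xs_lens) interface inherited from ConformerEncoder (that method is not part of this file):

    import torch
    from wenet.e_branchformer.encoder import EBranchformerEncoder

    encoder = EBranchformerEncoder(input_size=80, output_size=256,
                                   num_blocks=2, cgmlp_linear_units=1024)
    xs = torch.randn(2, 100, 80)      # (batch, frames, fbank dim)
    xs_lens = torch.tensor([100, 80])
    out, masks = encoder(xs, xs_lens)
    print(out.shape)                  # roughly (2, 100 // 4, 256) after conv2d subsampling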
wenet/e_branchformer/encoder_layer.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
2
+ # 2023 Voicecomm Inc (Kai Li)
3
+ # 2023 Lucky Wong
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """EBranchformerEncoderLayer definition."""
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Optional, Tuple
22
+
23
+ from wenet.transformer.attention import T_CACHE
24
+
25
+
26
+ class EBranchformerEncoderLayer(torch.nn.Module):
27
+ """E-Branchformer encoder layer module.
28
+
29
+ Args:
30
+ size (int): model dimension
31
+ attn: standard self-attention or efficient attention
32
+ cgmlp: ConvolutionalGatingMLP
33
+ feed_forward: feed-forward module, optional
34
+ feed_forward_macaron: macaron-style feed-forward module, optional
35
+ dropout_rate (float): dropout probability
36
+ merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ size: int,
42
+ attn: torch.nn.Module,
43
+ cgmlp: torch.nn.Module,
44
+ feed_forward: Optional[torch.nn.Module],
45
+ feed_forward_macaron: Optional[torch.nn.Module],
46
+ dropout_rate: float,
47
+ merge_conv_kernel: int = 3,
48
+ causal: bool = True,
49
+ stochastic_depth_rate=0.0,
50
+ ):
51
+ super().__init__()
52
+
53
+ self.size = size
54
+ self.attn = attn
55
+ self.cgmlp = cgmlp
56
+
57
+ self.feed_forward = feed_forward
58
+ self.feed_forward_macaron = feed_forward_macaron
59
+ self.ff_scale = 1.0
60
+ if self.feed_forward is not None:
61
+ self.norm_ff = nn.LayerNorm(size)
62
+ if self.feed_forward_macaron is not None:
63
+ self.ff_scale = 0.5
64
+ self.norm_ff_macaron = nn.LayerNorm(size)
65
+
66
+ self.norm_mha = nn.LayerNorm(size) # for the MHA module
67
+ self.norm_mlp = nn.LayerNorm(size) # for the MLP module
68
+ # for the final output of the block
69
+ self.norm_final = nn.LayerNorm(size)
70
+
71
+ self.dropout = torch.nn.Dropout(dropout_rate)
72
+
73
+ if causal:
74
+ padding = 0
75
+ self.lorder = merge_conv_kernel - 1
76
+ else:
77
+ # kernel_size should be an odd number for non-causal convolution
78
+ assert (merge_conv_kernel - 1) % 2 == 0
79
+ padding = (merge_conv_kernel - 1) // 2
80
+ self.lorder = 0
81
+ self.depthwise_conv_fusion = torch.nn.Conv1d(
82
+ size + size,
83
+ size + size,
84
+ kernel_size=merge_conv_kernel,
85
+ stride=1,
86
+ padding=padding,
87
+ groups=size + size,
88
+ bias=True,
89
+ )
90
+ self.merge_proj = torch.nn.Linear(size + size, size)
91
+ self.stochastic_depth_rate = stochastic_depth_rate
92
+
93
+ def _forward(
94
+ self,
95
+ x: torch.Tensor,
96
+ mask: torch.Tensor,
97
+ pos_emb: torch.Tensor,
98
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
99
+ att_cache: T_CACHE = (torch.zeros(
100
+ (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
101
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
102
+ stoch_layer_coeff: float = 1.0
103
+ ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
104
+
105
+ if self.feed_forward_macaron is not None:
106
+ residual = x
107
+ x = self.norm_ff_macaron(x)
108
+ x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
109
+ self.feed_forward_macaron(x))
110
+
111
+ # Two branches
112
+ x1 = x
113
+ x2 = x
114
+
115
+ # Branch 1: multi-headed attention module
116
+ x1 = self.norm_mha(x1)
117
+ x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, att_cache)
118
+ x1 = self.dropout(x_att)
119
+
120
+ # Branch 2: convolutional gating mlp
121
+ # Fake new cnn cache here, and then change it in conv_module
122
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
123
+ x2 = self.norm_mlp(x2)
124
+ x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache)
125
+ x2 = self.dropout(x2)
126
+
127
+ # Merge two branches
128
+ x_concat = torch.cat([x1, x2], dim=-1)
129
+ x_tmp = x_concat.transpose(1, 2)
130
+ if self.lorder > 0:
131
+ x_tmp = nn.functional.pad(x_tmp, (self.lorder, 0), "constant", 0.0)
132
+ assert x_tmp.size(2) > self.lorder
133
+ x_tmp = self.depthwise_conv_fusion(x_tmp)
134
+ x_tmp = x_tmp.transpose(1, 2)
135
+ x = x + stoch_layer_coeff * self.dropout(
136
+ self.merge_proj(x_concat + x_tmp))
137
+
138
+ if self.feed_forward is not None:
139
+ # feed forward module
140
+ residual = x
141
+ x = self.norm_ff(x)
142
+ x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
143
+ self.feed_forward(x))
144
+
145
+ x = self.norm_final(x)
146
+
147
+ return x, mask, new_att_cache, new_cnn_cache
148
+
149
+ def forward(
150
+ self,
151
+ x: torch.Tensor,
152
+ mask: torch.Tensor,
153
+ pos_emb: torch.Tensor,
154
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
155
+ att_cache: T_CACHE = (torch.zeros(
156
+ (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
157
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
158
+ ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
159
+ """Compute encoded features.
160
+
161
+ Args:
162
+ x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
163
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
164
+ pos_emb (torch.Tensor): positional encoding, must not be None
165
+ for BranchformerEncoderLayer.
166
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
167
+ (#batch, 1,time), (0, 0, 0) means fake mask.
168
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
169
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
170
+ cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
171
+ (#batch=1, size, cache_t2)
172
+
173
+ Returns:
174
+ torch.Tensor: Output tensor (#batch, time, size).
175
+ torch.Tensor: Mask tensor (#batch, time, time).
176
+ torch.Tensor: att_cache tensor,
177
+ (#batch=1, head, cache_t1 + time, d_k * 2).
178
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
179
+ """
180
+
181
+ stoch_layer_coeff = 1.0
182
+ # with stochastic depth, residual connection `x + f(x)` becomes
183
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
184
+ if self.training:
185
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
186
+ return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
187
+ stoch_layer_coeff)
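
For reference, the stochastic-depth handling above amounts to the following rescaling rule (random layer dropping itself is driven by LayerDropModuleList in the encoder); a tiny sketch of the coefficient:

    # Residual update of EBranchformerEncoderLayer with stochastic depth rate p:
    #   training:   x <- x + 1.0 / (1 - p) * f(x)
    #   inference:  x <- x + f(x)
    p = 0.1
    stoch_layer_coeff = 1.0 / (1 - p)  # ~1.11; applied only when self.training is True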
wenet/efficient_conformer/__init__.py ADDED
File without changes
wenet/efficient_conformer/attention.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2022 58.com(Wuba) Inc AI Lab.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple, Optional
21
+
22
+ import torch
23
+ from torch import nn
24
+ import torch.nn.functional as F
25
+ from wenet.transformer.attention import MultiHeadedAttention
26
+
27
+
28
+ class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention):
29
+ """Multi-Head Attention layer with relative position encoding.
30
+ Paper:
31
+ https://arxiv.org/abs/1901.02860
32
+ https://arxiv.org/abs/2109.01163
33
+ Args:
34
+ n_head (int): The number of heads.
35
+ n_feat (int): The number of features.
36
+ dropout_rate (float): Dropout rate.
37
+ """
38
+
39
+ def __init__(self, n_head, n_feat, dropout_rate, group_size=3):
40
+ """Construct an RelPositionMultiHeadedAttention object."""
41
+ super().__init__(n_head, n_feat, dropout_rate)
42
+ # linear transformation for positional encoding
43
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
44
+ self.group_size = group_size
45
+ self.d_k = n_feat // n_head # for GroupedAttention
46
+ self.n_feat = n_feat
47
+ # these two learnable bias are used in matrix c and matrix d
48
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
49
+ self.pos_bias_u = nn.Parameter(
50
+ torch.Tensor(self.h, self.d_k * self.group_size))
51
+ self.pos_bias_v = nn.Parameter(
52
+ torch.Tensor(self.h, self.d_k * self.group_size))
53
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
54
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
55
+
56
+ def rel_shift(self, x, zero_triu: bool = False):
57
+ """Compute relative positinal encoding.
58
+ Args:
59
+ x (torch.Tensor): Input tensor (batch, time, size).
60
+ zero_triu (bool): If true, return the lower triangular part of
61
+ the matrix.
62
+ Returns:
63
+ torch.Tensor: Output tensor.
64
+ """
65
+
66
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
67
+ device=x.device,
68
+ dtype=x.dtype)
69
+ x_padded = torch.cat([zero_pad, x], dim=-1)
70
+
71
+ x_padded = x_padded.view(x.size()[0],
72
+ x.size()[1],
73
+ x.size(3) + 1, x.size(2))
74
+ x = x_padded[:, :, 1:].view_as(x)
75
+
76
+ if zero_triu:
77
+ ones = torch.ones((x.size(2), x.size(3)))
78
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
79
+
80
+ return x
81
+
82
+ def pad4group(self, Q, K, V, P, mask, group_size: int = 3):
83
+ """
84
+ q: (#batch, time1, size) -> (#batch, head, time1, size/head)
85
+ k,v: (#batch, time2, size) -> (#batch, head, time2, size/head)
86
+ p: (#batch, time2, size)
87
+ """
88
+ # Compute Overflows
89
+ overflow_Q = Q.size(2) % group_size
90
+ overflow_KV = K.size(2) % group_size
91
+
92
+ # if-else for ONNX export
93
+ # 0 // 0.00000000000000001 = 0
94
+ # 1 // 1.00000000000000001 = 1
95
+ padding_Q = (group_size - overflow_Q) * int(
96
+ overflow_Q // (overflow_Q + 0.00000000000000001))
97
+ padding_KV = (group_size - overflow_KV) * int(
98
+ overflow_KV // (overflow_KV + 0.00000000000000001))
99
+
100
+ batch_size, _, seq_len_KV, _ = K.size()
101
+
102
+ # Input Padding (B, T, D) -> (B, T + P, D)
103
+ Q = F.pad(Q, (0, 0, 0, padding_Q), value=0.0)
104
+ K = F.pad(K, (0, 0, 0, padding_KV), value=0.0)
105
+ V = F.pad(V, (0, 0, 0, padding_KV), value=0.0)
106
+
107
+ if mask is not None and mask.size(2) > 0: # time2 > 0:
108
+ mask = mask[:, ::group_size, ::group_size]
109
+
110
+ Q = Q.transpose(1, 2).contiguous().view(
111
+ batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
112
+ K = K.transpose(1, 2).contiguous().view(
113
+ batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
114
+ V = V.transpose(1, 2).contiguous().view(
115
+ batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
116
+
117
+ # process pos_emb
118
+ P_batch_size = P.size(0)
119
+ overflow_P = P.size(1) % group_size
120
+ padding_P = group_size - overflow_P if overflow_P else 0
121
+ P = F.pad(P, (0, 0, 0, padding_P), value=0.0)
122
+ P = P.view(P_batch_size, -1, self.h,
123
+ self.d_k * group_size).transpose(1, 2)
124
+
125
+ return Q, K, V, P, mask, padding_Q
126
+
127
+ def forward_attention(self,
128
+ value: torch.Tensor,
129
+ scores: torch.Tensor,
130
+ mask: torch.Tensor = torch.ones((0, 0, 0),
131
+ dtype=torch.bool),
132
+ padding_q: Optional[int] = None) -> torch.Tensor:
133
+ """Compute attention context vector.
134
+
135
+ Args:
136
+ value (torch.Tensor): Transformed value, size
137
+ (#batch, n_head, time2, d_k).
138
+ scores (torch.Tensor): Attention score, size
139
+ (#batch, n_head, time1, time2).
140
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
141
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
142
+ padding_q : for GroupedAttention in efficient conformer
143
+
144
+ Returns:
145
+ torch.Tensor: Transformed value (#batch, time1, d_model)
146
+ weighted by the attention score (#batch, time1, time2).
147
+
148
+ """
149
+ n_batch = value.size(0)
150
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
151
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
152
+ # 1st chunk to ease the onnx export.]
153
+ # 2. pytorch training
154
+ if mask.size(2) > 0: # time2 > 0
155
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
156
+ # For last chunk, time2 might be larger than scores.size(-1)
157
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
158
+ scores = scores.masked_fill(mask, -float('inf'))
159
+ attn = torch.softmax(scores, dim=-1).masked_fill(
160
+ mask, 0.0) # (batch, head, time1, time2)
161
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
162
+ # 1. onnx(16/-1, -1/-1, 16/0)
163
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
164
+ else:
165
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
166
+
167
+ p_attn = self.dropout(attn)
168
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
169
+
170
+ # n_feat != h * d_k may happen in GroupedAttention
171
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.n_feat)
172
+ ) # (batch, time1, d_model)
173
+ if padding_q is not None:
174
+ # for GroupedAttention in efficient conformer
175
+ x = x[:, :x.size(1) - padding_q]
176
+
177
+ return self.linear_out(x) # (batch, time1, d_model)
178
+
179
+ def forward(
180
+ self,
181
+ query: torch.Tensor,
182
+ key: torch.Tensor,
183
+ value: torch.Tensor,
184
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
185
+ pos_emb: torch.Tensor = torch.empty(0),
186
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
187
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
188
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
189
+ Args:
190
+ query (torch.Tensor): Query tensor (#batch, time1, size).
191
+ key (torch.Tensor): Key tensor (#batch, time2, size).
192
+ value (torch.Tensor): Value tensor (#batch, time2, size).
193
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
194
+ (#batch, time1, time2).
195
+ pos_emb (torch.Tensor): Positional embedding tensor
196
+ (#batch, time2, size).
197
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
198
+ where `cache_t == chunk_size * num_decoding_left_chunks`
199
+ and `head * d_k == size`
200
+ Returns:
201
+ torch.Tensor: Output tensor (#batch, time1, d_model).
202
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
203
+ where `cache_t == chunk_size * num_decoding_left_chunks`
204
+ and `head * d_k == size`
205
+ """
206
+ q = self.linear_q(query)
207
+ k = self.linear_k(key) # (#batch, time2, size)
208
+ v = self.linear_v(value)
209
+ p = self.linear_pos(pos_emb) # (#batch, time2, size)
210
+
211
+ batch_size, seq_len_KV, _ = k.size() # seq_len_KV = time2
212
+
213
+ # (#batch, time2, size) -> (#batch, head, time2, size/head)
214
+ q = q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
215
+ k = k.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
216
+ v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
217
+ if cache.size(0) > 0:
218
+ # use attention cache
219
+ key_cache, value_cache = torch.split(cache,
220
+ cache.size(-1) // 2,
221
+ dim=-1)
222
+ k = torch.cat([key_cache, k], dim=2)
223
+ v = torch.cat([value_cache, v], dim=2)
224
+ new_cache = torch.cat((k, v), dim=-1)
225
+
226
+ # k and p may not match, e.g. time2=18+18/2=27 > mask=36/2=18
227
+ if mask is not None and mask.size(2) > 0:
228
+ time2 = mask.size(2)
229
+ k = k[:, :, -time2:, :]
230
+ v = v[:, :, -time2:, :]
231
+
232
+ # q k v p: (batch, head, time1, d_k)
233
+ q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask,
234
+ self.group_size)
235
+
236
+ # q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k)
237
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
238
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
239
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
240
+
241
+ # compute attention score
242
+ # first compute matrix a and matrix c
243
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
244
+ # (batch, head, time1, time2)
245
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
246
+
247
+ # compute matrix b and matrix d
248
+ # (batch, head, time1, time2)
249
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
250
+ # Remove rel_shift since it is useless in speech recognition,
251
+ # and it requires special attention for streaming.
252
+ # matrix_bd = self.rel_shift(matrix_bd)
253
+
254
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
255
+ self.d_k * self.group_size) # (batch, head, time1, time2)
256
+
257
+ return self.forward_attention(v, scores, mask, padding_q), new_cache
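Note on the `sqrt(self.d_k * self.group_size)` scaling above: grouped attention concatenates `group_size` neighbouring frames along the feature axis, so the effective key/query dimension grows by `group_size` while the attended length shrinks by the same factor. A minimal, self-contained sketch of that idea (the helper `group_frames` below is hypothetical and only illustrates the shape arithmetic, not the actual `pad4group` implementation):

import torch

def group_frames(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # (batch, head, time, d_k) -> (batch, head, time // group_size, d_k * group_size)
    b, h, t, d = x.size()
    assert t % group_size == 0, "pad the time axis first"
    return x.view(b, h, t // group_size, d * group_size)

q = torch.randn(2, 4, 12, 64)
k = torch.randn(2, 4, 12, 64)
gq, gk = group_frames(q, 3), group_frames(k, 3)
# scores over the grouped frames, scaled by sqrt(d_k * group_size)
scores = torch.matmul(gq, gk.transpose(-2, -1)) / (64 * 3) ** 0.5
print(scores.shape)  # torch.Size([2, 4, 4, 4])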
wenet/efficient_conformer/convolution.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2022 58.com(Wuba) Inc AI Lab.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+ from typing import Tuple
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+
23
+ class ConvolutionModule(nn.Module):
24
+ """ConvolutionModule in Conformer model."""
25
+
26
+ def __init__(self,
27
+ channels: int,
28
+ kernel_size: int = 15,
29
+ activation: nn.Module = nn.ReLU(),
30
+ norm: str = "batch_norm",
31
+ causal: bool = False,
32
+ bias: bool = True,
33
+ stride: int = 1):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (bool): Whether to use causal convolution or not
39
+ stride (int): stride of the strided convolution, for efficient Conformer
40
+ """
41
+ super().__init__()
42
+
43
+ self.pointwise_conv1 = nn.Conv1d(
44
+ channels,
45
+ 2 * channels,
46
+ kernel_size=1,
47
+ stride=1,
48
+ padding=0,
49
+ bias=bias,
50
+ )
51
+ # self.lorder is used to distinguish if it's a causal convolution,
52
+ # if self.lorder > 0: it's a causal convolution, the input will be
53
+ # padded with self.lorder frames on the left in forward.
54
+ # else: it's a symmetrical convolution
55
+ if causal:
56
+ padding = 0
57
+ self.lorder = kernel_size - 1
58
+ else:
59
+ # kernel_size should be an odd number for non-causal convolution
60
+ assert (kernel_size - 1) % 2 == 0
61
+ padding = (kernel_size - 1) // 2
62
+ self.lorder = 0
63
+
64
+ self.depthwise_conv = nn.Conv1d(
65
+ channels,
66
+ channels,
67
+ kernel_size,
68
+ stride=stride, # for depthwise_conv in StrideConv
69
+ padding=padding,
70
+ groups=channels,
71
+ bias=bias,
72
+ )
73
+
74
+ assert norm in ['batch_norm', 'layer_norm']
75
+ if norm == "batch_norm":
76
+ self.use_layer_norm = False
77
+ self.norm = nn.BatchNorm1d(channels)
78
+ else:
79
+ self.use_layer_norm = True
80
+ self.norm = nn.LayerNorm(channels)
81
+
82
+ self.pointwise_conv2 = nn.Conv1d(
83
+ channels,
84
+ channels,
85
+ kernel_size=1,
86
+ stride=1,
87
+ padding=0,
88
+ bias=bias,
89
+ )
90
+ self.activation = activation
91
+ self.stride = stride
92
+
93
+ def forward(
94
+ self,
95
+ x: torch.Tensor,
96
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
97
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
98
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
99
+ """Compute convolution module.
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, channels).
102
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
103
+ (0, 0, 0) means fake mask.
104
+ cache (torch.Tensor): left context cache, it is only
105
+ used in causal convolution (#batch, channels, cache_t),
106
+ (0, 0, 0) meas fake cache.
107
+ Returns:
108
+ torch.Tensor: Output tensor (#batch, time, channels).
109
+ """
110
+ # exchange the temporal dimension and the feature dimension
111
+ x = x.transpose(1, 2) # (#batch, channels, time)
112
+
113
+ # mask batch padding
114
+ if mask_pad.size(2) > 0: # time > 0
115
+ x.masked_fill_(~mask_pad, 0.0)
116
+
117
+ if self.lorder > 0:
118
+ if cache.size(2) == 0: # cache_t == 0
119
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
120
+ else:
121
+ # When export ONNX,the first cache is not None but all-zero,
122
+ # cause shape error in residual block,
123
+ # eg. cache14 + x9 = 23, 23-7+1=17 != 9
124
+ cache = cache[:, :, -self.lorder:]
125
+ assert cache.size(0) == x.size(0) # equal batch
126
+ assert cache.size(1) == x.size(1) # equal channel
127
+ x = torch.cat((cache, x), dim=2)
128
+ assert (x.size(2) > self.lorder)
129
+ new_cache = x[:, :, -self.lorder:]
130
+ else:
131
+ # It's better we just return None if no cache is requried,
132
+ # However, for JIT export, here we just fake one tensor instead of
133
+ # None.
134
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
135
+
136
+ # GLU mechanism
137
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
138
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
139
+
140
+ # 1D Depthwise Conv
141
+ x = self.depthwise_conv(x)
142
+ if self.use_layer_norm:
143
+ x = x.transpose(1, 2)
144
+ x = self.activation(self.norm(x))
145
+ if self.use_layer_norm:
146
+ x = x.transpose(1, 2)
147
+ x = self.pointwise_conv2(x)
148
+ # mask batch padding
149
+ if mask_pad.size(2) > 0: # time > 0
150
+ if mask_pad.size(2) != x.size(2):
151
+ mask_pad = mask_pad[:, :, ::self.stride]
152
+ x.masked_fill_(~mask_pad, 0.0)
153
+
154
+ return x.transpose(1, 2), new_cache
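The causal branch above keeps a rolling cache of the last `lorder = kernel_size - 1` frames so that each streamed chunk sees exactly the left context a non-streaming run would have seen. A minimal sketch of that bookkeeping with toy shapes (this is not the module's API, only the padding/cache arithmetic):

import torch
import torch.nn.functional as F

kernel_size, channels = 15, 8
lorder = kernel_size - 1
weight = torch.randn(channels, 1, kernel_size)     # depthwise kernel (groups=channels)
cache = torch.zeros(1, channels, lorder)           # empty history for the first chunk
for _ in range(3):                                 # stream three chunks
    chunk = torch.randn(1, channels, 16)           # (batch, channels, time)
    x = torch.cat((cache, chunk), dim=2)           # prepend cached left context
    cache = x[:, :, -lorder:]                      # context carried to the next chunk
    y = F.conv1d(x, weight, groups=channels)       # no extra padding needed
    assert y.size(2) == chunk.size(2)              # output keeps the chunk length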
wenet/efficient_conformer/encoder.py ADDED
@@ -0,0 +1,560 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2022 58.com(Wuba) Inc AI Lab.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from EfficientConformer(https://github.com/burchim/EfficientConformer)
17
+ # Paper(https://arxiv.org/abs/2109.01163)
18
+ """Encoder definition."""
19
+ from typing import Tuple, Optional, List, Union
20
+
21
+ import torch
22
+ import logging
23
+ import torch.nn.functional as F
24
+
25
+ from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
26
+ from wenet.transformer.encoder_layer import ConformerEncoderLayer
27
+
28
+ from wenet.efficient_conformer.convolution import ConvolutionModule
29
+ from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
30
+
31
+ from wenet.utils.mask import make_pad_mask
32
+ from wenet.utils.mask import add_optional_chunk_mask
33
+ from wenet.utils.class_utils import (
34
+ WENET_ATTENTION_CLASSES,
35
+ WENET_EMB_CLASSES,
36
+ WENET_SUBSAMPLE_CLASSES,
37
+ WENET_ACTIVATION_CLASSES,
38
+ )
39
+
40
+
41
+ class EfficientConformerEncoder(torch.nn.Module):
42
+ """Conformer encoder module."""
43
+
44
+ def __init__(self,
45
+ input_size: int,
46
+ output_size: int = 256,
47
+ attention_heads: int = 4,
48
+ linear_units: int = 2048,
49
+ num_blocks: int = 6,
50
+ dropout_rate: float = 0.1,
51
+ positional_dropout_rate: float = 0.1,
52
+ attention_dropout_rate: float = 0.0,
53
+ input_layer: str = "conv2d",
54
+ pos_enc_layer_type: str = "rel_pos",
55
+ normalize_before: bool = True,
56
+ static_chunk_size: int = 0,
57
+ use_dynamic_chunk: bool = False,
58
+ global_cmvn: torch.nn.Module = None,
59
+ use_dynamic_left_chunk: bool = False,
60
+ macaron_style: bool = True,
61
+ activation_type: str = "swish",
62
+ use_cnn_module: bool = True,
63
+ cnn_module_kernel: int = 15,
64
+ causal: bool = False,
65
+ cnn_module_norm: str = "batch_norm",
66
+ stride_layer_idx: Optional[Union[int, List[int]]] = 3,
67
+ stride: Optional[Union[int, List[int]]] = 2,
68
+ group_layer_idx: Optional[Union[int, List[int],
69
+ tuple]] = (0, 1, 2, 3),
70
+ group_size: int = 3,
71
+ stride_kernel: bool = True,
72
+ **kwargs):
73
+ """Construct Efficient Conformer Encoder
74
+
75
+ Args:
76
+ input_size to use_dynamic_chunk, see in BaseEncoder
77
+ macaron_style (bool): Whether to use macaron style for
78
+ positionwise layer.
79
+ activation_type (str): Encoder activation function type.
80
+ use_cnn_module (bool): Whether to use convolution module.
81
+ cnn_module_kernel (int): Kernel size of convolution module.
82
+ causal (bool): whether to use causal convolution or not.
83
+ stride_layer_idx (list): layer id with StrideConv, start from 0
84
+ stride (list): stride size of each StrideConv in efficient conformer
85
+ group_layer_idx (list): layer id with GroupedAttention, start from 0
86
+ group_size (int): group size of every GroupedAttention layer
87
+ stride_kernel (bool): default True. If True, recompute the cnn kernel sizes with the stride.
88
+ """
89
+ super().__init__()
90
+ self._output_size = output_size
91
+
92
+ logging.info(
93
+ f"input_layer = {input_layer}, "
94
+ f"subsampling_class = {WENET_SUBSAMPLE_CLASSES[input_layer]}")
95
+
96
+ self.global_cmvn = global_cmvn
97
+ self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
98
+ input_size,
99
+ output_size,
100
+ dropout_rate,
101
+ WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
102
+ positional_dropout_rate),
103
+ )
104
+ self.input_layer = input_layer
105
+ self.normalize_before = normalize_before
106
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
107
+ self.static_chunk_size = static_chunk_size
108
+ self.use_dynamic_chunk = use_dynamic_chunk
109
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
110
+
111
+ activation = WENET_ACTIVATION_CLASSES[activation_type]()
112
+ self.num_blocks = num_blocks
113
+ self.attention_heads = attention_heads
114
+ self.cnn_module_kernel = cnn_module_kernel
115
+ self.global_chunk_size = 0
116
+ self.chunk_feature_map = 0
117
+
118
+ # efficient conformer configs
119
+ self.stride_layer_idx = [stride_layer_idx] \
120
+ if type(stride_layer_idx) == int else stride_layer_idx
121
+ self.stride = [stride] \
122
+ if type(stride) == int else stride
123
+ self.group_layer_idx = [group_layer_idx] \
124
+ if type(group_layer_idx) == int else group_layer_idx
125
+ self.grouped_size = group_size # group size of every GroupedAttention layer
126
+
127
+ assert len(self.stride) == len(self.stride_layer_idx)
128
+ self.cnn_module_kernels = [cnn_module_kernel
129
+ ] # kernel size of each StridedConv
130
+ for i in self.stride:
131
+ if stride_kernel:
132
+ self.cnn_module_kernels.append(self.cnn_module_kernels[-1] //
133
+ i)
134
+ else:
135
+ self.cnn_module_kernels.append(self.cnn_module_kernels[-1])
136
+
137
+ logging.info(f"stride_layer_idx= {self.stride_layer_idx}, "
138
+ f"stride = {self.stride}, "
139
+ f"cnn_module_kernel = {self.cnn_module_kernels}, "
140
+ f"group_layer_idx = {self.group_layer_idx}, "
141
+ f"grouped_size = {self.grouped_size}")
142
+
143
+ # feed-forward module definition
144
+ positionwise_layer = PositionwiseFeedForward
145
+ positionwise_layer_args = (
146
+ output_size,
147
+ linear_units,
148
+ dropout_rate,
149
+ activation,
150
+ )
151
+ # convolution module definition
152
+ convolution_layer = ConvolutionModule
153
+
154
+ # encoder definition
155
+ index = 0
156
+ layers = []
157
+ for i in range(num_blocks):
158
+ # self-attention module definition
159
+ if i in self.group_layer_idx:
160
+ encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
161
+ "grouped_rel_selfattn"]
162
+ encoder_selfattn_layer_args = (attention_heads, output_size,
163
+ attention_dropout_rate,
164
+ self.grouped_size)
165
+ else:
166
+ if pos_enc_layer_type == "no_pos":
167
+ encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
168
+ "selfattn"]
169
+ else:
170
+ encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
171
+ "rel_selfattn"]
172
+ encoder_selfattn_layer_args = (attention_heads, output_size,
173
+ attention_dropout_rate)
174
+
175
+ # conformer module definition
176
+ if i in self.stride_layer_idx:
177
+ # conformer block with downsampling
178
+ convolution_layer_args_stride = (
179
+ output_size, self.cnn_module_kernels[index], activation,
180
+ cnn_module_norm, causal, True, self.stride[index])
181
+ layers.append(
182
+ StrideConformerEncoderLayer(
183
+ output_size,
184
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
185
+ positionwise_layer(*positionwise_layer_args),
186
+ positionwise_layer(*positionwise_layer_args)
187
+ if macaron_style else None,
188
+ convolution_layer(*convolution_layer_args_stride)
189
+ if use_cnn_module else None,
190
+ torch.nn.AvgPool1d(
191
+ kernel_size=self.stride[index],
192
+ stride=self.stride[index],
193
+ padding=0,
194
+ ceil_mode=True,
195
+ count_include_pad=False), # pointwise_conv_layer
196
+ dropout_rate,
197
+ normalize_before,
198
+ ))
199
+ index = index + 1
200
+ else:
201
+ # conformer block
202
+ convolution_layer_args_normal = (
203
+ output_size, self.cnn_module_kernels[index], activation,
204
+ cnn_module_norm, causal)
205
+ layers.append(
206
+ ConformerEncoderLayer(
207
+ output_size,
208
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
209
+ positionwise_layer(*positionwise_layer_args),
210
+ positionwise_layer(*positionwise_layer_args)
211
+ if macaron_style else None,
212
+ convolution_layer(*convolution_layer_args_normal)
213
+ if use_cnn_module else None,
214
+ dropout_rate,
215
+ normalize_before,
216
+ ))
217
+
218
+ self.encoders = torch.nn.ModuleList(layers)
219
+
220
+ def set_global_chunk_size(self, chunk_size):
221
+ """Used in ONNX export.
222
+ """
223
+ logging.info(f"set global chunk size: {chunk_size}, default is 0.")
224
+ self.global_chunk_size = chunk_size
225
+ if self.embed.subsampling_rate == 2:
226
+ self.chunk_feature_map = 2 * self.global_chunk_size + 1
227
+ elif self.embed.subsampling_rate == 6:
228
+ self.chunk_feature_map = 6 * self.global_chunk_size + 5
229
+ elif self.embed.subsampling_rate == 8:
230
+ self.chunk_feature_map = 8 * self.global_chunk_size + 7
231
+ else:
232
+ self.chunk_feature_map = 4 * self.global_chunk_size + 3
233
+
234
+ def output_size(self) -> int:
235
+ return self._output_size
236
+
237
+ def calculate_downsampling_factor(self, i: int) -> int:
238
+ factor = 1
239
+ for idx, stride_idx in enumerate(self.stride_layer_idx):
240
+ if i > stride_idx:
241
+ factor *= self.stride[idx]
242
+ return factor
243
+
244
+ def forward(
245
+ self,
246
+ xs: torch.Tensor,
247
+ xs_lens: torch.Tensor,
248
+ decoding_chunk_size: int = 0,
249
+ num_decoding_left_chunks: int = -1,
250
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ """Embed positions in tensor.
252
+ Args:
253
+ xs: padded input tensor (B, T, D)
254
+ xs_lens: input length (B)
255
+ decoding_chunk_size: decoding chunk size for dynamic chunk
256
+ 0: default for training, use random dynamic chunk.
257
+ <0: for decoding, use full chunk.
258
+ >0: for decoding, use fixed chunk size as set.
259
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
260
+ the chunk size is decoding_chunk_size.
261
+ >=0: use num_decoding_left_chunks
262
+ <0: use all left chunks
263
+ Returns:
264
+ encoder output tensor xs, and subsampled masks
265
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
266
+ masks: torch.Tensor batch padding mask after subsample
267
+ (B, 1, T' ~= T/subsample_rate)
268
+ """
269
+ T = xs.size(1)
270
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
271
+ if self.global_cmvn is not None:
272
+ xs = self.global_cmvn(xs)
273
+ xs, pos_emb, masks = self.embed(xs, masks)
274
+ mask_pad = masks # (B, 1, T/subsample_rate)
275
+ chunk_masks = add_optional_chunk_mask(xs, masks,
276
+ self.use_dynamic_chunk,
277
+ self.use_dynamic_left_chunk,
278
+ decoding_chunk_size,
279
+ self.static_chunk_size,
280
+ num_decoding_left_chunks)
281
+ index = 0 # traverse stride
282
+ for i, layer in enumerate(self.encoders):
283
+ # layer return : x, mask, new_att_cache, new_cnn_cache
284
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
285
+ if i in self.stride_layer_idx:
286
+ masks = masks[:, :, ::self.stride[index]]
287
+ chunk_masks = chunk_masks[:, ::self.stride[index], ::self.
288
+ stride[index]]
289
+ mask_pad = masks
290
+ pos_emb = pos_emb[:, ::self.stride[index], :]
291
+ index = index + 1
292
+
293
+ if self.normalize_before:
294
+ xs = self.after_norm(xs)
295
+ # Here we assume the mask is not changed in encoder layers, so just
296
+ # return the masks before encoder layers, and the masks will be used
297
+ # for cross attention with decoder later
298
+ return xs, masks
299
+
300
+ def forward_chunk(
301
+ self,
302
+ xs: torch.Tensor,
303
+ offset: int,
304
+ required_cache_size: int,
305
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
306
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
307
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
308
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
309
+ """ Forward just one chunk
310
+
311
+ Args:
312
+ xs (torch.Tensor): chunk input
313
+ offset (int): current offset in encoder output time stamp
314
+ required_cache_size (int): cache size required for next chunk
315
+ computation
316
+ >=0: actual cache size
317
+ <0: means all history cache is required
318
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
319
+ transformer/conformer attention, with shape
320
+ (elayers, head, cache_t1, d_k * 2), where
321
+ `head * d_k == hidden-dim` and
322
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
323
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
324
+ (elayers, b=1, hidden-dim, cache_t2), where
325
+ `cache_t2 == cnn.lorder - 1`
326
+ att_mask : mask matrix of self attention
327
+
328
+ Returns:
329
+ torch.Tensor: output of current input xs
330
+ torch.Tensor: subsampling cache required for next chunk computation
331
+ List[torch.Tensor]: encoder layers output cache required for next
332
+ chunk computation
333
+ List[torch.Tensor]: conformer cnn cache
334
+
335
+ """
336
+ assert xs.size(0) == 1
337
+
338
+ # using downsampling factor to recover offset
339
+ offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
340
+
341
+ chunk_masks = torch.ones(1,
342
+ xs.size(1),
343
+ device=xs.device,
344
+ dtype=torch.bool)
345
+ chunk_masks = chunk_masks.unsqueeze(1) # (1, 1, xs-time)
346
+
347
+ real_len = 0
348
+ if self.global_chunk_size > 0:
349
+ # for ONNX decode simulation, padding xs to chunk_size
350
+ real_len = xs.size(1)
351
+ pad_len = self.chunk_feature_map - real_len
352
+ xs = F.pad(xs, (0, 0, 0, pad_len), value=0.0)
353
+ chunk_masks = F.pad(chunk_masks, (0, pad_len), value=0.0)
354
+
355
+ if self.global_cmvn is not None:
356
+ xs = self.global_cmvn(xs)
357
+
358
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
359
+ xs, pos_emb, chunk_masks = self.embed(xs, chunk_masks, offset)
360
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
361
+ chunk_size = xs.size(1)
362
+ attention_key_size = cache_t1 + chunk_size
363
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
364
+ # shape(pos_emb) = (b=1, chunk_size, emb_size=output_size=hidden-dim)
365
+
366
+ if required_cache_size < 0:
367
+ next_cache_start = 0
368
+ elif required_cache_size == 0:
369
+ next_cache_start = attention_key_size
370
+ else:
371
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
372
+
373
+ r_att_cache = []
374
+ r_cnn_cache = []
375
+ mask_pad = torch.ones(1,
376
+ xs.size(1),
377
+ device=xs.device,
378
+ dtype=torch.bool)
379
+ mask_pad = mask_pad.unsqueeze(1) # batchPad (b=1, 1, time=chunk_size)
380
+
381
+ if self.global_chunk_size > 0:
382
+ # for ONNX decode simulation
383
+ pos_emb = self.embed.position_encoding(
384
+ offset=max(offset - cache_t1, 0),
385
+ size=cache_t1 + self.global_chunk_size)
386
+ att_mask[:, :, -self.global_chunk_size:] = chunk_masks
387
+ mask_pad = chunk_masks.to(torch.bool)
388
+ else:
389
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
390
+ size=attention_key_size)
391
+
392
+ max_att_len, max_cnn_len = 0, 0 # for repeat_interleave of new_att_cache
393
+ for i, layer in enumerate(self.encoders):
394
+ factor = self.calculate_downsampling_factor(i)
395
+ # NOTE(xcsong): Before layer.forward
396
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
397
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
398
+ # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
399
+ att_cache_trunc = 0
400
+ if xs.size(1) + att_cache.size(2) / factor > pos_emb.size(1):
401
+ # The time step is not divisible by the downsampling multiple
402
+ att_cache_trunc = xs.size(1) + \
403
+ att_cache.size(2) // factor - pos_emb.size(1) + 1
404
+ xs, _, new_att_cache, new_cnn_cache = layer(
405
+ xs,
406
+ att_mask,
407
+ pos_emb,
408
+ mask_pad=mask_pad,
409
+ att_cache=att_cache[i:i +
410
+ 1, :, ::factor, :][:, :,
411
+ att_cache_trunc:, :],
412
+ cnn_cache=cnn_cache[i, :, :, :]
413
+ if cnn_cache.size(0) > 0 else cnn_cache)
414
+
415
+ if i in self.stride_layer_idx:
416
+ # compute time dimension for next block
417
+ efficient_index = self.stride_layer_idx.index(i)
418
+ att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
419
+ stride[efficient_index]]
420
+ mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
421
+ stride[efficient_index]]
422
+ pos_emb = pos_emb[:, ::self.stride[efficient_index], :]
423
+
424
+ # shape(new_att_cache) = [batch, head, time2, outdim]
425
+ new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
426
+ # shape(new_cnn_cache) = [1, batch, outdim, cache_t2]
427
+ new_cnn_cache = new_cnn_cache.unsqueeze(0)
428
+
429
+ # use repeat_interleave to new_att_cache
430
+ new_att_cache = new_att_cache.repeat_interleave(repeats=factor,
431
+ dim=2)
432
+ # pad new_cnn_cache to cnn.lorder for causal convolution
433
+ new_cnn_cache = F.pad(
434
+ new_cnn_cache,
435
+ (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0))
436
+
437
+ if i == 0:
438
+ # record length for the first block as max length
439
+ max_att_len = new_att_cache.size(2)
440
+ max_cnn_len = new_cnn_cache.size(3)
441
+
442
+ # update real shape of att_cache and cnn_cache
443
+ r_att_cache.append(new_att_cache[:, :, -max_att_len:, :])
444
+ r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
445
+
446
+ if self.normalize_before:
447
+ xs = self.after_norm(xs)
448
+
449
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
450
+ # ? may be larger than cache_t1, it depends on required_cache_size
451
+ r_att_cache = torch.cat(r_att_cache, dim=0)
452
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
453
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
454
+
455
+ if self.global_chunk_size > 0 and real_len:
456
+ chunk_real_len = real_len // self.embed.subsampling_rate // \
457
+ self.calculate_downsampling_factor(self.num_blocks + 1)
458
+ # Keeping 1 more timestep can mitigate information leakage
459
+ # from the encoder caused by the padding
460
+ xs = xs[:, :chunk_real_len + 1, :]
461
+
462
+ return xs, r_att_cache, r_cnn_cache
463
+
464
+ def forward_chunk_by_chunk(
465
+ self,
466
+ xs: torch.Tensor,
467
+ decoding_chunk_size: int,
468
+ num_decoding_left_chunks: int = -1,
469
+ use_onnx=False) -> Tuple[torch.Tensor, torch.Tensor]:
470
+ """ Forward input chunk by chunk with chunk_size like a streaming
471
+ fashion
472
+
473
+ Here we should pay special attention to computation cache in the
474
+ streaming style forward chunk by chunk. Three things should be taken
475
+ into account for computation in the current network:
476
+ 1. transformer/conformer encoder layers output cache
477
+ 2. convolution in conformer
478
+ 3. convolution in subsampling
479
+
480
+ However, we don't implement subsampling cache for:
481
+ 1. We can control subsampling module to output the right result by
482
+ overlapping input instead of cache left context, even though it
483
+ wastes some computation, but subsampling only takes a very
484
+ small fraction of computation in the whole model.
485
+ 2. Typically, there are several convolution layers with subsampling
486
+ in subsampling module, it is tricky and complicated to do cache
487
+ with different convolution layers with different subsampling
488
+ rate.
489
+ 3. Currently, nn.Sequential is used to stack all the convolution
490
+ layers in subsampling, we need to rewrite it to make it work
491
+ with cache, which is not preferred.
492
+ Args:
493
+ xs (torch.Tensor): (1, max_len, dim)
494
+ decoding_chunk_size (int): decoding chunk size
495
+ num_decoding_left_chunks (int):
496
+ use_onnx (bool): True for simulating ONNX model inference.
497
+ """
498
+ assert decoding_chunk_size > 0
499
+ # The model is trained by static or dynamic chunk
500
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
501
+ subsampling = self.embed.subsampling_rate
502
+ context = self.embed.right_context + 1 # Add current frame
503
+ stride = subsampling * decoding_chunk_size
504
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
505
+ num_frames = xs.size(1)
506
+
507
+ outputs = []
508
+ offset = 0
509
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
510
+ if use_onnx:
511
+ logging.info("Simulating for ONNX runtime ...")
512
+ att_cache: torch.Tensor = torch.zeros(
513
+ (self.num_blocks, self.attention_heads, required_cache_size,
514
+ self.output_size() // self.attention_heads * 2),
515
+ device=xs.device)
516
+ cnn_cache: torch.Tensor = torch.zeros(
517
+ (self.num_blocks, 1, self.output_size(),
518
+ self.cnn_module_kernel - 1),
519
+ device=xs.device)
520
+ self.set_global_chunk_size(chunk_size=decoding_chunk_size)
521
+ else:
522
+ logging.info("Simulating for JIT runtime ...")
523
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0),
524
+ device=xs.device)
525
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0),
526
+ device=xs.device)
527
+
528
+ # Feed forward overlap input step by step
529
+ for cur in range(0, num_frames - context + 1, stride):
530
+ end = min(cur + decoding_window, num_frames)
531
+ logging.info(f"-->> frame chunk msg: cur={cur}, "
532
+ f"end={end}, num_frames={end-cur}, "
533
+ f"decoding_window={decoding_window}")
534
+ if use_onnx:
535
+ att_mask: torch.Tensor = torch.ones(
536
+ (1, 1, required_cache_size + decoding_chunk_size),
537
+ dtype=torch.bool,
538
+ device=xs.device)
539
+ if cur == 0:
540
+ att_mask[:, :, :required_cache_size] = 0
541
+ else:
542
+ att_mask: torch.Tensor = torch.ones((0, 0, 0),
543
+ dtype=torch.bool,
544
+ device=xs.device)
545
+
546
+ chunk_xs = xs[:, cur:end, :]
547
+ (y, att_cache, cnn_cache) = \
548
+ self.forward_chunk(
549
+ chunk_xs, offset, required_cache_size,
550
+ att_cache, cnn_cache, att_mask)
551
+ outputs.append(y)
552
+ offset += y.size(1)
553
+
554
+ ys = torch.cat(outputs, 1)
555
+ masks = torch.ones(1,
556
+ 1,
557
+ ys.size(1),
558
+ device=ys.device,
559
+ dtype=torch.bool)
560
+ return ys, masks
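A quick, stand-alone check of `calculate_downsampling_factor` used throughout the streaming code above: a layer's time axis is reduced by the product of the strides of all stride layers that precede it (sketch only, with the constructor defaults assumed):

stride_layer_idx = [3]    # defaults from the constructor above
stride = [2]

def downsampling_factor(i: int) -> int:
    factor = 1
    for idx, stride_idx in enumerate(stride_layer_idx):
        if i > stride_idx:
            factor *= stride[idx]
    return factor

print([downsampling_factor(i) for i in range(6)])  # [1, 1, 1, 1, 2, 2]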
wenet/efficient_conformer/encoder_layer.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2022 58.com(Wuba) Inc AI Lab.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder self-attention layer definition."""
18
+
19
+ from typing import Optional, Tuple
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class StrideConformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+ Args:
27
+ size (int): Input dimension.
28
+ self_attn (torch.nn.Module): Self-attention module instance.
29
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
30
+ instance can be used as the argument.
31
+ feed_forward (torch.nn.Module): Feed-forward module instance.
32
+ `PositionwiseFeedForward` instance can be used as the argument.
33
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
34
+ instance.
35
+ `PositionwiseFeedForward` instance can be used as the argument.
36
+ conv_module (torch.nn.Module): Convolution module instance.
37
+ `ConvolutionModule` instance can be used as the argument.
38
+ dropout_rate (float): Dropout rate.
39
+ normalize_before (bool):
40
+ True: use layer_norm before each sub-block.
41
+ False: use layer_norm after each sub-block.
42
+ """
43
+
44
+ def __init__(self,
45
+ size: int,
46
+ self_attn: torch.nn.Module,
47
+ feed_forward: Optional[nn.Module] = None,
48
+ feed_forward_macaron: Optional[nn.Module] = None,
49
+ conv_module: Optional[nn.Module] = None,
50
+ pointwise_conv_layer: Optional[nn.Module] = None,
51
+ dropout_rate: float = 0.1,
52
+ normalize_before: bool = True):
53
+ """Construct an EncoderLayer object."""
54
+ super().__init__()
55
+ self.self_attn = self_attn
56
+ self.feed_forward = feed_forward
57
+ self.feed_forward_macaron = feed_forward_macaron
58
+ self.conv_module = conv_module
59
+ self.pointwise_conv_layer = pointwise_conv_layer
60
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
61
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
62
+ if feed_forward_macaron is not None:
63
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
64
+ self.ff_scale = 0.5
65
+ else:
66
+ self.ff_scale = 1.0
67
+ if self.conv_module is not None:
68
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
69
+ self.norm_final = nn.LayerNorm(
70
+ size, eps=1e-5) # for the final output of the block
71
+ self.dropout = nn.Dropout(dropout_rate)
72
+ self.size = size
73
+ self.normalize_before = normalize_before
74
+ self.concat_linear = nn.Linear(size + size, size)
75
+
76
+ def forward(
77
+ self,
78
+ x: torch.Tensor,
79
+ mask: torch.Tensor,
80
+ pos_emb: torch.Tensor,
81
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
82
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
83
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
84
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
85
+ """Compute encoded features.
86
+
87
+ Args:
88
+ x (torch.Tensor): (#batch, time, size)
89
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
90
+ (0, 0, 0) means fake mask.
91
+ pos_emb (torch.Tensor): positional encoding, must not be None
92
+ for ConformerEncoderLayer.
93
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
94
+ (#batch, 1, time), (0, 0, 0) means fake mask.
95
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
96
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
97
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
98
+ (#batch=1, size, cache_t2)
99
+ Returns:
100
+ torch.Tensor: Output tensor (#batch, time, size).
101
+ torch.Tensor: Mask tensor (#batch, time, time).
102
+ torch.Tensor: att_cache tensor,
103
+ (#batch=1, head, cache_t1 + time, d_k * 2).
104
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
105
+ """
106
+
107
+ # whether to use macaron style
108
+ if self.feed_forward_macaron is not None:
109
+ residual = x
110
+ if self.normalize_before:
111
+ x = self.norm_ff_macaron(x)
112
+ x = residual + self.ff_scale * self.dropout(
113
+ self.feed_forward_macaron(x))
114
+ if not self.normalize_before:
115
+ x = self.norm_ff_macaron(x)
116
+
117
+ # multi-headed self-attention module
118
+ residual = x
119
+ if self.normalize_before:
120
+ x = self.norm_mha(x)
121
+
122
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
123
+ att_cache)
124
+
125
+ x = residual + self.dropout(x_att)
126
+ if not self.normalize_before:
127
+ x = self.norm_mha(x)
128
+
129
+ # convolution module
130
+ # Fake new cnn cache here, and then change it in conv_module
131
+ new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
132
+ if self.conv_module is not None:
133
+ residual = x
134
+ if self.normalize_before:
135
+ x = self.norm_conv(x)
136
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
137
+
138
+ # add pointwise_conv for efficient conformer
139
+ # pointwise_conv_layer does not change shape
140
+ if self.pointwise_conv_layer is not None:
141
+ residual = residual.transpose(1, 2)
142
+ residual = self.pointwise_conv_layer(residual)
143
+ residual = residual.transpose(1, 2)
144
+ assert residual.size(0) == x.size(0)
145
+ assert residual.size(1) == x.size(1)
146
+ assert residual.size(2) == x.size(2)
147
+
148
+ x = residual + self.dropout(x)
149
+
150
+ if not self.normalize_before:
151
+ x = self.norm_conv(x)
152
+
153
+ # feed forward module
154
+ residual = x
155
+ if self.normalize_before:
156
+ x = self.norm_ff(x)
157
+
158
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
159
+ if not self.normalize_before:
160
+ x = self.norm_ff(x)
161
+
162
+ if self.conv_module is not None:
163
+ x = self.norm_final(x)
164
+
165
+ return x, mask, new_att_cache, new_cnn_cache
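Why the extra `pointwise_conv_layer` on the residual branch: in a stride layer the conv module reduces the time axis, so the residual has to be pooled with the same stride before the two tensors can be added. A small shape check under that assumption (the `AvgPool1d` settings mirror the ones passed in by the encoder above):

import torch

x = torch.randn(2, 16, 256)                        # (batch, time, size)
pool = torch.nn.AvgPool1d(kernel_size=2, stride=2,
                          ceil_mode=True, count_include_pad=False)
residual = pool(x.transpose(1, 2)).transpose(1, 2) # pool over time, not features
print(residual.shape)                              # torch.Size([2, 8, 256])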
wenet/efficient_conformer/subsampling.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 58.com(Wuba) Inc AI Lab.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+ from wenet.transformer.subsampling import BaseSubsampling
22
+
23
+
24
+ class Conv2dSubsampling2(BaseSubsampling):
25
+ """Convolutional 2D subsampling (to 1/4 length).
26
+
27
+ Args:
28
+ idim (int): Input dimension.
29
+ odim (int): Output dimension.
30
+ dropout_rate (float): Dropout rate.
31
+
32
+ """
33
+
34
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
35
+ pos_enc_class: torch.nn.Module):
36
+ """Construct an Conv2dSubsampling4 object."""
37
+ super().__init__()
38
+ self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2),
39
+ torch.nn.ReLU())
40
+ self.out = torch.nn.Sequential(
41
+ torch.nn.Linear(odim * ((idim - 1) // 2), odim))
42
+ self.pos_enc = pos_enc_class
43
+ # The right context for every conv layer is computed by:
44
+ # (kernel_size - 1) * frame_rate_of_this_layer
45
+ self.subsampling_rate = 2
46
+ # 2 = (3 - 1) * 1
47
+ self.right_context = 2
48
+
49
+ def forward(
50
+ self,
51
+ x: torch.Tensor,
52
+ x_mask: torch.Tensor,
53
+ offset: Union[int, torch.Tensor] = 0
54
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
+ """Subsample x.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor (#batch, time, idim).
59
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
60
+
61
+ Returns:
62
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
63
+ where time' = time // 2.
64
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
65
+ where time' = time // 2.
66
+ torch.Tensor: positional encoding
67
+
68
+ """
69
+ x = x.unsqueeze(1) # (b, c=1, t, f)
70
+ x = self.conv(x)
71
+ b, c, t, f = x.size()
72
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
73
+ x, pos_emb = self.pos_enc(x, offset)
74
+ return x, pos_emb, x_mask[:, :, :-2:2]
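Sanity check of the time arithmetic in `Conv2dSubsampling2`: a 3x3 convolution with stride 2 and no padding maps `time` frames to `(time - 1) // 2`, and the mask slice `x_mask[:, :, :-2:2]` shrinks by exactly the same amount (toy tensors, illustration only):

import torch

for T in (10, 11, 67):
    x = torch.randn(1, 1, T, 80)                   # (batch, channel, time, feat)
    t_out = torch.nn.Conv2d(1, 4, 3, 2)(x).size(2)
    mask = torch.ones(1, 1, T, dtype=torch.bool)[:, :, :-2:2]
    assert t_out == (T - 1) // 2 == mask.size(2)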
wenet/finetune/lora/__init__.py ADDED
File without changes
wenet/finetune/lora/config.yaml ADDED
@@ -0,0 +1,13 @@
1
+ init_batch_size: 2
2
+ init_iters: 8
3
+ init_config:
4
+ mode: "gradient" # option: "simple", "svd", "gradient"
5
+ lora_A: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
6
+ lora_A_std: 0.01 # only needed when lora_A is "gaussian"
7
+ lora_B: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
8
+ lora_B_std: 0.01 # only needed when lora_B is "gaussian"
9
+ scale: "stable" # option: "default", "stable", "unit", "normalized", "gd", "weightS"
10
+ stable_gamma: 2 # only needed when scale is "stable"
11
+ direction: "ArB2r" # option: "ArBr", "A2rBr", "ArB2r"(only needed when mode is "gradient")
12
+ dtype: "fp32" # option: "bf16", "fp32"
13
+ norm_clip: false # norm clipping
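One possible way to load this config so that the `init_config` block is usable with attribute access, as `reinit_lora_modules` further below expects (`init_config.mode`, `init_config.lora_A`, ...). The loader itself is an assumption for illustration, not the project's own entry point:

import yaml
from types import SimpleNamespace

with open("wenet/finetune/lora/config.yaml") as f:
    cfg = yaml.safe_load(f)
# assuming the init options sit under the `init_config:` mapping shown above
init_config = SimpleNamespace(**cfg["init_config"])
print(cfg["init_batch_size"], init_config.mode, init_config.scale)  # 2 gradient stable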
wenet/finetune/lora/layers.py ADDED
@@ -0,0 +1,350 @@
1
+ # Copyright (c) 2021 microsoft
2
+ # 2023 Alan ([email protected])
3
+ # -----------------------------------------------------------------------------
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for
5
+ # license information.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ import math
13
+ from typing import List
14
+
15
+
16
+ class LoRALayer():
17
+
18
+ def __init__(
19
+ self,
20
+ r: int,
21
+ lora_alpha: int,
22
+ lora_dropout: float,
23
+ merge_weights: bool,
24
+ ):
25
+ self.r = r
26
+ self.lora_alpha = lora_alpha
27
+ # Optional dropout
28
+ if lora_dropout > 0.:
29
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
30
+ else:
31
+ self.lora_dropout = self.identity
32
+ # Mark the weight as unmerged
33
+ self.merged = False
34
+ self.merge_weights = merge_weights
35
+
36
+ def identity(self, x):
37
+ return x
38
+
39
+
40
+ class Embedding(nn.Embedding, LoRALayer):
41
+ # LoRA implemented in a dense layer
42
+ def __init__(self,
43
+ num_embeddings: int,
44
+ embedding_dim: int,
45
+ r: int = 0,
46
+ lora_alpha: int = 1,
47
+ merge_weights: bool = True,
48
+ **kwargs):
49
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
50
+ LoRALayer.__init__(self,
51
+ r=r,
52
+ lora_alpha=lora_alpha,
53
+ lora_dropout=0,
54
+ merge_weights=merge_weights)
55
+ # Actual trainable parameters
56
+ if r > 0:
57
+ self.lora_A = nn.Parameter(
58
+ self.weight.new_zeros((r, num_embeddings)))
59
+ self.lora_B = nn.Parameter(
60
+ self.weight.new_zeros((embedding_dim, r)))
61
+ self.scaling = self.lora_alpha / self.r
62
+ # Freezing the pre-trained weight matrix
63
+ self.weight.requires_grad = False
64
+ self.reset_parameters()
65
+
66
+ def reset_parameters(self):
67
+ nn.Embedding.reset_parameters(self)
68
+ if hasattr(self, 'lora_A'):
69
+ # initialize A the same way as the default for nn.Linear and B to zero
70
+ nn.init.zeros_(self.lora_A)
71
+ nn.init.normal_(self.lora_B)
72
+
73
+ def train(self, mode: bool = True):
74
+ nn.Embedding.train(self, mode)
75
+ if mode:
76
+ if self.merge_weights and self.merged:
77
+ # Make sure that the weights are not merged
78
+ if self.r > 0:
79
+ temp = (self.lora_B @ self.lora_A).transpose(0, 1)
80
+ self.weight.data -= temp * self.scaling
81
+ self.merged = False
82
+ else:
83
+ if self.merge_weights and not self.merged:
84
+ # Merge the weights and mark it
85
+ if self.r > 0:
86
+ temp = (self.lora_B @ self.lora_A).transpose(0, 1)
87
+ self.weight.data += temp * self.scaling
88
+ self.merged = True
89
+
90
+ def forward(self, x: torch.Tensor):
91
+ if self.r > 0 and not self.merged:
92
+ result = nn.Embedding.forward(self, x)
93
+ after_A = F.embedding(x, self.lora_A.transpose(0, 1),
94
+ self.padding_idx, self.max_norm,
95
+ self.norm_type, self.scale_grad_by_freq,
96
+ self.sparse)
97
+ result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
98
+ return result
99
+ else:
100
+ return nn.Embedding.forward(self, x)
101
+
102
+
103
+ class Linear(nn.Linear, LoRALayer):
104
+ # LoRA implemented in a dense layer
105
+ def __init__(
106
+ self,
107
+ in_features: int,
108
+ out_features: int,
109
+ r: int = 0,
110
+ lora_alpha: int = 1,
111
+ lora_dropout: float = 0.,
112
+ fan_in_fan_out: bool = False,
113
+ # Set this to True if the layer to replace stores weight like (fan_in,
114
+ # fan_out)
115
+ merge_weights: bool = True,
116
+ **kwargs):
117
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
118
+ LoRALayer.__init__(self,
119
+ r=r,
120
+ lora_alpha=lora_alpha,
121
+ lora_dropout=lora_dropout,
122
+ merge_weights=merge_weights)
123
+
124
+ self.fan_in_fan_out = fan_in_fan_out
125
+ # Actual trainable parameters
126
+ if r > 0:
127
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
128
+ self.lora_B = nn.Parameter(self.weight.new_zeros(
129
+ (out_features, r)))
130
+ self.scaling = self.lora_alpha / self.r
131
+ # Freezing the pre-trained weight matrix
132
+ self.weight.requires_grad = False
133
+ self.reset_parameters()
134
+ if fan_in_fan_out:
135
+ self.weight.data = self.weight.data.transpose(0, 1)
136
+
137
+ def reset_parameters(self):
138
+ nn.Linear.reset_parameters(self)
139
+ if hasattr(self, 'lora_A'):
140
+ # initialize A the same way as the default for nn.Linear and B to zero
141
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
142
+ nn.init.zeros_(self.lora_B)
143
+
144
+ def T(self, w):
145
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
146
+
147
+ def train(self, mode: bool = True):
148
+ nn.Linear.train(self, mode)
149
+ if mode:
150
+ if self.merge_weights and self.merged:
151
+ # Make sure that the weights are not merged
152
+ if self.r > 0:
153
+ temp = self.T(self.lora_B @ self.lora_A)
154
+ self.weight.data -= temp * self.scaling
155
+ self.merged = False
156
+ else:
157
+ if self.merge_weights and not self.merged:
158
+ # Merge the weights and mark it
159
+ if self.r > 0:
160
+ temp = self.T(self.lora_B @ self.lora_A)
161
+ self.weight.data += temp * self.scaling
162
+ self.merged = True
163
+
164
+ def forward(self, x: torch.Tensor):
165
+ if self.r > 0 and not self.merged:
166
+ result = F.linear(x, self.T(self.weight), bias=self.bias)
167
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
168
+ @ self.lora_B.transpose(0, 1)) * self.scaling
169
+ return result
170
+ else:
171
+ return F.linear(x, self.T(self.weight), bias=self.bias)
172
+
173
+
174
+ class MergedLinear(nn.Linear, LoRALayer):
175
+ # LoRA implemented in a dense layer
176
+ def __init__(self,
177
+ in_features: int,
178
+ out_features: int,
179
+ r: int = 0,
180
+ lora_alpha: int = 1,
181
+ lora_dropout: float = 0.,
182
+ enable_lora: List[bool] = None,
183
+ fan_in_fan_out: bool = False,
184
+ merge_weights: bool = True,
185
+ **kwargs):
186
+ if enable_lora is None:
187
+ enable_lora = [False]
188
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
189
+ LoRALayer.__init__(self,
190
+ r=r,
191
+ lora_alpha=lora_alpha,
192
+ lora_dropout=lora_dropout,
193
+ merge_weights=merge_weights)
194
+ assert out_features % len(enable_lora) == 0, \
195
+ 'The length of enable_lora must divide out_features'
196
+ self.enable_lora = enable_lora
197
+ self.fan_in_fan_out = fan_in_fan_out
198
+ # Actual trainable parameters
199
+ if r > 0 and any(enable_lora):
200
+ self.lora_A = nn.Parameter(
201
+ self.weight.new_zeros((r * sum(enable_lora), in_features)))
202
+ self.lora_B = nn.Parameter(
203
+ self.weight.new_zeros(
204
+ (out_features // len(enable_lora) * sum(enable_lora), r)))
205
+ # weights for Conv1D with groups=sum(enable_lora)
206
+ self.scaling = self.lora_alpha / self.r
207
+ # Freezing the pre-trained weight matrix
208
+ self.weight.requires_grad = False
209
+ # Compute the indices
210
+ self.lora_ind = self.weight.new_zeros(
211
+ (out_features, ), dtype=torch.bool).view(len(enable_lora), -1)
212
+ self.lora_ind[enable_lora, :] = True
213
+ self.lora_ind = self.lora_ind.view(-1)
214
+ self.reset_parameters()
215
+ if fan_in_fan_out:
216
+ self.weight.data = self.weight.data.transpose(0, 1)
217
+
218
+ def reset_parameters(self):
219
+ nn.Linear.reset_parameters(self)
220
+ if hasattr(self, 'lora_A'):
221
+ # initialize A the same way as the default for nn.Linear and B to zero
222
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
223
+ nn.init.zeros_(self.lora_B)
224
+
225
+ def zero_pad(self, x):
226
+ result = x.new_zeros((len(self.lora_ind), *x.size()[1:]))
227
+ result[self.lora_ind] = x
228
+ return result
229
+
230
+ def T(self, w):
231
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
232
+
233
+ def merge_AB(self):
234
+ delta_w = F.conv1d(self.lora_A.unsqueeze(0),
235
+ self.lora_B.unsqueeze(-1),
236
+ groups=sum(self.enable_lora)).squeeze(0)
237
+ return self.T(delta_w)
238
+
239
+ def train(self, mode: bool = True):
240
+ nn.Linear.train(self, mode)
241
+ if mode:
242
+ if self.merge_weights and self.merged:
243
+ # Make sure that the weights are not merged
244
+ if self.r > 0 and any(self.enable_lora):
245
+ self.weight.data -= self.merge_AB() * self.scaling
246
+ self.merged = False
247
+ else:
248
+ if self.merge_weights and not self.merged:
249
+ # Merge the weights and mark it
250
+ if self.r > 0 and any(self.enable_lora):
251
+ self.weight.data += self.merge_AB() * self.scaling
252
+ self.merged = True
253
+
254
+ def forward(self, x: torch.Tensor):
255
+ if self.merged:
256
+ return F.linear(x, self.T(self.weight), bias=self.bias)
257
+ else:
258
+ result = F.linear(x, self.T(self.weight), bias=self.bias)
259
+ if self.r > 0:
260
+ temp = self.T(self.merge_AB().T)
261
+ result += self.lora_dropout(x) @ temp * self.scaling
262
+ return result
263
+
264
+
265
+ class ConvLoRA(nn.Module, LoRALayer):
266
+
267
+ def __init__(self,
268
+ conv_module,
269
+ in_channels,
270
+ out_channels,
271
+ kernel_size,
272
+ r=0,
273
+ lora_alpha=1,
274
+ lora_dropout=0.,
275
+ merge_weights=True,
276
+ **kwargs):
277
+ super(ConvLoRA, self).__init__()
278
+ self.conv = conv_module(in_channels, out_channels, kernel_size,
279
+ **kwargs)
280
+ LoRALayer.__init__(self,
281
+ r=r,
282
+ lora_alpha=lora_alpha,
283
+ lora_dropout=lora_dropout,
284
+ merge_weights=merge_weights)
285
+ assert isinstance(kernel_size, int)
286
+ # Actual trainable parameters
287
+ if r > 0:
288
+ self.lora_A = nn.Parameter(
289
+ self.conv.weight.new_zeros(
290
+ (r * kernel_size, in_channels * kernel_size)))
291
+ self.lora_B = nn.Parameter(
292
+ self.conv.weight.new_zeros(
293
+ (out_channels // self.conv.groups * kernel_size,
294
+ r * kernel_size)))
295
+ self.scaling = self.lora_alpha / self.r
296
+ # Freezing the pre-trained weight matrix
297
+ self.conv.weight.requires_grad = False
298
+ self.reset_parameters()
299
+ self.merged = False
300
+
301
+ def reset_parameters(self):
302
+ self.conv.reset_parameters()
303
+ if hasattr(self, 'lora_A'):
304
+ # initialize A the same way as the default for nn.Linear and B to zero
305
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
306
+ nn.init.zeros_(self.lora_B)
307
+
308
+ def train(self, mode=True):
309
+ super(ConvLoRA, self).train(mode)
310
+ if mode:
311
+ if self.merge_weights and self.merged:
312
+ if self.r > 0:
313
+ # Make sure that the weights are not merged
314
+ self.conv.weight.data -= (self.lora_B @ self.lora_A).view(
315
+ self.conv.weight.shape) * self.scaling
316
+ self.merged = False
317
+ else:
318
+ if self.merge_weights and not self.merged:
319
+ if self.r > 0:
320
+ # Merge the weights and mark it
321
+ self.conv.weight.data += (self.lora_B @ self.lora_A).view(
322
+ self.conv.weight.shape) * self.scaling
323
+ self.merged = True
324
+
325
+ def forward(self, x):
326
+ if self.r > 0 and not self.merged:
327
+ return self.conv._conv_forward(
328
+ x, self.conv.weight +
329
+ (self.lora_B @ self.lora_A).view(self.conv.weight.shape) *
330
+ self.scaling, self.conv.bias)
331
+ return self.conv(x)
332
+
333
+
334
+ class Conv2d(ConvLoRA):
335
+
336
+ def __init__(self, *args, **kwargs):
337
+ super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
338
+
339
+
340
+ class Conv1d(ConvLoRA):
341
+
342
+ def __init__(self, *args, **kwargs):
343
+ super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
344
+
345
+
346
+ # Can be extended to other conv types in the same way
347
+ class Conv3d(ConvLoRA):
348
+
349
+ def __init__(self, *args, **kwargs):
350
+ super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
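Usage sketch for the LoRA `Linear` defined above: with rank r > 0 the frozen base weight W receives a low-rank update, i.e. y = x W^T + (lora_alpha / r) * x A^T B^T, and only the bias and the LoRA factors stay trainable (toy sizes, illustration only):

import torch
from wenet.finetune.lora.layers import Linear as LoRALinear

layer = LoRALinear(256, 256, r=8, lora_alpha=16, lora_dropout=0.1)
x = torch.randn(4, 10, 256)
y = layer(x)                                       # (4, 10, 256)
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
print(trainable)                                   # ['bias', 'lora_A', 'lora_B']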
wenet/finetune/lora/utils.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) 2021 microsoft
2
+ # 2023 Alan ([email protected])
3
+ # -----------------------------------------------------------------------------
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for
5
+ # license information.
6
+ # -----------------------------------------------------------------------------
7
+
8
+ import logging
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from typing import Dict, List
13
+
14
+ import wenet.finetune.lora.layers as lora
15
+
16
+
17
+ def get_nested_attr(module, attr_path):
18
+ attrs = attr_path.split('.')
19
+ for attr in attrs:
20
+ if hasattr(module, attr):
21
+ module = getattr(module, attr)
22
+ else:
23
+ return None
24
+ return module
25
+
26
+
27
+ def inject_lora(module, lora_config):
28
+ lora_rank = lora_config["lora_rank"]
29
+ lora_alpha = lora_config["lora_alpha"]
30
+ lora_dropout = lora_config["lora_dropout"]
31
+ for lora_attr in lora_config["lora_list"]:
32
+ if hasattr(module, lora_attr):
33
+ submodule = getattr(module, lora_attr)
34
+ n_feat = submodule.in_features
35
+ lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
36
+ lora_alpha=lora_alpha,
37
+ lora_dropout=lora_dropout)
38
+ setattr(module, lora_attr, lora_linear)
39
+
40
+
41
+ def inject_lora_to_model(model, lora_config):
42
+ lora_modules = []
43
+ for module in lora_config["lora_modules"]:
44
+ submodule = get_nested_attr(model, module)
45
+ for layer in submodule:
46
+ lora_modules.append(layer)
47
+
48
+ updated_lora_modules = []
49
+ for i in range(len(lora_modules)):
50
+ for attn_attr in lora_config["lora_attn_attr"]:
51
+ if hasattr(lora_modules[i], attn_attr):
52
+ updated_lora_modules.append(getattr(lora_modules[i], attn_attr))
53
+
54
+ for lora_module in updated_lora_modules:
55
+ inject_lora(lora_module, lora_config)
56
+
57
+
58
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
59
+ logging.info('freezing all params except lora module.')
60
+ for n, p in model.named_parameters():
61
+ if 'lora_' not in n:
62
+ p.requires_grad = False
63
+ if bias == 'none':
64
+ return
65
+ elif bias == 'all':
66
+ for n, p in model.named_parameters():
67
+ if 'bias' in n:
68
+ p.requires_grad = True
69
+ elif bias == 'lora_only':
70
+ for m in model.modules():
71
+ if isinstance(m, lora.LoRALayer) and \
72
+ hasattr(m, 'bias') and \
73
+ m.bias is not None:
74
+ m.bias.requires_grad = True
75
+ else:
76
+ raise NotImplementedError
77
+
78
+
79
+ def lora_state_dict(model: nn.Module,
80
+ bias: str = 'none') -> Dict[str, torch.Tensor]:
81
+ my_state_dict = model.state_dict()
82
+ if bias == 'none':
83
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
84
+ elif bias == 'all':
85
+ return {
86
+ k: my_state_dict[k]
87
+ for k in my_state_dict if 'lora_' in k or 'bias' in k
88
+ }
89
+ elif bias == 'lora_only':
90
+ to_return = {}
91
+ for k in my_state_dict:
92
+ if 'lora_' in k:
93
+ to_return[k] = my_state_dict[k]
94
+ bias_name = k.split('lora_')[0] + 'bias'
95
+ if bias_name in my_state_dict:
96
+ to_return[bias_name] = my_state_dict[bias_name]
97
+ return to_return
98
+ else:
99
+ raise NotImplementedError
100
+
101
+
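A small usage sketch combining the two helpers above: freeze everything except the LoRA parameters, then checkpoint only those parameters. `model` stands for any nn.Module that already contains LoRA layers, and the wrapper function name is just for illustration:

import torch

def save_lora_only(model: torch.nn.Module, path: str) -> None:
    mark_only_lora_as_trainable(model, bias='none')        # defined above
    torch.save(lora_state_dict(model, bias='none'), path)  # LoRA weights only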
102
+ def get_record_gradient_hook(model, record_dict):
103
+ def record_gradient_hook(grad):
104
+ for n, p in model.named_parameters():
105
+ if p.requires_grad and p.grad is not None:
106
+ if n not in record_dict:
107
+ record_dict[n] = p.grad.cpu()
108
+ else:
109
+ record_dict[n] += p.grad.cpu()
110
+ p.grad = None
111
+ return grad
112
+
113
+ return record_gradient_hook
114
+
115
+
116
+ def estimate_gradient(
117
+ model, dataloader, max_iters: int = 8,
118
+ device: torch.device = torch.device("cpu")
119
+ ) -> Dict[str, List[torch.Tensor]]:
120
+ r"""
121
+ Estimate the gradient of the model on the given dataset
122
+ """
123
+    logging.info("Estimating gradients over the dataset; this may take some time")
124
+ model.train()
125
+ named_grads = {}
126
+ hooks = []
127
+ requires_grad_states = {}
128
+ for name, param in model.named_parameters():
129
+ requires_grad_states[name] = param.requires_grad
130
+ param.requires_grad = True
131
+ hook = param.register_hook(get_record_gradient_hook(model, named_grads))
132
+ hooks.append(hook)
133
+ num = 0
134
+ for _, batch_dict in enumerate(dataloader):
135
+ num += 1
136
+ if max_iters is not None and num >= max_iters:
137
+ break
138
+ outputs = model(batch_dict, device)
139
+ outputs['loss'].backward()
140
+ get_record_gradient_hook(model, named_grads)(None) # get gradient of last layer
141
+ # make sure the gradient is cleared
142
+ for n, p in model.named_parameters():
143
+ if p.grad is not None:
144
+ p.grad = None
145
+ for n, _ in named_grads.items():
146
+ named_grads[n] /= num
147
+ for hook in hooks:
148
+ hook.remove()
149
+ # recover original requires_grad states
150
+ for name, param in model.named_parameters():
151
+ param.requires_grad = requires_grad_states[name]
152
+ torch.cuda.empty_cache()
153
+ return named_grads
154
+
155
+
156
+ @torch.no_grad()
157
+ def reinit_lora_modules(name, module, init_config, **kwargs):
158
+ r"""Refer to https://github.com/Outsider565/LoRA-GA/blob/
159
+ c185846309ea9012d0bcd46ebd30347dda1c592c/run_exp.py#L67
160
+ Reinitialize the lora model with the given configuration.
161
+ """
162
+ import math
163
+ lora_r = min(module.lora_A.shape)
164
+ a_dim = max(module.lora_A.shape)
165
+ b_dim = max(module.lora_B.shape)
166
+ if init_config.mode == "simple":
167
+ match init_config.lora_A:
168
+ case "gaussian":
169
+ torch.nn.init.normal_(
170
+ module.lora_A, mean=0.0,
171
+ std=init_config.lora_A_std
172
+ )
173
+ case "kaiming":
174
+ # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
175
+ torch.nn.init.kaiming_uniform_(module.lora_A,
176
+ a=math.sqrt(5))
177
+ case "fan_out_kaiming":
178
+ torch.nn.init.kaiming_normal_(
179
+ module.lora_A, mode="fan_out"
180
+ )
181
+ case "xavier":
182
+ torch.nn.init.xavier_normal_(module.lora_A)
183
+ case "zeros":
184
+ torch.nn.init.zeros_(module.lora_A)
185
+ case "unit":
186
+ torch.nn.init.normal_(
187
+ module.lora_A, mean=0.0,
188
+ std=1.0 / (a_dim**0.5)
189
+ )
190
+ case "orthogonal":
191
+ torch.nn.init.orthogonal_(module.lora_A)
192
+ case _:
193
+ raise ValueError(
194
+ f"Unknown lora_A initialization: {init_config.lora_A}"
195
+ )
196
+ match init_config.lora_B:
197
+ case "gaussian":
198
+ torch.nn.init.normal_(
199
+ module.lora_B, mean=0.0,
200
+ std=init_config.lora_B_std
201
+ )
202
+ case "kaiming":
203
+ torch.nn.init.kaiming_normal_(module.lora_B)
204
+ case "fan_out_kaiming":
205
+ torch.nn.init.kaiming_normal_(
206
+ module.lora_B, mode="fan_out"
207
+ )
208
+ case "xavier":
209
+ torch.nn.init.xavier_normal_(module.lora_B)
210
+ case "zeros":
211
+ torch.nn.init.zeros_(module.lora_B)
212
+ case "unit":
213
+ torch.nn.init.normal_(
214
+ module.lora_B, mean=0.0,
215
+ std=1.0 / (b_dim**0.5)
216
+ )
217
+ case "orthogonal":
218
+ torch.nn.init.orthogonal_(module.lora_B)
219
+ case _:
220
+ raise ValueError(
221
+ f"Unknown lora_B initialization: {init_config.lora_B}"
222
+ )
223
+ if getattr(init_config, 'scale', '') == "stable":
224
+ gamma = init_config.stable_gamma
225
+ m, n = module.weight.shape
226
+ module.lora_B.data *= (m**0.25) / gamma**0.5
227
+ module.lora_A.data *= (n**0.25) / gamma**0.5
228
+ elif init_config.mode == "svd":
229
+ U, S, V = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
230
+ niter=4)
231
+ V = V.T
232
+ m, n = module.weight.shape
233
+ if init_config.scale == "default":
234
+ S = S / module.scaling
235
+ module.lora_B = torch.nn.Parameter(
236
+ (U[:, :lora_r] * torch.sqrt(S[:lora_r])).contiguous()
237
+ )
238
+ module.lora_A = torch.nn.Parameter(
239
+ (V[:lora_r, :].T * torch.sqrt(S[:lora_r])).T.contiguous()
240
+ )
241
+ elif init_config.scale == "stable":
242
+ gamma = init_config.stable_gamma
243
+ module.lora_B = torch.nn.Parameter(
244
+ (U[:, :lora_r] * (m**0.25) / gamma**0.5).contiguous()
245
+ )
246
+ module.lora_A = torch.nn.Parameter(
247
+ (V[:lora_r, :] * (n**0.25) / gamma**0.5).contiguous()
248
+ )
249
+ elif init_config.scale == "unit":
250
+ module.lora_B = torch.nn.Parameter((U[:, :lora_r]).contiguous())
251
+ module.lora_A = torch.nn.Parameter((V[:lora_r, :]).contiguous())
252
+ elif init_config.scale == "normalized":
253
+ S_sum = S[:lora_r].sum()
254
+ module.lora_B = torch.nn.Parameter(
255
+ (U[:, :lora_r] * torch.sqrt(S[:lora_r])
256
+ / torch.sqrt(S_sum) * lora_r**0.5).contiguous()
257
+ )
258
+ module.lora_A = torch.nn.Parameter(
259
+ (V[:lora_r, :].T * torch.sqrt(S[:lora_r])
260
+ / torch.sqrt(S_sum) * lora_r**0.5).T.contiguous()
261
+ )
262
+ elif init_config.mode == "gradient":
263
+ named_grad = kwargs["named_grads"]
264
+ grad_name = name + ".weight"
265
+ grads = named_grad[grad_name]
266
+ U, S, V = torch.svd_lowrank(grads.cuda().float(), q=4 * lora_r, niter=4)
267
+ V = V.T
268
+ # set direction
269
+ if init_config.direction == "ArBr":
270
+ B = U[:, 0 : 2 * lora_r : 2]
271
+ A = V[1 : 2 * lora_r : 2, :]
272
+ elif init_config.direction == "A2rBr":
273
+ B = U[:, :lora_r]
274
+ A = V[lora_r : 2 * lora_r, :]
275
+ elif init_config.direction == "ArB2r":
276
+ B = U[:, lora_r : 2 * lora_r]
277
+ A = V[:lora_r, :]
278
+ scaling_factor = module.scaling
279
+ if init_config.scale == "gd":
280
+ A = A / scaling_factor
281
+ B = B / scaling_factor
282
+ elif init_config.scale == "unit":
283
+            # Because A and B are orthogonal, no extra scaling is needed
284
+ pass
285
+ elif init_config.scale == "stable":
286
+ m, n = grads.shape
287
+ # m: feature_out, n: feature_in
288
+ # the scale of output is only related to the feature_out
289
+ gamma = init_config.stable_gamma
290
+ B = B * m**0.25 / gamma**0.5
291
+ A = A * m**0.25 / gamma**0.5
292
+ elif init_config.scale == "weightS":
293
+ _, S, _ = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
294
+ niter=4)
295
+ S = S / module.scaling
296
+ avg_s = torch.sqrt(S[:lora_r]).mean().to(A.device)
297
+ B = B * avg_s
298
+ A = A * avg_s
299
+ module.lora_B = torch.nn.Parameter(B.contiguous().cuda())
300
+ module.lora_A = torch.nn.Parameter(A.contiguous().cuda())
301
+
302
+ with torch.no_grad():
303
+ # consider dtype not in init_config
304
+ if not hasattr(init_config, "dtype"):
305
+ pass
306
+ elif init_config.dtype == "bf16":
307
+ module.lora_A.data = module.lora_A.data.to(torch.bfloat16)
308
+ module.lora_B.data = module.lora_B.data.to(torch.bfloat16)
309
+ elif init_config.dtype == "fp32":
310
+ module.lora_A.data = module.lora_A.data.to(torch.float32)
311
+ module.lora_B.data = module.lora_B.data.to(torch.float32)
312
+ # If lora_A@lora_B is not zero,
313
+ # then we need to subtract lora_A@lora_B from the original weight matrix
314
+ offset = (
315
+ module.lora_B @ module.lora_A
316
+ ).to(module.weight.data.device)
317
+ scaling_factor = module.scaling
318
+ offset *= scaling_factor
319
+ if hasattr(init_config, "norm_clip") and init_config.norm_clip:
320
+ # for numerical stability,
321
+            # offset's largest value must be less than weight's largest value
322
+ ratio = torch.max(torch.abs(module.weight.data)) / torch.max(
323
+ torch.abs(offset)
324
+ )
325
+ if ratio < 1:
326
+ offset *= ratio
327
+ module.lora_A.data *= ratio**0.5
328
+ module.lora_B.data *= ratio**0.5
329
+ logging.warning(f"Clipping offset by {ratio}")
330
+ try:
331
+ module.weight.data -= offset
332
+ except Exception as e:
333
+ logging.warning(f"{e}")
334
+            raise
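
The helpers in wenet/finetune/lora/utils.py are meant to be composed: inject LoRA layers into the attention projections, freeze everything except the LoRA parameters, and (for the "gradient" init mode) feed an estimated gradient into reinit_lora_modules. The sketch below shows one plausible wiring; it assumes an already-built wenet model and dataloader, and the concrete values in lora_config / init_config are illustrative placeholders, not values prescribed by this file.

import types
import torch
import wenet.finetune.lora.utils as lora_utils

# Hypothetical config; the key names mirror what inject_lora_to_model reads.
lora_config = {
    "lora_rank": 8,
    "lora_alpha": 8,
    "lora_dropout": 0.0,
    "lora_modules": ["encoder.encoders"],             # container(s) of layers
    "lora_attn_attr": ["self_attn"],                  # attention attr on each layer
    "lora_list": ["linear_q", "linear_k", "linear_v", "linear_out"],
}

lora_utils.inject_lora_to_model(model, lora_config)   # model: assumed wenet ASR model
lora_utils.mark_only_lora_as_trainable(model, bias="none")

# Optional LoRA-GA style re-initialization: estimate gradients on a few
# batches, then use them to seed lora_A / lora_B.
init_config = types.SimpleNamespace(mode="gradient", direction="ArBr",
                                    scale="stable", stable_gamma=16)
named_grads = lora_utils.estimate_gradient(model, dataloader,   # dataloader: assumed
                                           max_iters=8,
                                           device=torch.device("cuda"))
for name, module in model.named_modules():
    if hasattr(module, "lora_A"):
        lora_utils.reinit_lora_modules(name, module, init_config,
                                       named_grads=named_grads)

# Only the LoRA parameters need to be checkpointed.
torch.save(lora_utils.lora_state_dict(model), "lora_only.pt")

Saving lora_state_dict alone keeps checkpoints small; the frozen base weights are restored from the original model at load time.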
wenet/k2/__init__.py ADDED
File without changes
wenet/k2/model.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) 2023 Binbin Zhang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, Tuple
16
+
17
+ import torch
18
+ from torch.nn.utils.rnn import pad_sequence
19
+
20
+ from wenet.transformer.asr_model import ASRModel
21
+ from wenet.transformer.ctc import CTC
22
+ from wenet.transformer.decoder import TransformerDecoder
23
+ from wenet.transformer.encoder import TransformerEncoder
24
+ from wenet.utils.common import (IGNORE_ID, add_sos_eos, reverse_pad_list)
25
+
26
+
27
+ class K2Model(ASRModel):
28
+
29
+ def __init__(
30
+ self,
31
+ vocab_size: int,
32
+ encoder: TransformerEncoder,
33
+ decoder: TransformerDecoder,
34
+ ctc: CTC,
35
+ ctc_weight: float = 0.5,
36
+ ignore_id: int = IGNORE_ID,
37
+ reverse_weight: float = 0.0,
38
+ lsm_weight: float = 0.0,
39
+ length_normalized_loss: bool = False,
40
+ lfmmi_dir: str = '',
41
+ special_tokens: dict = None,
42
+ device: torch.device = torch.device("cuda"),
43
+ ):
44
+ super().__init__(vocab_size,
45
+ encoder,
46
+ decoder,
47
+ ctc,
48
+ ctc_weight,
49
+ ignore_id,
50
+ reverse_weight,
51
+ lsm_weight,
52
+ length_normalized_loss,
53
+ special_tokens=special_tokens)
54
+ self.lfmmi_dir = lfmmi_dir
55
+ self.device = device
56
+ if self.lfmmi_dir != '':
57
+ self.load_lfmmi_resource()
58
+
59
+ @torch.jit.unused
60
+ def _forward_ctc(
61
+ self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
62
+ text: torch.Tensor,
63
+ text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
64
+ loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
65
+ text)
66
+ return loss_ctc, ctc_probs
67
+
68
+ @torch.jit.unused
69
+ def load_lfmmi_resource(self):
70
+ try:
71
+ import icefall
72
+ except ImportError:
73
+ print('Error: Failed to import icefall')
74
+ with open('{}/tokens.txt'.format(self.lfmmi_dir), 'r') as fin:
75
+ for line in fin:
76
+ arr = line.strip().split()
77
+ if arr[0] == '<sos/eos>':
78
+ self.sos_eos_id = int(arr[1])
79
+ device = torch.device(self.device)
80
+ self.graph_compiler = icefall.mmi_graph_compiler.MmiTrainingGraphCompiler(
81
+ self.lfmmi_dir,
82
+ device=device,
83
+ oov="<UNK>",
84
+ sos_id=self.sos_eos_id,
85
+ eos_id=self.sos_eos_id,
86
+ )
87
+ self.lfmmi = icefall.mmi.LFMMILoss(
88
+ graph_compiler=self.graph_compiler,
89
+ den_scale=1,
90
+ use_pruned_intersect=False,
91
+ )
92
+ self.word_table = {}
93
+ with open('{}/words.txt'.format(self.lfmmi_dir), 'r') as fin:
94
+ for line in fin:
95
+ arr = line.strip().split()
96
+ assert len(arr) == 2
97
+ self.word_table[int(arr[1])] = arr[0]
98
+
99
+ @torch.jit.unused
100
+ def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
101
+ try:
102
+ import k2
103
+ except ImportError:
104
+ print('Error: Failed to import k2')
105
+ ctc_probs = self.ctc.log_softmax(encoder_out)
106
+ supervision_segments = torch.stack((
107
+ torch.arange(len(encoder_mask)),
108
+ torch.zeros(len(encoder_mask)),
109
+ encoder_mask.squeeze(dim=1).sum(dim=1).to('cpu'),
110
+ ), 1).to(torch.int32)
111
+ dense_fsa_vec = k2.DenseFsaVec(
112
+ ctc_probs,
113
+ supervision_segments,
114
+ allow_truncate=3,
115
+ )
116
+ text = [
117
+ ' '.join([self.word_table[j.item()] for j in i if j != -1])
118
+ for i in text
119
+ ]
120
+ loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
121
+ return loss, ctc_probs
122
+
123
+ def load_hlg_resource_if_necessary(self, hlg, word):
124
+ try:
125
+ import k2
126
+ except ImportError:
127
+ print('Error: Failed to import k2')
128
+ if not hasattr(self, 'hlg'):
129
+ device = torch.device(self.device)
130
+ self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
131
+ if not hasattr(self.hlg, "lm_scores"):
132
+ self.hlg.lm_scores = self.hlg.scores.clone()
133
+ if not hasattr(self, 'word_table'):
134
+ self.word_table = {}
135
+ with open(word, 'r') as fin:
136
+ for line in fin:
137
+ arr = line.strip().split()
138
+ assert len(arr) == 2
139
+ self.word_table[int(arr[1])] = arr[0]
140
+
141
+ @torch.no_grad()
142
+ def hlg_onebest(
143
+ self,
144
+ speech: torch.Tensor,
145
+ speech_lengths: torch.Tensor,
146
+ decoding_chunk_size: int = -1,
147
+ num_decoding_left_chunks: int = -1,
148
+ simulate_streaming: bool = False,
149
+ hlg: str = '',
150
+ word: str = '',
151
+ symbol_table: Dict[str, int] = None,
152
+ ) -> List[int]:
153
+ try:
154
+ import icefall
155
+ except ImportError:
156
+ print('Error: Failed to import icefall')
157
+ self.load_hlg_resource_if_necessary(hlg, word)
158
+ encoder_out, encoder_mask = self._forward_encoder(
159
+ speech, speech_lengths, decoding_chunk_size,
160
+ num_decoding_left_chunks,
161
+ simulate_streaming) # (B, maxlen, encoder_dim)
162
+ ctc_probs = self.ctc.log_softmax(
163
+ encoder_out) # (1, maxlen, vocab_size)
164
+ supervision_segments = torch.stack(
165
+ (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
166
+ encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
167
+ 1,
168
+ ).to(torch.int32)
169
+ lattice = icefall.decode.get_lattice(
170
+ nnet_output=ctc_probs,
171
+ decoding_graph=self.hlg,
172
+ supervision_segments=supervision_segments,
173
+ search_beam=20,
174
+ output_beam=7,
175
+ min_active_states=30,
176
+ max_active_states=10000,
177
+ subsampling_factor=4)
178
+ best_path = icefall.decode.one_best_decoding(lattice=lattice,
179
+ use_double_scores=True)
180
+ hyps = icefall.utils.get_texts(best_path)
181
+ hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
182
+ for i in hyps]
183
+ return hyps
184
+
185
+ @torch.no_grad()
186
+ def hlg_rescore(
187
+ self,
188
+ speech: torch.Tensor,
189
+ speech_lengths: torch.Tensor,
190
+ decoding_chunk_size: int = -1,
191
+ num_decoding_left_chunks: int = -1,
192
+ simulate_streaming: bool = False,
193
+ lm_scale: float = 0,
194
+ decoder_scale: float = 0,
195
+ r_decoder_scale: float = 0,
196
+ hlg: str = '',
197
+ word: str = '',
198
+ symbol_table: Dict[str, int] = None,
199
+ ) -> List[int]:
200
+ try:
201
+ import k2
202
+ import icefall
203
+ except ImportError:
204
+ print('Error: Failed to import k2 & icefall')
205
+ self.load_hlg_resource_if_necessary(hlg, word)
206
+ device = speech.device
207
+ encoder_out, encoder_mask = self._forward_encoder(
208
+ speech, speech_lengths, decoding_chunk_size,
209
+ num_decoding_left_chunks,
210
+ simulate_streaming) # (B, maxlen, encoder_dim)
211
+ ctc_probs = self.ctc.log_softmax(
212
+ encoder_out) # (1, maxlen, vocab_size)
213
+ supervision_segments = torch.stack(
214
+ (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
215
+ encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
216
+ 1,
217
+ ).to(torch.int32)
218
+ lattice = icefall.decode.get_lattice(
219
+ nnet_output=ctc_probs,
220
+ decoding_graph=self.hlg,
221
+ supervision_segments=supervision_segments,
222
+ search_beam=20,
223
+ output_beam=7,
224
+ min_active_states=30,
225
+ max_active_states=10000,
226
+ subsampling_factor=4)
227
+ nbest = icefall.decode.Nbest.from_lattice(
228
+ lattice=lattice,
229
+ num_paths=100,
230
+ use_double_scores=True,
231
+ nbest_scale=0.5,
232
+ )
233
+ nbest = nbest.intersect(lattice)
234
+ assert hasattr(nbest.fsa, "lm_scores")
235
+ assert hasattr(nbest.fsa, "tokens")
236
+ assert isinstance(nbest.fsa.tokens, torch.Tensor)
237
+
238
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
239
+ tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
240
+ tokens = tokens.remove_values_leq(0)
241
+ hyps = tokens.tolist()
242
+
243
+ # cal attention_score
244
+ hyps_pad = pad_sequence([
245
+ torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
246
+ ], True, self.ignore_id) # (beam_size, max_hyps_len)
247
+ ori_hyps_pad = hyps_pad
248
+ hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
249
+ device=device,
250
+ dtype=torch.long) # (beam_size,)
251
+ hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
252
+ hyps_lens = hyps_lens + 1 # Add <sos> at begining
253
+ encoder_out_repeat = []
254
+ tot_scores = nbest.tot_scores()
255
+ repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
256
+ for i in range(len(encoder_out)):
257
+ encoder_out_repeat.append(encoder_out[i:i + 1].repeat(
258
+ repeats[i], 1, 1))
259
+ encoder_out = torch.concat(encoder_out_repeat, dim=0)
260
+ encoder_mask = torch.ones(encoder_out.size(0),
261
+ 1,
262
+ encoder_out.size(1),
263
+ dtype=torch.bool,
264
+ device=device)
265
+ # used for right to left decoder
266
+ r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
267
+ r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
268
+ self.ignore_id)
269
+ reverse_weight = 0.5
270
+ decoder_out, r_decoder_out, _ = self.decoder(
271
+ encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
272
+ reverse_weight) # (beam_size, max_hyps_len, vocab_size)
273
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
274
+ decoder_out = decoder_out
275
+ # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
276
+ # conventional transformer decoder.
277
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
278
+ r_decoder_out = r_decoder_out
279
+
280
+ decoder_scores = torch.tensor([
281
+ sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
282
+ for i in range(len(hyps))
283
+ ],
284
+ device=device) # noqa
285
+ r_decoder_scores = []
286
+ for i in range(len(hyps)):
287
+ score = 0
288
+ for j in range(len(hyps[i])):
289
+ score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
290
+ score += r_decoder_out[i, len(hyps[i]), self.eos]
291
+ r_decoder_scores.append(score)
292
+ r_decoder_scores = torch.tensor(r_decoder_scores, device=device)
293
+
294
+ am_scores = nbest.compute_am_scores()
295
+ ngram_lm_scores = nbest.compute_lm_scores()
296
+ tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
297
+ decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
298
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
299
+ max_indexes = ragged_tot_scores.argmax()
300
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
301
+ hyps = icefall.utils.get_texts(best_path)
302
+ hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
303
+ for i in hyps]
304
+ return hyps
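
K2Model layers LF-MMI training and HLG-based decoding on top of the standard wenet ASRModel, and both decode entry points expect a compiled HLG graph plus a word table on disk (k2 and icefall must be installed). A minimal usage sketch follows; the model, features, file paths, and symbol_table are assumptions for illustration, not values fixed by this file.

import torch

# `model` is assumed to be a trained K2Model on the target device,
# `speech` a (batch, frames, feat_dim) feature tensor, `speech_lengths`
# the per-utterance frame counts, and `symbol_table` the token -> id map.
model.eval()

# One-best decoding through the HLG graph (already wrapped in torch.no_grad()).
hyps = model.hlg_onebest(
    speech, speech_lengths,
    hlg="data/lang/HLG.pt",      # hypothetical path to a compiled HLG
    word="data/lang/words.txt",  # hypothetical word table
    symbol_table=symbol_table)

# N-best rescoring that mixes AM, n-gram LM and attention-decoder scores.
hyps = model.hlg_rescore(
    speech, speech_lengths,
    lm_scale=0.5, decoder_scale=0.5, r_decoder_scale=0.3,
    hlg="data/lang/HLG.pt",
    word="data/lang/words.txt",
    symbol_table=symbol_table)

# Both methods return one list of token ids per utterance; map the ids back
# through the tokenizer to obtain text.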
wenet/llm_asr/__init__.py ADDED
File without changes