Spaces: Running on Zero
tomxxie committed
Commit 568e264
Parent(s): 66817ed
适配zeroGPU (adapt to ZeroGPU)
This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +11 -3
- wenet/LLM/causallm_model.py +207 -0
- wenet/LLM/decoder.py +161 -0
- wenet/LLM/sampler.py +43 -0
- wenet/__init__.py +1 -0
- wenet/bin/alignment.py +268 -0
- wenet/bin/average_model.py +125 -0
- wenet/bin/export_ipex.py +95 -0
- wenet/bin/export_jit.py +71 -0
- wenet/bin/export_onnx_bpu.py +1065 -0
- wenet/bin/export_onnx_cpu.py +470 -0
- wenet/bin/export_onnx_gpu.py +1263 -0
- wenet/bin/recognize.py +336 -0
- wenet/bin/recognize4llmasr.py +340 -0
- wenet/bin/recognize_onnx_gpu.py +297 -0
- wenet/bin/train.py +232 -0
- wenet/branchformer/__init__.py +0 -0
- wenet/branchformer/cgmlp.py +194 -0
- wenet/branchformer/encoder.py +177 -0
- wenet/branchformer/encoder_layer.py +245 -0
- wenet/cli/__init__.py +0 -0
- wenet/cli/hub.py +116 -0
- wenet/cli/model.py +176 -0
- wenet/cli/paraformer_model.py +82 -0
- wenet/cli/transcribe.py +87 -0
- wenet/ctl_model/asr_model_ctl.py +277 -0
- wenet/ctl_model/encoder.py +172 -0
- wenet/dataset/__init__.py +0 -0
- wenet/dataset/datapipes.py +470 -0
- wenet/dataset/dataset.py +234 -0
- wenet/dataset/deprecated/dataset.py +202 -0
- wenet/dataset/deprecated/processor.py +1023 -0
- wenet/dataset/kaldi_io.py +772 -0
- wenet/dataset/processor.py +694 -0
- wenet/dataset/wav_distortion.py +336 -0
- wenet/e_branchformer/encoder.py +165 -0
- wenet/e_branchformer/encoder_layer.py +187 -0
- wenet/efficient_conformer/__init__.py +0 -0
- wenet/efficient_conformer/attention.py +257 -0
- wenet/efficient_conformer/convolution.py +154 -0
- wenet/efficient_conformer/encoder.py +560 -0
- wenet/efficient_conformer/encoder_layer.py +165 -0
- wenet/efficient_conformer/subsampling.py +74 -0
- wenet/finetune/lora/__init__.py +0 -0
- wenet/finetune/lora/config.yaml +13 -0
- wenet/finetune/lora/layers.py +350 -0
- wenet/finetune/lora/utils.py +334 -0
- wenet/k2/__init__.py +0 -0
- wenet/k2/model.py +304 -0
- wenet/llm_asr/__init__.py +0 -0
app.py
CHANGED
@@ -10,9 +10,9 @@ import os
 
 import sys
 
+import yaml
 
 sys.path.insert(0, './')
-from gxl_ai_utils.utils import utils_file
 from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.utils.init_model import init_model
 import logging
@@ -20,6 +20,14 @@ import librosa
 import torch
 import torchaudio
 import numpy as np
+def makedir_for_file(filepath):
+    dirpath = os.path.dirname(filepath)
+    if not os.path.exists(dirpath):
+        os.makedirs(dirpath)
+def load_dict_from_yaml(file_path: str):
+    with open(file_path, 'rt', encoding='utf-8') as f:
+        dict_1 = yaml.load(f, Loader=yaml.FullLoader)
+    return dict_1
 
 # 将图片转换为 Base64
 with open("lab.png", "rb") as image_file:
@@ -53,7 +61,7 @@ def init_model_my():
     args = SimpleNamespace(**{
         "checkpoint": checkpoint_path,
     })
-    configs =
+    configs = load_dict_from_yaml(config_path)
     model, configs = init_model(args, configs)
     model = model.cuda()
     tokenizer = init_tokenizer(configs)
@@ -73,7 +81,7 @@ def do_resample(input_wav_path, output_wav_path):
     waveform = torch.mean(waveform, dim=0, keepdim=True)
     waveform = torchaudio.transforms.Resample(
         orig_freq=sample_rate, new_freq=16000)(waveform)
-
+    makedir_for_file(output_wav_path)
     torchaudio.save(output_wav_path, waveform, 16000)
 
 def true_decode_fuc(input_wav_path, input_prompt):
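The two helpers added above replace the previous gxl_ai_utils dependency in app.py. A minimal sketch of how they fit together (the file paths here are hypothetical, not taken from the commit):

# Hypothetical usage of the new helpers: load a WeNet YAML config and make
# sure the output directory exists before the resampled audio is written.
config_path = "conf/train_conformer.yaml"    # hypothetical config file
configs = load_dict_from_yaml(config_path)   # plain yaml.load replacement

output_wav_path = "exp/tmp/resampled.wav"    # hypothetical output path
makedir_for_file(output_wav_path)            # creates exp/tmp/ if missing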
wenet/LLM/causallm_model.py
ADDED
@@ -0,0 +1,207 @@
from typing import Dict, List, Optional, Union
import torch
from wenet.LLM.decoder import DecoderOnly
from wenet.LLM.sampler import sampler
from wenet.utils.common import IGNORE_ID, th_accuracy
from wenet.utils.mask import make_pad_mask, subsequent_mask


class CausalLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int,
        decoder: DecoderOnly,
        special_tokens: dict,
        tie_word_embedding: bool = False,
        linear_bias: bool = False,
        ignore_id: int = IGNORE_ID,
        lsm_weight: float = 0.0,
        reduction: str = 'mean',
    ) -> None:
        super().__init__()
        del special_tokens

        self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size)
        self.out = torch.nn.Linear(decoder.hidden_size,
                                   vocab_size,
                                   bias=linear_bias)

        self.decoder = decoder
        self.vocab_size = vocab_size
        self.criterion_att = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_id,
            label_smoothing=lsm_weight,
            reduction=reduction,
        )
        self.tie_word_embedding = tie_word_embedding
        self.ignore_id = ignore_id

    @torch.jit.unused
    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """ Forward for training
        """
        text = batch['feats'].to(device)
        target = batch['target'].to(device)
        text_length = batch['feats_lengths'].to(device)

        mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze(
            1)  # (B,1,L)
        causal_mask = subsequent_mask(
            mask.size(-1), device=mask.device).unsqueeze(0)  # (1,L,L)
        att_mask = causal_mask & mask  # (B, L, L)

        embeding = self.embed(text)
        decoder_out = self.out(self.decoder(embeding,
                                            att_mask)[0])  # (B, L, vocab_size)
        loss = self.criterion_att(decoder_out.view(-1, self.vocab_size),
                                  target.view(-1))
        acc = th_accuracy(decoder_out.view(-1, self.vocab_size),
                          target,
                          ignore_label=self.ignore_id)

        return {
            "loss": loss,
            "ppl": torch.exp(loss.detach()),
            "th_accuracy": acc
        }

    def tie_or_clone_weights(self, jit_mode: bool):
        if not self.tie_word_embedding:
            return
        if jit_mode:
            self.out.weight = torch.nn.Parameter(self.embed.weight.clone())
        else:
            self.out.weight = self.embed.weight
        # TODO(Mddct): whether to deal bias for other llm model

    @torch.jit.unused
    @torch.inference_mode()
    def generate(
        self,
        prompts_tokens: List[List[int]],
        device: torch.device,
        stop_tokens: List[int],
        dtype: torch.dtype = torch.float32,
        output_len: int = 100,
        temperature: Union[float, None] = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> List[List[int]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        batch_size = len(prompts_tokens)
        min_prompt_len = min(len(p) for p in prompts_tokens)
        max_prompt_len = max(len(p) for p in prompts_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.decoder.pos_enc.max_len

        # build KV caches
        kv_caches = []
        for _ in range(len(self.decoder.decoders)):
            size = (batch_size, 0, self.decoder.n_kv_head,
                    self.decoder.head_dim)
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

        # prepare inputs
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      IGNORE_ID,
                                      dtype=torch.int64,
                                      device=device)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            IGNORE_ID,
                                            dtype=torch.int64,
                                            device=device)
        # right padding
        for i, p in enumerate(prompts_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])

        prompt_mask_tensor = token_ids_tensor != IGNORE_ID
        input_positions_tensor = torch.arange(0,
                                              min_prompt_len,
                                              dtype=torch.int64).to(device)
        mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len),
                                 dtype=torch.bool)
        mask_tensor = torch.tril(mask_tensor).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        att_mask = curr_mask_tensor.squeeze(
            1)[:, :min_prompt_len, :min_prompt_len]
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1
                                                    ]).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len,
                                    dtype=torch.int64).to(device)

        input_token_embeding = self.embed(input_token_ids_tensor)
        offset = torch.tensor([0] * len(prompts_tokens)).to(device)
        input_offset = offset

        stop_tokens_tensor = torch.tensor(stop_tokens, device=device)
        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            decoder_out, kv_caches, = self.decoder(
                input_token_embeding,
                att_mask,
                input_offset,
                kv_caches,
            )
            decoder_out = self.out(decoder_out)
            decoder_out = decoder_out.index_select(1, output_positions_tensor)
            next_token_ids = sampler(
                decoder_out,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_token_embeding = self.embed(input_token_ids_tensor)

            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(
                2, input_positions_tensor)
            att_mask = curr_mask_tensor.squeeze(1)[:, :output_index +
                                                   1, :output_index + 1]

            output_positions_tensor = torch.tensor(
                0, dtype=torch.int64).to(device)
            input_offset = offset + output_index.unsqueeze(-1)
            output_index = output_index + 1

            if all(torch.isin(next_token_ids, stop_tokens_tensor)):
                break

        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompts_tokens[i]
                                        ):len(prompts_tokens[i]) + output_len]
            for stop_token in stop_tokens:
                try:
                    eos_index = trimmed_output.index(stop_token)
                    trimmed_output = trimmed_output[:eos_index]
                    break
                except Exception:
                    continue
            results.append(trimmed_output)

        return results
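For orientation, a rough sketch of how CausalLM.generate() is typically driven; the token ids, stop token and the pre-built `model` object below are invented for illustration and are not part of this commit:

# Hypothetical driver for CausalLM.generate().
import torch

prompts_tokens = [[1, 15, 27, 9], [1, 42, 8]]   # two already-tokenized prompts
stop_tokens = [2]                                # e.g. the tokenizer's EOS id

model = model.cuda().eval()                      # assumes `model` is a CausalLM
outputs = model.generate(prompts_tokens,
                         device=torch.device('cuda'),
                         stop_tokens=stop_tokens,
                         dtype=torch.float32,
                         output_len=64,
                         temperature=0.8,
                         top_p=0.9,
                         top_k=50)
# `outputs` is one list of generated token ids per prompt,
# truncated at the first stop token.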
wenet/LLM/decoder.py
ADDED
@@ -0,0 +1,161 @@
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE

from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
                                     WENET_ATTENTION_CLASSES,
                                     WENET_EMB_CLASSES, WENET_MLP_CLASSES,
                                     WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        super().__init__()

        assert selfattention_layer_type in ['rope_abs_selfattn']
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed)
        if activation_type == "gelu" and gelu_approximate is not None:
            activation = WENET_ACTIVATION_CLASSES['gelu'](
                approximate=gelu_approximate)
        else:
            activation = WENET_ACTIVATION_CLASSES[activation_type]()

        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks
        # TODO: support lora & refactor lora
        self.decoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                hidden_size,
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    hidden_size,
                    attention_dropout_rate,
                    query_bias,
                    key_bias,
                    value_bias,
                    use_sdpa,
                    n_kv_head,
                    head_dim,
                    style=rope_style),
                mlp_class(hidden_size, linear_units, dropout_rate, activation,
                          mlp_bias),
                dropout_rate,
                normalize_before,
                layer_norm_type=layer_norm_type,
                norm_eps=norm_eps,
                rms_norm_offset=rms_norm_offset,
            ) for _ in range(self.num_blocks)
        ])
        self.pre_norm = normalize_before
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            norm_class = WENET_NORM_CLASSES[layer_norm_type]
            if layer_norm_type == "rms_norm":
                norm_class = partial(
                    norm_class,
                    add_unit_offset=rms_norm_offset,
                )
            self.final_norm = norm_class(hidden_size, eps=norm_eps)

        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self._hidden_size = hidden_size
        self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        input: torch.Tensor,
        att_mask: torch.Tensor,
        input_position: Union[int, torch.Tensor] = 0,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        xs, pos_emb = self.pos_enc(input, offset=input_position)
        if self.use_sdpa:
            att_mask = mask_to_bias(att_mask, xs.dtype)

        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
        else:
            xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
                                                kv_caches)
        if self.pre_norm and self.final_norm is not None:
            xs = self.final_norm(xs)
        return xs, kv_caches

    def forward_layers(
        self,
        xs: torch.Tensor,
        att_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        if self.training:
            for (i, layer) in enumerate(self.decoders):
                xs, _, _, _ = layer(xs, att_mask, pos_emb)
            new_kv_caches = kv_caches
        else:
            assert kv_caches is not None
            new_kv_caches = []
            for (i, layer) in enumerate(self.decoders):
                xs, _, new_kv_cache, _ = layer(xs,
                                               att_mask,
                                               pos_emb,
                                               att_cache=(kv_caches[i][0],
                                                          kv_caches[i][1]))
                new_kv_caches.append(new_kv_cache)

        return xs, new_kv_caches

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    att_mask: torch.Tensor,
                                    pos_emb: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
                                          pos_emb)
        return xs

    @property
    def hidden_size(self):
        return self._hidden_size
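A small, illustrative sketch of how this decoder is exercised in inference mode, mirroring the calls made by CausalLM.generate(); the dimensions are arbitrary and it assumes a working wenet checkout so that the rope_pos embedding, attention and MLP classes referenced above resolve:

# Illustrative only: tiny DecoderOnly in eval mode with empty KV caches.
import torch
from wenet.LLM.decoder import DecoderOnly

decoder = DecoderOnly(n_kv_head=2, head_dim=32, hidden_size=64,
                      attention_heads=2, linear_units=128, num_blocks=2)
decoder.eval()

B, L = 1, 5
embed = torch.randn(B, L, 64)                        # token embeddings
att_mask = torch.tril(torch.ones(B, L, L, dtype=torch.bool))
kv_caches = [(torch.zeros(B, 0, 2, 32), torch.zeros(B, 0, 2, 32))
             for _ in range(len(decoder.decoders))]

with torch.no_grad():
    out, kv_caches = decoder(embed, att_mask, 0, kv_caches)
print(out.shape)  # expected (1, 5, 64)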
wenet/LLM/sampler.py
ADDED
@@ -0,0 +1,43 @@
from typing import Union
import torch


# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26
@torch.no_grad()
def sampler(
    logits: torch.Tensor,
    temperatures: Union[torch.Tensor, None],
    top_ps: torch.Tensor,
    top_ks: torch.Tensor,
) -> torch.Tensor:
    assert logits.size(1) == 1
    logits = logits.squeeze(1)  # (batch_size, vocab_size)
    if temperatures is None:
        return torch.argmax(logits, dim=-1).squeeze(dim=-1)

    # Apply temperature scaling.
    logits.div_(temperatures.unsqueeze(dim=1))

    # Calculate probabilities with softmax.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

    # Apply top-p, top-k.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
    probs_sort = torch.where(top_ps_mask, 0, probs_sort)

    top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
    top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
    top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
    probs_sort = torch.where(top_ks_mask, 0, probs_sort)

    # Re-normalization.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    probs = torch.gather(probs_sort,
                         dim=-1,
                         index=torch.argsort(probs_idx, dim=-1))

    next_token_ids = torch.multinomial(probs, num_samples=1,
                                       replacement=True).squeeze(dim=-1)
    return next_token_ids
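A quick self-contained check of the sampler on dummy logits; shapes follow the assertion above (the logits must carry a singleton time dimension), and all values below are invented:

# Hypothetical smoke test for sampler(): batch of 2, vocab of 10, one time step.
import torch
from wenet.LLM.sampler import sampler

logits = torch.randn(2, 1, 10)            # (batch, 1, vocab_size)
temperatures = torch.tensor([0.8, 0.8])   # per-utterance temperature
top_ps = torch.tensor([0.9, 0.9])         # nucleus (top-p) threshold
top_ks = torch.tensor([5, 5])             # keep only the 5 best tokens

ids = sampler(logits, temperatures, top_ps, top_ks)
print(ids.shape)                          # torch.Size([2])

# Passing temperatures=None falls back to greedy argmax decoding.
greedy = sampler(logits, None, top_ps, top_ks)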
wenet/__init__.py
ADDED
@@ -0,0 +1 @@
from wenet.cli.model import load_model  # noqa
wenet/bin/alignment.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu)
#               2022 Tinnove Inc (authors: Wei Ren)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader
from textgrid import TextGrid, IntervalTier
import math

from wenet.dataset.dataset import Dataset
from wenet.utils.ctc_utils import force_align
from wenet.utils.common import get_subsample
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer


def generator_textgrid(maxtime, lines, output):
    # Download Praat: https://www.fon.hum.uva.nl/praat/
    interval = maxtime / (len(lines) + 1)
    margin = 0.0001

    tg = TextGrid(maxTime=maxtime)
    linetier = IntervalTier(name="line", maxTime=maxtime)

    i = 0
    for l in lines:
        s, e, w = l.split()
        linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w)

    tg.append(linetier)
    print("successfully generator {}".format(output))
    tg.write(output)


def get_frames_timestamp(alignment,
                         prob,
                         blank_thres=0.999,
                         thres=0.0000000001):
    # convert alignment to a praat format, which is a doing phonetics
    # by computer and helps analyzing alignment
    timestamp = []
    # get frames level duration for each token
    start = 0
    end = 0
    local_start = 0
    while end < len(alignment):
        while end < len(alignment) and alignment[end] == 0:
            end += 1
        if end == len(alignment):
            timestamp[-1] += alignment[start:]
            break
        end += 1
        while end < len(alignment) and alignment[end - 1] == alignment[end]:
            end += 1
        local_start = end - 1
        # find the possible front border for current token
        while local_start >= start and (
                prob[local_start][0] < math.log(blank_thres)
                or prob[local_start][alignment[end - 1]] > math.log(thres)):
            alignment[local_start] = alignment[end - 1]
            local_start -= 1
        cur_alignment = alignment[start:end]
        timestamp.append(cur_alignment)
        start = end
    return timestamp


def get_labformat(timestamp, subsample):
    begin = 0
    begin_time = 0
    duration = 0
    labformat = []
    for idx, t in enumerate(timestamp):
        # 25ms frame_length, 10ms hop_length, 1/subsample
        subsample = get_subsample(configs)
        # time duration
        i = 0
        while t[i] == 0:
            i += 1
        begin = i
        dur = 0
        while i < len(t) and t[i] != 0:
            i += 1
            dur += 1
        begin = begin_time + begin * 0.01 * subsample
        duration = dur * 0.01 * subsample
        if idx < len(timestamp) - 1:
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[t[-1]]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[t[-1]]))
        else:  # last token
            non_blank = 0
            for i in t:
                if i != 0:
                    token = i
                    break
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[token]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[token]))
        begin_time += len(t) * 0.01 * subsample
    return labformat


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='use ctc to generate alignment')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--input_file', required=True, help='format data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device',
                        type=str,
                        default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--blank_thres',
                        default=0.999999,
                        type=float,
                        help='ctc blank thes')
    parser.add_argument('--thres',
                        default=0.000001,
                        type=float,
                        help='ctc non blank thes')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument(
        '--non_lang_syms',
        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--result_file',
                        required=True,
                        help='alignment result file')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--gen_praat',
                        action='store_true',
                        help='convert alignment to a praat format')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')

    args = parser.parse_args()
    print(args)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    if args.gpu != -1:
        # remain the original usage of gpu
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    if args.batch_size > 1:
        logging.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Load dict
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
    eos = len(char_dict) - 1

    # Init dataset and data loader
    ali_conf = copy.deepcopy(configs['dataset_conf'])

    ali_conf['filter_conf']['max_length'] = 102400
    ali_conf['filter_conf']['min_length'] = 0
    ali_conf['filter_conf']['token_max_length'] = 102400
    ali_conf['filter_conf']['token_min_length'] = 0
    ali_conf['filter_conf']['max_output_input_ratio'] = 102400
    ali_conf['filter_conf']['min_output_input_ratio'] = 0
    ali_conf['speed_perturb'] = False
    ali_conf['spec_aug'] = False
    ali_conf['spec_trim'] = False
    ali_conf['shuffle'] = False
    ali_conf['sort'] = False
    ali_conf['fbank_conf']['dither'] = 0.0
    ali_conf['batch_conf']['batch_type'] = "static"
    ali_conf['batch_conf']['batch_size'] = args.batch_size

    tokenizer = init_tokenizer(configs)
    ali_dataset = Dataset(args.data_type,
                          args.input_file,
                          tokenizer,
                          ali_conf,
                          partition=False)

    ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)

    # Init asr model from configs
    model, configs = init_model(args, configs)

    device = torch.device(args.device)
    model = model.to(device)

    model.eval()
    with torch.no_grad(), open(args.result_file, 'w',
                               encoding='utf-8') as fout:
        for batch_idx, batch in enumerate(ali_data_loader):
            print("#" * 80)
            key, feat, target, feats_length, target_length = batch

            feat = feat.to(device)
            target = target.to(device)
            feats_length = feats_length.to(device)
            target_length = target_length.to(device)
            # Let's assume B = batch_size and N = beam_size
            # 1. Encoder
            encoder_out, encoder_mask = model._forward_encoder(
                feat, feats_length)  # (B, maxlen, encoder_dim)
            maxlen = encoder_out.size(1)
            ctc_probs = model.ctc.log_softmax(
                encoder_out)  # (1, maxlen, vocab_size)
            # print(ctc_probs.size(1))
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = force_align(ctc_probs, target)
            fout.write('{} {}\n'.format(key[0], alignment))

            if args.gen_praat:
                timestamp = get_frames_timestamp(alignment, ctc_probs,
                                                 args.blank_thres, args.thres)
                subsample = get_subsample(configs)
                labformat = get_labformat(timestamp, subsample)

                lab_path = os.path.join(os.path.dirname(args.result_file),
                                        key[0] + ".lab")
                with open(lab_path, 'w', encoding='utf-8') as f:
                    f.writelines(labformat)

                textgrid_path = os.path.join(os.path.dirname(args.result_file),
                                             key[0] + ".TextGrid")
                generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 *
                                   subsample,
                                   lines=labformat,
                                   output=textgrid_path)
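The frame-to-seconds conversion used by get_labformat above is simply frame_index * 10 ms * subsample. A toy illustration of that arithmetic, with invented values:

# Toy example of the timestamp arithmetic (numbers are invented).
# With a 10 ms hop and an encoder subsampling factor of 4, each CTC frame
# covers 0.04 s of audio.
subsample = 4
alignment = [0, 0, 5, 0, 0, 9]   # CTC ids per frame, 0 = blank
for idx, token in enumerate(alignment):
    if token != 0:
        begin = idx * 0.01 * subsample
        print("token {} starts around {:.2f} s".format(token, begin))
# token 5 starts around 0.08 s
# token 9 starts around 0.20 s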
wenet/bin/average_model.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob
import sys

import yaml
import torch


def get_args():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    parser.add_argument('--val_best',
                        action="store_true",
                        help='averaged model')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')
    parser.add_argument('--min_epoch',
                        default=0,
                        type=int,
                        help='min epoch used for averaging model')
    parser.add_argument('--max_epoch',
                        default=sys.maxsize,
                        type=int,
                        help='max epoch used for averaging model')
    parser.add_argument('--min_step',
                        default=0,
                        type=int,
                        help='min step used for averaging model')
    parser.add_argument('--max_step',
                        default=sys.maxsize,
                        type=int,
                        help='max step used for averaging model')
    parser.add_argument('--mode',
                        default="hybrid",
                        choices=["hybrid", "epoch", "step"],
                        type=str,
                        help='average mode')

    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    checkpoints = []
    val_scores = []
    if args.val_best:
        if args.mode == "hybrid":
            yamls = glob.glob('{}/*.yaml'.format(args.src_path))
            yamls = [
                f for f in yamls
                if not (os.path.basename(f).startswith('train')
                        or os.path.basename(f).startswith('init'))
            ]
        elif args.mode == "step":
            yamls = glob.glob('{}/step_*.yaml'.format(args.src_path))
        else:
            yamls = glob.glob('{}/epoch_*.yaml'.format(args.src_path))
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
                loss = dic_yaml['loss_dict']['loss']
                epoch = dic_yaml['epoch']
                step = dic_yaml['step']
                tag = dic_yaml['tag']
                if epoch >= args.min_epoch and epoch <= args.max_epoch \
                        and step >= args.min_step and step <= args.max_step:
                    val_scores += [[epoch, step, loss, tag]]
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
        path_list = [
            args.src_path + '/{}.pt'.format(score[-1])
            for score in sorted_val_scores[:args.num]
        ]
    else:
        path_list = glob.glob('{}/[!init]*.pt'.format(args.src_path))
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-args.num:]
    print(path_list)
    avg = {}
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k not in avg.keys():
                avg[k] = states[k].clone()
            else:
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
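The core of the script above is a plain parameter average over the selected checkpoints. A minimal sketch of that operation on two in-memory state dicts (the tensor names and values are invented):

# Minimal sketch of checkpoint averaging with two dummy state dicts.
import torch

state_a = {"linear.weight": torch.ones(2, 2) * 1.0}
state_b = {"linear.weight": torch.ones(2, 2) * 3.0}

avg = {}
for states in (state_a, state_b):
    for k, v in states.items():
        avg[k] = v.clone() if k not in avg else avg[k] + v
for k in avg:
    avg[k] = torch.true_divide(avg[k], 2)   # same op the script uses

print(avg["linear.weight"])  # every entry is 2.0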
wenet/bin/export_ipex.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (C) 2021-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import print_function

import argparse
import logging
import os

import torch
import yaml

from wenet.utils.init_model import init_model
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert


def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_file', default=None, help='output file')
    parser.add_argument('--dtype',
                        default="fp32",
                        help='choose the dtype to run:[fp32,bf16]')
    parser.add_argument('--output_quant_file',
                        default=None,
                        help='output quantized model file')
    args = parser.parse_args()
    return args


def scripting(model):
    with torch.inference_mode():
        script_model = torch.jit.script(model)
        script_model = torch.jit.freeze(
            script_model,
            preserved_attrs=[
                "forward_encoder_chunk", "ctc_activation",
                "forward_attention_decoder", "subsampling_rate",
                "right_context", "sos_symbol", "eos_symbol",
                "is_bidirectional_decoder"
            ])
    return script_model


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # No need gpu for model export
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model, configs = init_model(args, configs)
    print(model)

    # Apply IPEX optimization
    model.eval()
    torch._C._jit_set_texpr_fuser_enabled(False)
    model.to(memory_format=torch.channels_last)
    if args.dtype == "fp32":
        ipex_model = ipex.optimize(model)
    elif args.dtype == "bf16":  # For Intel 4th generation Xeon (SPR)
        ipex_model = ipex.optimize(model,
                                   dtype=torch.bfloat16,
                                   weights_prepack=False)

    # Export jit torch script model
    if args.output_file:
        if args.dtype == "fp32":
            script_model = scripting(ipex_model)
        elif args.dtype == "bf16":
            torch._C._jit_set_autocast_mode(True)
            with torch.cpu.amp.autocast():
                script_model = scripting(ipex_model)
        script_model.save(args.output_file)
        print('Export model successfully, see {}'.format(args.output_file))

    # Export quantized jit torch script model
    if args.output_quant_file:
        dynamic_qconfig = ipex.quantization.default_dynamic_qconfig
        dummy_data = (torch.zeros(1, 67, 80), 16, -16,
                      torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7))
        model = prepare(model, dynamic_qconfig, dummy_data)
        model = convert(model)
        script_quant_model = scripting(model)
        script_quant_model.save(args.output_quant_file)
        print('Export quantized model successfully, '
              'see {}'.format(args.output_quant_file))


if __name__ == '__main__':
    main()
wenet/bin/export_jit.py
ADDED
@@ -0,0 +1,71 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
import os

import torch
import yaml

from wenet.utils.init_model import init_model


def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_file', default=None, help='output file')
    parser.add_argument('--output_quant_file',
                        default=None,
                        help='output quantized model file')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    args.jit = True
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # No need gpu for model export
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model, configs = init_model(args, configs)
    model.eval()
    print(model)
    # Export jit torch script model

    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print('Export model successfully, see {}'.format(args.output_file))

    # Export quantized jit torch script model
    if args.output_quant_file:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8)
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print('Export quantized model successfully, '
              'see {}'.format(args.output_quant_file))


if __name__ == '__main__':
    main()
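Both export scripts above finish with script_model.save(...); the resulting artifact can later be loaded without the Python model definition. A small sketch of such a consumer (the file path is made up):

# Hypothetical consumer of an exported TorchScript model.
import torch

scripted = torch.jit.load("final.zip", map_location="cpu")  # path is made up
scripted.eval()
# Methods preserved at export time (e.g. subsampling_rate in the list passed
# to torch.jit.freeze above) remain callable on the loaded module.
print(scripted.subsampling_rate())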
wenet/bin/export_onnx_bpu.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Horizon Inc. Xingchen Song ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""NOTE(xcsong): Currently, we only support
|
15 |
+
1. specific conformer encoder architecture, see:
|
16 |
+
encoder: conformer
|
17 |
+
encoder_conf:
|
18 |
+
activation_type: **must be** relu
|
19 |
+
attention_heads: 2 or 4 or 8 or any number divisible by output_size
|
20 |
+
causal: **must be** true
|
21 |
+
cnn_module_kernel: 1 ~ 7
|
22 |
+
cnn_module_norm: **must be** batch_norm
|
23 |
+
input_layer: **must be** conv2d8
|
24 |
+
linear_units: 1 ~ 2048
|
25 |
+
normalize_before: **must be** true
|
26 |
+
num_blocks: 1 ~ 12
|
27 |
+
output_size: 1 ~ 512
|
28 |
+
pos_enc_layer_type: **must be** no_pos
|
29 |
+
selfattention_layer_type: **must be** selfattn
|
30 |
+
use_cnn_module: **must be** true
|
31 |
+
use_dynamic_chunk: **must be** true
|
32 |
+
use_dynamic_left_chunk: **must be** true
|
33 |
+
|
34 |
+
2. specific decoding method: ctc_greedy_search
|
35 |
+
"""
|
36 |
+
|
37 |
+
from __future__ import print_function
|
38 |
+
|
39 |
+
import os
|
40 |
+
import sys
|
41 |
+
import copy
|
42 |
+
import math
|
43 |
+
import yaml
|
44 |
+
import logging
|
45 |
+
from typing import Tuple
|
46 |
+
|
47 |
+
import torch
|
48 |
+
import numpy as np
|
49 |
+
|
50 |
+
from wenet.transformer.embedding import NoPositionalEncoding
|
51 |
+
from wenet.utils.init_model import init_model
|
52 |
+
from wenet.bin.export_onnx_cpu import (get_args, to_numpy,
|
53 |
+
print_input_output_info)
|
54 |
+
|
55 |
+
try:
|
56 |
+
import onnx
|
57 |
+
import onnxruntime
|
58 |
+
except ImportError:
|
59 |
+
print('Please install onnx and onnxruntime!')
|
60 |
+
sys.exit(1)
|
61 |
+
|
62 |
+
logger = logging.getLogger(__file__)
|
63 |
+
logger.setLevel(logging.INFO)
|
64 |
+
|
65 |
+
|
66 |
+
class BPULayerNorm(torch.nn.Module):
|
67 |
+
"""Refactor torch.nn.LayerNorm to meet 4-D dataflow."""
|
68 |
+
|
69 |
+
def __init__(self, module, chunk_size=8, run_on_bpu=False):
|
70 |
+
super().__init__()
|
71 |
+
original = copy.deepcopy(module)
|
72 |
+
self.hidden = module.weight.size(0)
|
73 |
+
self.chunk_size = chunk_size
|
74 |
+
self.run_on_bpu = run_on_bpu
|
75 |
+
|
76 |
+
if self.run_on_bpu:
|
77 |
+
self.weight = torch.nn.Parameter(
|
78 |
+
module.weight.reshape(1, self.hidden, 1,
|
79 |
+
1).repeat(1, 1, 1, chunk_size))
|
80 |
+
self.bias = torch.nn.Parameter(
|
81 |
+
module.bias.reshape(1, self.hidden, 1,
|
82 |
+
1).repeat(1, 1, 1, chunk_size))
|
83 |
+
self.negtive = torch.nn.Parameter(
|
84 |
+
torch.ones((1, self.hidden, 1, chunk_size)) * -1.0)
|
85 |
+
self.eps = torch.nn.Parameter(
|
86 |
+
torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps)
|
87 |
+
self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
|
88 |
+
self.mean_conv_1.weight = torch.nn.Parameter(
|
89 |
+
torch.ones(self.hidden, self.hidden, 1, 1) /
|
90 |
+
(1.0 * self.hidden))
|
91 |
+
self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
|
92 |
+
self.mean_conv_2.weight = torch.nn.Parameter(
|
93 |
+
torch.ones(self.hidden, self.hidden, 1, 1) /
|
94 |
+
(1.0 * self.hidden))
|
95 |
+
else:
|
96 |
+
self.norm = module
|
97 |
+
|
98 |
+
self.check_equal(original)
|
99 |
+
|
100 |
+
def check_equal(self, module):
|
101 |
+
random_data = torch.randn(1, self.chunk_size, self.hidden)
|
102 |
+
orig_out = module(random_data)
|
103 |
+
new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
|
104 |
+
np.testing.assert_allclose(to_numpy(orig_out),
|
105 |
+
to_numpy(
|
106 |
+
new_out.squeeze(2).transpose(1, 2)),
|
107 |
+
rtol=1e-02,
|
108 |
+
atol=1e-03)
|
109 |
+
|
110 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
111 |
+
if self.run_on_bpu:
|
112 |
+
u = self.mean_conv_1(x) # (1, h, 1, c)
|
113 |
+
numerator = x + u * self.negtive # (1, h, 1, c)
|
114 |
+
s = torch.pow(numerator, 2) # (1, h, 1, c)
|
115 |
+
s = self.mean_conv_2(s) # (1, h, 1, c)
|
116 |
+
denominator = torch.sqrt(s + self.eps) # (1, h, 1, c)
|
117 |
+
x = torch.div(numerator, denominator) # (1, h, 1, c)
|
118 |
+
x = x * self.weight + self.bias
|
119 |
+
else:
|
120 |
+
x = x.squeeze(2).transpose(1, 2).contiguous()
|
121 |
+
x = self.norm(x)
|
122 |
+
x = x.transpose(1, 2).contiguous().unsqueeze(2)
|
123 |
+
return x
|
124 |
+
|
125 |
+
|
126 |
+
class BPUIdentity(torch.nn.Module):
|
127 |
+
"""Refactor torch.nn.Identity().
|
128 |
+
For inserting BPU node whose input == output.
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(self, channels):
|
132 |
+
super().__init__()
|
133 |
+
self.channels = channels
|
134 |
+
self.identity_conv = torch.nn.Conv2d(channels,
|
135 |
+
channels,
|
136 |
+
1,
|
137 |
+
groups=channels,
|
138 |
+
bias=False)
|
139 |
+
torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels)
|
140 |
+
|
141 |
+
self.check_equal()
|
142 |
+
|
143 |
+
def check_equal(self):
|
144 |
+
random_data = torch.randn(1, self.channels, 1, 10)
|
145 |
+
result = self.forward(random_data)
|
146 |
+
np.testing.assert_allclose(to_numpy(random_data),
|
147 |
+
to_numpy(result),
|
148 |
+
rtol=1e-02,
|
149 |
+
atol=1e-03)
|
150 |
+
|
151 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
152 |
+
"""Identity with 4-D dataflow, input == output.
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): (batch, in_channel, 1, time)
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
(torch.Tensor): (batch, in_channel, 1, time).
|
158 |
+
"""
|
159 |
+
return self.identity_conv(x)
|
160 |
+
|
161 |
+
|
162 |
+
class BPULinear(torch.nn.Module):
    """Refactor torch.nn.Linear or pointwise_conv"""

    def __init__(self, module, is_pointwise_conv=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.weight.size(1)
        self.odim = module.weight.size(0)
        self.is_pointwise_conv = is_pointwise_conv

        # Modify weight & bias
        self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
        if is_pointwise_conv:
            # (odim, idim, kernel=1) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(
                module.weight.unsqueeze(-1))
        else:
            # (odim, idim) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(
                module.weight.unsqueeze(2).unsqueeze(3))
        self.linear.bias = module.bias

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.idim)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
        original_result = module(random_data)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
            original_result = original_result.transpose(1, 2)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Linear with 4-D dataflow.
        Args:
            x (torch.Tensor): (batch, in_channel, 1, time)
        Returns:
            (torch.Tensor): (batch, out_channel, 1, time).
        """
        return self.linear(x)

class BPUGlobalCMVN(torch.nn.Module):
    """Refactor wenet/transformer/cmvn.py::GlobalCMVN"""

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        self.norm_var = module.norm_var

        # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
        self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
        self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """CMVN with 4-D dataflow.
        Args:
            x (torch.Tensor): (batch, 1, mel_dim, time)
        Returns:
            (torch.Tensor): normalized feature with same shape.
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x

class BPUConv2dSubsampling8(torch.nn.Module):
    """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8

    NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.right_context = module.right_context
        self.subsampling_rate = module.subsampling_rate
        assert isinstance(module.pos_enc, NoPositionalEncoding)

        # 1. Modify self.conv
        # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
        #   to (1, 1, mel_dim, frames) for more efficient computation.
        self.conv = module.conv
        for idx in [0, 2, 4]:
            self.conv[idx].weight = torch.nn.Parameter(
                module.conv[idx].weight.transpose(2, 3))

        # 2. Modify self.linear
        # NOTE(xcsong): Split final projection to meet the requirement of
        #   maximum kernel_size (7 for XJ3)
        self.linear = torch.nn.ModuleList()
        odim = module.linear.weight.size(0)  # 512, in this case
        freq = module.linear.weight.size(1) // odim  # 4608 // 512 == 9
        self.odim, self.freq = odim, freq
        weight = module.linear.weight.reshape(
            odim, odim, freq,
            1)  # (odim, odim * freq) -> (odim, odim, freq, 1)
        self.split_size = []
        num_split = (freq - 1) // 7 + 1  # XJ3 requires kernel_size <= 7
        slice_begin = 0
        for idx in range(num_split):
            kernel_size = min(freq, (idx + 1) * 7) - idx * 7
            conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1),
                                       (kernel_size, 1))
            conv_ele.weight = torch.nn.Parameter(
                weight[:, :, slice_begin:slice_begin + kernel_size, :])
            conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias))
            self.linear.append(conv_ele)
            self.split_size.append(kernel_size)
            slice_begin += kernel_size
        self.linear[0].bias = torch.nn.Parameter(module.linear.bias)

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 67, 80)
        mask = torch.zeros(1, 1, 67)
        original_result, _, _ = module(random_data, mask)  # (1, 8, 512)
        random_data = random_data.transpose(1,
                                            2).unsqueeze(0)  # (1, 1, 80, 67)
        new_result = self.forward(random_data)  # (1, 512, 1, 8)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x with 4-D dataflow.
        Args:
            x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
                where time' = time // 8.
        """
        x = self.conv(x)  # (1, odim, freq, time')
        x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
        x = torch.split(x, self.split_size, dim=2)
        for idx, (x_part, layer) in enumerate(zip(x, self.linear)):
            x_out += layer(x_part)
        return x_out

class BPUMultiHeadedAttention(torch.nn.Module):
    """Refactor wenet/transformer/attention.py::MultiHeadedAttention

    NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
        we do not consider RelPositionMultiHeadedAttention currently.
    """

    def __init__(self, module, chunk_size, left_chunks):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.d_k = module.d_k
        self.h = module.h
        n_feat = self.d_k * self.h
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.time = chunk_size * (left_chunks + 1)
        self.activation = torch.nn.Softmax(dim=-1)

        # 1. Modify self.linear_x
        self.linear_q = BPULinear(module.linear_q)
        self.linear_k = BPULinear(module.linear_k)
        self.linear_v = BPULinear(module.linear_v)
        self.linear_out = BPULinear(module.linear_out)
        # 2. denom
        self.register_buffer(
            "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k)))

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
        mask = torch.ones((1, self.h, self.chunk_size, self.time),
                          dtype=torch.bool)
        cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks,
                            self.d_k * 2)
        original_out, original_cache = module(random_data, random_data,
                                              random_data, mask[:, 0, :, :],
                                              torch.empty(0), cache)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.reshape(1, self.h, self.d_k * 2,
                              self.chunk_size * self.left_chunks)
        new_out, new_cache = self.forward(random_data, random_data,
                                          random_data, mask, cache)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cache),
                                   to_numpy(new_cache.transpose(2, 3)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor,
        cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
            k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
            v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
            mask (torch.Tensor): Mask tensor,
                (#batch, head, chunk_size, cache_t + chunk_size).
            cache (torch.Tensor): Cache tensor
                (1, head, d_k * 2, cache_t),
                where `cache_t == chunk_size * left_chunks`.

        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: Cache tensor
                (1, head, d_k * 2, cache_t + chunk_size)
                where `cache_t == chunk_size * left_chunks`
        """
        # 1. Forward QKV
        q = self.linear_q(q)  # (1, d, 1, c)  d == size, c == chunk_size
        k = self.linear_k(k)  # (1, d, 1, c)
        v = self.linear_v(v)  # (1, d, 1, c)
        q = q.view(1, self.h, self.d_k, self.chunk_size)
        k = k.view(1, self.h, self.d_k, self.chunk_size)
        v = v.view(1, self.h, self.d_k, self.chunk_size)
        q = q.transpose(2, 3)  # (batch, head, time1, d_k)
        k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
        k = torch.cat((k_cache, k), dim=3)
        v = torch.cat((v_cache, v), dim=3)
        new_cache = torch.cat((k, v), dim=2)
        # 2. (Q^T)K
        scores = torch.matmul(q, k) * self.denom  # (#b, n_head, time1, time2)
        # 3. Forward attention
        mask = mask.eq(0)
        scores = scores.masked_fill(mask, -float('inf'))
        attn = self.activation(scores).masked_fill(mask, 0.0)
        attn = attn.transpose(2, 3)
        x = torch.matmul(v, attn)
        x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
        x_out = self.linear_out(x)
        return x_out, new_cache

class BPUConvolution(torch.nn.Module):
    """Refactor wenet/transformer/convolution.py::ConvolutionModule

    NOTE(xcsong): Only support use_layer_norm == False
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.lorder = module.lorder
        self.use_layer_norm = False
        self.activation = module.activation
        channels = module.pointwise_conv1.weight.size(1)
        self.channels = channels
        kernel_size = module.depthwise_conv.weight.size(2)
        assert module.use_layer_norm is False

        # 1. Modify self.pointwise_conv1
        self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)

        # 2. Modify self.depthwise_conv
        self.depthwise_conv = torch.nn.Conv2d(channels,
                                              channels, (1, kernel_size),
                                              stride=1,
                                              groups=channels)
        self.depthwise_conv.weight = torch.nn.Parameter(
            module.depthwise_conv.weight.unsqueeze(-2))
        self.depthwise_conv.bias = torch.nn.Parameter(
            module.depthwise_conv.bias)

        # 3. Modify self.norm, Only support batchnorm2d
        self.norm = torch.nn.BatchNorm2d(channels)
        self.norm.training = False
        self.norm.num_features = module.norm.num_features
        self.norm.eps = module.norm.eps
        self.norm.momentum = module.norm.momentum
        self.norm.weight = torch.nn.Parameter(module.norm.weight)
        self.norm.bias = torch.nn.Parameter(module.norm.bias)
        self.norm.running_mean = module.norm.running_mean
        self.norm.running_var = module.norm.running_var

        # 4. Modify self.pointwise_conv2
        self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)

        # 5. Identity conv, for running `concat` on BPU
        self.identity = BPUIdentity(channels)

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.channels)
        cache = torch.zeros((1, self.channels, self.lorder))
        original_out, original_cache = module(random_data, cache=cache)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.unsqueeze(2)
        new_out, new_cache = self.forward(random_data, cache)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cache),
                                   to_numpy(new_cache.squeeze(2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor,
                cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, 1, cache_t).
        Returns:
            torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
            torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
        """
        # Concat cache
        x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
        new_cache = x[:, :, :, -self.lorder:]

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, 1, dim)
        x = torch.nn.functional.glu(x, dim=1)  # (b, channel, 1, dim)

        # Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x, new_cache

class BPUFFN(torch.nn.Module):
    """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.activation = module.activation

        # 1. Modify self.w_x
        self.w_1 = BPULinear(module.w_1)
        self.w_2 = BPULinear(module.w_2)

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.w_1.idim)
        original_out = module(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_out = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_out),
                                   to_numpy(
                                       new_out.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, D, 1, L)
        Returns:
            output tensor, (B, D, 1, L)
        """
        return self.w_2(self.activation(self.w_1(x)))

class BPUConformerEncoderLayer(torch.nn.Module):
    """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer
    """

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.size = module.size
        assert module.normalize_before is True
        assert module.concat_after is False

        # 1. Modify submodules
        self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
        self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size,
                                                 left_chunks)
        self.conv_module = BPUConvolution(module.conv_module)
        self.feed_forward = BPUFFN(module.feed_forward)

        # 2. Modify norms
        self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
        self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size,
                                     ln_run_on_bpu)
        self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, chunk_size,
                                           ln_run_on_bpu)
        self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size,
                                      ln_run_on_bpu)
        self.norm_final = BPULayerNorm(module.norm_final, chunk_size,
                                       ln_run_on_bpu)

        # 3. 4-D ff_scale
        self.register_buffer("ff_scale",
                             torch.full((1, self.size, 1, 1), module.ff_scale))

        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.self_attn.chunk_size
        time2 = self.self_attn.time
        h, d_k = self.self_attn.h, self.self_attn.d_k
        random_x = torch.randn(1, time1, self.size)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
        original_x, _, original_att_cache, original_cnn_cache = module(
            random_x,
            att_mask[:, 0, :, :],
            torch.empty(0),
            att_cache=att_cache,
            cnn_cache=cnn_cache)
        random_x = random_x.transpose(1, 2).unsqueeze(2)
        att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.unsqueeze(2)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_mask, att_cache, cnn_cache)
        np.testing.assert_allclose(to_numpy(original_att_cache),
                                   to_numpy(new_att_cache.transpose(2, 3)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_x),
                                   to_numpy(new_x.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(original_cnn_cache),
                                   to_numpy(new_cnn_cache.squeeze(2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(
        self, x: torch.Tensor, att_mask: torch.Tensor, att_cache: torch.Tensor,
        cnn_cache: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, size, 1, chunk_size)
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size),
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, 1, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: att_cache tensor,
                (1, head, d_k * 2, cache_t1 + chunk_size).
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
        """
        # 1. ffn_macaron
        residual = x
        x = self.norm_ff_macron(x)
        x = residual + self.ff_scale * self.feed_forward_macaron(x)

        # 2. attention
        residual = x
        x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache)
        x = residual + x_att

        # 3. convolution
        residual = x
        x = self.norm_conv(x)
        x, new_cnn_cache = self.conv_module(x, cnn_cache)
        x = residual + x

        # 4. ffn
        residual = x
        x = self.norm_ff(x)
        x = residual + self.ff_scale * self.feed_forward(x)

        # 5. final post-norm
        x = self.norm_final(x)

        return x, new_att_cache, new_cnn_cache

class BPUConformerEncoder(torch.nn.Module):
    """Refactor wenet/transformer/encoder.py::ConformerEncoder
    """

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        output_size = module.output_size()
        self._output_size = module.output_size()
        self.after_norm = module.after_norm
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.head = module.encoders[0].self_attn.h
        self.layers = len(module.encoders)

        # 1. Modify submodules
        self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
        self.embed = BPUConv2dSubsampling8(module.embed)
        self.encoders = torch.nn.ModuleList()
        for layer in module.encoders:
            self.encoders.append(
                BPUConformerEncoderLayer(layer, chunk_size, left_chunks,
                                         ln_run_on_bpu))

        # 2. Auxiliary conv
        self.identity_cnncache = BPUIdentity(output_size)

        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.encoders[0].self_attn.chunk_size
        time2 = self.encoders[0].self_attn.time
        layers = self.layers
        h, d_k = self.head, self.encoders[0].self_attn.d_k
        decoding_window = (self.chunk_size - 1) * \
            module.embed.subsampling_rate + \
            module.embed.right_context + 1
        lorder = self.encoders[0].conv_module.lorder
        random_x = torch.randn(1, decoding_window, 80)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
        orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
            random_x,
            0,
            time2 - time1,
            att_mask=att_mask[:, 0, :, :],
            att_cache=att_cache,
            cnn_cache=cnn_cache)
        random_x = random_x.unsqueeze(0)
        att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_cache, cnn_cache, att_mask)
        caches = torch.split(new_att_cache, h, dim=1)
        caches = [c.transpose(2, 3) for c in caches]
        np.testing.assert_allclose(to_numpy(orig_att_cache),
                                   to_numpy(torch.cat(caches, dim=0)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(to_numpy(orig_x),
                                   to_numpy(new_x.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)
        np.testing.assert_allclose(
            to_numpy(orig_cnn_cache),
            to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03)

    def forward(
        self, xs: torch.Tensor, att_cache: torch.Tensor,
        cnn_cache: torch.Tensor, att_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (1, head * elayers, d_k * 2, cache_t1), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (1, hidden-dim, elayers, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size),

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, hidden-dim, 1, chunk_size).
            torch.Tensor: new attention cache required for next chunk, with
                same shape as the original att_cache.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
        xs = xs.transpose(2, 3)
        xs = self.global_cmvn(xs)
        # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
        xs = self.embed(xs)

        att_cache = torch.split(att_cache, self.head, dim=1)
        cnn_cache = self.identity_cnncache(cnn_cache)
        cnn_cache = torch.split(cnn_cache, 1, dim=2)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            xs, new_att_cache, new_cnn_cache = layer(xs,
                                                     att_mask,
                                                     att_cache=att_cache[i],
                                                     cnn_cache=cnn_cache[i])
            r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:])
            r_cnn_cache.append(new_cnn_cache)
        r_att_cache = torch.cat(r_att_cache, dim=1)
        r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2))

        xs = xs.squeeze(2).transpose(1, 2).contiguous()
        xs = self.after_norm(xs)
        # NOTE(xcsong): 4D in, 4D out to meet the requirement of CTC input.
        xs = xs.transpose(1, 2).contiguous().unsqueeze(2)  # (B, C, 1, T)

        return (xs, r_att_cache, r_cnn_cache)

class BPUCTC(torch.nn.Module):
    """Refactor wenet/transformer/ctc.py::CTC
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.ctc_lo.weight.size(1)
        num_class = module.ctc_lo.weight.size(0)

        # 1. Modify self.ctc_lo, Split final projection to meet the
        #   requirement of maximum in/out channels (2048 for XJ3)
        self.ctc_lo = torch.nn.ModuleList()
        self.split_size = []
        num_split = (num_class - 1) // 2048 + 1
        for idx in range(num_split):
            out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
            conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
            self.ctc_lo.append(conv_ele)
            self.split_size.append(out_channel)
        orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0)
        orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
        for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
            w = w.unsqueeze(2).unsqueeze(3)
            self.ctc_lo[i].weight = torch.nn.Parameter(w)
            self.ctc_lo[i].bias = torch.nn.Parameter(b)

        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 100, self.idim)
        original_result = module.ctc_lo(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(to_numpy(original_result),
                                   to_numpy(
                                       new_result.squeeze(2).transpose(1, 2)),
                                   rtol=1e-02,
                                   atol=1e-03)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """frame activations, without softmax.

        Args:
            Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
        Returns:
            torch.Tensor: (B, num_class, 1, chunk_size)
        """
        out = []
        for i, layer in enumerate(self.ctc_lo):
            out.append(layer(x))
        out = torch.cat(out, dim=1)
        return out

def export_encoder(asr_model, args):
    logger.info("Stage-1: export encoder")
    decode_window, mel_dim = args.decoding_window, args.feature_size
    encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size,
                                  args.num_decoding_left_chunks,
                                  args.ln_run_on_bpu)
    encoder.eval()
    encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx')

    logger.info("Stage-1.1: prepare inputs for encoder")
    chunk = torch.randn((1, 1, decode_window, mel_dim))
    required_cache_size = encoder.chunk_size * encoder.left_chunks
    kv_time = required_cache_size + encoder.chunk_size
    hidden, layers = encoder._output_size, len(encoder.encoders)
    head = encoder.encoders[0].self_attn.h
    d_k = hidden // head
    lorder = encoder.encoders[0].conv_module.lorder
    att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
    att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
    att_mask[:, :, :, :required_cache_size] = 0
    cnn_cache = torch.zeros((1, hidden, layers, lorder))
    inputs = (chunk, att_cache, cnn_cache, att_mask)
    logger.info("chunk.size(): {} att_cache.size(): {} "
                "cnn_cache.size(): {} att_mask.size(): {}".format(
                    list(chunk.size()), list(att_cache.size()),
                    list(cnn_cache.size()), list(att_mask.size())))

    logger.info("Stage-1.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask"
    attributes['output_name'] = "output;r_att_cache;r_cnn_cache"
    attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap"
    attributes['norm_type'] = \
        "no_preprocess;no_preprocess;no_preprocess;no_preprocess"
    attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW"
    attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW"
    attributes['input_shape'] = \
        "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
            chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3),
            att_cache.size(0), att_cache.size(1), att_cache.size(2),
            att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1),
            cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0),
            att_mask.size(1), att_mask.size(2), att_mask.size(3)
        )
    torch.onnx.export(  # NOTE(xcsong): only support opset==11
        encoder,
        inputs,
        encoder_outpath,
        opset_version=11,
        export_params=True,
        do_constant_folding=True,
        input_names=attributes['input_name'].split(';'),
        output_names=attributes['output_name'].split(';'),
        dynamic_axes=None,
        verbose=False)
    onnx_encoder = onnx.load(encoder_outpath)
    for k in vars(args):
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_encoder)
    onnx.helper.printable_graph(onnx_encoder.graph)
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    logger.info('Export onnx_encoder, done! see {}'.format(encoder_outpath))

    logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
    torch_output = []
    torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    for i in range(10):
        logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
                    ", att_mask: {}".format(i, list(torch_chunk.size()),
                                            list(torch_att_cache.size()),
                                            list(torch_cnn_cache.size()),
                                            list(torch_att_mask.size())))
        torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask)
        torch_output.append(out)
    torch_output = torch.cat(torch_output, dim=-1)

    onnx_output = []
    onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    ort_session = onnxruntime.InferenceSession(encoder_outpath)
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
                    " att_mask: {}".format(i, onnx_chunk.shape,
                                           onnx_att_cache.shape,
                                           onnx_cnn_cache.shape,
                                           onnx_att_mask.shape))
        onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        ort_inputs = {
            'chunk': onnx_chunk,
            'att_cache': onnx_att_cache,
            'cnn_cache': onnx_cnn_cache,
            'att_mask': onnx_att_mask,
        }
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
    onnx_output = np.concatenate(onnx_output, axis=-1)

    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output,
                               rtol=1e-03,
                               atol=1e-04)
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_encoder, pass!")
    return encoder, ort_session

def export_ctc(asr_model, args):
    logger.info("Stage-2: export ctc")
    ctc = BPUCTC(asr_model.ctc).eval()
    ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx')

    logger.info("Stage-2.1: prepare inputs for ctc")
    hidden = torch.randn((1, args.output_size, 1, args.chunk_size))

    logger.info("Stage-2.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes['input_name'], attributes['input_type'] = "hidden", "featuremap"
    attributes['norm_type'] = "no_preprocess"
    attributes['input_layout_train'] = "NCHW"
    attributes['input_layout_rt'] = "NCHW"
    attributes['input_shape'] = "{}x{}x{}x{}".format(
        hidden.size(0),
        hidden.size(1),
        hidden.size(2),
        hidden.size(3),
    )
    torch.onnx.export(ctc,
                      hidden,
                      ctc_outpath,
                      opset_version=11,
                      export_params=True,
                      do_constant_folding=True,
                      input_names=['hidden'],
                      output_names=['probs'],
                      dynamic_axes=None,
                      verbose=False)
    onnx_ctc = onnx.load(ctc_outpath)
    for k in vars(args):
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath))

    logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})

    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output[0],
                               rtol=1e-03,
                               atol=1e-04)
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_ctc, pass!")
    return ctc, ort_session


def export_decoder(asr_model, args):
    logger.info("Currently, Decoder is not supported.")

if __name__ == '__main__':
    torch.manual_seed(777)
    args = get_args()
    args.ln_run_on_bpu = False
    # NOTE(xcsong): XJ3 BPU only support static shapes
    assert args.chunk_size > 0
    assert args.num_decoding_left_chunks > 0
    os.system("mkdir -p " + args.output_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model, configs = init_model(args, configs)
    model.eval()
    print(model)

    args.feature_size = configs['input_dim']
    args.output_size = model.encoder.output_size()
    args.decoding_window = (args.chunk_size - 1) * \
        model.encoder.embed.subsampling_rate + \
        model.encoder.embed.right_context + 1

    export_encoder(model, args)
    export_ctc(model, args)
    export_decoder(model, args)
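For reference, the exported BPU encoder above is fully static: every input keeps the shape fixed at export time, and the caches returned by one chunk are fed back for the next. Below is a minimal onnxruntime driver sketch; all concrete sizes (16/4 streaming config, 12 layers, 4 heads, 256-dim output, lorder 14, 80-dim fbank) and the path 'exp/encoder.onnx' are illustrative assumptions, not values read from a real checkpoint.

# Illustrative driver for the static-shape encoder.onnx produced above.
# All model dimensions here are assumptions for the sketch, not real values.
import numpy as np
import onnxruntime

chunk_size, left_chunks = 16, 4                  # assumed 16/4 config
layers, head, hidden, lorder = 12, 4, 256, 14    # assumed model dims
subsampling_rate, right_context = 8, 14          # Conv2dSubsampling8
d_k = hidden // head
required_cache_size = chunk_size * left_chunks
decoding_window = (chunk_size - 1) * subsampling_rate + right_context + 1

sess = onnxruntime.InferenceSession('exp/encoder.onnx')  # path is an example
att_cache = np.zeros((1, layers * head, d_k * 2, required_cache_size), np.float32)
cnn_cache = np.zeros((1, hidden, layers, lorder), np.float32)
att_mask = np.ones((1, head, chunk_size, required_cache_size + chunk_size), np.float32)
att_mask[:, :, :, :required_cache_size] = 0

# Dummy feature chunks stand in for real fbank frames.
feature_chunks = [
    np.random.randn(1, 1, decoding_window, 80).astype(np.float32)
    for _ in range(3)
]
for i, chunk in enumerate(feature_chunks):
    att_mask[:, :, :, -(chunk_size * (i + 1)):] = 1
    output, att_cache, cnn_cache = sess.run(
        None, {'chunk': chunk, 'att_cache': att_cache,
               'cnn_cache': cnn_cache, 'att_mask': att_mask})
    # output: (1, hidden, 1, chunk_size), ready to be fed into ctc.onnx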
wenet/bin/export_onnx_cpu.py
ADDED
@@ -0,0 +1,470 @@
# Copyright (c) 2022, Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
import os
import copy
import sys

import torch
import yaml
import numpy as np

from wenet.utils.init_model import init_model

try:
    import onnx
    import onnxruntime
    from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
    print('Please install onnx and onnxruntime!')
    sys.exit(1)

def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_dir', required=True, help='output directory')
    parser.add_argument('--chunk_size',
                        required=True,
                        type=int,
                        help='decoding chunk size')
    parser.add_argument('--num_decoding_left_chunks',
                        required=True,
                        type=int,
                        help='cache chunks')
    parser.add_argument('--reverse_weight',
                        default=0.5,
                        type=float,
                        help='reverse_weight in attention_rescoring')
    args = parser.parse_args()
    return args

def to_numpy(tensor):
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    else:
        return tensor.cpu().numpy()


def print_input_output_info(onnx_model, name, prefix="\t\t"):
    input_names = [node.name for node in onnx_model.graph.input]
    input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
                    for node in onnx_model.graph.input]
    output_names = [node.name for node in onnx_model.graph.output]
    output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
                     for node in onnx_model.graph.output]
    print("{}{} inputs : {}".format(prefix, name, input_names))
    print("{}{} input shapes : {}".format(prefix, name, input_shapes))
    print("{}{} outputs: {}".format(prefix, name, output_names))
    print("{}{} output shapes : {}".format(prefix, name, output_shapes))

def export_encoder(asr_model, args):
    print("Stage-1: export encoder")
    encoder = asr_model.encoder
    encoder.forward = encoder.forward_chunk
    encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx')

    print("\tStage-1.1: prepare inputs for encoder")
    chunk = torch.randn(
        (args['batch'], args['decoding_window'], args['feature_size']))
    offset = 0
    # NOTE(xcsong): The uncertainty of `next_cache_start` only appears
    #   in the first few chunks, this is caused by dynamic att_cache shape, i.e.,
    #   (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
    #   chunks. One way to ease the ONNX export is to keep `next_cache_start`
    #   as a fixed value. To do this, for the **first** chunk, if
    #   left_chunks > 0, we feed real cache & real mask to the model, otherwise
    #   fake cache & fake mask. In this way, we get:
    #   1. 16/-1 mode: next_cache_start == 0 for all chunks
    #   2. 16/4 mode: next_cache_start == chunk_size for all chunks
    #   3. 16/0 mode: next_cache_start == chunk_size for all chunks
    #   4. -1/-1 mode: next_cache_start == 0 for all chunks
    #   NO MORE DYNAMIC CHANGES!!
    #
    # NOTE(Mddct): We retain the current design for the convenience of supporting some
    #   inference frameworks without dynamic shapes. If you're interested in an all-in-one
    #   model that supports different chunks please see:
    #   https://github.com/wenet-e2e/wenet/pull/1174

    if args['left_chunks'] > 0:  # 16/4
        required_cache_size = args['chunk_size'] * args['left_chunks']
        offset = required_cache_size
        # Real cache
        att_cache = torch.zeros(
            (args['num_blocks'], args['head'], required_cache_size,
             args['output_size'] // args['head'] * 2))
        # Real mask
        att_mask = torch.ones(
            (args['batch'], 1, required_cache_size + args['chunk_size']),
            dtype=torch.bool)
        att_mask[:, :, :required_cache_size] = 0
    elif args['left_chunks'] <= 0:  # 16/-1, -1/-1, 16/0
        required_cache_size = -1 if args['left_chunks'] < 0 else 0
        # Fake cache
        att_cache = torch.zeros((args['num_blocks'], args['head'], 0,
                                 args['output_size'] // args['head'] * 2))
        # Fake mask
        att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
    cnn_cache = torch.zeros(
        (args['num_blocks'], args['batch'], args['output_size'],
         args['cnn_module_kernel'] - 1))
    inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache,
              att_mask)
    print("\t\tchunk.size(): {}\n".format(chunk.size()),
          "\t\toffset: {}\n".format(offset),
          "\t\trequired_cache: {}\n".format(required_cache_size),
          "\t\tatt_cache.size(): {}\n".format(att_cache.size()),
          "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
          "\t\tatt_mask.size(): {}\n".format(att_mask.size()))

    print("\tStage-1.2: torch.onnx.export")
    dynamic_axes = {
        'chunk': {
            1: 'T'
        },
        'att_cache': {
            2: 'T_CACHE'
        },
        'att_mask': {
            2: 'T_ADD_T_CACHE'
        },
        'output': {
            1: 'T'
        },
        'r_att_cache': {
            2: 'T_CACHE'
        },
    }
    # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is
    #   to avoid padding the last chunk (which usually contains fewer
    #   frames than required). For users who want static axes, just pop
    #   out the specific axis.
    # if args['chunk_size'] > 0:  # 16/4, 16/-1, 16/0
    #     dynamic_axes.pop('chunk')
    #     dynamic_axes.pop('output')
    # if args['left_chunks'] >= 0:  # 16/4, 16/0
    #     # NOTE(xsong): since we feed real cache & real mask into the
    #     #   model when left_chunks > 0, the shape of cache will never
    #     #   be changed.
    #     dynamic_axes.pop('att_cache')
    #     dynamic_axes.pop('r_att_cache')
    torch.onnx.export(encoder,
                      inputs,
                      encoder_outpath,
                      opset_version=13,
                      export_params=True,
                      do_constant_folding=True,
                      input_names=[
                          'chunk', 'offset', 'required_cache_size',
                          'att_cache', 'cnn_cache', 'att_mask'
                      ],
                      output_names=['output', 'r_att_cache', 'r_cnn_cache'],
                      dynamic_axes=dynamic_axes,
                      verbose=False)
    onnx_encoder = onnx.load(encoder_outpath)
    for (k, v) in args.items():
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_encoder)
    onnx.helper.printable_graph(onnx_encoder.graph)
    # NOTE(xcsong): to add those metadatas we need to reopen
    #   the file and resave it.
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    # Dynamic quantization
    model_fp32 = encoder_outpath
    model_quant = os.path.join(args['output_dir'], 'encoder.quant.onnx')
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print('\t\tExport onnx_encoder, done! see {}'.format(encoder_outpath))

    print("\tStage-1.3: check onnx_encoder and torch_encoder")
    torch_output = []
    torch_chunk = copy.deepcopy(chunk)
    torch_offset = copy.deepcopy(offset)
    torch_required_cache_size = copy.deepcopy(required_cache_size)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    torch_att_mask = copy.deepcopy(att_mask)
    for i in range(10):
        print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
              " cnn_cache: {}, att_mask: {}".format(
                  i, list(torch_chunk.size()), torch_offset,
                  list(torch_att_cache.size()), list(torch_cnn_cache.size()),
                  list(torch_att_mask.size())))
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args['left_chunks'] > 0:  # 16/4
            torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk, torch_offset, torch_required_cache_size,
            torch_att_cache, torch_cnn_cache, torch_att_mask)
        torch_output.append(out)
        torch_offset += out.size(1)
    torch_output = torch.cat(torch_output, dim=1)

    onnx_output = []
    onnx_chunk = to_numpy(chunk)
    onnx_offset = np.array((offset)).astype(np.int64)
    onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    onnx_att_mask = to_numpy(att_mask)
    ort_session = onnxruntime.InferenceSession(
        encoder_outpath, providers=['CPUExecutionProvider'])
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {},"
              " cnn_cache: {}, att_mask: {}".format(i, onnx_chunk.shape,
                                                    onnx_offset,
                                                    onnx_att_cache.shape,
                                                    onnx_cnn_cache.shape,
                                                    onnx_att_mask.shape))
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args['left_chunks'] > 0:  # 16/4
            onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
        ort_inputs = {
            'chunk': onnx_chunk,
            'offset': onnx_offset,
            'required_cache_size': onnx_required_cache_size,
            'att_cache': onnx_att_cache,
            'cnn_cache': onnx_cnn_cache,
            'att_mask': onnx_att_mask
        }
        # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
        #   will be hardcoded to 0 or chunk_size by ONNX, thus
        #   required_cache_size and att_mask are no longer needed and they will
        #   be removed by ONNX automatically.
        for k in list(ort_inputs):
            if k not in input_names:
                ort_inputs.pop(k)
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
        onnx_offset += ort_outs[0].shape[1]
    onnx_output = np.concatenate(onnx_output, axis=1)

    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output,
                               rtol=1e-03,
                               atol=1e-05)
    meta = ort_session.get_modelmeta()
    print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
    print("\t\tCheck onnx_encoder, pass!")

def export_ctc(asr_model, args):
    print("Stage-2: export ctc")
    ctc = asr_model.ctc
    ctc.forward = ctc.log_softmax
    ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx')

    print("\tStage-2.1: prepare inputs for ctc")
    hidden = torch.randn(
        (args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16,
         args['output_size']))

    print("\tStage-2.2: torch.onnx.export")
    dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}}
    torch.onnx.export(ctc,
                      hidden,
                      ctc_outpath,
                      opset_version=13,
                      export_params=True,
                      do_constant_folding=True,
                      input_names=['hidden'],
                      output_names=['probs'],
                      dynamic_axes=dynamic_axes,
                      verbose=False)
    onnx_ctc = onnx.load(ctc_outpath)
    for (k, v) in args.items():
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    # Dynamic quantization
    model_fp32 = ctc_outpath
    model_quant = os.path.join(args['output_dir'], 'ctc.quant.onnx')
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath))

    print("\tStage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(
        ctc_outpath, providers=['CPUExecutionProvider'])
    onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})

    np.testing.assert_allclose(to_numpy(torch_output),
                               onnx_output[0],
                               rtol=1e-03,
                               atol=1e-05)
    print("\t\tCheck onnx_ctc, pass!")

def export_decoder(asr_model, args):
|
325 |
+
print("Stage-3: export decoder")
|
326 |
+
decoder = asr_model
|
327 |
+
# NOTE(lzhin): parameters of encoder will be automatically removed
|
328 |
+
# since they are not used during rescoring.
|
329 |
+
decoder.forward = decoder.forward_attention_decoder
|
330 |
+
decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx')
|
331 |
+
|
332 |
+
print("\tStage-3.1: prepare inputs for decoder")
|
333 |
+
# hardcode time->200 nbest->10 len->20, they are dynamic axes.
|
334 |
+
encoder_out = torch.randn((1, 200, args['output_size']))
|
335 |
+
hyps = torch.randint(low=0, high=args['vocab_size'], size=[10, 20])
|
336 |
+
hyps[:, 0] = args['vocab_size'] - 1 # <sos>
|
337 |
+
hyps_lens = torch.randint(low=15, high=21, size=[10])
|
338 |
+
|
339 |
+
print("\tStage-3.2: torch.onnx.export")
|
340 |
+
dynamic_axes = {
|
341 |
+
'hyps': {
|
342 |
+
0: 'NBEST',
|
343 |
+
1: 'L'
|
344 |
+
},
|
345 |
+
'hyps_lens': {
|
346 |
+
0: 'NBEST'
|
347 |
+
},
|
348 |
+
'encoder_out': {
|
349 |
+
1: 'T'
|
350 |
+
},
|
351 |
+
'score': {
|
352 |
+
0: 'NBEST',
|
353 |
+
1: 'L'
|
354 |
+
},
|
355 |
+
'r_score': {
|
356 |
+
0: 'NBEST',
|
357 |
+
1: 'L'
|
358 |
+
}
|
359 |
+
}
|
360 |
+
inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight'])
|
361 |
+
torch.onnx.export(
|
362 |
+
decoder,
|
363 |
+
inputs,
|
364 |
+
decoder_outpath,
|
365 |
+
opset_version=13,
|
366 |
+
export_params=True,
|
367 |
+
do_constant_folding=True,
|
368 |
+
input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'],
|
369 |
+
output_names=['score', 'r_score'],
|
370 |
+
dynamic_axes=dynamic_axes,
|
371 |
+
verbose=False)
|
372 |
+
onnx_decoder = onnx.load(decoder_outpath)
|
373 |
+
for (k, v) in args.items():
|
374 |
+
meta = onnx_decoder.metadata_props.add()
|
375 |
+
meta.key, meta.value = str(k), str(v)
|
376 |
+
onnx.checker.check_model(onnx_decoder)
|
377 |
+
onnx.helper.printable_graph(onnx_decoder.graph)
|
378 |
+
onnx.save(onnx_decoder, decoder_outpath)
|
379 |
+
print_input_output_info(onnx_decoder, "onnx_decoder")
|
380 |
+
model_fp32 = decoder_outpath
|
381 |
+
model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx')
|
382 |
+
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
|
383 |
+
print('\t\tExport onnx_decoder, done! see {}'.format(decoder_outpath))
|
384 |
+
|
385 |
+
print("\tStage-3.3: check onnx_decoder and torch_decoder")
|
386 |
+
torch_score, torch_r_score = decoder(hyps, hyps_lens, encoder_out,
|
387 |
+
args['reverse_weight'])
|
388 |
+
ort_session = onnxruntime.InferenceSession(
|
389 |
+
decoder_outpath, providers=['CPUExecutionProvider'])
|
390 |
+
input_names = [node.name for node in onnx_decoder.graph.input]
|
391 |
+
ort_inputs = {
|
392 |
+
'hyps': to_numpy(hyps),
|
393 |
+
'hyps_lens': to_numpy(hyps_lens),
|
394 |
+
'encoder_out': to_numpy(encoder_out),
|
395 |
+
'reverse_weight': np.array((args['reverse_weight'])),
|
396 |
+
}
|
397 |
+
for k in list(ort_inputs):
|
398 |
+
if k not in input_names:
|
399 |
+
ort_inputs.pop(k)
|
400 |
+
onnx_output = ort_session.run(None, ort_inputs)
|
401 |
+
|
402 |
+
np.testing.assert_allclose(to_numpy(torch_score),
|
403 |
+
onnx_output[0],
|
404 |
+
rtol=1e-03,
|
405 |
+
atol=1e-05)
|
406 |
+
if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0:
|
407 |
+
np.testing.assert_allclose(to_numpy(torch_r_score),
|
408 |
+
onnx_output[1],
|
409 |
+
rtol=1e-03,
|
410 |
+
atol=1e-05)
|
411 |
+
print("\t\tCheck onnx_decoder, pass!")
|
412 |
+
|
413 |
+
|
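The dynamically quantized decoder saved above (decoder.quant.onnx) is exported but never checked by this script. A minimal sketch of an extra check that could sit at the end of export_decoder is given below; the looser int8 tolerances are an assumption, and model_quant, ort_inputs, torch_score and to_numpy are the names already defined in that function.
# Hypothetical extra check for the quantized decoder (not part of the script).
# Reuses model_quant and ort_inputs from export_decoder above; int8 weights
# usually need looser tolerances than the fp32 comparison.
quant_session = onnxruntime.InferenceSession(
    model_quant, providers=['CPUExecutionProvider'])
quant_output = quant_session.run(None, ort_inputs)
np.testing.assert_allclose(to_numpy(torch_score),
                           quant_output[0],
                           rtol=1e-01,
                           atol=1e-02)
print("\t\tCheck onnx_decoder (quantized), pass!")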
414 |
+
def main():
|
415 |
+
torch.manual_seed(777)
|
416 |
+
args = get_args()
|
417 |
+
logging.basicConfig(level=logging.DEBUG,
|
418 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
419 |
+
output_dir = args.output_dir
|
420 |
+
os.system("mkdir -p " + output_dir)
|
421 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
422 |
+
|
423 |
+
with open(args.config, 'r') as fin:
|
424 |
+
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
425 |
+
|
426 |
+
model, configs = init_model(args, configs)
|
427 |
+
model.eval()
|
428 |
+
print(model)
|
429 |
+
|
430 |
+
arguments = {}
|
431 |
+
arguments['output_dir'] = output_dir
|
432 |
+
arguments['batch'] = 1
|
433 |
+
arguments['chunk_size'] = args.chunk_size
|
434 |
+
arguments['left_chunks'] = args.num_decoding_left_chunks
|
435 |
+
arguments['reverse_weight'] = args.reverse_weight
|
436 |
+
arguments['output_size'] = configs['encoder_conf']['output_size']
|
437 |
+
arguments['num_blocks'] = configs['encoder_conf']['num_blocks']
|
438 |
+
arguments['cnn_module_kernel'] = configs['encoder_conf'].get(
|
439 |
+
'cnn_module_kernel', 1)
|
440 |
+
arguments['head'] = configs['encoder_conf']['attention_heads']
|
441 |
+
arguments['feature_size'] = configs['input_dim']
|
442 |
+
arguments['vocab_size'] = configs['output_dim']
|
443 |
+
# NOTE(xcsong): if chunk_size == -1, hardcode to 67
|
444 |
+
arguments['decoding_window'] = (args.chunk_size - 1) * \
|
445 |
+
model.encoder.embed.subsampling_rate + \
|
446 |
+
model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67
|
447 |
+
arguments['encoder'] = configs['encoder']
|
448 |
+
arguments['decoder'] = configs['decoder']
|
449 |
+
arguments['subsampling_rate'] = model.subsampling_rate()
|
450 |
+
arguments['right_context'] = model.right_context()
|
451 |
+
arguments['sos_symbol'] = model.sos_symbol()
|
452 |
+
arguments['eos_symbol'] = model.eos_symbol()
|
453 |
+
arguments['is_bidirectional_decoder'] = 1 \
|
454 |
+
if model.is_bidirectional_decoder() else 0
|
455 |
+
|
456 |
+
# NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
|
457 |
+
# not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
|
458 |
+
# streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
|
459 |
+
# want to use 16/-1 or any other streaming mode in `decoder_main`,
|
460 |
+
# please export onnx in the same config.
|
461 |
+
if arguments['left_chunks'] > 0:
|
462 |
+
assert arguments['chunk_size'] > 0 # -1/4 not supported
|
463 |
+
|
464 |
+
export_encoder(model, arguments)
|
465 |
+
export_ctc(model, arguments)
|
466 |
+
export_decoder(model, arguments)
|
467 |
+
|
468 |
+
|
469 |
+
if __name__ == '__main__':
|
470 |
+
main()
|
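Each key of args is also written into the ONNX metadata_props of the exported decoder (see the loop above), so a runtime can recover the decoding configuration without the training YAML. A minimal consumer-side sketch follows; the onnx_dir path is an assumption for illustration.
# Hypothetical consumer-side read of the metadata written by the export above.
import os
import onnx

onnx_dir = 'exp/onnx'  # assumption: the --output_dir used at export time
decoder_model = onnx.load(os.path.join(onnx_dir, 'decoder.onnx'))
meta = {p.key: p.value for p in decoder_model.metadata_props}
print(meta['vocab_size'], meta['reverse_weight'],
      meta['is_bidirectional_decoder'])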
wenet/bin/export_onnx_gpu.py
ADDED
@@ -0,0 +1,1263 @@
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import yaml
|
23 |
+
import logging
|
24 |
+
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from wenet.transformer.ctc import CTC
|
27 |
+
from wenet.transformer.decoder import TransformerDecoder
|
28 |
+
from wenet.transformer.encoder import BaseEncoder
|
29 |
+
from wenet.utils.init_model import init_model
|
30 |
+
from wenet.utils.mask import make_pad_mask
|
31 |
+
|
32 |
+
try:
|
33 |
+
import onnxruntime
|
34 |
+
except ImportError:
|
35 |
+
print("Please install onnxruntime-gpu!")
|
36 |
+
sys.exit(1)
|
37 |
+
|
38 |
+
logger = logging.getLogger(__file__)
|
39 |
+
logger.setLevel(logging.INFO)
|
40 |
+
|
41 |
+
|
42 |
+
class Encoder(torch.nn.Module):
|
43 |
+
|
44 |
+
def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
|
45 |
+
super().__init__()
|
46 |
+
self.encoder = encoder
|
47 |
+
self.ctc = ctc
|
48 |
+
self.beam_size = beam_size
|
49 |
+
|
50 |
+
def forward(
|
51 |
+
self,
|
52 |
+
speech: torch.Tensor,
|
53 |
+
speech_lengths: torch.Tensor,
|
54 |
+
):
|
55 |
+
"""Encoder
|
56 |
+
Args:
|
57 |
+
speech: (Batch, Length, ...)
|
58 |
+
speech_lengths: (Batch, )
|
59 |
+
Returns:
|
60 |
+
encoder_out: B x T x F
|
61 |
+
encoder_out_lens: B
|
62 |
+
ctc_log_probs: B x T x V
|
63 |
+
beam_log_probs: B x T x beam_size
|
64 |
+
beam_log_probs_idx: B x T x beam_size
|
65 |
+
"""
|
66 |
+
encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1,
|
67 |
+
-1)
|
68 |
+
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
69 |
+
ctc_log_probs = self.ctc.log_softmax(encoder_out)
|
70 |
+
encoder_out_lens = encoder_out_lens.int()
|
71 |
+
beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
|
72 |
+
self.beam_size,
|
73 |
+
dim=2)
|
74 |
+
return (
|
75 |
+
encoder_out,
|
76 |
+
encoder_out_lens,
|
77 |
+
ctc_log_probs,
|
78 |
+
beam_log_probs,
|
79 |
+
beam_log_probs_idx,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class StreamingEncoder(torch.nn.Module):
|
84 |
+
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
model,
|
88 |
+
required_cache_size,
|
89 |
+
beam_size,
|
90 |
+
transformer=False,
|
91 |
+
return_ctc_logprobs=False,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
self.ctc = model.ctc
|
95 |
+
self.subsampling_rate = model.encoder.embed.subsampling_rate
|
96 |
+
self.embed = model.encoder.embed
|
97 |
+
self.global_cmvn = model.encoder.global_cmvn
|
98 |
+
self.required_cache_size = required_cache_size
|
99 |
+
self.beam_size = beam_size
|
100 |
+
self.encoder = model.encoder
|
101 |
+
self.transformer = transformer
|
102 |
+
self.return_ctc_logprobs = return_ctc_logprobs
|
103 |
+
|
104 |
+
def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
|
105 |
+
cache_mask):
|
106 |
+
"""Streaming Encoder
|
107 |
+
Args:
|
108 |
+
chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
|
109 |
+
where `time == (chunk_size - 1) * subsample_rate + \
|
110 |
+
subsample.right_context + 1`
|
111 |
+
offset (torch.Tensor): offset with shape (b, 1)
|
112 |
+
1 is retained for triton deployment
|
113 |
+
required_cache_size (int): cache size required for next chunk
|
114 |
+
computation
|
115 |
+
> 0: actual cache size
|
116 |
+
<= 0: not allowed in streaming gpu encoder
|
117 |
+
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
118 |
+
transformer/conformer attention, with shape
|
119 |
+
(b, elayers, head, cache_t1, d_k * 2), where
|
120 |
+
`head * d_k == hidden-dim` and
|
121 |
+
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
122 |
+
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
123 |
+
(b, elayers, hidden-dim, cache_t2), where
|
124 |
+
`cache_t2 == cnn.lorder - 1`
|
125 |
+
cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
|
126 |
+
in a batch of requests, each request may have different
|
127 |
+
history cache. Cache mask is used to indicate the effective
|
128 |
+
cache for each request
|
129 |
+
Returns:
|
130 |
+
torch.Tensor: log probabilities of ctc output and cutoff by beam size
|
131 |
+
with shape (b, chunk_size, beam)
|
132 |
+
torch.Tensor: index of top beam size probabilities for each timestep
|
133 |
+
with shape (b, chunk_size, beam)
|
134 |
+
torch.Tensor: output of current input xs,
|
135 |
+
with shape (b, chunk_size, hidden-dim).
|
136 |
+
torch.Tensor: new attention cache required for next chunk, with
|
137 |
+
same shape (b, elayers, head, cache_t1, d_k * 2)
|
138 |
+
as the original att_cache
|
139 |
+
torch.Tensor: new conformer cnn cache required for next chunk, with
|
140 |
+
same shape as the original cnn_cache.
|
141 |
+
torch.Tensor: new cache mask, with same shape as the original
|
142 |
+
cache mask
|
143 |
+
"""
|
144 |
+
offset = offset.squeeze(1)
|
145 |
+
T = chunk_xs.size(1)
|
146 |
+
chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
|
147 |
+
# B X 1 X T
|
148 |
+
chunk_mask = chunk_mask.to(chunk_xs.dtype)
|
149 |
+
# transpose batch & num_layers dim
|
150 |
+
att_cache = torch.transpose(att_cache, 0, 1)
|
151 |
+
cnn_cache = torch.transpose(cnn_cache, 0, 1)
|
152 |
+
|
153 |
+
# rewrite encoder.forward_chunk
|
154 |
+
# <---------forward_chunk START--------->
|
155 |
+
xs = self.global_cmvn(chunk_xs)
|
156 |
+
# chunk mask is important for batch inferencing since
|
157 |
+
# different sequence in a batch has different length
|
158 |
+
xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
|
159 |
+
cache_size = att_cache.size(3) # required cache size
|
160 |
+
masks = torch.cat((cache_mask, chunk_mask), dim=2)
|
161 |
+
index = offset - cache_size
|
162 |
+
|
163 |
+
pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
|
164 |
+
pos_emb = pos_emb.to(dtype=xs.dtype)
|
165 |
+
|
166 |
+
next_cache_start = -self.required_cache_size
|
167 |
+
r_cache_mask = masks[:, :, next_cache_start:]
|
168 |
+
|
169 |
+
r_att_cache = []
|
170 |
+
r_cnn_cache = []
|
171 |
+
for i, layer in enumerate(self.encoder.encoders):
|
172 |
+
xs, _, new_att_cache, new_cnn_cache = layer(
|
173 |
+
xs,
|
174 |
+
masks,
|
175 |
+
pos_emb,
|
176 |
+
att_cache=att_cache[i],
|
177 |
+
cnn_cache=cnn_cache[i],
|
178 |
+
)
|
179 |
+
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
|
180 |
+
# shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
|
181 |
+
r_att_cache.append(
|
182 |
+
new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
|
183 |
+
if not self.transformer:
|
184 |
+
r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
|
185 |
+
if self.encoder.normalize_before:
|
186 |
+
chunk_out = self.encoder.after_norm(xs)
|
187 |
+
else:
|
188 |
+
chunk_out = xs
|
189 |
+
|
190 |
+
r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
|
191 |
+
if not self.transformer:
|
192 |
+
r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
|
193 |
+
|
194 |
+
# <---------forward_chunk END--------->
|
195 |
+
|
196 |
+
log_ctc_probs = self.ctc.log_softmax(chunk_out)
|
197 |
+
log_probs, log_probs_idx = torch.topk(log_ctc_probs,
|
198 |
+
self.beam_size,
|
199 |
+
dim=2)
|
200 |
+
log_probs = log_probs.to(chunk_xs.dtype)
|
201 |
+
|
202 |
+
r_offset = offset + chunk_out.shape[1]
|
203 |
+
# the below ops not supported in Tensorrt
|
204 |
+
# chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
|
205 |
+
# rounding_mode='floor')
|
206 |
+
chunk_out_lens = chunk_lens // self.subsampling_rate
|
207 |
+
r_offset = r_offset.unsqueeze(1)
|
208 |
+
if self.return_ctc_logprobs:
|
209 |
+
return (
|
210 |
+
log_ctc_probs,
|
211 |
+
chunk_out,
|
212 |
+
chunk_out_lens,
|
213 |
+
r_offset,
|
214 |
+
r_att_cache,
|
215 |
+
r_cnn_cache,
|
216 |
+
r_cache_mask,
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
return (
|
220 |
+
log_probs,
|
221 |
+
log_probs_idx,
|
222 |
+
chunk_out,
|
223 |
+
chunk_out_lens,
|
224 |
+
r_offset,
|
225 |
+
r_att_cache,
|
226 |
+
r_cnn_cache,
|
227 |
+
r_cache_mask,
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
class StreamingSqueezeformerEncoder(torch.nn.Module):
|
232 |
+
|
233 |
+
def __init__(self, model, required_cache_size, beam_size):
|
234 |
+
super().__init__()
|
235 |
+
self.ctc = model.ctc
|
236 |
+
self.subsampling_rate = model.encoder.embed.subsampling_rate
|
237 |
+
self.embed = model.encoder.embed
|
238 |
+
self.global_cmvn = model.encoder.global_cmvn
|
239 |
+
self.required_cache_size = required_cache_size
|
240 |
+
self.beam_size = beam_size
|
241 |
+
self.encoder = model.encoder
|
242 |
+
self.reduce_idx = model.encoder.reduce_idx
|
243 |
+
self.recover_idx = model.encoder.recover_idx
|
244 |
+
if self.reduce_idx is None:
|
245 |
+
self.time_reduce = None
|
246 |
+
else:
|
247 |
+
if self.recover_idx is None:
|
248 |
+
self.time_reduce = "normal" # no recovery at the end
|
249 |
+
else:
|
250 |
+
self.time_reduce = "recover" # recovery at the end
|
251 |
+
assert len(self.reduce_idx) == len(self.recover_idx)
|
252 |
+
|
253 |
+
def calculate_downsampling_factor(self, i: int) -> int:
|
254 |
+
if self.reduce_idx is None:
|
255 |
+
return 1
|
256 |
+
else:
|
257 |
+
reduce_exp, recover_exp = 0, 0
|
258 |
+
for exp, rd_idx in enumerate(self.reduce_idx):
|
259 |
+
if i >= rd_idx:
|
260 |
+
reduce_exp = exp + 1
|
261 |
+
if self.recover_idx is not None:
|
262 |
+
for exp, rc_idx in enumerate(self.recover_idx):
|
263 |
+
if i >= rc_idx:
|
264 |
+
recover_exp = exp + 1
|
265 |
+
return int(2**(reduce_exp - recover_exp))
|
266 |
+
|
267 |
+
def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
|
268 |
+
cache_mask):
|
269 |
+
"""Streaming Encoder
|
270 |
+
Args:
|
271 |
+
chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
|
272 |
+
where `time == (chunk_size - 1) * subsample_rate + \
|
273 |
+
subsample.right_context + 1`
|
274 |
+
offset (torch.Tensor): offset with shape (b, 1)
|
275 |
+
1 is retained for triton deployment
|
276 |
+
required_cache_size (int): cache size required for next chunk
|
277 |
+
computation
|
278 |
+
> 0: actual cache size
|
279 |
+
<= 0: not allowed in streaming gpu encoder
|
280 |
+
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
281 |
+
transformer/conformer attention, with shape
|
282 |
+
(b, elayers, head, cache_t1, d_k * 2), where
|
283 |
+
`head * d_k == hidden-dim` and
|
284 |
+
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
285 |
+
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
286 |
+
(b, elayers, hidden-dim, cache_t2), where
|
287 |
+
`cache_t2 == cnn.lorder - 1`
|
288 |
+
cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
|
289 |
+
in a batch of requests, each request may have different
|
290 |
+
history cache. Cache mask is used to indicate the effective
|
291 |
+
cache for each request
|
292 |
+
Returns:
|
293 |
+
torch.Tensor: log probabilities of ctc output and cutoff by beam size
|
294 |
+
with shape (b, chunk_size, beam)
|
295 |
+
torch.Tensor: index of top beam size probabilities for each timestep
|
296 |
+
with shape (b, chunk_size, beam)
|
297 |
+
torch.Tensor: output of current input xs,
|
298 |
+
with shape (b, chunk_size, hidden-dim).
|
299 |
+
torch.Tensor: new attention cache required for next chunk, with
|
300 |
+
same shape (b, elayers, head, cache_t1, d_k * 2)
|
301 |
+
as the original att_cache
|
302 |
+
torch.Tensor: new conformer cnn cache required for next chunk, with
|
303 |
+
same shape as the original cnn_cache.
|
304 |
+
torch.Tensor: new cache mask, with same shape as the original
|
305 |
+
cache mask
|
306 |
+
"""
|
307 |
+
offset = offset.squeeze(1)
|
308 |
+
T = chunk_xs.size(1)
|
309 |
+
chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
|
310 |
+
# B X 1 X T
|
311 |
+
chunk_mask = chunk_mask.to(chunk_xs.dtype)
|
312 |
+
# transpose batch & num_layers dim
|
313 |
+
att_cache = torch.transpose(att_cache, 0, 1)
|
314 |
+
cnn_cache = torch.transpose(cnn_cache, 0, 1)
|
315 |
+
|
316 |
+
# rewrite encoder.forward_chunk
|
317 |
+
# <---------forward_chunk START--------->
|
318 |
+
xs = self.global_cmvn(chunk_xs)
|
319 |
+
# chunk mask is important for batch inferencing since
|
320 |
+
# different sequence in a batch has different length
|
321 |
+
xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
|
322 |
+
elayers, cache_size = att_cache.size(0), att_cache.size(3)
|
323 |
+
att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
|
324 |
+
index = offset - cache_size
|
325 |
+
|
326 |
+
pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
|
327 |
+
pos_emb = pos_emb.to(dtype=xs.dtype)
|
328 |
+
|
329 |
+
next_cache_start = -self.required_cache_size
|
330 |
+
r_cache_mask = att_mask[:, :, next_cache_start:]
|
331 |
+
|
332 |
+
r_att_cache = []
|
333 |
+
r_cnn_cache = []
|
334 |
+
mask_pad = torch.ones(1,
|
335 |
+
xs.size(1),
|
336 |
+
device=xs.device,
|
337 |
+
dtype=torch.bool)
|
338 |
+
mask_pad = mask_pad.unsqueeze(1)
|
339 |
+
max_att_len: int = 0
|
340 |
+
recover_activations: List[Tuple[torch.Tensor, torch.Tensor,
|
341 |
+
torch.Tensor, torch.Tensor]] = []
|
342 |
+
index = 0
|
343 |
+
xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
|
344 |
+
xs = self.encoder.preln(xs)
|
345 |
+
for i, layer in enumerate(self.encoder.encoders):
|
346 |
+
if self.reduce_idx is not None:
|
347 |
+
if self.time_reduce is not None and i in self.reduce_idx:
|
348 |
+
recover_activations.append(
|
349 |
+
(xs, att_mask, pos_emb, mask_pad))
|
350 |
+
(
|
351 |
+
xs,
|
352 |
+
xs_lens,
|
353 |
+
att_mask,
|
354 |
+
mask_pad,
|
355 |
+
) = self.encoder.time_reduction_layer(
|
356 |
+
xs, xs_lens, att_mask, mask_pad)
|
357 |
+
pos_emb = pos_emb[:, ::2, :]
|
358 |
+
if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
|
359 |
+
pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
|
360 |
+
index += 1
|
361 |
+
|
362 |
+
if self.recover_idx is not None:
|
363 |
+
if self.time_reduce == "recover" and i in self.recover_idx:
|
364 |
+
index -= 1
|
365 |
+
(
|
366 |
+
recover_tensor,
|
367 |
+
recover_att_mask,
|
368 |
+
recover_pos_emb,
|
369 |
+
recover_mask_pad,
|
370 |
+
) = recover_activations[index]
|
371 |
+
# recover output length for ctc decode
|
372 |
+
xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
|
373 |
+
xs = self.encoder.time_recover_layer(xs)
|
374 |
+
recoverd_t = recover_tensor.size(1)
|
375 |
+
xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
|
376 |
+
att_mask = recover_att_mask
|
377 |
+
pos_emb = recover_pos_emb
|
378 |
+
mask_pad = recover_mask_pad
|
379 |
+
|
380 |
+
factor = self.calculate_downsampling_factor(i)
|
381 |
+
|
382 |
+
xs, _, new_att_cache, new_cnn_cache = layer(
|
383 |
+
xs,
|
384 |
+
att_mask,
|
385 |
+
pos_emb,
|
386 |
+
att_cache=att_cache[i][:, :, ::factor, :]
|
387 |
+
[:, :, :pos_emb.size(1) - xs.size(1), :]
|
388 |
+
if elayers > 0 else att_cache[:, :, ::factor, :],
|
389 |
+
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
|
390 |
+
)
|
391 |
+
cached_att = new_att_cache[:, :, next_cache_start // factor:, :]
|
392 |
+
cached_cnn = new_cnn_cache.unsqueeze(1)
|
393 |
+
cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor,
|
394 |
+
1).flatten(2, 3))
|
395 |
+
if i == 0:
|
396 |
+
# record length for the first block as max length
|
397 |
+
max_att_len = cached_att.size(2)
|
398 |
+
r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
|
399 |
+
r_cnn_cache.append(cached_cnn)
|
400 |
+
|
401 |
+
chunk_out = xs
|
402 |
+
r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
|
403 |
+
r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
|
404 |
+
|
405 |
+
# <---------forward_chunk END--------->
|
406 |
+
|
407 |
+
log_ctc_probs = self.ctc.log_softmax(chunk_out)
|
408 |
+
log_probs, log_probs_idx = torch.topk(log_ctc_probs,
|
409 |
+
self.beam_size,
|
410 |
+
dim=2)
|
411 |
+
log_probs = log_probs.to(chunk_xs.dtype)
|
412 |
+
|
413 |
+
r_offset = offset + chunk_out.shape[1]
|
414 |
+
# the below ops not supported in Tensorrt
|
415 |
+
# chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
|
416 |
+
# rounding_mode='floor')
|
417 |
+
chunk_out_lens = chunk_lens // self.subsampling_rate
|
418 |
+
r_offset = r_offset.unsqueeze(1)
|
419 |
+
|
420 |
+
return (
|
421 |
+
log_probs,
|
422 |
+
log_probs_idx,
|
423 |
+
chunk_out,
|
424 |
+
chunk_out_lens,
|
425 |
+
r_offset,
|
426 |
+
r_att_cache,
|
427 |
+
r_cnn_cache,
|
428 |
+
r_cache_mask,
|
429 |
+
)
|
430 |
+
|
431 |
+
|
432 |
+
class StreamingEfficientConformerEncoder(torch.nn.Module):
|
433 |
+
|
434 |
+
def __init__(self, model, required_cache_size, beam_size):
|
435 |
+
super().__init__()
|
436 |
+
self.ctc = model.ctc
|
437 |
+
self.subsampling_rate = model.encoder.embed.subsampling_rate
|
438 |
+
self.embed = model.encoder.embed
|
439 |
+
self.global_cmvn = model.encoder.global_cmvn
|
440 |
+
self.required_cache_size = required_cache_size
|
441 |
+
self.beam_size = beam_size
|
442 |
+
self.encoder = model.encoder
|
443 |
+
|
444 |
+
# Efficient Conformer
|
445 |
+
self.stride_layer_idx = model.encoder.stride_layer_idx
|
446 |
+
self.stride = model.encoder.stride
|
447 |
+
self.num_blocks = model.encoder.num_blocks
|
448 |
+
self.cnn_module_kernel = model.encoder.cnn_module_kernel
|
449 |
+
|
450 |
+
def calculate_downsampling_factor(self, i: int) -> int:
|
451 |
+
factor = 1
|
452 |
+
for idx, stride_idx in enumerate(self.stride_layer_idx):
|
453 |
+
if i > stride_idx:
|
454 |
+
factor *= self.stride[idx]
|
455 |
+
return factor
|
456 |
+
|
457 |
+
def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
|
458 |
+
cache_mask):
|
459 |
+
"""Streaming Encoder
|
460 |
+
Args:
|
461 |
+
chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
|
462 |
+
where `time == (chunk_size - 1) * subsample_rate + \
|
463 |
+
subsample.right_context + 1`
|
464 |
+
chunk_lens (torch.Tensor):
|
465 |
+
offset (torch.Tensor): offset with shape (b, 1)
|
466 |
+
1 is retained for triton deployment
|
467 |
+
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
468 |
+
transformer/conformer attention, with shape
|
469 |
+
(b, elayers, head, cache_t1, d_k * 2), where
|
470 |
+
`head * d_k == hidden-dim` and
|
471 |
+
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
472 |
+
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
473 |
+
(b, elayers, hidden-dim, cache_t2), where
|
474 |
+
`cache_t2 == cnn.lorder - 1`
|
475 |
+
cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
|
476 |
+
in a batch of requests, each request may have different
|
477 |
+
history cache. Cache mask is used to indicate the effective
|
478 |
+
cache for each request
|
479 |
+
Returns:
|
480 |
+
torch.Tensor: log probabilities of ctc output and cutoff by beam size
|
481 |
+
with shape (b, chunk_size, beam)
|
482 |
+
torch.Tensor: index of top beam size probabilities for each timestep
|
483 |
+
with shape (b, chunk_size, beam)
|
484 |
+
torch.Tensor: output of current input xs,
|
485 |
+
with shape (b, chunk_size, hidden-dim).
|
486 |
+
torch.Tensor: new attention cache required for next chunk, with
|
487 |
+
same shape (b, elayers, head, cache_t1, d_k * 2)
|
488 |
+
as the original att_cache
|
489 |
+
torch.Tensor: new conformer cnn cache required for next chunk, with
|
490 |
+
same shape as the original cnn_cache.
|
491 |
+
torch.Tensor: new cache mask, with same shape as the original
|
492 |
+
cache mask
|
493 |
+
"""
|
494 |
+
offset = offset.squeeze(1) # (b, )
|
495 |
+
offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
|
496 |
+
|
497 |
+
T = chunk_xs.size(1)
|
498 |
+
chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # (b, 1, T)
|
499 |
+
# B X 1 X T
|
500 |
+
chunk_mask = chunk_mask.to(chunk_xs.dtype)
|
501 |
+
# transpose batch & num_layers dim
|
502 |
+
# Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2)
|
503 |
+
# Shape(cnn_cache): (elayers, b, outsize, cnn_kernel)
|
504 |
+
att_cache = torch.transpose(att_cache, 0, 1)
|
505 |
+
cnn_cache = torch.transpose(cnn_cache, 0, 1)
|
506 |
+
|
507 |
+
# rewrite encoder.forward_chunk
|
508 |
+
# <---------forward_chunk START--------->
|
509 |
+
xs = self.global_cmvn(chunk_xs)
|
510 |
+
# chunk mask is important for batch inferencing since
|
511 |
+
# different sequence in a batch has different length
|
512 |
+
xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
|
513 |
+
cache_size = att_cache.size(3) # required cache size
|
514 |
+
masks = torch.cat((cache_mask, chunk_mask), dim=2)
|
515 |
+
att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
|
516 |
+
index = offset - cache_size
|
517 |
+
|
518 |
+
pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
|
519 |
+
pos_emb = pos_emb.to(dtype=xs.dtype)
|
520 |
+
|
521 |
+
next_cache_start = -self.required_cache_size
|
522 |
+
r_cache_mask = masks[:, :, next_cache_start:]
|
523 |
+
|
524 |
+
r_att_cache = []
|
525 |
+
r_cnn_cache = []
|
526 |
+
mask_pad = chunk_mask.to(torch.bool)
|
527 |
+
max_att_len, max_cnn_len = (
|
528 |
+
0,
|
529 |
+
0,
|
530 |
+
) # for repeat_interleave of new_att_cache
|
531 |
+
for i, layer in enumerate(self.encoder.encoders):
|
532 |
+
factor = self.calculate_downsampling_factor(i)
|
533 |
+
# NOTE(xcsong): Before layer.forward
|
534 |
+
# shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2),
|
535 |
+
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
|
536 |
+
# shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
|
537 |
+
att_cache_trunc = 0
|
538 |
+
if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1):
|
539 |
+
# The time step is not divisible by the downsampling multiple
|
540 |
+
# We propose to double the chunk_size.
|
541 |
+
att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor -
|
542 |
+
pos_emb.size(1) + 1)
|
543 |
+
xs, _, new_att_cache, new_cnn_cache = layer(
|
544 |
+
xs,
|
545 |
+
att_mask,
|
546 |
+
pos_emb,
|
547 |
+
mask_pad=mask_pad,
|
548 |
+
att_cache=att_cache[i][:, :, ::factor, :][:, :,
|
549 |
+
att_cache_trunc:, :],
|
550 |
+
cnn_cache=cnn_cache[i, :, :, :]
|
551 |
+
if cnn_cache.size(0) > 0 else cnn_cache,
|
552 |
+
)
|
553 |
+
|
554 |
+
if i in self.stride_layer_idx:
|
555 |
+
# compute time dimension for next block
|
556 |
+
efficient_index = self.stride_layer_idx.index(i)
|
557 |
+
att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
|
558 |
+
stride[efficient_index], ]
|
559 |
+
mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
|
560 |
+
stride[efficient_index], ]
|
561 |
+
pos_emb = pos_emb[:, ::self.stride[efficient_index], :]
|
562 |
+
|
563 |
+
# shape(new_att_cache) = [batch, head, time2, outdim]
|
564 |
+
new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
|
565 |
+
# shape(new_cnn_cache) = [batch, 1, outdim, cache_t2]
|
566 |
+
new_cnn_cache = new_cnn_cache.unsqueeze(1) # shape(1):layerID
|
567 |
+
|
568 |
+
# use repeat_interleave to new_att_cache
|
569 |
+
# new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
|
570 |
+
new_att_cache = (new_att_cache.unsqueeze(3).repeat(
|
571 |
+
1, 1, 1, factor, 1).flatten(2, 3))
|
572 |
+
# pad new_cnn_cache up to cnn.lorder for causal convolution
|
573 |
+
new_cnn_cache = F.pad(
|
574 |
+
new_cnn_cache,
|
575 |
+
(self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0),
|
576 |
+
)
|
577 |
+
|
578 |
+
if i == 0:
|
579 |
+
# record length for the first block as max length
|
580 |
+
max_att_len = new_att_cache.size(2)
|
581 |
+
max_cnn_len = new_cnn_cache.size(3)
|
582 |
+
|
583 |
+
# update real shape of att_cache and cnn_cache
|
584 |
+
r_att_cache.append(new_att_cache[:, :,
|
585 |
+
-max_att_len:, :].unsqueeze(1))
|
586 |
+
r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
|
587 |
+
|
588 |
+
if self.encoder.normalize_before:
|
589 |
+
chunk_out = self.encoder.after_norm(xs)
|
590 |
+
else:
|
591 |
+
chunk_out = xs
|
592 |
+
|
593 |
+
# shape of r_att_cache: (b, elayers, head, time2, outdim)
|
594 |
+
r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
|
595 |
+
# shape of r_cnn_cache: (b, elayers, outdim, cache_t2)
|
596 |
+
r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
|
597 |
+
|
598 |
+
# <---------forward_chunk END--------->
|
599 |
+
|
600 |
+
log_ctc_probs = self.ctc.log_softmax(chunk_out)
|
601 |
+
log_probs, log_probs_idx = torch.topk(log_ctc_probs,
|
602 |
+
self.beam_size,
|
603 |
+
dim=2)
|
604 |
+
log_probs = log_probs.to(chunk_xs.dtype)
|
605 |
+
|
606 |
+
r_offset = offset + chunk_out.shape[1]
|
607 |
+
# the below ops not supported in Tensorrt
|
608 |
+
# chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
|
609 |
+
# rounding_mode='floor')
|
610 |
+
chunk_out_lens = (
|
611 |
+
chunk_lens // self.subsampling_rate //
|
612 |
+
self.calculate_downsampling_factor(self.num_blocks + 1))
|
613 |
+
chunk_out_lens += 1
|
614 |
+
r_offset = r_offset.unsqueeze(1)
|
615 |
+
|
616 |
+
return (
|
617 |
+
log_probs,
|
618 |
+
log_probs_idx,
|
619 |
+
chunk_out,
|
620 |
+
chunk_out_lens,
|
621 |
+
r_offset,
|
622 |
+
r_att_cache,
|
623 |
+
r_cnn_cache,
|
624 |
+
r_cache_mask,
|
625 |
+
)
|
626 |
+
|
627 |
+
|
628 |
+
class Decoder(torch.nn.Module):
|
629 |
+
|
630 |
+
def __init__(
|
631 |
+
self,
|
632 |
+
decoder: TransformerDecoder,
|
633 |
+
ctc_weight: float = 0.5,
|
634 |
+
reverse_weight: float = 0.0,
|
635 |
+
beam_size: int = 10,
|
636 |
+
decoder_fastertransformer: bool = False,
|
637 |
+
):
|
638 |
+
super().__init__()
|
639 |
+
self.decoder = decoder
|
640 |
+
self.ctc_weight = ctc_weight
|
641 |
+
self.reverse_weight = reverse_weight
|
642 |
+
self.beam_size = beam_size
|
643 |
+
self.decoder_fastertransformer = decoder_fastertransformer
|
644 |
+
|
645 |
+
def forward(
|
646 |
+
self,
|
647 |
+
encoder_out: torch.Tensor,
|
648 |
+
encoder_lens: torch.Tensor,
|
649 |
+
hyps_pad_sos_eos: torch.Tensor,
|
650 |
+
hyps_lens_sos: torch.Tensor,
|
651 |
+
r_hyps_pad_sos_eos: torch.Tensor,
|
652 |
+
ctc_score: torch.Tensor,
|
653 |
+
):
|
654 |
+
"""Encoder
|
655 |
+
Args:
|
656 |
+
encoder_out: B x T x F
|
657 |
+
encoder_lens: B
|
658 |
+
hyps_pad_sos_eos: B x beam x (T2+1),
|
659 |
+
hyps with sos & eos and padded by ignore id
|
660 |
+
hyps_lens_sos: B x beam, length for each hyp with sos
|
661 |
+
r_hyps_pad_sos_eos: B x beam x (T2+1),
|
662 |
+
reversed hyps with sos & eos and padded by ignore id
|
663 |
+
ctc_score: B x beam, ctc score for each hyp
|
664 |
+
Returns:
|
665 |
+
decoder_out: B x beam x T2 x V
|
666 |
+
r_decoder_out: B x beam x T2 x V
|
667 |
+
best_index: B
|
668 |
+
"""
|
669 |
+
B, T, F = encoder_out.shape
|
670 |
+
bz = self.beam_size
|
671 |
+
B2 = B * bz
|
672 |
+
encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
|
673 |
+
encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
|
674 |
+
encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
|
675 |
+
T2 = hyps_pad_sos_eos.shape[2] - 1
|
676 |
+
hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
|
677 |
+
hyps_lens = hyps_lens_sos.view(B2, )
|
678 |
+
hyps_pad_sos = hyps_pad[:, :-1].contiguous()
|
679 |
+
hyps_pad_eos = hyps_pad[:, 1:].contiguous()
|
680 |
+
|
681 |
+
r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
|
682 |
+
r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
|
683 |
+
r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
|
684 |
+
|
685 |
+
decoder_out, r_decoder_out, _ = self.decoder(
|
686 |
+
encoder_out,
|
687 |
+
encoder_mask,
|
688 |
+
hyps_pad_sos,
|
689 |
+
hyps_lens,
|
690 |
+
r_hyps_pad_sos,
|
691 |
+
self.reverse_weight,
|
692 |
+
)
|
693 |
+
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
|
694 |
+
V = decoder_out.shape[-1]
|
695 |
+
decoder_out = decoder_out.view(B2, T2, V)
|
696 |
+
mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2
|
697 |
+
# mask index, remove ignore id
|
698 |
+
index = torch.unsqueeze(hyps_pad_eos * mask, 2)
|
699 |
+
score = decoder_out.gather(2, index).squeeze(2) # B2 X T2
|
700 |
+
# mask padded part
|
701 |
+
score = score * mask
|
702 |
+
decoder_out = decoder_out.view(B, bz, T2, V)
|
703 |
+
if self.reverse_weight > 0:
|
704 |
+
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out,
|
705 |
+
dim=-1)
|
706 |
+
r_decoder_out = r_decoder_out.view(B2, T2, V)
|
707 |
+
index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
|
708 |
+
r_score = r_decoder_out.gather(2, index).squeeze(2)
|
709 |
+
r_score = r_score * mask
|
710 |
+
score = (score * (1 - self.reverse_weight) +
|
711 |
+
self.reverse_weight * r_score)
|
712 |
+
r_decoder_out = r_decoder_out.view(B, bz, T2, V)
|
713 |
+
score = torch.sum(score, axis=1) # B2
|
714 |
+
score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
|
715 |
+
best_index = torch.argmax(score, dim=1)
|
716 |
+
if self.decoder_fastertransformer:
|
717 |
+
return decoder_out, best_index
|
718 |
+
else:
|
719 |
+
return best_index
|
720 |
+
|
721 |
+
|
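The Decoder wrapper above returns only best_index unless decoder_fastertransformer is set. A minimal sketch, purely an assumption about the consumer side and not part of this file, of turning that index back into the winning token sequence (tensor names follow the dummies built in export_rescoring_decoder below):
# Hypothetical post-processing of the rescoring output: for each utterance,
# pick the hypothesis chosen by best_index and strip the leading sos token.
import torch

best_index = decoder(encoder_out, encoder_out_lens, hyps_pad_sos_eos,
                     hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score)
batch_idx = torch.arange(best_index.size(0))
best_hyps = hyps_pad_sos_eos[batch_idx, best_index]  # (B, T2 + 1), starts with sos
best_lens = hyps_lens_sos[batch_idx, best_index]     # lengths include sos
tokens = [best_hyps[i, 1:best_lens[i]].tolist()
          for i in range(best_index.size(0))]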
722 |
+
def to_numpy(tensors):
|
723 |
+
out = []
|
724 |
+
if type(tensors) == torch.tensor:
|
725 |
+
tensors = [tensors]
|
726 |
+
for tensor in tensors:
|
727 |
+
if tensor.requires_grad:
|
728 |
+
tensor = tensor.detach().cpu().numpy()
|
729 |
+
else:
|
730 |
+
tensor = tensor.cpu().numpy()
|
731 |
+
out.append(tensor)
|
732 |
+
return out
|
733 |
+
|
734 |
+
|
735 |
+
def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
|
736 |
+
for a, b in zip(xlist, blist):
|
737 |
+
try:
|
738 |
+
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
|
739 |
+
except AssertionError as error:
|
740 |
+
if tolerate_small_mismatch:
|
741 |
+
print(error)
|
742 |
+
else:
|
743 |
+
raise
|
744 |
+
|
745 |
+
|
746 |
+
def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
|
747 |
+
bz = 32
|
748 |
+
seq_len = 100
|
749 |
+
beam_size = args.beam_size
|
750 |
+
feature_size = configs["input_dim"]
|
751 |
+
|
752 |
+
speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
|
753 |
+
speech_lens = torch.randint(low=10,
|
754 |
+
high=seq_len,
|
755 |
+
size=(bz, ),
|
756 |
+
dtype=torch.int32)
|
757 |
+
encoder = Encoder(model.encoder, model.ctc, beam_size)
|
758 |
+
encoder.eval()
|
759 |
+
|
760 |
+
torch.onnx.export(
|
761 |
+
encoder,
|
762 |
+
(speech, speech_lens),
|
763 |
+
encoder_onnx_path,
|
764 |
+
export_params=True,
|
765 |
+
opset_version=13,
|
766 |
+
do_constant_folding=True,
|
767 |
+
input_names=["speech", "speech_lengths"],
|
768 |
+
output_names=[
|
769 |
+
"encoder_out",
|
770 |
+
"encoder_out_lens",
|
771 |
+
"ctc_log_probs",
|
772 |
+
"beam_log_probs",
|
773 |
+
"beam_log_probs_idx",
|
774 |
+
],
|
775 |
+
dynamic_axes={
|
776 |
+
"speech": {
|
777 |
+
0: "B",
|
778 |
+
1: "T"
|
779 |
+
},
|
780 |
+
"speech_lengths": {
|
781 |
+
0: "B"
|
782 |
+
},
|
783 |
+
"encoder_out": {
|
784 |
+
0: "B",
|
785 |
+
1: "T_OUT"
|
786 |
+
},
|
787 |
+
"encoder_out_lens": {
|
788 |
+
0: "B"
|
789 |
+
},
|
790 |
+
"ctc_log_probs": {
|
791 |
+
0: "B",
|
792 |
+
1: "T_OUT"
|
793 |
+
},
|
794 |
+
"beam_log_probs": {
|
795 |
+
0: "B",
|
796 |
+
1: "T_OUT"
|
797 |
+
},
|
798 |
+
"beam_log_probs_idx": {
|
799 |
+
0: "B",
|
800 |
+
1: "T_OUT"
|
801 |
+
},
|
802 |
+
},
|
803 |
+
verbose=False,
|
804 |
+
)
|
805 |
+
|
806 |
+
with torch.no_grad():
|
807 |
+
o0, o1, o2, o3, o4 = encoder(speech, speech_lens)
|
808 |
+
|
809 |
+
providers = ["CUDAExecutionProvider"]
|
810 |
+
ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
|
811 |
+
providers=providers)
|
812 |
+
ort_inputs = {
|
813 |
+
"speech": to_numpy(speech),
|
814 |
+
"speech_lengths": to_numpy(speech_lens),
|
815 |
+
}
|
816 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
817 |
+
|
818 |
+
# check encoder output
|
819 |
+
test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
|
820 |
+
logger.info("export offline onnx encoder succeed!")
|
821 |
+
onnx_config = {
|
822 |
+
"beam_size": args.beam_size,
|
823 |
+
"reverse_weight": args.reverse_weight,
|
824 |
+
"ctc_weight": args.ctc_weight,
|
825 |
+
"fp16": args.fp16,
|
826 |
+
}
|
827 |
+
return onnx_config
|
828 |
+
|
829 |
+
|
830 |
+
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
|
831 |
+
decoding_chunk_size = args.decoding_chunk_size
|
832 |
+
subsampling = model.encoder.embed.subsampling_rate
|
833 |
+
context = model.encoder.embed.right_context + 1
|
834 |
+
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
835 |
+
batch_size = 32
|
836 |
+
audio_len = decoding_window
|
837 |
+
feature_size = configs["input_dim"]
|
838 |
+
output_size = configs["encoder_conf"]["output_size"]
|
839 |
+
num_layers = configs["encoder_conf"]["num_blocks"]
|
840 |
+
# in transformer the cnn module will not be available
|
841 |
+
transformer = False
|
842 |
+
cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
|
843 |
+
if not cnn_module_kernel:
|
844 |
+
transformer = True
|
845 |
+
num_decoding_left_chunks = args.num_decoding_left_chunks
|
846 |
+
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
847 |
+
if configs["encoder"] == "squeezeformer":
|
848 |
+
encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
|
849 |
+
args.beam_size)
|
850 |
+
elif configs["encoder"] == "efficientConformer":
|
851 |
+
encoder = StreamingEfficientConformerEncoder(model,
|
852 |
+
required_cache_size,
|
853 |
+
args.beam_size)
|
854 |
+
else:
|
855 |
+
encoder = StreamingEncoder(
|
856 |
+
model,
|
857 |
+
required_cache_size,
|
858 |
+
args.beam_size,
|
859 |
+
transformer,
|
860 |
+
args.return_ctc_logprobs,
|
861 |
+
)
|
862 |
+
encoder.eval()
|
863 |
+
|
864 |
+
# begin to export encoder
|
865 |
+
chunk_xs = torch.randn(batch_size,
|
866 |
+
audio_len,
|
867 |
+
feature_size,
|
868 |
+
dtype=torch.float32)
|
869 |
+
chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
|
870 |
+
|
871 |
+
offset = torch.arange(0, batch_size).unsqueeze(1)
|
872 |
+
# (elayers, b, head, cache_t1, d_k * 2)
|
873 |
+
head = configs["encoder_conf"]["attention_heads"]
|
874 |
+
d_k = configs["encoder_conf"]["output_size"] // head
|
875 |
+
att_cache = torch.randn(
|
876 |
+
batch_size,
|
877 |
+
num_layers,
|
878 |
+
head,
|
879 |
+
required_cache_size,
|
880 |
+
d_k * 2,
|
881 |
+
dtype=torch.float32,
|
882 |
+
)
|
883 |
+
cnn_cache = torch.randn(
|
884 |
+
batch_size,
|
885 |
+
num_layers,
|
886 |
+
output_size,
|
887 |
+
cnn_module_kernel,
|
888 |
+
dtype=torch.float32,
|
889 |
+
)
|
890 |
+
|
891 |
+
cache_mask = torch.ones(batch_size,
|
892 |
+
1,
|
893 |
+
required_cache_size,
|
894 |
+
dtype=torch.float32)
|
895 |
+
input_names = [
|
896 |
+
"chunk_xs",
|
897 |
+
"chunk_lens",
|
898 |
+
"offset",
|
899 |
+
"att_cache",
|
900 |
+
"cnn_cache",
|
901 |
+
"cache_mask",
|
902 |
+
]
|
903 |
+
output_names = [
|
904 |
+
"log_probs",
|
905 |
+
"log_probs_idx",
|
906 |
+
"chunk_out",
|
907 |
+
"chunk_out_lens",
|
908 |
+
"r_offset",
|
909 |
+
"r_att_cache",
|
910 |
+
"r_cnn_cache",
|
911 |
+
"r_cache_mask",
|
912 |
+
]
|
913 |
+
if args.return_ctc_logprobs:
|
914 |
+
output_names = [
|
915 |
+
"ctc_log_probs",
|
916 |
+
"chunk_out",
|
917 |
+
"chunk_out_lens",
|
918 |
+
"r_offset",
|
919 |
+
"r_att_cache",
|
920 |
+
"r_cnn_cache",
|
921 |
+
"r_cache_mask",
|
922 |
+
]
|
923 |
+
input_tensors = (
|
924 |
+
chunk_xs,
|
925 |
+
chunk_lens,
|
926 |
+
offset,
|
927 |
+
att_cache,
|
928 |
+
cnn_cache,
|
929 |
+
cache_mask,
|
930 |
+
)
|
931 |
+
if transformer:
|
932 |
+
assert (args.return_ctc_logprobs is
|
933 |
+
False), "return_ctc_logprobs is not supported in transformer"
|
934 |
+
output_names.pop(6)
|
935 |
+
|
936 |
+
all_names = input_names + output_names
|
937 |
+
dynamic_axes = {}
|
938 |
+
for name in all_names:
|
939 |
+
# only the first dimension is dynamic
|
940 |
+
# all other dimension is fixed
|
941 |
+
dynamic_axes[name] = {0: "B"}
|
942 |
+
|
943 |
+
torch.onnx.export(
|
944 |
+
encoder,
|
945 |
+
input_tensors,
|
946 |
+
encoder_onnx_path,
|
947 |
+
export_params=True,
|
948 |
+
opset_version=14,
|
949 |
+
do_constant_folding=True,
|
950 |
+
input_names=input_names,
|
951 |
+
output_names=output_names,
|
952 |
+
dynamic_axes=dynamic_axes,
|
953 |
+
verbose=False,
|
954 |
+
)
|
955 |
+
|
956 |
+
with torch.no_grad():
|
957 |
+
torch_outs = encoder(chunk_xs, chunk_lens, offset, att_cache,
|
958 |
+
cnn_cache, cache_mask)
|
959 |
+
if transformer:
|
960 |
+
torch_outs = list(torch_outs); torch_outs.pop(6)  # drop r_cnn_cache
|
961 |
+
ort_session = onnxruntime.InferenceSession(
|
962 |
+
encoder_onnx_path, providers=["CUDAExecutionProvider"])
|
963 |
+
ort_inputs = {}
|
964 |
+
|
965 |
+
input_tensors = to_numpy(input_tensors)
|
966 |
+
for idx, name in enumerate(input_names):
|
967 |
+
ort_inputs[name] = input_tensors[idx]
|
968 |
+
if transformer:
|
969 |
+
del ort_inputs["cnn_cache"]
|
970 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
971 |
+
test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
|
972 |
+
logger.info("export to onnx streaming encoder succeed!")
|
973 |
+
onnx_config = {
|
974 |
+
"subsampling_rate": subsampling,
|
975 |
+
"context": context,
|
976 |
+
"decoding_chunk_size": decoding_chunk_size,
|
977 |
+
"num_decoding_left_chunks": num_decoding_left_chunks,
|
978 |
+
"beam_size": args.beam_size,
|
979 |
+
"fp16": args.fp16,
|
980 |
+
"feat_size": feature_size,
|
981 |
+
"decoding_window": decoding_window,
|
982 |
+
"cnn_module_kernel_cache": cnn_module_kernel,
|
983 |
+
"return_ctc_logprobs": args.return_ctc_logprobs,
|
984 |
+
}
|
985 |
+
return onnx_config
|
986 |
+
|
987 |
+
|
988 |
+
def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
|
989 |
+
decoder_fastertransformer):
|
990 |
+
bz, seq_len = 32, 100
|
991 |
+
beam_size = args.beam_size
|
992 |
+
decoder = Decoder(
|
993 |
+
model.decoder,
|
994 |
+
model.ctc_weight,
|
995 |
+
model.reverse_weight,
|
996 |
+
beam_size,
|
997 |
+
decoder_fastertransformer,
|
998 |
+
)
|
999 |
+
decoder.eval()
|
1000 |
+
|
1001 |
+
hyps_pad_sos_eos = torch.randint(low=3,
|
1002 |
+
high=1000,
|
1003 |
+
size=(bz, beam_size, seq_len))
|
1004 |
+
hyps_lens_sos = torch.randint(low=3,
|
1005 |
+
high=seq_len,
|
1006 |
+
size=(bz, beam_size),
|
1007 |
+
dtype=torch.int32)
|
1008 |
+
r_hyps_pad_sos_eos = torch.randint(low=3,
|
1009 |
+
high=1000,
|
1010 |
+
size=(bz, beam_size, seq_len))
|
1011 |
+
|
1012 |
+
output_size = configs["encoder_conf"]["output_size"]
|
1013 |
+
encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
|
1014 |
+
encoder_out_lens = torch.randint(low=3,
|
1015 |
+
high=seq_len,
|
1016 |
+
size=(bz, ),
|
1017 |
+
dtype=torch.int32)
|
1018 |
+
ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)
|
1019 |
+
|
1020 |
+
input_names = [
|
1021 |
+
"encoder_out",
|
1022 |
+
"encoder_out_lens",
|
1023 |
+
"hyps_pad_sos_eos",
|
1024 |
+
"hyps_lens_sos",
|
1025 |
+
"r_hyps_pad_sos_eos",
|
1026 |
+
"ctc_score",
|
1027 |
+
]
|
1028 |
+
output_names = ["best_index"]
|
1029 |
+
if decoder_fastertransformer:
|
1030 |
+
output_names.insert(0, "decoder_out")
|
1031 |
+
|
1032 |
+
torch.onnx.export(
|
1033 |
+
decoder,
|
1034 |
+
(
|
1035 |
+
encoder_out,
|
1036 |
+
encoder_out_lens,
|
1037 |
+
hyps_pad_sos_eos,
|
1038 |
+
hyps_lens_sos,
|
1039 |
+
r_hyps_pad_sos_eos,
|
1040 |
+
ctc_score,
|
1041 |
+
),
|
1042 |
+
decoder_onnx_path,
|
1043 |
+
export_params=True,
|
1044 |
+
opset_version=13,
|
1045 |
+
do_constant_folding=True,
|
1046 |
+
input_names=input_names,
|
1047 |
+
output_names=output_names,
|
1048 |
+
dynamic_axes={
|
1049 |
+
"encoder_out": {
|
1050 |
+
0: "B",
|
1051 |
+
1: "T"
|
1052 |
+
},
|
1053 |
+
"encoder_out_lens": {
|
1054 |
+
0: "B"
|
1055 |
+
},
|
1056 |
+
"hyps_pad_sos_eos": {
|
1057 |
+
0: "B",
|
1058 |
+
2: "T2"
|
1059 |
+
},
|
1060 |
+
"hyps_lens_sos": {
|
1061 |
+
0: "B"
|
1062 |
+
},
|
1063 |
+
"r_hyps_pad_sos_eos": {
|
1064 |
+
0: "B",
|
1065 |
+
2: "T2"
|
1066 |
+
},
|
1067 |
+
"ctc_score": {
|
1068 |
+
0: "B"
|
1069 |
+
},
|
1070 |
+
"best_index": {
|
1071 |
+
0: "B"
|
1072 |
+
},
|
1073 |
+
},
|
1074 |
+
verbose=False,
|
1075 |
+
)
|
1076 |
+
with torch.no_grad():
|
1077 |
+
o0 = decoder(
|
1078 |
+
encoder_out,
|
1079 |
+
encoder_out_lens,
|
1080 |
+
hyps_pad_sos_eos,
|
1081 |
+
hyps_lens_sos,
|
1082 |
+
r_hyps_pad_sos_eos,
|
1083 |
+
ctc_score,
|
1084 |
+
)
|
1085 |
+
providers = ["CUDAExecutionProvider"]
|
1086 |
+
ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
|
1087 |
+
providers=providers)
|
1088 |
+
|
1089 |
+
input_tensors = [
|
1090 |
+
encoder_out,
|
1091 |
+
encoder_out_lens,
|
1092 |
+
hyps_pad_sos_eos,
|
1093 |
+
hyps_lens_sos,
|
1094 |
+
r_hyps_pad_sos_eos,
|
1095 |
+
ctc_score,
|
1096 |
+
]
|
1097 |
+
ort_inputs = {}
|
1098 |
+
input_tensors = to_numpy(input_tensors)
|
1099 |
+
for idx, name in enumerate(input_names):
|
1100 |
+
ort_inputs[name] = input_tensors[idx]
|
1101 |
+
|
1102 |
+
# if model.reverse_weight == 0,
|
1103 |
+
# the r_hyps_pad will be removed
|
1104 |
+
# from the onnx decoder since it doesn't play any role
|
1105 |
+
if model.reverse_weight == 0:
|
1106 |
+
del ort_inputs["r_hyps_pad_sos_eos"]
|
1107 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
1108 |
+
|
1109 |
+
# check decoder output
|
1110 |
+
if decoder_fastertransformer:
|
1111 |
+
test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
|
1112 |
+
else:
|
1113 |
+
test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
|
1114 |
+
logger.info("export to onnx decoder succeed!")
|
1115 |
+
|
1116 |
+
|
1117 |
+
if __name__ == "__main__":
|
1118 |
+
parser = argparse.ArgumentParser(description="export x86_gpu model")
|
1119 |
+
parser.add_argument("--config", required=True, help="config file")
|
1120 |
+
parser.add_argument("--checkpoint", required=True, help="checkpoint model")
|
1121 |
+
parser.add_argument(
|
1122 |
+
"--cmvn_file",
|
1123 |
+
required=False,
|
1124 |
+
default="",
|
1125 |
+
type=str,
|
1126 |
+
help="global_cmvn file, default path is in config file",
|
1127 |
+
)
|
1128 |
+
parser.add_argument(
|
1129 |
+
"--reverse_weight",
|
1130 |
+
default=-1.0,
|
1131 |
+
type=float,
|
1132 |
+
required=False,
|
1133 |
+
help="reverse weight for bitransformer," +
|
1134 |
+
"default value is in config file",
|
1135 |
+
)
|
1136 |
+
parser.add_argument(
|
1137 |
+
"--ctc_weight",
|
1138 |
+
default=-1.0,
|
1139 |
+
type=float,
|
1140 |
+
required=False,
|
1141 |
+
help="ctc weight, default value is in config file",
|
1142 |
+
)
|
1143 |
+
parser.add_argument(
|
1144 |
+
"--beam_size",
|
1145 |
+
default=10,
|
1146 |
+
type=int,
|
1147 |
+
required=False,
|
1148 |
+
help="beam size would be ctc output size",
|
1149 |
+
)
|
1150 |
+
parser.add_argument(
|
1151 |
+
"--output_onnx_dir",
|
1152 |
+
default="onnx_model",
|
1153 |
+
help="output onnx encoder and decoder directory",
|
1154 |
+
)
|
1155 |
+
parser.add_argument(
|
1156 |
+
"--fp16",
|
1157 |
+
action="store_true",
|
1158 |
+
help="whether to export fp16 model, default false",
|
1159 |
+
)
|
1160 |
+
# arguments for streaming encoder
|
1161 |
+
parser.add_argument(
|
1162 |
+
"--streaming",
|
1163 |
+
action="store_true",
|
1164 |
+
help="whether to export streaming encoder, default false",
|
1165 |
+
)
|
1166 |
+
parser.add_argument(
|
1167 |
+
"--decoding_chunk_size",
|
1168 |
+
default=16,
|
1169 |
+
type=int,
|
1170 |
+
required=False,
|
1171 |
+
help="the decoding chunk size, <=0 is not supported",
|
1172 |
+
)
|
1173 |
+
parser.add_argument(
|
1174 |
+
"--num_decoding_left_chunks",
|
1175 |
+
default=5,
|
1176 |
+
type=int,
|
1177 |
+
required=False,
|
1178 |
+
help="number of left chunks, <= 0 is not supported",
|
1179 |
+
)
|
1180 |
+
parser.add_argument(
|
1181 |
+
"--decoder_fastertransformer",
|
1182 |
+
action="store_true",
|
1183 |
+
help="return decoder_out and best_index for ft",
|
1184 |
+
)
|
1185 |
+
parser.add_argument(
|
1186 |
+
"--return_ctc_logprobs",
|
1187 |
+
action="store_true",
|
1188 |
+
help="return full ctc_log_probs for TLG streaming encoder",
|
1189 |
+
)
|
1190 |
+
args = parser.parse_args()
|
1191 |
+
|
1192 |
+
torch.manual_seed(0)
|
1193 |
+
torch.set_printoptions(precision=10)
|
1194 |
+
|
1195 |
+
with open(args.config, "r") as fin:
|
1196 |
+
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
1197 |
+
if args.cmvn_file and os.path.exists(args.cmvn_file):
|
1198 |
+
if 'cmvn' not in configs:
|
1199 |
+
configs['cmvn'] = "global_cmvn"
|
1200 |
+
configs['cmvn_conf'] = {}
|
1201 |
+
else:
|
1202 |
+
assert configs['cmvn'] == "global_cmvn"
|
1203 |
+
assert configs['cmvn_conf'] is not None
|
1204 |
+
configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
|
1205 |
+
if (args.reverse_weight != -1.0
|
1206 |
+
and "reverse_weight" in configs["model_conf"]):
|
1207 |
+
configs["model_conf"]["reverse_weight"] = args.reverse_weight
|
1208 |
+
print("Update reverse weight to", args.reverse_weight)
|
1209 |
+
if args.ctc_weight != -1:
|
1210 |
+
print("Update ctc weight to ", args.ctc_weight)
|
1211 |
+
configs["model_conf"]["ctc_weight"] = args.ctc_weight
|
1212 |
+
configs["encoder_conf"]["use_dynamic_chunk"] = False
|
1213 |
+
|
1214 |
+
model, configs = init_model(args, configs)
|
1215 |
+
model.eval()
|
1216 |
+
|
1217 |
+
if not os.path.exists(args.output_onnx_dir):
|
1218 |
+
os.mkdir(args.output_onnx_dir)
|
1219 |
+
encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder.onnx")
|
1220 |
+
export_enc_func = None
|
1221 |
+
if args.streaming:
|
1222 |
+
assert args.decoding_chunk_size > 0
|
1223 |
+
assert args.num_decoding_left_chunks > 0
|
1224 |
+
export_enc_func = export_online_encoder
|
1225 |
+
else:
|
1226 |
+
export_enc_func = export_offline_encoder
|
1227 |
+
|
1228 |
+
onnx_config = export_enc_func(model, configs, args, logger,
|
1229 |
+
encoder_onnx_path)
|
1230 |
+
|
1231 |
+
decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
|
1232 |
+
export_rescoring_decoder(
|
1233 |
+
model,
|
1234 |
+
configs,
|
1235 |
+
args,
|
1236 |
+
logger,
|
1237 |
+
decoder_onnx_path,
|
1238 |
+
args.decoder_fastertransformer,
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
if args.fp16:
|
1242 |
+
try:
|
1243 |
+
import onnxmltools
|
1244 |
+
from onnxmltools.utils.float16_converter import (
|
1245 |
+
convert_float_to_float16, )
|
1246 |
+
except ImportError:
|
1247 |
+
print("Please install onnxmltools!")
|
1248 |
+
sys.exit(1)
|
1249 |
+
encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
|
1250 |
+
encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
|
1251 |
+
encoder_onnx_path = os.path.join(args.output_onnx_dir,
|
1252 |
+
"encoder_fp16.onnx")
|
1253 |
+
onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
|
1254 |
+
decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
|
1255 |
+
decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
|
1256 |
+
decoder_onnx_path = os.path.join(args.output_onnx_dir,
|
1257 |
+
"decoder_fp16.onnx")
|
1258 |
+
onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
|
1259 |
+
# dump configurations
|
1260 |
+
|
1261 |
+
config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
|
1262 |
+
with open(config_dir, "w") as out:
|
1263 |
+
yaml.dump(onnx_config, out)
|
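After export, the output directory holds encoder.onnx, decoder.onnx and the dumped config.yaml. A minimal sketch of loading them back for a quick offline smoke test; the directory name, the 80-dim features and the fp32 (non --fp16) export are assumptions for illustration.
# Hypothetical smoke test of the exported offline encoder (fp32 export).
import os
import numpy as np
import onnxruntime
import yaml

onnx_dir = 'onnx_model'  # assumption: the --output_onnx_dir used above
with open(os.path.join(onnx_dir, 'config.yaml')) as fin:
    onnx_conf = yaml.load(fin, Loader=yaml.FullLoader)
sess = onnxruntime.InferenceSession(os.path.join(onnx_dir, 'encoder.onnx'),
                                    providers=['CUDAExecutionProvider'])
feats = np.random.randn(1, 200, 80).astype(np.float32)  # assumed 80-dim fbank
feats_len = np.array([200], dtype=np.int32)
outs = sess.run(None, {'speech': feats, 'speech_lengths': feats_len})
encoder_out, encoder_out_lens = outs[0], outs[1]
beam_log_probs, beam_log_probs_idx = outs[3], outs[4]
print(encoder_out.shape, encoder_out_lens, 'beam_size =', onnx_conf['beam_size'])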
wenet/bin/recognize.py
ADDED
@@ -0,0 +1,336 @@
|
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os

import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device', type=str, default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--dtype', type=str, default='fp32',
                        choices=['fp16', 'fp32', 'bf16'],
                        help='model\'s dtype')
    parser.add_argument('--num_workers', default=0, type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam size for search')
    parser.add_argument('--length_penalty', type=float, default=0.0,
                        help='length penalty')
    parser.add_argument('--blank_penalty', type=float, default=0.0,
                        help='blank penalty')
    parser.add_argument('--result_dir', required=True, help='asr result dir')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size for decoding')
    parser.add_argument('--modes', nargs='+',
                        help="""decoding mode, support the following:
                             attention
                             ctc_greedy_search
                             ctc_prefix_beam_search
                             attention_rescoring
                             rnnt_greedy_search
                             rnnt_beam_search
                             rnnt_beam_attn_rescoring
                             ctc_beam_td_attn_rescoring
                             hlg_onebest
                             hlg_rescore
                             paraformer_greedy_search
                             paraformer_beam_search""")
    parser.add_argument('--search_ctc_weight', type=float, default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight', type=float, default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight', type=float, default=0.0,
                        help='ctc weight for rescoring in attention rescoring '
                             'decode mode, and ctc weight for rescoring in '
                             'transducer attention rescore decode mode')

    parser.add_argument('--transducer_weight', type=float, default=0.0,
                        help='transducer weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--attn_weight', type=float, default=0.0,
                        help='attention weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size', type=int, default=-1,
                        help='''decoding chunk size,
                             <0: for decoding, use full chunk.
                             >0: for decoding, use fixed chunk size as set.
                             0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks', type=int, default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming', action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight', type=float, default=0.0,
                        help='''right to left weight for attention rescoring
                             decode mode''')
    parser.add_argument('--override_config', action='append', default=[],
                        help="override yaml config")

    parser.add_argument('--word', default='', type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg', default='', type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale', type=float, default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale', type=float, default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale', type=float, default=0.0,
                        help='right-to-left decoder scale for hlg attention rescore decode')

    parser.add_argument('--context_bias_mode', type=str, default='',
                        help='''Context bias mode, selectable from the following
                             option: decoding-graph, deep-biasing''')
    parser.add_argument('--context_list_path', type=str, default='',
                        help='Context list path')
    parser.add_argument('--context_graph_score', type=float, default=0.0,
                        help='''The higher the score, the greater the degree of
                             bias using decoding-graph for biasing''')

    parser.add_argument('--use_lora', type=bool, default=False,
                        help='''Whether to use lora for biasing''')
    parser.add_argument("--lora_ckpt_path", default=None, type=str,
                        help="lora checkpoint path.")

    parser.add_argument('--task', type=str, default='asr',
                        help='task tag passed to the model, e.g. asr')
    parser.add_argument('--lang', type=str, default='zh',
                        help='language tag passed to the model, e.g. zh')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    if args.gpu != -1:
        # remain the original usage of gpu
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    test_conf = copy.deepcopy(configs['dataset_conf'])

    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['cycle'] = 1
    test_conf['list_shuffle'] = False
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = args.batch_size

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)

    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=args.num_workers)

    # Init asr model from configs
    args.jit = False
    model, configs = init_model(args, configs)

    device = torch.device(args.device)
    model = model.to(device)
    model.eval()
    dtype = torch.float32
    if args.dtype == 'fp16':
        dtype = torch.float16
    elif args.dtype == 'bf16':
        dtype = torch.bfloat16
    logging.info("compute dtype is {}".format(dtype))

    context_graph = None
    if 'decoding-graph' in args.context_bias_mode:
        context_graph = ContextGraph(args.context_list_path,
                                     tokenizer.symbol_table,
                                     configs['tokenizer_conf']['bpe_path'],
                                     args.context_graph_score)

    _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
    logging.info("blank_id is {}".format(blank_id))

    # TODO(Dinghao Zhou): Support RNN-T related decoding
    # TODO(Lv Xiang): Support k2 related decoding
    # TODO(Kaixun Huang): Support context graph
    files = {}
    for mode in args.modes:
        dir_name = os.path.join(args.result_dir, mode)
        os.makedirs(dir_name, exist_ok=True)
        file_name = os.path.join(dir_name, 'text')
        files[mode] = open(file_name, 'w', encoding='utf-8')
    max_format_len = max([len(mode) for mode in args.modes])

    with torch.cuda.amp.autocast(enabled=True,
                                 dtype=dtype,
                                 cache_enabled=False):
        with torch.no_grad():
            utt_num = 0
            # logging.info(f'utt_num: {utt_num}')
            for batch_idx, batch in enumerate(test_data_loader):
                keys = batch["keys"]
                feats = batch["feats"].to(device)
                target = batch["target"].to(device)
                feats_lengths = batch["feats_lengths"].to(device)
                target_lengths = batch["target_lengths"].to(device)
                batch_size = feats.size(0)
                # task_list = ["transcribe" for i in range(batch_size)]
                task_list = [args.task for i in range(batch_size)]
                lang_list = [args.lang for i in range(batch_size)]
                infos = {"tasks": task_list, "langs": lang_list}
                results = model.decode(
                    args.modes,
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    ctc_weight=args.ctc_weight,
                    simulate_streaming=args.simulate_streaming,
                    reverse_weight=args.reverse_weight,
                    context_graph=context_graph,
                    blank_id=blank_id,
                    blank_penalty=args.blank_penalty,
                    length_penalty=args.length_penalty,
                    infos=infos)
                for i, key in enumerate(keys):
                    utt_num += 1
                    for mode, hyps in results.items():
                        tokens = hyps[i].tokens
                        line = '{} {}'.format(key,
                                              tokenizer.detokenize(tokens)[0])
                        logging.info('{} {}'.format(mode.ljust(max_format_len),
                                                    line))
                        files[mode].write(line + '\n')
                        # if utt_num % 500 == 0:
                        #     files[mode].flush()
    for mode, f in files.items():
        f.flush()  # flush buffered results to disk
        f.close()


if __name__ == '__main__':
    main()
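The decode loop above boils down to one call per batch: model.decode takes the list of modes plus padded features and returns, per mode, one hypothesis per utterance. A minimal single-utterance sketch, assuming model, tokenizer, device and blank_id are built exactly as in main(); the keyword names used here are only the ones that also appear in the call above, the beam size is passed positionally.

import torch

feats = torch.randn(1, 200, 80, device=device)      # (batch, frames, fbank dim), dummy input
feats_lengths = torch.tensor([200], device=device)
infos = {"tasks": ["asr"], "langs": ["zh"]}
with torch.no_grad():
    results = model.decode(["ctc_greedy_search"], feats, feats_lengths, 10,
                           blank_id=blank_id, infos=infos)
hyp = results["ctc_greedy_search"][0]                # first (only) utterance
print(tokenizer.detokenize(hyp.tokens)[0])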
wenet/bin/recognize4llmasr.py
ADDED
@@ -0,0 +1,340 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os

import torch
import yaml
from gxl_ai_utils.utils.utils_model import set_random_seed
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.llm_asr.llmasr_model import LLMASR_Model
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device', type=str, default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--dtype', type=str, default='fp32',
                        choices=['fp16', 'fp32', 'bf16'],
                        help='model\'s dtype')
    parser.add_argument('--num_workers', default=0, type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam size for search')
    parser.add_argument('--length_penalty', type=float, default=0.0,
                        help='length penalty')
    parser.add_argument('--blank_penalty', type=float, default=0.0,
                        help='blank penalty')
    parser.add_argument('--result_dir', required=True, help='asr result dir')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size for decoding')
    parser.add_argument('--modes', nargs='+',
                        help="""decoding mode, support the following:
                             attention
                             ctc_greedy_search
                             ctc_prefix_beam_search
                             attention_rescoring
                             rnnt_greedy_search
                             rnnt_beam_search
                             rnnt_beam_attn_rescoring
                             ctc_beam_td_attn_rescoring
                             hlg_onebest
                             hlg_rescore
                             paraformer_greedy_search
                             paraformer_beam_search""")
    parser.add_argument('--search_ctc_weight', type=float, default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight', type=float, default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight', type=float, default=0.0,
                        help='ctc weight for rescoring in attention rescoring '
                             'decode mode, and ctc weight for rescoring in '
                             'transducer attention rescore decode mode')

    parser.add_argument('--transducer_weight', type=float, default=0.0,
                        help='transducer weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--attn_weight', type=float, default=0.0,
                        help='attention weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size', type=int, default=-1,
                        help='''decoding chunk size,
                             <0: for decoding, use full chunk.
                             >0: for decoding, use fixed chunk size as set.
                             0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks', type=int, default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming', action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight', type=float, default=0.0,
                        help='''right to left weight for attention rescoring
                             decode mode''')
    parser.add_argument('--override_config', action='append', default=[],
                        help="override yaml config")

    parser.add_argument('--word', default='', type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg', default='', type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale', type=float, default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale', type=float, default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale', type=float, default=0.0,
                        help='right-to-left decoder scale for hlg attention rescore decode')

    parser.add_argument('--context_bias_mode', type=str, default='',
                        help='''Context bias mode, selectable from the following
                             option: decoding-graph, deep-biasing''')
    parser.add_argument('--context_list_path', type=str, default='',
                        help='Context list path')
    parser.add_argument('--context_graph_score', type=float, default=0.0,
                        help='''The higher the score, the greater the degree of
                             bias using decoding-graph for biasing''')

    parser.add_argument('--use_lora', type=bool, default=False,
                        help='''Whether to use lora for biasing''')
    parser.add_argument("--lora_ckpt_path", default=None, type=str,
                        help="lora checkpoint path.")

    parser.add_argument('--task', type=str, default='asr',
                        help='task tag passed to the model, e.g. asr')
    parser.add_argument('--lang', type=str, default='zh',
                        help='language tag passed to the model, e.g. zh')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    set_random_seed(777)

    if args.gpu != -1:
        # remain the original usage of gpu
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)
    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(configs['dataset_conf'])

    test_conf['filter_conf']['max_length'] = 3000  # Whisper handles at most 30 s (was 102400)
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = True
    test_conf['sort'] = False
    test_conf['cycle'] = 1
    test_conf['list_shuffle'] = True
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = 1
    test_conf['split_num'] = 1

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)

    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=args.num_workers)

    # Init asr model from configs
    args.jit = False
    model, configs = init_model(args, configs)

    device = torch.device(args.device)
    model: LLMASR_Model = model.to(device)
    model.eval()
    dtype = torch.float32
    if args.dtype == 'fp16':
        dtype = torch.float16
    elif args.dtype == 'bf16':
        dtype = torch.bfloat16
    logging.info("compute dtype is {}".format(dtype))

    context_graph = None
    if 'decoding-graph' in args.context_bias_mode:
        context_graph = ContextGraph(args.context_list_path,
                                     tokenizer.symbol_table,
                                     configs['tokenizer_conf']['bpe_path'],
                                     args.context_graph_score)

    _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
    logging.info("blank_id is {}".format(blank_id))

    # TODO(Dinghao Zhou): Support RNN-T related decoding
    # TODO(Lv Xiang): Support k2 related decoding
    # TODO(Kaixun Huang): Support context graph
    files = {}
    modes = ['llmasr_decode']
    for mode in modes:
        dir_name = os.path.join(args.result_dir, mode)
        os.makedirs(dir_name, exist_ok=True)
        file_name = os.path.join(dir_name, 'text')
        files[mode] = open(file_name, 'w', encoding='utf-8')
    max_format_len = max([len(mode) for mode in modes])

    # Get prompt config
    from gxl_ai_utils.utils import utils_file
    global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')

    with torch.cuda.amp.autocast(enabled=True,
                                 dtype=dtype,
                                 cache_enabled=False):
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_data_loader):
                keys = batch["keys"]
                feats = batch["feats"].to(device)
                target = batch["target"].to(device)
                feats_lengths = batch["feats_lengths"].to(device)
                target_lengths = batch["target_lengths"].to(device)
                batch_size = feats.size(0)

                import random
                if '><' in args.task:
                    args.task = args.task.replace('><', '> <')
                if args.task == "<TRANSCRIBE>" or args.task == "<transcribe>":
                    is_truncation = False
                else:
                    is_truncation = True
                random_index = random.randint(0, len(global_prompt_dict[args.task]) - 1)
                prompt = global_prompt_dict[args.task][random_index]
                # print(args.task, prompt)

                res_text = model.generate(wavs=feats, wavs_len=feats_lengths, prompt=prompt)
                for mode in modes:
                    line = "{}\t{}".format(keys[0], res_text[0])
                    files[mode].write(line + '\n')
                    utils_file.logging_print('{} {} {}'.format(batch_idx, keys[0], res_text[0]))
                if batch_idx % 100 == 0:
                    for mode, f in files.items():
                        f.flush()  # flush buffered results to disk
                # if batch_idx >= 1000 and is_truncation:
                #     utils_file.logging_info('using the truncate-at-3000 strategy')
                #     break
    for mode, f in files.items():
        f.flush()  # flush buffered results to disk
        f.close()


if __name__ == '__main__':
    main()
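The prompt handling above relies on a task-to-prompt-list YAML loaded through utils_file.load_dict_from_yaml. Below is a plain-PyYAML sketch of the same lookup; it assumes the file maps each task tag to a list of prompt strings, which is what the random indexing above implies.

import random
import yaml

with open('conf/prompt_stage4.yaml', 'r', encoding='utf-8') as fin:
    prompt_dict = yaml.safe_load(fin)          # assumed layout: {task tag: [prompt, prompt, ...]}

task = "<TRANSCRIBE>"
prompt = random.choice(prompt_dict[task])      # same effect as randint + indexing
print(prompt)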
wenet/bin/recognize_onnx_gpu.py
ADDED
@@ -0,0 +1,297 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for testing exported onnx encoder and decoder from
export_onnx_gpu.py. The exported onnx models only support batch offline ASR inference.
It requires a python wrapped c++ ctc decoder.
Please install it by following:
https://github.com/Slyne/ctc_decoder.git
"""
from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.config import override_config
from wenet.utils.init_tokenizer import init_tokenizer

import onnxruntime as rt
import multiprocessing
import numpy as np

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by referring to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)

def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type', default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--encoder_onnx', required=True,
                        help='encoder onnx file')
    parser.add_argument('--decoder_onnx', required=True,
                        help='decoder onnx file')
    parser.add_argument('--result_file', required=True, help='asr result file')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size for decoding')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring'
                        ],
                        default='attention_rescoring',
                        help='decoding mode')
    parser.add_argument('--bpe_model', default=None, type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config', action='append', default=[],
                        help="override yaml config")
    parser.add_argument('--fp16', action='store_true',
                        help='whether to export fp16 model, default false')
    args = parser.parse_args()
    return args

def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['fbank_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = args.batch_size

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    # Init asr model from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        EP_list = ['CPUExecutionProvider']

    encoder_ort_session = rt.InferenceSession(args.encoder_onnx,
                                              providers=EP_list)
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx,
                                                  providers=EP_list)

    # Load dict
    vocabulary = []
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])

    vocab_size = len(char_dict)
    sos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<sos>", vocab_size - 1))
    eos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<eos>", vocab_size - 1))

    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for _, batch in enumerate(test_data_loader):
            keys = batch['keys']
            feats = batch['feats']
            feats_lengths = batch['feats_lengths']
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths
            }
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            encoder_out, encoder_out_lens, ctc_log_probs, \
                beam_log_probs, beam_log_probs_idx = ort_outs
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
            if args.mode == 'ctc_greedy_search':
                if beam_size != 1:
                    log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, vocabulary, num_processes, True,
                                 0)
            elif args.mode in ('ctc_prefix_beam_search',
                               "attention_rescoring"):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq, batch_log_probs_ids, batch_root,
                    batch_start, beam_size, num_processes, 0, -2, 0.99999)
                if args.mode == 'ctc_prefix_beam_search':
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
            if args.mode == 'attention_rescoring':
                ctc_score, all_hyps = [], []
                max_len = 0
                for hyps in score_hyps:
                    cur_len = len(hyps)
                    if len(hyps) < beam_size:
                        hyps += (beam_size - cur_len) * [(-float("INF"),
                                                          (0, ))]
                    cur_ctc_score = []
                    for hyp in hyps:
                        cur_ctc_score.append(hyp[0])
                        all_hyps.append(list(hyp[1]))
                        if len(hyp[1]) > max_len:
                            max_len = len(hyp[1])
                    ctc_score.append(cur_ctc_score)
                if args.fp16:
                    ctc_score = np.array(ctc_score, dtype=np.float16)
                else:
                    ctc_score = np.array(ctc_score, dtype=np.float32)
                hyps_pad_sos_eos = np.ones(
                    (batch_size, beam_size, max_len + 2),
                    dtype=np.int64) * IGNORE_ID
                r_hyps_pad_sos_eos = np.ones(
                    (batch_size, beam_size, max_len + 2),
                    dtype=np.int64) * IGNORE_ID
                hyps_lens_sos = np.ones((batch_size, beam_size),
                                        dtype=np.int32)
                k = 0
                for i in range(batch_size):
                    for j in range(beam_size):
                        cand = all_hyps[k]
                        l = len(cand) + 2
                        hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
                        r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [
                            eos
                        ]
                        hyps_lens_sos[i][j] = len(cand) + 1
                        k += 1
                decoder_ort_inputs = {
                    decoder_ort_session.get_inputs()[0].name: encoder_out,
                    decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
                    decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
                    decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
                    decoder_ort_session.get_inputs()[-1].name: ctc_score
                }
                if reverse_weight > 0:
                    r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs(
                    )[4].name
                    decoder_ort_inputs[
                        r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
                best_index = decoder_ort_session.run(None,
                                                     decoder_ort_inputs)[0]
                best_sents = []
                k = 0
                for idx in best_index:
                    cur_best_sent = all_hyps[k:k + beam_size][idx]
                    best_sents.append(cur_best_sent)
                    k += beam_size
                hyps = map_batch(best_sents, vocabulary, num_processes)

            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))

if __name__ == '__main__':
    main()
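Most of the attention-rescoring branch above is spent packing the n-best CTC hypotheses into fixed-size tensors for the ONNX decoder. A toy illustration of that padding with made-up token ids; IGNORE_ID is assumed to be -1 here, the real value comes from wenet.utils.common.

import numpy as np

IGNORE_ID = -1                                 # assumption for this sketch
sos, eos = 5537, 5537                          # made-up ids; typically vocab_size - 1
cands = [[12, 7, 99], [12, 7]]                 # beam_size = 2 hypotheses for one utterance
max_len = max(len(c) for c in cands)
hyps_pad_sos_eos = np.full((1, 2, max_len + 2), IGNORE_ID, dtype=np.int64)
hyps_lens_sos = np.ones((1, 2), dtype=np.int32)
for j, cand in enumerate(cands):
    hyps_pad_sos_eos[0, j, :len(cand) + 2] = [sos] + cand + [eos]
    hyps_lens_sos[0, j] = len(cand) + 1        # length including <sos>, excluding <eos>
print(hyps_pad_sos_eos)
print(hyps_lens_sos)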
wenet/bin/train.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import datetime
import logging
import os
import random

import numpy as np
import yaml
import torch

import torch.distributed as dist

from torch.distributed.elastic.multiprocessing.errors import record
from wenet.utils.common import lrs_to_str, TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu

from wenet.utils.executor import Executor
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.train_utils import (
    add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
    add_deepspeed_args, add_trace_args, init_distributed,
    init_dataset_and_dataloader, check_modify_and_save_config,
    init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
    wrap_cuda_model, init_summarywriter, save_model, log_per_epoch,
    add_lora_args, reinit_lora)
from gxl_ai_utils.utils import utils_file

try:
    import torch_npu

    torch_npu.npu.conv.allow_hf32 = False
    # import deepspeed_npu
    from torch_npu.npu import amp
    from torch_npu.contrib import transfer_to_npu
except ImportError:
    utils_file.logging_warning(
        "torch_npu is not installed, please install torch_npu first if you want to use torch_npu")
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

from msprobe.pytorch import seed_all
import gc

gc.set_threshold(700, 10, 10000)  # Python GC thresholds


# import deepspeed_npu
def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--train_engine',
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_fsdp', 'deepspeed'],
                        help='Engine for paralleled training')
    # set default value of device to "cuda", avoiding the modify of original scripts
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator for training')
    # load deepspeed checkpoint
    parser.add_argument('--load_dir', type=str, default=None)
    parser.add_argument('--ckpt_id', type=str, default=None)
    parser = add_model_args(parser)
    parser = add_dataset_args(parser)
    parser = add_ddp_args(parser)
    parser = add_lora_args(parser)
    parser = add_deepspeed_args(parser)
    parser = add_fsdp_args(parser)
    parser = add_trace_args(parser)
    args = parser.parse_args()
    if args.train_engine == "deepspeed":
        args.deepspeed = True
        assert args.deepspeed_config is not None
    return args


# NOTE(xcsong): On worker errors, this record tool will summarize the
# details of the error (e.g. time, rank, host, pid, traceback, etc).
@record
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    # Set random seed
    torch.manual_seed(777)
    random.seed(777)
    np.random.seed(777)
    utils_file.logging_info('start strict seeding')
    seed_all(777)
    utils_file.logging_info('finished strict seeding')
    logging.info('Random seed set to {}'.format(777))

    # Read config
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    # init tokenizer
    tokenizer = init_tokenizer(configs)

    # Init env for ddp OR deepspeed
    _, _, rank = init_distributed(args)

    # Init asr model from configs
    model, configs = init_model(args, configs)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, tokenizer)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs,
                                           tokenizer.symbol_table)

    if hasattr(args, 'lora_reinit') and args.lora_reinit:
        reinit_lora(model, args, configs, tokenizer)

    # Check model is jitable & print model architectures
    trace_and_print_model(args, model)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # Dispatch model from cpu to gpu
    model, device = wrap_cuda_model(args, model, configs)

    # Get optimizer & scheduler
    model, optimizer, scheduler = init_optimizer_and_scheduler(
        args, configs, model)

    # Load deepspeed checkpoint
    if args.load_dir is not None and \
            args.ckpt_id is not None:
        _, client_sd = model.load_checkpoint(args.load_dir, args.ckpt_id)

    # Save checkpoints
    # save_model(model,
    #            info_dict={
    #                "save_time":
    #                datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
    #                "tag":
    #                "init",
    #                **configs
    #            })

    # Get executor
    tag = configs["init_infos"].get("tag", "init")
    executor = Executor(global_step=configs["init_infos"].get('step', -1),
                        device=device)

    # Init scaler, used for pytorch amp mixed precision training
    scaler = init_scaler(args)

    # Start training loop
    start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
    # if save_interval in configs, steps mode else epoch mode
    end_epoch = configs.get('max_epoch', 100)
    assert start_epoch <= end_epoch
    configs.pop("init_infos", None)
    final_epoch = None
    for epoch in range(start_epoch, end_epoch):
        configs['epoch'] = epoch

        lrs = [group['lr'] for group in optimizer.param_groups]
        logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format(
            epoch, executor.step, lrs_to_str(lrs), rank))

        dist.barrier(
        )  # NOTE(xcsong): Ensure all ranks start Train at the same time.
        # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join`
        group_join = dist.new_group(  # fix by zhaoyi for multi-node training
            backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        # group_join = None
        executor.train(model, optimizer, scheduler, train_data_loader,
                       cv_data_loader, writer, configs, scaler, group_join)
        # dist.destroy_process_group(group_join)

        dist.barrier(
        )  # NOTE(xcsong): Ensure all ranks start CV at the same time.
        loss_dict = executor.cv(model, cv_data_loader, configs)
        info_dict = {
            'epoch': epoch,
            'lrs': [group['lr'] for group in optimizer.param_groups],
            'step': executor.step,
            'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
            'tag': "epoch_{}".format(epoch),
            'loss_dict': loss_dict,
            **configs
        }
        # epoch cv: tensorboard && log
        log_per_epoch(writer, info_dict=info_dict)
        save_model(model, info_dict=info_dict)

        final_epoch = epoch

    if final_epoch is not None and rank == 0:
        final_model_path = os.path.join(args.model_dir, 'final.pt')
        os.remove(final_model_path) if os.path.exists(
            final_model_path) else None
        os.symlink('{}.pt'.format(final_epoch), final_model_path)
        writer.close()
    dist.barrier(
    )  # NOTE(yktian): Ensure all ranks end Train before destroy process group.
    dist.destroy_process_group()


if __name__ == '__main__':
    main()
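The end-of-training bookkeeping above simply repoints a final.pt symlink at the last epoch checkpoint. A small self-contained sketch of that step; the helper name and example paths are illustrative only, and checkpoints are assumed to be saved as {epoch}.pt as in the loop above.

import os

def refresh_final_symlink(model_dir: str, final_epoch: int) -> None:
    """Point model_dir/final.pt at the checkpoint saved for final_epoch."""
    final_model_path = os.path.join(model_dir, 'final.pt')
    if os.path.lexists(final_model_path):      # also covers a dangling symlink
        os.remove(final_model_path)
    os.symlink('{}.pt'.format(final_epoch), final_model_path)

refresh_final_symlink('exp/my_model', 9)       # exp/my_model/final.pt -> 9.pt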
wenet/branchformer/__init__.py
ADDED
File without changes
wenet/branchformer/cgmlp.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
|
2 |
+
# 2023 Voicecomm Inc (Kai Li)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""MLP with convolutional gating (cgMLP) definition.
|
17 |
+
|
18 |
+
References:
|
19 |
+
https://openreview.net/forum?id=RA-zVvZLYIy
|
20 |
+
+    https://arxiv.org/abs/2105.08050
+
+"""
+
+from typing import Tuple
+import torch
+import torch.nn as nn
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
+
+
+class ConvolutionalSpatialGatingUnit(torch.nn.Module):
+    """Convolutional Spatial Gating Unit (CSGU)."""
+
+    def __init__(
+        self,
+        size: int,
+        kernel_size: int,
+        dropout_rate: float,
+        use_linear_after_conv: bool,
+        gate_activation: str,
+        causal: bool = True,
+    ):
+        super().__init__()
+
+        # split input channels
+        n_channels = size // 2
+        self.norm = nn.LayerNorm(n_channels)
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0: it's a causal convolution, the input will be
+        # padded with self.lorder frames on the left in forward.
+        # else: it's a symmetrical convolution
+        if causal:
+            padding = 0
+            self.lorder = kernel_size - 1
+        else:
+            # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
+            padding = (kernel_size - 1) // 2
+            self.lorder = 0
+        self.conv = torch.nn.Conv1d(
+            n_channels,
+            n_channels,
+            kernel_size,
+            1,
+            padding,
+            groups=n_channels,
+        )
+        if use_linear_after_conv:
+            self.linear = torch.nn.Linear(n_channels, n_channels)
+        else:
+            self.linear = None
+
+        if gate_activation == "identity":
+            self.act = torch.nn.Identity()
+        else:
+            self.act = WENET_ACTIVATION_CLASSES[gate_activation]()
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def espnet_initialization_fn(self):
+        torch.nn.init.normal_(self.conv.weight, std=1e-6)
+        torch.nn.init.ones_(self.conv.bias)
+        if self.linear is not None:
+            torch.nn.init.normal_(self.linear.weight, std=1e-6)
+            torch.nn.init.ones_(self.linear.bias)
+
+    def forward(
+        self, x: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward method
+
+        Args:
+            x (torch.Tensor): (batch, time, channels)
+            cache (torch.Tensor): left context cache, it is only
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+
+        Returns:
+            out (torch.Tensor): (batch, time, channels/2)
+        """
+
+        x_r, x_g = x.chunk(2, dim=-1)
+        # exchange the temporal dimension and the feature dimension
+        x_g = x_g.transpose(1, 2)  # (#batch, channels, time)
+
+        if self.lorder > 0:
+            if cache.size(2) == 0:  # cache_t == 0
+                x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0)
+            else:
+                assert cache.size(0) == x_g.size(0)  # equal batch
+                assert cache.size(1) == x_g.size(1)  # equal channel
+                x_g = torch.cat((cache, x_g), dim=2)
+            assert (x_g.size(2) > self.lorder)
+            new_cache = x_g[:, :, -self.lorder:]
+        else:
+            # It's better we just return None if no cache is required,
+            # However, for JIT export, here we just fake one tensor instead of
+            # None.
+            new_cache = torch.zeros((0, 0, 0),
+                                    dtype=x_g.dtype,
+                                    device=x_g.device)
+
+        x_g = x_g.transpose(1, 2)
+        x_g = self.norm(x_g)  # (N, T, D/2)
+        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
+        if self.linear is not None:
+            x_g = self.linear(x_g)
+
+        x_g = self.act(x_g)
+        out = x_r * x_g  # (N, T, D/2)
+        out = self.dropout(out)
+        return out, new_cache
+
+
+class ConvolutionalGatingMLP(torch.nn.Module):
+    """Convolutional Gating MLP (cgMLP)."""
+
+    def __init__(
+        self,
+        size: int,
+        linear_units: int,
+        kernel_size: int,
+        dropout_rate: float,
+        use_linear_after_conv: bool,
+        gate_activation: str,
+        causal: bool = True,
+    ):
+        super().__init__()
+
+        self.channel_proj1 = torch.nn.Sequential(
+            torch.nn.Linear(size, linear_units), torch.nn.GELU())
+        self.csgu = ConvolutionalSpatialGatingUnit(
+            size=linear_units,
+            kernel_size=kernel_size,
+            dropout_rate=dropout_rate,
+            use_linear_after_conv=use_linear_after_conv,
+            gate_activation=gate_activation,
+            causal=causal,
+        )
+        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        cache: torch.Tensor = torch.zeros((0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward method
+
+        Args:
+            x (torch.Tensor): (batch, time, channels)
+            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask. Not used yet
+            cache (torch.Tensor): left context cache, it is only
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+
+        Returns:
+            out (torch.Tensor): (batch, time, channels)
+        """
+
+        xs_pad = x
+
+        # size -> linear_units
+        xs_pad = self.channel_proj1(xs_pad)
+
+        # linear_units -> linear_units/2
+        xs_pad, new_cnn_cache = self.csgu(xs_pad, cache)
+
+        # linear_units/2 -> size
+        xs_pad = self.channel_proj2(xs_pad)
+
+        out = xs_pad
+
+        return out, new_cnn_cache
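
Note (editor's sketch, not part of the commit): a minimal example of exercising the cgMLP block above on a dummy batch, assuming the wenet package in this repo is importable; the shapes follow the docstrings (input (batch, time, size), output back at size after channel_proj2), and the fake mask/cache tensors mirror the defaults used in the file.

    import torch
    from wenet.branchformer.cgmlp import ConvolutionalGatingMLP

    cgmlp = ConvolutionalGatingMLP(size=256, linear_units=2048, kernel_size=31,
                                   dropout_rate=0.1, use_linear_after_conv=False,
                                   gate_activation='identity', causal=True)
    x = torch.randn(2, 50, 256)                       # (batch, time, size)
    out, cnn_cache = cgmlp(x, torch.ones(0, 0, 0), torch.zeros(0, 0, 0))
    print(out.shape)                                  # torch.Size([2, 50, 256])
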
wenet/branchformer/encoder.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
+# 2023 Voicecomm Inc (Kai Li)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+
+import torch
+
+from typing import List, Optional, Union
+
+from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
+from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
+from wenet.transformer.encoder import BaseEncoder
+from wenet.utils.class_utils import (
+    WENET_ATTENTION_CLASSES, )
+
+
+class BranchformerEncoder(BaseEncoder):
+    """Branchformer encoder module."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        use_attn: bool = True,
+        attention_heads: int = 4,
+        selfattention_layer_type: str = "rel_selfattn",
+        pos_enc_layer_type: str = "rel_pos",
+        use_cgmlp: bool = True,
+        cgmlp_linear_units: int = 2048,
+        cgmlp_conv_kernel: int = 31,
+        use_linear_after_conv: bool = False,
+        gate_activation: str = "identity",
+        merge_method: str = "concat",
+        cgmlp_weight: Union[float, List[float]] = 0.5,
+        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
+        num_blocks: int = 12,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        stochastic_depth_rate: Union[float, List[float]] = 0.0,
+        static_chunk_size: int = 0,
+        use_dynamic_chunk: bool = False,
+        global_cmvn: torch.nn.Module = None,
+        use_dynamic_left_chunk: bool = False,
+        causal: bool = False,
+        query_bias: bool = True,
+        key_bias: bool = True,
+        value_bias: bool = True,
+        gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+    ):
+        super().__init__(input_size, output_size, attention_heads,
+                         cgmlp_linear_units, num_blocks, dropout_rate,
+                         positional_dropout_rate, attention_dropout_rate,
+                         input_layer, pos_enc_layer_type, True,
+                         static_chunk_size, use_dynamic_chunk, global_cmvn,
+                         use_dynamic_left_chunk, gradient_checkpointing,
+                         use_sdpa, layer_norm_type, norm_eps)
+
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            query_bias,
+            key_bias,
+            value_bias,
+            use_sdpa,
+            n_kv_head,
+            head_dim,
+        )
+
+        cgmlp_layer = ConvolutionalGatingMLP
+        cgmlp_layer_args = (
+            output_size,
+            cgmlp_linear_units,
+            cgmlp_conv_kernel,
+            dropout_rate,
+            use_linear_after_conv,
+            gate_activation,
+            causal,
+        )
+
+        if isinstance(stochastic_depth_rate, float):
+            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
+        if len(stochastic_depth_rate) != num_blocks:
+            raise ValueError(
+                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
+                f"should be equal to num_blocks ({num_blocks})")
+
+        if isinstance(cgmlp_weight, float):
+            cgmlp_weight = [cgmlp_weight] * num_blocks
+        if len(cgmlp_weight) != num_blocks:
+            raise ValueError(
+                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
+                f"num_blocks ({num_blocks})")
+
+        if isinstance(attn_branch_drop_rate, float):
+            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
+        if len(attn_branch_drop_rate) != num_blocks:
+            raise ValueError(
+                f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
+                f"should be equal to num_blocks ({num_blocks})")
+
+        self.encoders = LayerDropModuleList(
+            p=stochastic_depth_rate,
+            modules=[
+                BranchformerEncoderLayer(
+                    output_size,
+                    WENET_ATTENTION_CLASSES[selfattention_layer_type](
+                        *encoder_selfattn_layer_args) if use_attn else None,
+                    cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
+                    dropout_rate,
+                    merge_method,
+                    cgmlp_weight[lnum],
+                    attn_branch_drop_rate[lnum],
+                    stochastic_depth_rate[lnum],
+                ) for lnum in range(num_blocks)
+            ])
+
+
+# modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py  # noqa
+class LayerDropModuleList(torch.nn.ModuleList):
+    """
+    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
+
+    We refresh the choice of which layers to drop every time we iterate
+    over the LayerDropModuleList instance. During evaluation we always
+    iterate over all layers.
+
+    Usage::
+
+        layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
+        for layer in layers:  # this might iterate over layers 1 and 3
+            x = layer(x)
+        for layer in layers:  # this might iterate over all layers
+            x = layer(x)
+        for layer in layers:  # this might not iterate over any layers
+            x = layer(x)
+
+    Args:
+        p (float): probability of dropping out each layer
+        modules (iterable, optional): an iterable of modules to add
+
+    Limitations:
+        1 can work with ddp when layer's gradient checkpoint disabled
+        2 can't work with ddp when layer's gradient checkpoint enabled
+        3 can work with fsdp
+        4 can work with deepspeed
+    """
+
+    def __init__(self, p: List[float], modules=None):
+        super().__init__(modules)
+        assert len(p) == len(self)
+        self.p = p
+
+    def __iter__(self):
+        dropout_probs = torch.empty(len(self)).uniform_()
+        for i, m in enumerate(super().__iter__()):
+            if not self.training or (dropout_probs[i] > self.p[i]):
+                yield m
wenet/branchformer/encoder_layer.py
ADDED
@@ -0,0 +1,245 @@
+# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
+# 2023 Voicecomm Inc (Kai Li)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""BranchformerEncoderLayer definition."""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+
+from wenet.transformer.attention import T_CACHE
+
+
+class BranchformerEncoderLayer(torch.nn.Module):
+    """Branchformer encoder layer module.
+
+    Args:
+        size (int): model dimension
+        attn: standard self-attention or efficient attention, optional
+        cgmlp: ConvolutionalGatingMLP, optional
+        dropout_rate (float): dropout probability
+        merge_method (str): concat, learned_ave, fixed_ave
+        cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
+            used if merge_method is fixed_ave
+        attn_branch_drop_rate (float): probability of dropping the attn branch,
+            used if merge_method is learned_ave
+        stochastic_depth_rate (float): stochastic depth probability
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attn: Optional[torch.nn.Module],
+        cgmlp: Optional[torch.nn.Module],
+        dropout_rate: float,
+        merge_method: str,
+        cgmlp_weight: float = 0.5,
+        attn_branch_drop_rate: float = 0.0,
+        stochastic_depth_rate: float = 0.0,
+    ):
+        super().__init__()
+        assert (attn is not None) or (
+            cgmlp is not None), "At least one branch should be valid"
+
+        self.size = size
+        self.attn = attn
+        self.cgmlp = cgmlp
+        self.merge_method = merge_method
+        self.cgmlp_weight = cgmlp_weight
+        self.attn_branch_drop_rate = attn_branch_drop_rate
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.use_two_branches = (attn is not None) and (cgmlp is not None)
+
+        if attn is not None:
+            self.norm_mha = nn.LayerNorm(size)  # for the MHA module
+        if cgmlp is not None:
+            self.norm_mlp = nn.LayerNorm(size)  # for the MLP module
+        self.norm_final = nn.LayerNorm(
+            size)  # for the final output of the block
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        # # attention-based pooling for two branches
+        self.pooling_proj1 = torch.nn.Linear(size, 1)
+        self.pooling_proj2 = torch.nn.Linear(size, 1)
+
+        # # linear projections for calculating merging weights
+        self.weight_proj1 = torch.nn.Linear(size, 1)
+        self.weight_proj2 = torch.nn.Linear(size, 1)
+
+        if self.use_two_branches:
+            if self.merge_method == "concat":
+                self.merge_proj = torch.nn.Linear(size + size, size)
+
+            elif self.merge_method == "learned_ave":
+                # linear projection after weighted average
+                self.merge_proj = torch.nn.Linear(size, size)
+
+            elif self.merge_method == "fixed_ave":
+                assert (0.0 <= cgmlp_weight <=
+                        1.0), "cgmlp weight should be between 0.0 and 1.0"
+
+                # remove the other branch if only one branch is used
+                if cgmlp_weight == 0.0:
+                    self.use_two_branches = False
+                    self.cgmlp = None
+                    self.norm_mlp = None
+                elif cgmlp_weight == 1.0:
+                    self.use_two_branches = False
+                    self.attn = None
+                    self.norm_mha = None
+
+                # linear projection after weighted average
+                self.merge_proj = torch.nn.Linear(size, size)
+            else:
+                raise ValueError(f"unknown merge method: {merge_method}")
+        else:
+            self.merge_proj = torch.nn.Identity()
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        att_cache: T_CACHE = (torch.zeros(
+            (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        stoch_layer_coeff: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
+        # Two branches
+        x1 = x
+        x2 = x
+
+        # Branch 1: multi-headed attention module
+        if self.attn is not None:
+            x1 = self.norm_mha(x1)
+            x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb,
+                                             att_cache)
+            x1 = self.dropout(x_att)
+
+        # Branch 2: convolutional gating mlp
+        # Fake new cnn cache here, and then change it in conv_module
+        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+        if self.cgmlp is not None:
+            x2 = self.norm_mlp(x2)
+            x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache)
+            x2 = self.dropout(x2)
+
+        # Merge two branches
+        if self.use_two_branches:
+            if self.merge_method == "concat":
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj(torch.cat([x1, x2], dim=-1)))
+            elif self.merge_method == "learned_ave":
+                if (self.training and self.attn_branch_drop_rate > 0
+                        and torch.rand(1).item() < self.attn_branch_drop_rate):
+                    # Drop the attn branch
+                    w1, w2 = torch.tensor(0.0), torch.tensor(1.0)
+                else:
+                    # branch1
+                    score1 = (self.pooling_proj1(x1).transpose(1, 2) /
+                              self.size**0.5)
+                    score1 = score1.masked_fill(mask_pad.eq(0), -float('inf'))
+                    score1 = torch.softmax(score1, dim=-1).masked_fill(
+                        mask_pad.eq(0), 0.0)
+
+                    pooled1 = torch.matmul(score1,
+                                           x1).squeeze(1)  # (batch, size)
+                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)
+
+                    # branch2
+                    score2 = (self.pooling_proj2(x2).transpose(1, 2) /
+                              self.size**0.5)
+                    score2 = score2.masked_fill(mask_pad.eq(0), -float('inf'))
+                    score2 = torch.softmax(score2, dim=-1).masked_fill(
+                        mask_pad.eq(0), 0.0)
+
+                    pooled2 = torch.matmul(score2,
+                                           x2).squeeze(1)  # (batch, size)
+                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)
+
+                    # normalize weights of two branches
+                    merge_weights = torch.softmax(torch.cat([weight1, weight2],
+                                                            dim=-1),
+                                                  dim=-1)  # (batch, 2)
+                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
+                        -1)  # (batch, 2, 1, 1)
+                    w1, w2 = merge_weights[:,
+                                           0], merge_weights[:,
+                                                             1]  # (batch, 1, 1)
+
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj(w1 * x1 + w2 * x2))
+            elif self.merge_method == "fixed_ave":
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj((1.0 - self.cgmlp_weight) * x1 +
+                                    self.cgmlp_weight * x2))
+            else:
+                raise RuntimeError(
+                    f"unknown merge method: {self.merge_method}")
+        else:
+            if self.attn is None:
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
+            elif self.cgmlp is None:
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
+            else:
+                # This should not happen
+                raise RuntimeError(
+                    "Both branches are not None, which is unexpected.")
+
+        x = self.norm_final(x)
+
+        return x, mask, new_att_cache, new_cnn_cache
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        att_cache: T_CACHE = (torch.zeros(
+            (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
+        """Compute encoded features.
+
+        Args:
+            x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
+            pos_emb (torch.Tensor): positional encoding, must not be None
+                for BranchformerEncoderLayer.
+            mask_pad (torch.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
+                (#batch=1, size, cache_t2)
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time, time).
+            torch.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+        """
+
+        stoch_layer_coeff = 1.0
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        if self.training:
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+        return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
+                             stoch_layer_coeff)
wenet/cli/__init__.py
ADDED
File without changes
|
wenet/cli/hub.py
ADDED
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 Mddct([email protected])
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import requests
+import sys
+import tarfile
+from pathlib import Path
+from urllib.request import urlretrieve
+
+import tqdm
+
+
+def download(url: str, dest: str, only_child=True):
+    """ download from url to dest
+    """
+    assert os.path.exists(dest)
+    print('Downloading {} to {}'.format(url, dest))
+
+    def progress_hook(t):
+        last_b = [0]
+
+        def update_to(b=1, bsize=1, tsize=None):
+            if tsize not in (None, -1):
+                t.total = tsize
+            displayed = t.update((b - last_b[0]) * bsize)
+            last_b[0] = b
+            return displayed
+
+        return update_to
+
+    # *.tar.gz
+    name = url.split('?')[0].split('/')[-1]
+    tar_path = os.path.join(dest, name)
+    with tqdm.tqdm(unit='B',
+                   unit_scale=True,
+                   unit_divisor=1024,
+                   miniters=1,
+                   desc=(name)) as t:
+        urlretrieve(url,
+                    filename=tar_path,
+                    reporthook=progress_hook(t),
+                    data=None)
+        t.total = t.n
+
+    with tarfile.open(tar_path) as f:
+        if not only_child:
+            f.extractall(dest)
+        else:
+            for tarinfo in f:
+                if "/" not in tarinfo.name:
+                    continue
+                name = os.path.basename(tarinfo.name)
+                fileobj = f.extractfile(tarinfo)
+                with open(os.path.join(dest, name), "wb") as writer:
+                    writer.write(fileobj.read())
+
+
+class Hub(object):
+    """Hub for wenet pretrain runtime model
+    """
+    # TODO(Mddct): make assets class to support other language
+    Assets = {
+        # wenetspeech
+        "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz",
+        # gigaspeech
+        "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
+        # paraformer
+        "paraformer": "paraformer.tar.gz"
+    }
+
+    def __init__(self) -> None:
+        pass
+
+    @staticmethod
+    def get_model_by_lang(lang: str) -> str:
+        if lang not in Hub.Assets.keys():
+            print('ERROR: Unsupported language {} !!!'.format(lang))
+            sys.exit(1)
+
+        # NOTE(Mddct): model_dir structure
+        # Path.Home()/.wenet
+        # - chs
+        #    - units.txt
+        #    - final.zip
+        # - en
+        #    - units.txt
+        #    - final.zip
+        model = Hub.Assets[lang]
+        model_dir = os.path.join(Path.home(), ".wenet", lang)
+        if not os.path.exists(model_dir):
+            os.makedirs(model_dir)
+        # TODO(Mddct): model metadata
+        if set(["final.zip",
+                "units.txt"]).issubset(set(os.listdir(model_dir))):
+            return model_dir
+        # If not exist, download
+        response = requests.get(
+            "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree"  # noqa
+        )
+        model_info = next(data for data in response.json()["Data"]
+                          if data["Key"] == model)
+        model_url = model_info['Url']
+        download(model_url, model_dir, only_child=True)
+        return model_dir
wenet/cli/model.py
ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) 2023 Binbin Zhang ([email protected])
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+from wenet.cli.hub import Hub
+from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
+                                   gen_timestamps_from_peak)
+from wenet.utils.file_utils import read_symbol_table
+from wenet.transformer.search import (attention_rescoring,
+                                      ctc_prefix_beam_search, DecodeResult)
+from wenet.utils.context_graph import ContextGraph
+from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu
+
+
+class Model:
+
+    def __init__(self,
+                 model_dir: str,
+                 gpu: int = -1,
+                 beam: int = 5,
+                 context_path: str = None,
+                 context_score: float = 6.0,
+                 resample_rate: int = 16000):
+        model_path = os.path.join(model_dir, 'final.zip')
+        units_path = os.path.join(model_dir, 'units.txt')
+        self.model = torch.jit.load(model_path)
+        self.resample_rate = resample_rate
+        self.model.eval()
+        if gpu >= 0:
+            device = 'cuda:{}'.format(gpu)
+        else:
+            device = 'cpu'
+        self.device = torch.device(device)
+        self.model.to(device)
+        self.symbol_table = read_symbol_table(units_path)
+        self.char_dict = {v: k for k, v in self.symbol_table.items()}
+        self.beam = beam
+        if context_path is not None:
+            self.context_graph = ContextGraph(context_path,
+                                              self.symbol_table,
+                                              context_score=context_score)
+        else:
+            self.context_graph = None
+
+    def compute_feats(self, audio_file: str) -> torch.Tensor:
+        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
+        waveform = waveform.to(torch.float)
+        if sample_rate != self.resample_rate:
+            waveform = torchaudio.transforms.Resample(
+                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
+        # NOTE (MengqingCao): complex dtype not supported in torch_npu.abs() now,
+        # thus, delay placing data on NPU after the calculation of fbank.
+        # revert me after complex dtype is supported.
+        if "npu" not in self.device.__str__():
+            waveform = waveform.to(self.device)
+        feats = kaldi.fbank(waveform,
+                            num_mel_bins=80,
+                            frame_length=25,
+                            frame_shift=10,
+                            energy_floor=0.0,
+                            sample_frequency=self.resample_rate)
+        if "npu" in self.device.__str__():
+            feats = feats.to(self.device)
+        feats = feats.unsqueeze(0)
+        return feats
+
+    @torch.no_grad()
+    def _decode(self,
+                audio_file: str,
+                tokens_info: bool = False,
+                label: str = None) -> dict:
+        feats = self.compute_feats(audio_file)
+        encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
+        encoder_lens = torch.tensor([encoder_out.size(1)],
+                                    dtype=torch.long,
+                                    device=encoder_out.device)
+        ctc_probs = self.model.ctc_activation(encoder_out)
+        if label is None:
+            ctc_prefix_results = ctc_prefix_beam_search(
+                ctc_probs,
+                encoder_lens,
+                self.beam,
+                context_graph=self.context_graph)
+        else:  # force align mode, construct ctc prefix result from alignment
+            label_t = self.tokenize(label)
+            alignment = force_align(ctc_probs.squeeze(0),
+                                    torch.tensor(label_t, dtype=torch.long))
+            peaks = gen_ctc_peak_time(alignment)
+            ctc_prefix_results = [
+                DecodeResult(tokens=label_t,
+                             score=0.0,
+                             times=peaks,
+                             nbest=[label_t],
+                             nbest_scores=[0.0],
+                             nbest_times=[peaks])
+            ]
+        rescoring_results = attention_rescoring(self.model, ctc_prefix_results,
+                                                encoder_out, encoder_lens, 0.3,
+                                                0.5)
+        res = rescoring_results[0]
+        result = {}
+        result['text'] = ''.join([self.char_dict[x] for x in res.tokens])
+        result['confidence'] = res.confidence
+
+        if tokens_info:
+            frame_rate = self.model.subsampling_rate(
+            ) * 0.01  # 0.01 seconds per frame
+            max_duration = encoder_out.size(1) * frame_rate
+            times = gen_timestamps_from_peak(res.times, max_duration,
+                                             frame_rate, 1.0)
+            tokens_info = []
+            for i, x in enumerate(res.tokens):
+                tokens_info.append({
+                    'token': self.char_dict[x],
+                    'start': round(times[i][0], 3),
+                    'end': round(times[i][1], 3),
+                    'confidence': round(res.tokens_confidence[i], 2)
+                })
+            result['tokens'] = tokens_info
+        return result
+
+    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
+        return self._decode(audio_file, tokens_info)
+
+    def tokenize(self, label: str):
+        # TODO(Binbin Zhang): Support BPE
+        tokens = []
+        for c in label:
+            if c == ' ':
+                c = "▁"
+            tokens.append(c)
+        token_list = []
+        for c in tokens:
+            if c in self.symbol_table:
+                token_list.append(self.symbol_table[c])
+            elif '<unk>' in self.symbol_table:
+                token_list.append(self.symbol_table['<unk>'])
+        return token_list
+
+    def align(self, audio_file: str, label: str) -> dict:
+        return self._decode(audio_file, True, label)
+
+
+def load_model(language: str = None,
+               model_dir: str = None,
+               gpu: int = -1,
+               beam: int = 5,
+               context_path: str = None,
+               context_score: float = 6.0,
+               device: str = "cpu") -> Model:
+    if model_dir is None:
+        model_dir = Hub.get_model_by_lang(language)
+
+    if gpu != -1:
+        # remain the original usage of gpu
+        device = "cuda"
+    model = Model(model_dir, gpu, beam, context_path, context_score)
+    model.device = torch.device(device)
+    model.model.to(device)
+    return model
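
Note (editor's sketch, not part of the commit): a hedged usage example of the Python API defined above; 'test.wav' is a placeholder path, and the first call without model_dir downloads a pretrained runtime model into ~/.wenet via Hub.get_model_by_lang.

    from wenet.cli.model import load_model

    model = load_model(language='chinese')            # or model_dir='/path/to/model'
    result = model.transcribe('test.wav', tokens_info=True)
    print(result['text'], result['confidence'])
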
wenet/cli/paraformer_model.py
ADDED
@@ -0,0 +1,82 @@
+import os
+
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+from wenet.cli.hub import Hub
+from wenet.paraformer.search import (gen_timestamps_from_peak,
+                                     paraformer_greedy_search)
+from wenet.text.paraformer_tokenizer import ParaformerTokenizer
+from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu
+
+
+class Paraformer:
+
+    def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
+
+        model_path = os.path.join(model_dir, 'final.zip')
+        units_path = os.path.join(model_dir, 'units.txt')
+        self.model = torch.jit.load(model_path)
+        self.resample_rate = resample_rate
+        self.device = torch.device("cpu")
+        self.tokenizer = ParaformerTokenizer(symbol_table=units_path)
+
+    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
+        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
+        waveform = waveform.to(torch.float).to(self.device)
+        if sample_rate != self.resample_rate:
+            waveform = torchaudio.transforms.Resample(
+                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
+        feats = kaldi.fbank(waveform,
+                            num_mel_bins=80,
+                            frame_length=25,
+                            frame_shift=10,
+                            energy_floor=0.0,
+                            sample_frequency=self.resample_rate,
+                            window_type="hamming")
+        feats = feats.unsqueeze(0)
+        feats_lens = torch.tensor([feats.size(1)],
+                                  dtype=torch.int64,
+                                  device=feats.device)
+
+        decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
+            feats, feats_lens)
+        cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
+        res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
+        result = {}
+        result['confidence'] = res.confidence
+        result['text'] = self.tokenizer.detokenize(res.tokens)[0]
+        if tokens_info:
+            tokens_info = []
+            times = gen_timestamps_from_peak(res.times,
+                                             num_frames=tp_alphas.size(1),
+                                             frame_rate=0.02)
+
+            for i, x in enumerate(res.tokens):
+                tokens_info.append({
+                    'token': self.tokenizer.char_dict[x],
+                    'start': round(times[i][0], 3),
+                    'end': round(times[i][1], 3),
+                    'confidence': round(res.tokens_confidence[i], 2)
+                })
+            result['tokens'] = tokens_info
+
+        return result
+
+    def align(self, audio_file: str, label: str) -> dict:
+        raise NotImplementedError("Align is currently not supported")
+
+
+def load_model(model_dir: str = None,
+               gpu: int = -1,
+               device: str = "cpu") -> Paraformer:
+    if model_dir is None:
+        model_dir = Hub.get_model_by_lang('paraformer')
+    if gpu != -1:
+        # remain the original usage of gpu
+        device = "cuda"
+    paraformer = Paraformer(model_dir)
+    paraformer.device = torch.device(device)
+    paraformer.model.to(device)
+    return paraformer
wenet/cli/transcribe.py
ADDED
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 Binbin Zhang ([email protected])
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+from wenet.cli.paraformer_model import load_model as load_paraformer
+from wenet.cli.model import load_model
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('audio_file', help='audio file to transcribe')
+    parser.add_argument('-l',
+                        '--language',
+                        choices=[
+                            'chinese',
+                            'english',
+                        ],
+                        default='chinese',
+                        help='language type')
+    parser.add_argument('-m',
+                        '--model_dir',
+                        default=None,
+                        help='specify your own model dir')
+    parser.add_argument('-g',
+                        '--gpu',
+                        type=int,
+                        default='-1',
+                        help='gpu id to decode, default is cpu.')
+    parser.add_argument('--device',
+                        type=str,
+                        default='cpu',
+                        choices=["cpu", "npu", "cuda"],
+                        help='accelerator to use')
+    parser.add_argument('-t',
+                        '--show_tokens_info',
+                        action='store_true',
+                        help='whether to output token(word) level information'
+                        ', such times/confidence')
+    parser.add_argument('--align',
+                        action='store_true',
+                        help='force align the input audio and transcript')
+    parser.add_argument('--label', type=str, help='the input label to align')
+    parser.add_argument('--paraformer',
+                        action='store_true',
+                        help='whether to use the best chinese model')
+    parser.add_argument('--beam', type=int, default=5, help="beam size")
+    parser.add_argument('--context_path',
+                        type=str,
+                        default=None,
+                        help='context list file')
+    parser.add_argument('--context_score',
+                        type=float,
+                        default=6.0,
+                        help='context score')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    if args.paraformer:
+        model = load_paraformer(args.model_dir, args.gpu, args.device)
+    else:
+        model = load_model(args.language, args.model_dir, args.gpu, args.beam,
+                           args.context_path, args.context_score, args.device)
+    if args.align:
+        result = model.align(args.audio_file, args.label)
+    else:
+        result = model.transcribe(args.audio_file, args.show_tokens_info)
+    print(result)
+
+
+if __name__ == "__main__":
+    main()
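
Note (editor's sketch, not part of the commit): given the argparse definition above, the script can be invoked directly; the audio path below is a placeholder.

    python wenet/cli/transcribe.py test.wav -l english -t
    python wenet/cli/transcribe.py test.wav --paraformer --device cuda
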
wenet/ctl_model/asr_model_ctl.py
ADDED
@@ -0,0 +1,277 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2023 NetEase Inc. (authors: Yuting Yang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet) and
+# fairseq(https://github.com/facebookresearch/fairseq)
+
+from typing import Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from wenet.transformer.ctc import CTC
+from wenet.transformer.decoder import TransformerDecoder
+from wenet.ctl_model.encoder import TransformerEncoder
+from wenet.transformer.asr_model import ASRModel
+from wenet.utils.common import IGNORE_ID
+
+
+class CTLModel(ASRModel):
+    """
+    Implementation of Interspeech 2023 paper:
+    'Enhancing the Unified Streaming and Non-streaming Model
+     with Contrastive Learning'
+    https://arxiv.org/abs/2306.00755
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder: TransformerEncoder,
+        decoder: TransformerDecoder,
+        ctc: CTC,
+        ctc_weight: float = 0.5,
+        ignore_id: int = IGNORE_ID,
+        reverse_weight: float = 0.0,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        logit_temp: float = 0.1,
+        n_negatives: int = 0,
+        ctl_weight: float = 1,
+        special_tokens: dict = None,
+    ):
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        super().__init__(vocab_size,
+                         encoder,
+                         decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
+
+        # For CTL Loss
+        self.n_negatives = n_negatives
+        self.ctl_weight = ctl_weight
+        self.logit_temp = logit_temp
+
+    @torch.jit.unused
+    def forward(
+        self,
+        batch: dict,
+        device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+
+        speech = batch['feats'].to(device)
+        speech_lengths = batch['feats_lengths'].to(device)
+        text = batch['target'].to(device)
+        text_lengths = batch['target_lengths'].to(device)
+        loss_full, encoder_out_full, _, _ = self.forward_full(
+            speech, speech_lengths, text, text_lengths)
+        loss_chunk, encoder_out, lens_chunk, encoder_mask = self.forward_chunk(
+            speech, speech_lengths, text, text_lengths)
+
+        ctl_loss = 0.0
+        if self.ctl_weight > 0 and self.n_negatives > 0:
+            num = encoder_out_full.size(1)
+            targets = encoder_out_full
+            src = encoder_out
+            negs, negs_idxs = self.sample_negatives(targets,
+                                                    targets.size(1),
+                                                    speech_lengths=lens_chunk)
+            ctl_loss = self.CTL(src, targets, negs, encoder_mask)
+
+        loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss
+        return {
+            "loss": loss,
+            "loss_full": loss_full,
+            "loss_chunk": loss_chunk,
+            "loss_ctl": ctl_loss
+        }
+
+    def forward_full(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ):
+        """Full context mode
+        Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        """
+
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
+                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
+                                         text.shape, text_lengths.shape)
+        # 1. Encoder
+        encoder_out, encoder_mask = self.encoder.forward_full(
+            speech, speech_lengths)
+        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+
+        # 2a. Attention-decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
+                                                    text, text_lengths)
+        else:
+            loss_att = None
+
+        # 2b. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
+                                text_lengths)
+        else:
+            loss_ctc = None
+
+        if loss_ctc is None:
+            loss = loss_att
+        elif loss_att is None:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc[0] + (1 -
+                                                    self.ctc_weight) * loss_att
+        return loss, encoder_out, encoder_out_lens, encoder_mask
+
+    def forward_chunk(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ):
+        """Chunk-based context mode
+        Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        """
+
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
+                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
+                                         text.shape, text_lengths.shape)
+        # 1. Encoder
+        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
+        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+
+        # 2a. Attention-decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
+                                                    text, text_lengths)
+        else:
+            loss_att = None
+
+        # 2b. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
+                                text_lengths)
+        else:
+            loss_ctc = None
+
+        if loss_ctc is None:
+            loss = loss_att
+        elif loss_att is None:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc[0] + (1 -
+                                                    self.ctc_weight) * loss_att
+        return loss, encoder_out, encoder_out_lens, encoder_mask
+
+    def sample_negatives(self, y, num, padding_count=0, speech_lengths=None):
+        if self.n_negatives == 0:
+            return y.new(0)
+        bsz, tsz, fsz = y.shape
+        y = y.reshape(-1, fsz)  # BTC => (BxT)C
+
+        # FIXME: what happens if padding_count is specified?
+        high = tsz - (padding_count or 0)
+        with torch.no_grad():
+            assert high > 1, f"{bsz,tsz,fsz}"
+
+            if self.n_negatives > 0:
+                tszs = (torch.arange(num).unsqueeze(-1).expand(
+                    -1, self.n_negatives).flatten())
+                if speech_lengths is not None:
+                    neg_idxs = [
+                        torch.randint(low=0,
+                                      high=speech_lengths[i].item() - 1,
+                                      size=(1, self.n_negatives * tsz))
+                        for i in range(len(speech_lengths))
+                    ]
+                    neg_idxs = torch.cat(neg_idxs).reshape(
+                        bsz, self.n_negatives * tsz)
+                else:
+                    neg_idxs = torch.randint(low=0,
+                                             high=num - 1,
+                                             size=(bsz,
+                                                   self.n_negatives * tsz))
+                neg_idxs[neg_idxs >= tszs] += 1
+
+        if self.n_negatives > 0:
+            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
+
+        negs = y[neg_idxs.view(-1)]
+        negs = negs.contiguous().view(bsz, num, self.n_negatives,
+                                      fsz).permute(2, 0, 1, 3)  # to NxBxTxC
+        return negs, neg_idxs
+
+    def compute_preds(self, x, y, negatives):
+        neg_is_pos = (y == negatives).all(-1)
+        y = y.unsqueeze(0)
+        targets = torch.cat([y, negatives], dim=0)
+
+        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
+        logits = logits / self.logit_temp
+        logits = logits.type_as(x)
+
+        if neg_is_pos.any():
+            if not hasattr(self, "_inftensor"):
+                self._inftensor = float("-inf")
+            # logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
+            logits[1:][neg_is_pos] = self._inftensor
+        logits = logits.transpose(0, 2)
+        logits = logits.transpose(0, 1)
+        logits = logits.reshape(-1, logits.size(-1))
+        return logits
+
+    def CTL(self, x, y, negs, mask=None):
+        # Step1: compute cosine similarity, shape [B*T, n_negatives+1]
+        logits = self.compute_preds(x, y, negs)
+
+        # Step2: target shape [B*T]
+        target = x.new_zeros(x.size(0) * x.size(1), dtype=torch.long)
+
+        # Step3: compute CTL loss
+        if mask is not None:
+            normalize_length = mask.sum()
+            bz, sz = mask.size(0), mask.size(-1)
+            mask = mask.squeeze(1).reshape(bz * sz).eq(0)
+            ce = F.cross_entropy(logits, target, reduction='none')
+            loss = ce.masked_fill(mask, 0).sum() / normalize_length
+        else:
+            loss = F.cross_entropy(logits, target)
+
+        return loss
wenet/ctl_model/encoder.py
ADDED
@@ -0,0 +1,172 @@
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
# 2023 NetEase Inc
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Optional, Tuple

import torch

from wenet.utils.mask import make_pad_mask
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder


class DualTransformerEncoder(TransformerEncoder):
    """Transformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        selfattention_layer_type: str = "selfattn",
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """ Construct DualTransformerEncoder
        Support both the full context mode and the streaming mode separately
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, query_bias, key_bias,
                         value_bias, activation_type, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps, n_kv_head,
                         head_dim, selfattention_layer_type, mlp_type,
                         mlp_bias, n_expert, n_expert_activated)

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        for layer in self.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks


class DualConformerEncoder(ConformerEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        conv_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """ Construct DualConformerEncoder
        Support both the full context mode and the streaming mode separately
        """
        super().__init__(
            input_size, output_size, attention_heads, linear_units, num_blocks,
            dropout_rate, positional_dropout_rate, attention_dropout_rate,
            input_layer, pos_enc_layer_type, normalize_before,
            static_chunk_size, use_dynamic_chunk, global_cmvn,
            use_dynamic_left_chunk, positionwise_conv_kernel_size,
            macaron_style, selfattention_layer_type, activation_type,
            use_cnn_module, cnn_module_kernel, causal, cnn_module_norm,
            query_bias, key_bias, value_bias, conv_bias,
            gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps,
            n_kv_head, head_dim, mlp_type, mlp_bias, n_expert,
            n_expert_activated)

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        for layer in self.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks
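Both classes add a `forward_full` pass that always uses full-context padding masks, alongside the chunk-based `forward` inherited from the parent encoders. A minimal usage sketch (not part of the committed file; the 80-dim fbank input and layer count are illustrative assumptions, and the import path follows the wenet `ctl_model` package layout):

# Illustrative sketch only, not the committed code.
import torch
from wenet.ctl_model.encoder import DualConformerEncoder

encoder = DualConformerEncoder(input_size=80, output_size=256, num_blocks=2)
xs = torch.randn(2, 200, 80)        # (batch, frames, feature_dim)
xs_lens = torch.tensor([200, 160])  # valid frame count per utterance
out, masks = encoder.forward_full(xs, xs_lens)
print(out.shape, masks.shape)       # roughly (2, 49, 256) and (2, 1, 49) after conv2d subsampling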
wenet/dataset/__init__.py
ADDED
File without changes
wenet/dataset/datapipes.py
ADDED
@@ -0,0 +1,470 @@
# Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from collections.abc import Callable
import copy
import sys
import tarfile
import logging
from typing import List, Optional
import numpy as np
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import datapipes
from torch.utils.data.datapipes.iter import Mapper
from torch.utils.data.datapipes.iter.sharding import (
    SHARDING_PRIORITIES, ShardingFilterIterDataPipe)
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from wenet.dataset.processor import parse_url


@functional_datapipe("map_ignore_error")
class MapperIgnoreErrorDataPipe(Mapper):

    def __init__(self,
                 dataset: IterDataPipe,
                 fn: Callable,
                 input_col=None,
                 output_col=None,
                 log_error: bool = True) -> None:
        super().__init__(dataset, fn, input_col, output_col)
        self._iter = None
        self.log_error = log_error

    def __iter__(self):
        if self._iter is None:
            self._iter = iter(self.datapipe)

        while True:
            try:
                elem = next(self._iter)
                yield self._apply_fn(elem)
            except StopIteration:
                self._iter = None
                return
            except Exception as ex:
                if self.log_error:
                    logging.warning(str(ex))


@functional_datapipe('bucket_by_sequence_length')
class BucketBySequenceLengthDataPipe(IterDataPipe):

    def __init__(
        self,
        dataset: IterDataPipe,
        elem_length_func,
        bucket_boundaries: List[int],
        bucket_batch_sizes: List[int],
        wrapper_class=None,
    ) -> None:
        super().__init__()
        _check_unpickable_fn(elem_length_func)
        assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1
        self.bucket_batch_sizes = bucket_batch_sizes
        self.bucket_boundaries = bucket_boundaries + [sys.maxsize]
        self.elem_length_func = elem_length_func

        self._group_dp = GroupByWindowDataPipe(dataset,
                                               self._element_to_bucket_id,
                                               self._window_size_func,
                                               wrapper_class=wrapper_class)

    def __iter__(self):
        yield from self._group_dp

    def _element_to_bucket_id(self, elem):
        seq_len = self.elem_length_func(elem)
        bucket_id = 0
        for (i, b) in enumerate(self.bucket_boundaries):
            if seq_len < b:
                bucket_id = i
                break
        return bucket_id

    def _window_size_func(self, bucket_id):
        return self.bucket_batch_sizes[bucket_id]
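The bucketing above appends a `sys.maxsize` sentinel boundary and then picks the batch size of the first boundary the element length falls under. A rough standalone illustration of that mapping (the boundary and batch-size values below are made-up examples, not values used by this repo):

# Illustrative only: hypothetical boundaries and batch sizes.
import sys

bucket_boundaries = [500, 1000] + [sys.maxsize]
bucket_batch_sizes = [32, 16, 8]  # one more entry than the raw boundaries, as asserted above

def bucket_id_for(seq_len):
    for i, b in enumerate(bucket_boundaries):
        if seq_len < b:
            return i
    return 0

for seq_len in (120, 700, 2500):
    i = bucket_id_for(seq_len)
    print(seq_len, '-> bucket', i, '-> batch size', bucket_batch_sizes[i])
# 120 -> bucket 0 -> 32, 700 -> bucket 1 -> 16, 2500 -> bucket 2 -> 8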
@functional_datapipe("group_by_window")
class GroupByWindowDataPipe(datapipes.iter.Grouper):

    def __init__(
        self,
        dataset: IterDataPipe,
        key_func,
        window_size_func,
        wrapper_class=None,
    ):
        super().__init__(dataset,
                         key_func,
                         keep_key=False,
                         group_size=None,
                         drop_remaining=False)
        _check_unpickable_fn(window_size_func)
        self.dp = dataset
        self.window_size_func = window_size_func
        if wrapper_class is not None:
            _check_unpickable_fn(wrapper_class)
            del self.wrapper_class
            self.wrapper_class = wrapper_class

    def __iter__(self):
        for x in self.datapipe:
            key = self.group_key_fn(x)

            self.buffer_elements[key].append(x)
            self.curr_buffer_size += 1

            group_size = self.window_size_func(key)
            if group_size == len(self.buffer_elements[key]):
                result = self.wrapper_class(self.buffer_elements[key])
                yield result
                self.curr_buffer_size -= len(self.buffer_elements[key])
                del self.buffer_elements[key]

            if self.curr_buffer_size == self.max_buffer_size:
                result_to_yield = self._remove_biggest_key()
                if result_to_yield is not None:
                    result = self.wrapper_class(result_to_yield)
                    yield result

        for key in tuple(self.buffer_elements.keys()):
            result = self.wrapper_class(self.buffer_elements.pop(key))
            self.curr_buffer_size -= len(result)
            yield result


@functional_datapipe("sort")
class SortDataPipe(IterDataPipe):

    def __init__(self,
                 dataset: IterDataPipe,
                 buffer_size: int = 500,
                 key_func=None,
                 reverse=False) -> None:
        if key_func is not None:
            _check_unpickable_fn(key_func)
        self.buffer_size = buffer_size
        super().__init__()
        self.dp = dataset
        self._buffer = []
        self.key_func = key_func
        self.reverse = reverse

    def __iter__(self):
        for elem in self.dp:
            self._buffer.append(elem)
            if len(self._buffer) >= self.buffer_size:
                self._buffer.sort(key=self.key_func, reverse=self.reverse)
                for x in self._buffer:
                    yield x
                del self._buffer
                self._buffer = []
        # The sample left over
        self._buffer.sort(key=self.key_func, reverse=self.reverse)
        for x in self._buffer:
            yield x
        del self._buffer
        self._buffer = []


@functional_datapipe("dynamic_batch")
class DynamicBatchDataPipe(IterDataPipe):

    def __init__(self, dataset: IterDataPipe, window_class,
                 wrapper_class) -> None:
        _check_unpickable_fn(window_class)
        _check_unpickable_fn(wrapper_class)
        super().__init__()
        self.dp = dataset
        assert window_class is not None
        assert wrapper_class is not None
        self.window_class = window_class
        self._buffer = []
        self._wrappr_class = wrapper_class

    def __iter__(self):
        for elem in self.dp:
            if not self.window_class(elem, len(self._buffer)):
                self._buffer.append(elem)
            else:
                if len(self._buffer) > 0:
                    yield self._wrappr_class(self._buffer)
                del self._buffer
                self._buffer = [elem]
        if len(self._buffer) > 0:
            yield self._wrappr_class(self._buffer)
        del self._buffer
        self._buffer = []
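The `window_class` handed to `dynamic_batch` is a callable that, given the next sample and the current buffer length, returns True when the buffer should be flushed as a batch. A minimal sketch of such a predicate based on total padded frames (a hypothetical example in the style of wenet's frame-budget windows, not the exact class shipped in this commit):

# Hypothetical sketch, for illustration only.
class MaxFramesInBatchWindow:
    """Close the batch once padded frames would exceed max_frames_in_batch."""

    def __init__(self, max_frames_in_batch: int = 12000):
        self.longest_frames = 0
        self.max_frames_in_batch = max_frames_in_batch

    def __call__(self, sample, buffer_size: int) -> bool:
        assert 'feat' in sample
        new_frames = sample['feat'].size(0)
        self.longest_frames = max(self.longest_frames, new_frames)
        # padded size of the batch if this sample were appended
        frames_after_padding = self.longest_frames * (buffer_size + 1)
        if frames_after_padding > self.max_frames_in_batch:
            self.longest_frames = new_frames
            return True  # tells DynamicBatchDataPipe to flush the buffer
        return False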
@functional_datapipe("prefetch")
class PrefetchDataPipe(IterDataPipe):
    """Performs prefetching"""

    def __init__(
        self,
        dataset: IterDataPipe,
        buffer_size: int = 500,
    ):
        # TODO(Mddct): support multiprocessing pool with shared-memory to
        # prefetch
        super().__init__()
        self.dp = dataset
        self._iter = None
        self._prefetch_buffer_size = buffer_size
        self._buffer = None
        if self._prefetch_buffer_size > 0:
            self._buffer = collections.deque(maxlen=self._prefetch_buffer_size)

    def __iter__(self):
        if self._prefetch_buffer_size > 0:
            if self._iter is None:
                self._iter = iter(self.dp)
            assert self._buffer is not None

            while True:
                if len(self._buffer) <= self._prefetch_buffer_size // 2:
                    while len(self._buffer) < self._prefetch_buffer_size:
                        try:
                            self._buffer.append(next(self._iter))
                        except StopIteration:
                            if len(self._buffer) != 0:
                                while len(self._buffer) > 0:
                                    yield self._buffer.popleft()
                            self._iter = None
                            return
                while len(self._buffer) > self._prefetch_buffer_size // 2:
                    elem = self._buffer.popleft()
                    yield elem

        else:
            yield from self.dp


@functional_datapipe("repeat")
class RepeatDatapipe(IterDataPipe):

    def __init__(self, dataset: IterDataPipe, count: int = -1):
        super().__init__()
        self.dp = dataset
        self.count = count

    def __iter__(self):
        if self.count == 1:
            yield from self.dp
            return
        i = 0
        while self.count < 0 or i < self.count:
            for elem in self.dp:
                new_elem = copy.copy(elem)
                yield new_elem
            i += 1


@functional_datapipe("shard")
class ShardDataPipe(ShardingFilterIterDataPipe):

    def __init__(self, dataset: IterDataPipe, partition: bool = False):
        super().__init__(dataset, None)
        self.partition = partition
        self.dp = dataset

    def apply_sharding(self, num_of_instances: int, instance_id: int,
                       sharding_group: SHARDING_PRIORITIES):
        if self.partition:
            return super().apply_sharding(num_of_instances, instance_id,
                                          sharding_group)
        else:
            # We can not handle uneven data for CV on DDP, so we don't
            # sample data by rank, that means every GPU gets the same
            # and all the CV data
            info = torch.utils.data.get_worker_info()
            if info is None:
                self.num_of_instances = 1
                self.instance_id = 0
            else:
                n_workers_per_device = info.num_workers
                self.num_of_instances = n_workers_per_device
                self.instance_id = info.id


@functional_datapipe("interleave")
class InterlaveDataPipe(IterDataPipe):

    def __init__(
        self,
        source_datapipes: List[IterDataPipe],
        weights: Optional[List[float]] = None,
        seed=2027,
    ):
        super().__init__()
        self.rng = np.random.default_rng(seed)
        self.source_datapipes = source_datapipes
        self.weights = weights
        if weights is None:
            self.weights = [1 / len(self.source_datapipes)] * len(
                self.source_datapipes)
        else:
            self.weights = [weight / sum(weights) for weight in weights]
        self.iters = None

    def __iter__(self):
        weights = copy.deepcopy(self.weights)
        exhausted = len(self.source_datapipes) * [False]
        if self.iters is None:
            self.iters = [(i, iter(d))
                          for i, d in enumerate(self.source_datapipes)]
        while True:
            # TODO(Mddct): rng
            index_iter = self.rng.choice(self.iters, p=weights)
            i, ite = index_iter
            try:
                elem = next(ite)
                yield elem
            except StopIteration:
                weights[i] = 0.
                exhausted[i] = True
                if all(exhausted):
                    return
                weights = [weight / sum(weights) for weight in weights]


class TextLineDataPipe(IterDataPipe):
    """ Streaming text line
    """

    def __init__(self, filenames, mode='r'):
        super().__init__()
        _dp = datapipes.iter.FileLister(filenames)
        _dp = datapipes.iter.FileOpener(_dp, mode=mode)
        self.dp = _dp

    def __iter__(self):
        for fname, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                yield {"file_name": fname, "line": line}
            stream.close()


@functional_datapipe("tar_file_and_group")
class TarsDataPipe(IterDataPipe):
    """ Decode wenet's tar, yield {'txt': "...", "raw": "..."}
    """

    def __init__(self, dataset: IterDataPipe) -> None:
        super().__init__()
        self.dp = dataset

    def __iter__(self):
        from wenet.dataset.processor import AUDIO_FORMAT_SETS
        for sample in self.dp:
            assert 'file_name' in sample
            assert 'line' in sample
            assert 'stream' in sample
            try:
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {
                        'file_name': sample['file_name'],
                        'tar_file_name': sample['line']
                    }
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {
                                'file_name': sample['file_name'],
                                'tar_file_name': sample['line']
                            }
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                        prev_prefix = prefix
                    if prev_prefix is not None:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                msg = 'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['line'])
                logging.warning(msg)
            finally:
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()


class WenetRawDatasetSource(IterDataPipe):

    def __init__(self,
                 filenames: str,
                 prefetch: int = 500,
                 partition: bool = True,
                 shuffle: bool = False,
                 shuffle_size: int = 10000,
                 cycle: int = 1) -> None:
        super().__init__()
        self.dp = TextLineDataPipe(filenames)
        if shuffle:
            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
        self.dp = self.dp.repeat(cycle).prefetch(prefetch)
        self.dp = self.dp.shard(partition)

    def __iter__(self):
        for d in self.dp:
            yield d


class WenetTarShardDatasetSource(IterDataPipe):

    def __init__(self,
                 filenames: str,
                 prefetch: int = 500,
                 partition: bool = True,
                 shuffle: bool = False,
                 shuffle_size: int = 10000,
                 cycle: int = 1) -> None:
        super().__init__()
        self.dp = TextLineDataPipe(filenames)
        if shuffle:
            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
        self.dp = self.dp.repeat(cycle)
        self.dp = self.dp.shard(partition).map_ignore_error(
            parse_url).tar_file_and_group().prefetch(prefetch)

    def __iter__(self):
        for d in self.dp:
            yield d
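Because each pipe above is registered with `@functional_datapipe`, the transforms chain as methods on any `IterDataPipe`, which is exactly how the two source classes at the end compose them. A minimal sketch of iterating a raw-list source (the list-file path is hypothetical; further stages such as `.sort(...)` or `.dynamic_batch(...)` would be chained the same way):

# Illustrative sketch: the data-list path below is hypothetical.
from wenet.dataset.datapipes import WenetRawDatasetSource

pipe = WenetRawDatasetSource('data/train/data.list',
                             prefetch=100,
                             partition=True,
                             shuffle=True,
                             shuffle_size=1000,
                             cycle=1)
for sample in pipe:
    # each element is the dict yielded by TextLineDataPipe
    print(sample['file_name'], sample['line'][:80])
    break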
wenet/dataset/dataset.py
ADDED
@@ -0,0 +1,234 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

import wenet.dataset.deprecated.processor as processor
from wenet.text.base_tokenizer import BaseTokenizer
from wenet.utils.file_utils import read_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True, split_num=1):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition
        self.split_num = split_num

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def split_data(self, total_num):
        data = list(range(total_num))
        sub_epoch = self.epoch + 1
        full_epoch = sub_epoch // self.split_num
        num_per_sub_epochs = total_num // self.split_num
        random.Random(full_epoch).shuffle(data)

        split_index = sub_epoch - full_epoch * self.split_num
        begin = split_index * num_per_sub_epochs
        end = (begin + num_per_sub_epochs
               if (split_index + 1) < self.split_num else total_num)

        # print(f'begin: {begin}, end: {end}, world_size: {self.world_size}')
        return data[begin:end]

    def sample(self, data, split_num=1):
        """ Sample data according to rank/world_size/num_workers

        Args:
            data(List): input data list

        Returns:
            List: data list after sample
        """
        if self.split_num == 1:
            data = list(range(len(data)))
        else:
            data = self.split_data(len(data))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        # print(f'num dataset: {len(data)}')
        data = data[self.worker_id::self.num_workers]
        self.epoch += 1
        return data


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True, split_num=1):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition, split_num)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            # yield dict(src=src)
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def Dataset(data_type,
            data_list_file,
            tokenizer: BaseTokenizer,
            conf,
            partition=True):
    """ Construct dataset from arguments

        We have two shuffle stages in the Dataset. The first is global
        shuffle at shards tar/raw file level. The second is global shuffle
        at training samples level.

        Args:
            data_type(str): raw/shard
            bpe_model(str): model for english bpe part
            partition(bool): whether to do data partition in terms of rank
    """
    assert data_type in ['raw', 'shard', 'shard_full_data']
    lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    split_num = conf.get('split_num', 1)
    dataset = DataList(lists, shuffle=shuffle, partition=partition,
                       split_num=split_num)
    if data_type == 'shard':
        dataset = Processor(dataset, processor.url_opener)
        dataset = Processor(dataset, processor.tar_file_and_group)
    elif data_type == 'shard_full_data':
        dataset = Processor(dataset, processor.url_opener)
        dataset = Processor(dataset, processor.tar_file_and_group_full_data)
    else:
        dataset = Processor(dataset, processor.parse_raw)

    speaker_conf = conf.get('speaker_conf', None)
    if speaker_conf is not None:
        dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)

    if conf.get('eod_id', None) is not None:
        tokenizer.eod_id = conf['eod_id']
    # prompt dict
    from gxl_ai_utils.utils import utils_file
    global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')
    dataset = Processor(dataset, processor.tokenize, tokenizer,
                        global_prompt_dict=global_prompt_dict)
    filter_conf = conf.get('filter_conf', {})
    dataset = Processor(dataset, processor.filter, **filter_conf)

    resample_conf = conf.get('resample_conf', {})
    dataset = Processor(dataset, processor.resample, **resample_conf)

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = Processor(dataset, processor.speed_perturb)

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
                            **log_mel_spectrogram_conf)

    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)

    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = Processor(dataset, processor.sort, **sort_conf)

    batch_conf = conf.get('batch_conf', {})
    dataset = Processor(dataset, processor.batch, **batch_conf)
    dataset = Processor(dataset, processor.padding)
    return dataset
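Note that this `Dataset` variant unconditionally loads `conf/prompt_stage4.yaml` for the prompt dictionary and expects each sample to carry the extra task/prompt fields. A minimal sketch of how it might be called (paths and conf values are illustrative; in practice the conf comes from the training YAML, and the tokenizer is built elsewhere, e.g. via `init_tokenizer`):

# Illustrative only: paths and conf values are hypothetical.
conf = {
    'shuffle': True,
    'filter_conf': {'max_length': 1200, 'min_length': 10},
    'fbank_conf': {'num_mel_bins': 80, 'frame_length': 25, 'frame_shift': 10},
    'batch_conf': {'batch_type': 'static', 'batch_size': 8},
}
# tokenizer = init_tokenizer(configs)   # built from the same training config
# train_set = Dataset('shard', 'data/train/shards.list', tokenizer, conf, partition=True)
# for batch in train_set:
#     ...   # padded batches ready for the model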
wenet/dataset/deprecated/dataset.py
ADDED
@@ -0,0 +1,202 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

import wenet.dataset.deprecated.processor as processor
from wenet.text.base_tokenizer import BaseTokenizer
from wenet.utils.file_utils import read_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

        Args:
            data(List): input data list

        Returns:
            List: data list after sample
        """
        data = list(range(len(data)))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        data = data[self.worker_id::self.num_workers]
        return data


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            # yield dict(src=src)
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def Dataset(data_type,
            data_list_file,
            tokenizer: BaseTokenizer,
            conf,
            partition=True):
    """ Construct dataset from arguments

        We have two shuffle stages in the Dataset. The first is global
        shuffle at shards tar/raw file level. The second is global shuffle
        at training samples level.

        Args:
            data_type(str): raw/shard
            bpe_model(str): model for english bpe part
            partition(bool): whether to do data partition in terms of rank
    """
    assert data_type in ['raw', 'shard']
    lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    if data_type == 'shard':
        dataset = Processor(dataset, processor.url_opener)
        dataset = Processor(dataset, processor.tar_file_and_group)
    else:
        dataset = Processor(dataset, processor.parse_raw)

    speaker_conf = conf.get('speaker_conf', None)
    if speaker_conf is not None:
        dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)

    dataset = Processor(dataset, processor.tokenize, tokenizer)
    filter_conf = conf.get('filter_conf', {})
    dataset = Processor(dataset, processor.filter, **filter_conf)

    resample_conf = conf.get('resample_conf', {})
    dataset = Processor(dataset, processor.resample, **resample_conf)

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = Processor(dataset, processor.speed_perturb)

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
                            **log_mel_spectrogram_conf)

    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)

    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = Processor(dataset, processor.sort, **sort_conf)

    batch_conf = conf.get('batch_conf', {})
    dataset = Processor(dataset, processor.batch, **batch_conf)
    dataset = Processor(dataset, processor.padding)
    return dataset
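The double slicing in `DistributedSampler.sample` above (first by DDP rank, then by dataloader worker) gives each (rank, worker) pair a disjoint slice of the shard list. A small worked example with made-up sizes:

# Worked example with hypothetical sizes: 8 shards, world_size=2, num_workers=2.
shards = list(range(8))                      # [0, 1, 2, 3, 4, 5, 6, 7]
rank, world_size = 1, 2
worker_id, num_workers = 0, 2

by_rank = shards[rank::world_size]           # [1, 3, 5, 7]
by_worker = by_rank[worker_id::num_workers]  # [1, 5]
print(by_worker)  # each (rank, worker) pair ends up with 2 of the 8 shards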
wenet/dataset/deprecated/processor.py
ADDED
@@ -0,0 +1,1023 @@
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import librosa
|
17 |
+
import logging
|
18 |
+
import json
|
19 |
+
import random
|
20 |
+
import tarfile
|
21 |
+
from subprocess import PIPE, Popen
|
22 |
+
from urllib.parse import urlparse
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torchaudio
|
26 |
+
import torchaudio.compliance.kaldi as kaldi
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from gxl_ai_utils.utils import utils_file
|
29 |
+
from torch.nn.utils.rnn import pad_sequence
|
30 |
+
from wenet.text.base_tokenizer import BaseTokenizer
|
31 |
+
|
32 |
+
# torchaudio.utils.sox_utils.set_buffer_size(16500)
|
33 |
+
torchaudio.set_audio_backend("soundfile")
|
34 |
+
|
35 |
+
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
36 |
+
|
37 |
+
|
38 |
+
def url_opener(data):
|
39 |
+
""" Give url or local file, return file descriptor
|
40 |
+
Inplace operation.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
data(Iterable[str]): url or local file list
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Iterable[{src, stream}]
|
47 |
+
"""
|
48 |
+
for sample in data:
|
49 |
+
assert 'src' in sample
|
50 |
+
# TODO(Binbin Zhang): support HTTP
|
51 |
+
url = sample['src']
|
52 |
+
try:
|
53 |
+
pr = urlparse(url)
|
54 |
+
# local file
|
55 |
+
if pr.scheme == '' or pr.scheme == 'file':
|
56 |
+
stream = open(url, 'rb')
|
57 |
+
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
|
58 |
+
else:
|
59 |
+
cmd = f'wget -q -O - {url}'
|
60 |
+
process = Popen(cmd, shell=True, stdout=PIPE)
|
61 |
+
sample.update(process=process)
|
62 |
+
stream = process.stdout
|
63 |
+
sample.update(stream=stream)
|
64 |
+
yield sample
|
65 |
+
except Exception as ex:
|
66 |
+
logging.warning('Failed to open {}'.format(url))
|
67 |
+
|
68 |
+
|
69 |
+
def tar_file_and_group(data):
|
70 |
+
""" Expand a stream of open tar files into a stream of tar file contents.
|
71 |
+
And groups the file with same prefix
|
72 |
+
|
73 |
+
Args:
|
74 |
+
data: Iterable[{src, stream}]
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Iterable[{key, wav, txt, sample_rate}]
|
78 |
+
"""
|
79 |
+
for sample in data:
|
80 |
+
assert 'stream' in sample
|
81 |
+
stream = None
|
82 |
+
try:
|
83 |
+
stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
|
84 |
+
prev_prefix = None
|
85 |
+
example = {}
|
86 |
+
valid = True
|
87 |
+
for tarinfo in stream:
|
88 |
+
name = tarinfo.name
|
89 |
+
pos = name.rfind('.')
|
90 |
+
assert pos > 0
|
91 |
+
prefix, postfix = name[:pos], name[pos + 1:]
|
92 |
+
if prev_prefix is not None and prefix != prev_prefix:
|
93 |
+
example['key'] = prev_prefix
|
94 |
+
if valid:
|
95 |
+
yield example
|
96 |
+
example = {}
|
97 |
+
valid = True
|
98 |
+
with stream.extractfile(tarinfo) as file_obj:
|
99 |
+
try:
|
100 |
+
if postfix == 'txt':
|
101 |
+
example['txt'] = file_obj.read().decode(
|
102 |
+
'utf8').strip()
|
103 |
+
elif postfix in AUDIO_FORMAT_SETS:
|
104 |
+
waveform, sample_rate = torchaudio.load(file_obj)
|
105 |
+
example['wav'] = waveform
|
106 |
+
example['sample_rate'] = sample_rate
|
107 |
+
else:
|
108 |
+
example[postfix] = file_obj.read()
|
109 |
+
except Exception as ex:
|
110 |
+
valid = False
|
111 |
+
logging.warning('error to parse {}'.format(name))
|
112 |
+
prev_prefix = prefix
|
113 |
+
if prev_prefix is not None:
|
114 |
+
example['key'] = prev_prefix
|
115 |
+
yield example
|
116 |
+
except Exception as ex:
|
117 |
+
logging.warning(
|
118 |
+
'In tar_file_and_group: {} when processing {}'.format(
|
119 |
+
ex, sample['src']))
|
120 |
+
finally:
|
121 |
+
if stream is not None:
|
122 |
+
stream.close()
|
123 |
+
if 'process' in sample:
|
124 |
+
sample['process'].communicate()
|
125 |
+
sample['stream'].close()
|
126 |
+
|
127 |
+
|
128 |
+
def tar_file_and_group_full_data(data):
|
129 |
+
""" Expand a stream of open tar files into a stream of tar file contents.
|
130 |
+
And groups the file with same prefix
|
131 |
+
|
132 |
+
Args:
|
133 |
+
data: Iterable[{src, stream}]
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
Iterable[{key, wav, txt, sample_rate}]
|
137 |
+
"""
|
138 |
+
for sample in data:
|
139 |
+
assert 'stream' in sample
|
140 |
+
stream = None
|
141 |
+
try:
|
142 |
+
stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
|
143 |
+
prev_prefix = None
|
144 |
+
example = {}
|
145 |
+
valid = True
|
146 |
+
for tarinfo in stream:
|
147 |
+
name = tarinfo.name
|
148 |
+
pos = name.rfind('.')
|
149 |
+
assert pos > 0
|
150 |
+
prefix, postfix = name[:pos], name[pos + 1:]
|
151 |
+
if prev_prefix is not None and prefix != prev_prefix:
|
152 |
+
example['key'] = prev_prefix
|
153 |
+
if valid:
|
154 |
+
# assert 'txt' in example
|
155 |
+
if 'txt' not in example:
|
156 |
+
example['txt'] = ''
|
157 |
+
yield example
|
158 |
+
example = {}
|
159 |
+
valid = True
|
160 |
+
with stream.extractfile(tarinfo) as file_obj:
|
161 |
+
try:
|
162 |
+
if postfix == 'txt':
|
163 |
+
example['txt'] = file_obj.read().decode(
|
164 |
+
'utf8').strip()
|
165 |
+
elif postfix == 'lang':
|
166 |
+
example['lang'] = file_obj.read().decode(
|
167 |
+
'utf8').strip()
|
168 |
+
elif postfix == 'speaker':
|
169 |
+
try:
|
170 |
+
example['speaker'] = file_obj.read().decode(
|
171 |
+
'utf8').strip()
|
172 |
+
except Exception as ex:
|
173 |
+
example['speaker'] = "none"
|
174 |
+
elif postfix == 'emotion':
|
175 |
+
example['emotion'] = file_obj.read().decode(
|
176 |
+
'utf8').strip()
|
177 |
+
elif postfix == 'gender':
|
178 |
+
example['gender'] = file_obj.read().decode(
|
179 |
+
'utf8').strip()
|
180 |
+
elif postfix == 'task':
|
181 |
+
example['task'] = file_obj.read().decode(
|
182 |
+
'utf8').strip()
|
183 |
+
elif postfix == 'speech_token':
|
184 |
+
example['speech_token'] = file_obj.read()
|
185 |
+
elif postfix == 'duration':
|
186 |
+
duration_str = file_obj.read().decode(
|
187 |
+
'utf8').strip()
|
188 |
+
try:
|
189 |
+
duration_float = float(duration_str)
|
190 |
+
example['duration'] = duration_float
|
191 |
+
except Exception as ex:
|
192 |
+
logging.warning(f'error to parse duration {duration_str}')
|
193 |
+
example['duration'] = 0
|
194 |
+
|
195 |
+
elif postfix in AUDIO_FORMAT_SETS:
|
196 |
+
waveform, sample_rate = torchaudio.load(file_obj)
|
197 |
+
# 检查音频的维度
|
198 |
+
num_channels = waveform.shape[0]
|
199 |
+
# 如果音频是多通道的,则进行通道平均
|
200 |
+
if num_channels > 1:
|
201 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
202 |
+
example['wav'] = waveform
|
203 |
+
example['sample_rate'] = sample_rate
|
204 |
+
else:
|
205 |
+
example[postfix] = file_obj.read()
|
206 |
+
except Exception as ex:
|
207 |
+
valid = False
|
208 |
+
# logging.warning('error to parse {}'.format(name))
|
209 |
+
prev_prefix = prefix
|
210 |
+
if prev_prefix is not None:
|
211 |
+
example['key'] = prev_prefix
|
212 |
+
if 'txt' in example:
|
213 |
+
yield example
|
214 |
+
|
215 |
+
except Exception as ex:
|
216 |
+
logging.warning(
|
217 |
+
'In tar_file_and_group: {} when processing {}'.format(
|
218 |
+
ex, sample['src']))
|
219 |
+
finally:
|
220 |
+
if stream is not None:
|
221 |
+
stream.close()
|
222 |
+
if 'process' in sample:
|
223 |
+
sample['process'].communicate()
|
224 |
+
sample['stream'].close()
|
225 |
+
|
226 |
+
|
227 |
+
def parse_raw(data):
|
228 |
+
""" Parse key/wav/txt from json line
|
229 |
+
|
230 |
+
Args:
|
231 |
+
data: Iterable[str], str is a json line has key/wav/txt
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
Iterable[{key, wav, txt, sample_rate}]
|
235 |
+
"""
|
236 |
+
for sample in data:
|
237 |
+
assert 'src' in sample
|
238 |
+
json_line = sample['src']
|
239 |
+
obj = json.loads(json_line)
|
240 |
+
assert 'key' in obj
|
241 |
+
assert 'wav' in obj
|
242 |
+
assert 'txt' in obj
|
243 |
+
key = obj['key']
|
244 |
+
wav_file = obj['wav']
|
245 |
+
txt = obj['txt']
|
246 |
+
try:
|
247 |
+
if 'start' in obj:
|
248 |
+
assert 'end' in obj
|
249 |
+
sample_rate = torchaudio.info(wav_file).sample_rate
|
250 |
+
start_frame = int(obj['start'] * sample_rate)
|
251 |
+
end_frame = int(obj['end'] * sample_rate)
|
252 |
+
waveform, _ = torchaudio.load(filepath=wav_file,
|
253 |
+
num_frames=end_frame -
|
254 |
+
start_frame,
|
255 |
+
frame_offset=start_frame)
|
256 |
+
else:
|
257 |
+
waveform, sample_rate = torchaudio.load(wav_file)
|
258 |
+
# 检查音频的维度
|
259 |
+
num_channels = waveform.shape[0]
|
260 |
+
# 如果音频是多通道的,则进行通道平均
|
261 |
+
if num_channels > 1:
|
262 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
263 |
+
example = copy.deepcopy(obj) # copy and keep all the fields
|
264 |
+
example['wav'] = waveform # overwrite wav
|
265 |
+
example['sample_rate'] = sample_rate
|
266 |
+
yield example
|
267 |
+
except Exception as ex:
|
268 |
+
logging.warning('Failed to read {}'.format(wav_file))
|
269 |
+
|
270 |
+
|
271 |
+
def parse_speaker(data, speaker_table_path):
|
272 |
+
speaker_dict = {}
|
273 |
+
with open(speaker_table_path, 'r', encoding='utf8') as fin:
|
274 |
+
for line in fin:
|
275 |
+
arr = line.strip().split()
|
276 |
+
speaker_dict[arr[0]] = int(arr[1])
|
277 |
+
for sample in data:
|
278 |
+
assert 'speaker' in sample
|
279 |
+
speaker = sample['speaker']
|
280 |
+
sample['speaker'] = speaker_dict.get(speaker, 0)
|
281 |
+
yield sample
|
282 |
+
|
283 |
+
|
284 |
+
def filter(data,
|
285 |
+
max_length=1200,
|
286 |
+
min_length=10,
|
287 |
+
token_max_length=250,
|
288 |
+
token_min_length=1,
|
289 |
+
min_output_input_ratio=0.00005,
|
290 |
+
max_output_input_ratio=1,
|
291 |
+
filter_no_extra_info: bool = False,
|
292 |
+
max_seq_len=1000):
|
293 |
+
""" Filter sample according to feature and label length
|
294 |
+
Inplace operation.
|
295 |
+
|
296 |
+
Args::
|
297 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
298 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
299 |
+
min_length: drop utterance which is less than min_length(10ms)
|
300 |
+
token_max_length: drop utterance which is greater than
|
301 |
+
token_max_length, especially when use char unit for
|
302 |
+
english modeling
|
303 |
+
token_min_length: drop utterance which is
|
304 |
+
less than token_max_length
|
305 |
+
min_output_input_ratio: minimal ration of
|
306 |
+
token_length / feats_length(10ms)
|
307 |
+
max_output_input_ratio: maximum ration of
|
308 |
+
token_length / feats_length(10ms)
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
Iterable[{key, wav, label, sample_rate}]
|
312 |
+
"""
|
313 |
+
for sample in data:
|
314 |
+
try:
|
315 |
+
assert 'sample_rate' in sample
|
316 |
+
assert 'wav' in sample
|
317 |
+
assert 'label' in sample
|
318 |
+
except:
|
319 |
+
continue
|
320 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
321 |
+
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
|
322 |
+
|
323 |
+
# filter for shard_in_common
|
324 |
+
if filter_no_extra_info:
|
325 |
+
if 'lang' not in sample:
|
326 |
+
continue
|
327 |
+
if 'task' not in sample:
|
328 |
+
continue
|
329 |
+
|
330 |
+
if num_frames < min_length:
|
331 |
+
continue
|
332 |
+
|
333 |
+
# if "output_type" in sample and sample["output_type"] == "speech2text_token":
|
334 |
+
# max_length = int(max_length / 2)
|
335 |
+
# if "output_type" in sample and sample["output_type"] == "text2token":
|
336 |
+
# max_length = int(max_length / 1.5)
|
337 |
+
if num_frames > max_length:
|
338 |
+
# continue
|
339 |
+
if 'task' in sample and sample['task'] == '<CAPTION>':
|
340 |
+
# utils_file.logging_limit_print('进行了随机剪裁')
|
341 |
+
# 随机选择一个起始点进行裁剪
|
342 |
+
start_frame = random.randint(0, int(num_frames - max_length))
|
343 |
+
end_frame = start_frame + max_length
|
344 |
+
sample['wav'] = sample['wav'][:, int(start_frame / 100 * sample['sample_rate']): int(
|
345 |
+
end_frame / 100 * sample['sample_rate'])]
|
346 |
+
# print('sample[', sample['wav'].shape)
|
347 |
+
else:
|
348 |
+
continue
|
349 |
+
if len(sample['label']) < token_min_length:
|
350 |
+
continue
|
351 |
+
if len(sample['label']) > token_max_length:
|
352 |
+
continue
|
353 |
+
# if num_frames != 0:
|
354 |
+
# if len(sample['label']) / num_frames < min_output_input_ratio:
|
355 |
+
# continue
|
356 |
+
# if len(sample['label']) / num_frames > max_output_input_ratio:
|
357 |
+
# continue
|
358 |
+
|
359 |
+
if sample["output_type"] == "speech2text_token":
|
360 |
+
seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label']) + len(sample['speech_token'])
|
361 |
+
elif sample["output_type"] == "text2token":
|
362 |
+
seq_len = len(sample['prompt']) + len(sample['label']) + len(sample['speech_token'])
|
363 |
+
else:
|
364 |
+
seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label'])
|
365 |
+
utils_file.logging_limit_print(f'seqlen: {seq_len}, output_type:{sample["output_type"]},len(sample["prompt"]):{len(sample["prompt"])},num_frames / 8:{num_frames / 8},len(sample["label"]):{len(sample["label"])},len(sample["speech_token"]):{len(sample["speech_token"])} ')
|
366 |
+
if max_seq_len > 0 and max_seq_len < seq_len:
|
367 |
+
utils_file.logging_limit_print(f"seqlen: {seq_len} 超过了最大长度:{max_seq_len},contiune")
|
368 |
+
continue
|
369 |
+
yield sample
|
370 |
+
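A minimal sketch of the frame arithmetic used throughout filter (illustrative, not part of the commit): lengths are counted in 10 ms frames, so the sample count is divided by the sample rate and multiplied by 100.
import torch
wav = torch.zeros(1, 3 * 16000)            # 3 s of audio at 16 kHz
num_frames = wav.size(1) / 16000 * 100     # -> 300.0 ten-millisecond frames
assert num_frames == 300.0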
|
371 |
+
|
372 |
+
def resample(data, resample_rate=16000):
|
373 |
+
""" Resample data.
|
374 |
+
Inplace operation.
|
375 |
+
|
376 |
+
Args:
|
377 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
378 |
+
resample_rate: target resample rate
|
379 |
+
|
380 |
+
Returns:
|
381 |
+
Iterable[{key, wav, label, sample_rate}]
|
382 |
+
"""
|
383 |
+
for sample in data:
|
384 |
+
assert 'sample_rate' in sample
|
385 |
+
assert 'wav' in sample
|
386 |
+
sample_rate = sample['sample_rate']
|
387 |
+
waveform = sample['wav']
|
388 |
+
if sample_rate != resample_rate:
|
389 |
+
sample['sample_rate'] = resample_rate
|
390 |
+
sample['wav'] = torchaudio.transforms.Resample(
|
391 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
392 |
+
yield sample
|
393 |
+
|
394 |
+
|
395 |
+
def speed_perturb(data, speeds=None):
|
396 |
+
""" Apply speed perturb to the data.
|
397 |
+
Inplace operation.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
401 |
+
speeds(List[float]): optional speed
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
Iterable[{key, wav, label, sample_rate}]
|
405 |
+
"""
|
406 |
+
if speeds is None:
|
407 |
+
speeds = [0.9, 1.0, 1.1]
|
408 |
+
for sample in data:
|
409 |
+
assert 'sample_rate' in sample
|
410 |
+
assert 'wav' in sample
|
411 |
+
sample_rate = sample['sample_rate']
|
412 |
+
waveform = sample['wav']
|
413 |
+
speed = random.choice(speeds)
|
414 |
+
if speed != 1.0:
|
415 |
+
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
|
416 |
+
waveform, sample_rate,
|
417 |
+
[['speed', str(speed)], ['rate', str(sample_rate)]])
|
418 |
+
sample['wav'] = wav
|
419 |
+
|
420 |
+
yield sample
|
421 |
+
|
422 |
+
|
423 |
+
def compute_fbank(data,
|
424 |
+
num_mel_bins=23,
|
425 |
+
frame_length=25,
|
426 |
+
frame_shift=10,
|
427 |
+
dither=0.0):
|
428 |
+
""" Extract fbank
|
429 |
+
|
430 |
+
Args:
|
431 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
432 |
+
|
433 |
+
Returns:
|
434 |
+
Iterable[{key, feat, label}]
|
435 |
+
"""
|
436 |
+
for sample in data:
|
437 |
+
assert 'sample_rate' in sample
|
438 |
+
assert 'wav' in sample
|
439 |
+
assert 'key' in sample
|
440 |
+
assert 'label' in sample
|
441 |
+
sample_rate = sample['sample_rate']
|
442 |
+
waveform = sample['wav']
|
443 |
+
waveform = waveform * (1 << 15)
|
444 |
+
# Only keep key, feat, label
|
445 |
+
mat = kaldi.fbank(waveform,
|
446 |
+
num_mel_bins=num_mel_bins,
|
447 |
+
frame_length=frame_length,
|
448 |
+
frame_shift=frame_shift,
|
449 |
+
dither=dither,
|
450 |
+
energy_floor=0.0,
|
451 |
+
sample_frequency=sample_rate)
|
452 |
+
sample['feat'] = mat
|
453 |
+
yield sample
|
454 |
+
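A hedged usage sketch (assumption, not part of the commit): torchaudio's Kaldi-compliance fbank expects 16-bit-scaled sample values, which is why the code above multiplies the waveform by 1 << 15.
import torch
import torchaudio.compliance.kaldi as kaldi
wav = torch.randn(1, 16000)                      # 1 s of 16 kHz audio in [-1, 1]
feat = kaldi.fbank(wav * (1 << 15), num_mel_bins=80,
                   frame_length=25, frame_shift=10,
                   dither=0.0, energy_floor=0.0,
                   sample_frequency=16000.0)
print(feat.shape)                                # roughly (98, 80)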
|
455 |
+
|
456 |
+
def compute_mfcc(data,
|
457 |
+
num_mel_bins=23,
|
458 |
+
frame_length=25,
|
459 |
+
frame_shift=10,
|
460 |
+
dither=0.0,
|
461 |
+
num_ceps=40,
|
462 |
+
high_freq=0.0,
|
463 |
+
low_freq=20.0):
|
464 |
+
""" Extract mfcc
|
465 |
+
|
466 |
+
Args:
|
467 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
468 |
+
|
469 |
+
Returns:
|
470 |
+
Iterable[{key, feat, label}]
|
471 |
+
"""
|
472 |
+
for sample in data:
|
473 |
+
assert 'sample_rate' in sample
|
474 |
+
assert 'wav' in sample
|
475 |
+
assert 'key' in sample
|
476 |
+
assert 'label' in sample
|
477 |
+
sample_rate = sample['sample_rate']
|
478 |
+
waveform = sample['wav']
|
479 |
+
waveform = waveform * (1 << 15)
|
480 |
+
# Only keep key, feat, label
|
481 |
+
mat = kaldi.mfcc(waveform,
|
482 |
+
num_mel_bins=num_mel_bins,
|
483 |
+
frame_length=frame_length,
|
484 |
+
frame_shift=frame_shift,
|
485 |
+
dither=dither,
|
486 |
+
num_ceps=num_ceps,
|
487 |
+
high_freq=high_freq,
|
488 |
+
low_freq=low_freq,
|
489 |
+
sample_frequency=sample_rate)
|
490 |
+
sample['feat'] = mat
|
491 |
+
yield sample
|
492 |
+
|
493 |
+
|
494 |
+
def compute_log_mel_spectrogram(data,
|
495 |
+
n_fft=400,
|
496 |
+
hop_length=160,
|
497 |
+
num_mel_bins=80,
|
498 |
+
padding=0):
|
499 |
+
""" Extract log mel spectrogram, modified from openai-whisper, see:
|
500 |
+
- https://github.com/openai/whisper/blob/main/whisper/audio.py
|
501 |
+
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
|
502 |
+
|
503 |
+
Args:
|
504 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
505 |
+
|
506 |
+
Returns:
|
507 |
+
Iterable[{key, feat, label}]
|
508 |
+
"""
|
509 |
+
for sample in data:
|
510 |
+
assert 'sample_rate' in sample
|
511 |
+
assert 'wav' in sample
|
512 |
+
assert 'key' in sample
|
513 |
+
assert 'label' in sample
|
514 |
+
sample_rate = sample['sample_rate']
|
515 |
+
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
|
516 |
+
# print(f'wavform shape: {waveform.shape}')
|
517 |
+
if padding > 0:
|
518 |
+
waveform = F.pad(waveform, (0, padding))
|
519 |
+
window = torch.hann_window(n_fft)
|
520 |
+
stft = torch.stft(waveform,
|
521 |
+
n_fft,
|
522 |
+
hop_length,
|
523 |
+
window=window,
|
524 |
+
return_complex=True)
|
525 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
526 |
+
|
527 |
+
filters = torch.from_numpy(
|
528 |
+
librosa.filters.mel(sr=sample_rate,
|
529 |
+
n_fft=n_fft,
|
530 |
+
n_mels=num_mel_bins))
|
531 |
+
mel_spec = filters @ magnitudes
|
532 |
+
|
533 |
+
# NOTE(xcsong): https://github.com/openai/whisper/discussions/269
|
534 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
535 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
536 |
+
log_spec = (log_spec + 4.0) / 4.0
|
537 |
+
sample['feat'] = log_spec.transpose(0, 1)
|
538 |
+
yield sample
|
539 |
+
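The clamp/floor/scale steps above follow the openai-whisper recipe; a standalone sketch of just that normalization (illustrative, with a random mel matrix standing in for real features):
import torch
mel_spec = torch.rand(80, 300) * 10.0                       # stand-in for filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)    # keep an 80 dB dynamic range
log_spec = (log_spec + 4.0) / 4.0                           # map roughly into [-1, 1]
print(log_spec.min(), log_spec.max())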
|
540 |
+
|
541 |
+
import re
|
542 |
+
|
543 |
+
|
544 |
+
def process_text(text):
|
545 |
+
# 1. Remove the spaces around Chinese characters
|
546 |
+
text = re.sub(r'\s*([\u4e00-\u9fff])\s*', r'\1', text)
|
547 |
+
# 2. Lowercase English text
|
548 |
+
text = text.lower()
|
549 |
+
# 3. Remove spaces around the < and > symbols
|
550 |
+
text = re.sub(r'\s*<\s*', '<', text)
|
551 |
+
text = re.sub(r'\s*>\s*', '>', text)
|
552 |
+
return text
|
553 |
+
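An illustrative before/after for process_text (not part of the commit):
import re
text = ' 你 好 Hello < TRANSCRIBE > '
text = re.sub(r'\s*([\u4e00-\u9fff])\s*', r'\1', text)   # drop spaces around CJK chars
text = text.lower()
text = re.sub(r'\s*<\s*', '<', text)
text = re.sub(r'\s*>\s*', '>', text)
print(text)                                              # 你好hello<transcribe>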
|
554 |
+
|
555 |
+
global_style_dict = {
|
556 |
+
"朗读": "新闻科普",
|
557 |
+
"科普百科": "新闻科普",
|
558 |
+
"悬疑恐怖": "恐怖故事",
|
559 |
+
"童话故事": "童话故事",
|
560 |
+
"客服": "客服",
|
561 |
+
"诗歌": "诗歌散文",
|
562 |
+
"散文": "诗歌散文",
|
563 |
+
"武侠评书": "有声书",
|
564 |
+
"小说": "有声书",
|
565 |
+
"历史": "有声书",
|
566 |
+
"科幻": "有声书",
|
567 |
+
"对话": "日常口语",
|
568 |
+
"口语": "日常口语",
|
569 |
+
"幽默": "其他",
|
570 |
+
"其他": "其他",
|
571 |
+
}
|
572 |
+
|
573 |
+
|
574 |
+
def replace_keys_in_brackets(input_str, key_value_dict):
|
575 |
+
for key, value in key_value_dict.items():
|
576 |
+
# Build a regex pattern matching the <key> form
|
577 |
+
pattern = re.compile(r'<{}>'.format(key))
|
578 |
+
input_str = pattern.sub(f"<{value}>", input_str)
|
579 |
+
return input_str
|
580 |
+
|
581 |
+
|
582 |
+
def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
|
583 |
+
""" Decode text to chars or BPE
|
584 |
+
Inplace operation
|
585 |
+
|
586 |
+
Args:
|
587 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
588 |
+
|
589 |
+
Returns:
|
590 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
591 |
+
"""
|
592 |
+
for sample in data:
|
593 |
+
try:
|
594 |
+
assert 'txt' in sample
|
595 |
+
except AssertionError:
|
596 |
+
print(f'tokenize: {sample}')
|
597 |
+
exit()
|
598 |
+
if 'task' in sample:
|
599 |
+
task_name = sample['task']
|
600 |
+
# if "<AGE>" in task_name:
|
601 |
+
# txt = sample['txt'].replace("<YOUTH>", "<ADULT>").replace("<MIDDLE_AGE>", "<ADULT>").replace("<MIDDLE>", "<ADULT>")
|
602 |
+
if "<STYLE>" in sample['task']:
|
603 |
+
txt = replace_keys_in_brackets(sample['txt'], global_style_dict)
|
604 |
+
else:
|
605 |
+
txt = sample['txt']
|
606 |
+
else:
|
607 |
+
txt = sample['txt']
|
608 |
+
|
609 |
+
tokens, label = tokenizer.tokenize(process_text(txt))
|
610 |
+
sample['tokens'] = tokens # tokens are characters, label is the list of token ids
|
611 |
+
sample['label'] = label + [tokenizer.eod_id]
|
612 |
+
if 'task' in sample:
|
613 |
+
task_name = sample['task']
|
614 |
+
try:
|
615 |
+
random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
|
616 |
+
prompt = global_prompt_dict[task_name][random_index]
|
617 |
+
sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
|
618 |
+
except Exception:
|
619 |
+
pass
|
620 |
+
else:
|
621 |
+
task_name = '<TRANSCRIBE>'
|
622 |
+
try:
|
623 |
+
random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
|
624 |
+
prompt = global_prompt_dict[task_name][random_index]
|
625 |
+
sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
|
626 |
+
except Exception:
|
627 |
+
pass
|
628 |
+
|
629 |
+
if 'speech_token' in sample:
|
630 |
+
old_task_name = sample['task']
|
631 |
+
if old_task_name == "<TRANSCRIBE>":
|
632 |
+
task_name = '<TEXT2SPEECH_TOKEN>'
|
633 |
+
sample['output_type'] = 'text2token'
|
634 |
+
elif old_task_name == "<S2TCHAT>":
|
635 |
+
task_name = '<SPEECH2TEXT_SPEECH_TOKEN>'
|
636 |
+
sample['output_type'] = 'speech2text_token'
|
637 |
+
else:
|
638 |
+
task_name = old_task_name
|
639 |
+
try:
|
640 |
+
random_index = random.randint(0, len(global_prompt_dict[task_name]) - 1)
|
641 |
+
prompt = global_prompt_dict[task_name][random_index]
|
642 |
+
sample['prompt'] = tokenizer.tokenize(prompt)[1] # labels
|
643 |
+
except Exception:
|
644 |
+
pass
|
645 |
+
# Fix from sywang: only needed at inference time (raw format); the tar format already converts speech_token to an int list
|
646 |
+
# try:
|
647 |
+
# utils_file.logging_limit_print("type of sample['speech_token']: ", type(sample['speech_token']))
|
648 |
+
# speech_tokens = ast.literal_eval(sample['speech_token']) # 解析字符串为列表
|
649 |
+
# except (ValueError, SyntaxError) as e:
|
650 |
+
# print(f"解析错误: {e}在{speech_tokens}")
|
651 |
+
# speech_tokens = []
|
652 |
+
# speech_token = [int(x) for x in speech_tokens]
|
653 |
+
speech_token = [int(x) for x in sample['speech_token']]
|
654 |
+
sample['speech_token'] = speech_token + [4096]
|
655 |
+
else:
|
656 |
+
sample['output_type'] = 'text'
|
657 |
+
sample['speech_token'] = [4096]
|
658 |
+
yield sample
|
659 |
+
|
660 |
+
|
661 |
+
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
|
662 |
+
""" Do spec augmentation
|
663 |
+
Inplace operation
|
664 |
+
|
665 |
+
Args:
|
666 |
+
data: Iterable[{key, feat, label}]
|
667 |
+
num_t_mask: number of time mask to apply
|
668 |
+
num_f_mask: number of freq mask to apply
|
669 |
+
max_t: max width of time mask
|
670 |
+
max_f: max width of freq mask
|
671 |
+
max_w: max width of time warp
|
672 |
+
|
673 |
+
Returns
|
674 |
+
Iterable[{key, feat, label}]
|
675 |
+
"""
|
676 |
+
for sample in data:
|
677 |
+
assert 'feat' in sample
|
678 |
+
x = sample['feat']
|
679 |
+
assert isinstance(x, torch.Tensor)
|
680 |
+
y = x.clone().detach()
|
681 |
+
max_frames = y.size(0)
|
682 |
+
max_freq = y.size(1)
|
683 |
+
# time mask
|
684 |
+
for i in range(num_t_mask):
|
685 |
+
start = random.randint(0, max_frames - 1)
|
686 |
+
length = random.randint(1, max_t)
|
687 |
+
end = min(max_frames, start + length)
|
688 |
+
y[start:end, :] = 0
|
689 |
+
# freq mask
|
690 |
+
for i in range(num_f_mask):
|
691 |
+
start = random.randint(0, max_freq - 1)
|
692 |
+
length = random.randint(1, max_f)
|
693 |
+
end = min(max_freq, start + length)
|
694 |
+
y[:, start:end] = 0
|
695 |
+
sample['feat'] = y
|
696 |
+
yield sample
|
697 |
+
|
698 |
+
|
699 |
+
def spec_sub(data, max_t=20, num_t_sub=3):
|
700 |
+
""" Do spec substitute
|
701 |
+
Inplace operation
|
702 |
+
ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]
|
703 |
+
|
704 |
+
Args:
|
705 |
+
data: Iterable[{key, feat, label}]
|
706 |
+
max_t: max width of time substitute
|
707 |
+
num_t_sub: number of time substitute to apply
|
708 |
+
|
709 |
+
Returns
|
710 |
+
Iterable[{key, feat, label}]
|
711 |
+
"""
|
712 |
+
for sample in data:
|
713 |
+
assert 'feat' in sample
|
714 |
+
x = sample['feat']
|
715 |
+
assert isinstance(x, torch.Tensor)
|
716 |
+
y = x.clone().detach()
|
717 |
+
max_frames = y.size(0)
|
718 |
+
for i in range(num_t_sub):
|
719 |
+
start = random.randint(0, max_frames - 1)
|
720 |
+
length = random.randint(1, max_t)
|
721 |
+
end = min(max_frames, start + length)
|
722 |
+
# only substitute the earlier time chosen randomly for current time
|
723 |
+
pos = random.randint(0, start)
|
724 |
+
y[start:end, :] = x[start - pos:end - pos, :]
|
725 |
+
sample['feat'] = y
|
726 |
+
yield sample
|
727 |
+
|
728 |
+
|
729 |
+
def spec_trim(data, max_t=20):
|
730 |
+
""" Trim tailing frames. Inplace operation.
|
731 |
+
ref: TrimTail [https://arxiv.org/abs/2211.00522]
|
732 |
+
|
733 |
+
Args:
|
734 |
+
data: Iterable[{key, feat, label}]
|
735 |
+
max_t: max width of length trimming
|
736 |
+
|
737 |
+
Returns
|
738 |
+
Iterable[{key, feat, label}]
|
739 |
+
"""
|
740 |
+
for sample in data:
|
741 |
+
assert 'feat' in sample
|
742 |
+
x = sample['feat']
|
743 |
+
assert isinstance(x, torch.Tensor)
|
744 |
+
max_frames = x.size(0)
|
745 |
+
length = random.randint(1, max_t)
|
746 |
+
if length < max_frames / 2:
|
747 |
+
y = x.clone().detach()[:max_frames - length]
|
748 |
+
sample['feat'] = y
|
749 |
+
yield sample
|
750 |
+
|
751 |
+
|
752 |
+
def shuffle(data, shuffle_size=10000):
|
753 |
+
""" Local shuffle the data
|
754 |
+
|
755 |
+
Args:
|
756 |
+
data: Iterable[{key, feat, label}]
|
757 |
+
shuffle_size: buffer size for shuffle
|
758 |
+
|
759 |
+
Returns:
|
760 |
+
Iterable[{key, feat, label}]
|
761 |
+
"""
|
762 |
+
buf = []
|
763 |
+
for sample in data:
|
764 |
+
buf.append(sample)
|
765 |
+
if len(buf) >= shuffle_size:
|
766 |
+
random.shuffle(buf)
|
767 |
+
for x in buf:
|
768 |
+
yield x
|
769 |
+
buf = []
|
770 |
+
# The sample left over
|
771 |
+
random.shuffle(buf)
|
772 |
+
for x in buf:
|
773 |
+
yield x
|
774 |
+
|
775 |
+
|
776 |
+
def sort(data, sort_size=500):
|
777 |
+
""" Sort the data by feature length.
|
778 |
+
Sort is used after shuffle and before batch, so we can group
|
779 |
+
utts with similar lengths into a batch, and `sort_size` should
|
780 |
+
be less than `shuffle_size`
|
781 |
+
|
782 |
+
Args:
|
783 |
+
data: Iterable[{key, feat, label}]
|
784 |
+
sort_size: buffer size for sort
|
785 |
+
|
786 |
+
Returns:
|
787 |
+
Iterable[{key, feat, label}]
|
788 |
+
"""
|
789 |
+
|
790 |
+
buf = []
|
791 |
+
for sample in data:
|
792 |
+
buf.append(sample)
|
793 |
+
if len(buf) >= sort_size:
|
794 |
+
buf.sort(key=lambda x: x['feat'].size(0))
|
795 |
+
for x in buf:
|
796 |
+
yield x
|
797 |
+
buf = []
|
798 |
+
# The sample left over
|
799 |
+
buf.sort(key=lambda x: x['feat'].size(0))
|
800 |
+
for x in buf:
|
801 |
+
yield x
|
802 |
+
|
803 |
+
|
804 |
+
def static_batch(data, batch_size=16):
|
805 |
+
""" Static batch the data by `batch_size`
|
806 |
+
|
807 |
+
Args:
|
808 |
+
data: Iterable[{key, feat, label}]
|
809 |
+
batch_size: batch size
|
810 |
+
|
811 |
+
Returns:
|
812 |
+
Iterable[List[{key, feat, label}]]
|
813 |
+
"""
|
814 |
+
buf = []
|
815 |
+
for sample in data:
|
816 |
+
buf.append(sample)
|
817 |
+
if len(buf) >= batch_size:
|
818 |
+
yield buf
|
819 |
+
buf = []
|
820 |
+
if len(buf) > 0:
|
821 |
+
yield buf
|
822 |
+
|
823 |
+
|
824 |
+
def dynamic_batch(data, max_frames_in_batch=12000, max_seq_in_batch=10000000):
|
825 |
+
""" Dynamic batch the data until the total frames in batch
|
826 |
+
reach `max_frames_in_batch`
|
827 |
+
|
828 |
+
Args:
|
829 |
+
data: Iterable[{key, feat, label}]
|
830 |
+
max_frames_in_batch: max_frames in one batch
|
831 |
+
|
832 |
+
Returns:
|
833 |
+
Iterable[List[{key, feat, label}]]
|
834 |
+
"""
|
835 |
+
buf = []
|
836 |
+
longest_frames = 0
|
837 |
+
longest_seq = 0
|
838 |
+
max_frames_in_batch = max_frames_in_batch
|
839 |
+
|
840 |
+
buf_speech_token = []
|
841 |
+
longest_frames_token = 0
|
842 |
+
longest_seq_token = 0
|
843 |
+
max_frames_in_batch_token = int(max_frames_in_batch)
|
844 |
+
|
845 |
+
buf_speech_token_with_text = []
|
846 |
+
longest_frames_token_with_text = 0
|
847 |
+
longest_seq_token_with_text = 0
|
848 |
+
max_frames_in_batch_token_with_text = int(max_frames_in_batch / 2.5)
|
849 |
+
|
850 |
+
for sample in data:
|
851 |
+
assert 'feat' in sample
|
852 |
+
assert isinstance(sample['feat'], torch.Tensor)
|
853 |
+
new_sample_frames = sample['feat'].size(0)
|
854 |
+
if "output_type" in sample and sample["output_type"] == "speech2text_token":
|
855 |
+
new_seq = sample['feat'].size(0) / 8 + len(sample['label']) + len(sample.get('prompt', [])) + len(
|
856 |
+
sample.get('speech_token', []))
|
857 |
+
longest_seq_token = max(longest_seq_token, new_seq)
|
858 |
+
utils_file.logging_limit_print(
|
859 |
+
f'batch fn: current new_seq: {new_seq}, longest_seq_token: {longest_seq_token}')
|
860 |
+
longest_frames_token = max(longest_frames_token, new_sample_frames)
|
861 |
+
frames_after_padding_token = longest_frames_token * (len(buf_speech_token)+1)
|
862 |
+
seq_after_padding_token = longest_seq_token * (len(buf_speech_token)+1)
|
863 |
+
utils_file.logging_limit_print(
|
864 |
+
f'batch fn: current new_seq: {new_seq}, longest_seq_token: {longest_seq_token}, seq_after_padding_token: {seq_after_padding_token}')
|
865 |
+
utils_file.logging_limit_print(
|
866 |
+
f'batch fn: current new_sample_frames: {new_sample_frames}, longest_frames_token: {longest_frames_token}, frames_after_padding_token: {frames_after_padding_token}')
|
867 |
+
if frames_after_padding_token > max_frames_in_batch_token or seq_after_padding_token > max_seq_in_batch:
|
868 |
+
yield buf_speech_token
|
869 |
+
buf_speech_token = [sample]
|
870 |
+
longest_frames_token = new_sample_frames
|
871 |
+
longest_seq_token = new_seq
|
872 |
+
else:
|
873 |
+
buf_speech_token.append(sample)
|
874 |
+
elif "output_type" in sample and sample["output_type"] == "text2token":
|
875 |
+
new_seq = len(sample['label']) + len(sample.get('prompt', [])) + len(
|
876 |
+
sample.get('speech_token', []))
|
877 |
+
longest_seq_token_with_text = max(longest_seq_token_with_text, new_seq)
|
878 |
+
longest_frames_token_with_text = max(longest_frames_token_with_text, new_sample_frames)
|
879 |
+
frames_after_padding_token_with_text = longest_frames_token_with_text * (len(buf_speech_token_with_text)+1)
|
880 |
+
seq_after_padding_token_with_text = longest_seq_token_with_text * (len(buf_speech_token_with_text)+1)
|
881 |
+
if frames_after_padding_token_with_text > max_frames_in_batch_token_with_text or seq_after_padding_token_with_text > max_seq_in_batch:
|
882 |
+
yield buf_speech_token_with_text
|
883 |
+
buf_speech_token_with_text = [sample]
|
884 |
+
longest_frames_token_with_text = new_sample_frames
|
885 |
+
longest_seq_token_with_text = new_seq
|
886 |
+
else:
|
887 |
+
buf_speech_token_with_text.append(sample)
|
888 |
+
else:
|
889 |
+
new_seq = sample['feat'].size(0) / 8 + len(sample['label']) + len(sample.get('prompt', []))
|
890 |
+
longest_seq = max(longest_seq, new_seq)
|
891 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
892 |
+
frames_after_padding = longest_frames * (len(buf)+1)
|
893 |
+
seq_after_padding = longest_seq * (len(buf)+1)
|
894 |
+
if frames_after_padding > max_frames_in_batch or seq_after_padding > max_seq_in_batch:
|
895 |
+
yield buf
|
896 |
+
buf = [sample]
|
897 |
+
longest_frames = new_sample_frames
|
898 |
+
longest_seq = new_seq
|
899 |
+
else:
|
900 |
+
buf.append(sample)
|
901 |
+
if len(buf) > 0:
|
902 |
+
yield buf
|
903 |
+
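The batching rule above closes a batch once the padded cost (longest item seen so far times the would-be batch size) exceeds the budget; a toy sketch of that rule in isolation (illustrative numbers only):
lengths = [80, 100, 90, 400, 50]
budget, buf, longest = 300, [], 0
for n in lengths:
    longest = max(longest, n)
    if longest * (len(buf) + 1) > budget:
        print('flush batch:', buf)
        buf, longest = [n], n
    else:
        buf.append(n)
print('flush batch:', buf)                # -> [80, 100, 90], then [400], then [50]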
|
904 |
+
|
905 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, max_seq_in_batch=10000000):
|
906 |
+
""" Wrapper for static/dynamic batch
|
907 |
+
"""
|
908 |
+
if batch_type == 'static':
|
909 |
+
return static_batch(data, batch_size)
|
910 |
+
elif batch_type == 'dynamic':
|
911 |
+
return dynamic_batch(data, max_frames_in_batch, max_seq_in_batch=max_seq_in_batch)
|
912 |
+
else:
|
913 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
914 |
+
|
915 |
+
|
916 |
+
def padding(data):
|
917 |
+
""" Padding the data into training data
|
918 |
+
|
919 |
+
Args:
|
920 |
+
data: Iterable[List[{key, feat, label}]]
|
921 |
+
|
922 |
+
Returns:
|
923 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
924 |
+
"""
|
925 |
+
for sample in data:
|
926 |
+
assert isinstance(sample, list)
|
927 |
+
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
|
928 |
+
dtype=torch.int32)
|
929 |
+
order = torch.argsort(feats_length, descending=True)
|
930 |
+
feats_lengths = torch.tensor(
|
931 |
+
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
|
932 |
+
sorted_feats = [sample[i]['feat'] for i in order]
|
933 |
+
sorted_keys = [sample[i]['key'] for i in order]
|
934 |
+
sorted_labels = [
|
935 |
+
torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
|
936 |
+
]
|
937 |
+
sorted_speech_tokens = [
|
938 |
+
torch.tensor(sample[i]['speech_token'], dtype=torch.int64) for i in order
|
939 |
+
]
|
940 |
+
|
941 |
+
sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
|
942 |
+
label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
|
943 |
+
dtype=torch.int32)
|
944 |
+
speech_token_lengths = torch.tensor([x.size(0) for x in sorted_speech_tokens],
|
945 |
+
dtype=torch.int32)
|
946 |
+
wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
|
947 |
+
dtype=torch.int32)
|
948 |
+
# print('------------------')
|
949 |
+
# for feat_item in sorted_feats:
|
950 |
+
# print(feat_item.shape)
|
951 |
+
# print('------------------')
|
952 |
+
|
953 |
+
padded_feats = pad_sequence(sorted_feats,
|
954 |
+
batch_first=True,
|
955 |
+
padding_value=0)
|
956 |
+
padding_labels = pad_sequence(sorted_labels,
|
957 |
+
batch_first=True,
|
958 |
+
padding_value=-100)
|
959 |
+
|
960 |
+
padding_speech_tokens = pad_sequence(sorted_speech_tokens,
|
961 |
+
batch_first=True,
|
962 |
+
padding_value=-100)
|
963 |
+
padded_wavs = pad_sequence(sorted_wavs,
|
964 |
+
batch_first=True,
|
965 |
+
padding_value=0)
|
966 |
+
|
967 |
+
sorted_lang = [
|
968 |
+
sample[i].get('lang', 'cn') for i in order
|
969 |
+
]
|
970 |
+
|
971 |
+
sorted_speaker = [
|
972 |
+
sample[i].get('speaker', 'None') for i in order
|
973 |
+
]
|
974 |
+
|
975 |
+
sorted_emotion = [
|
976 |
+
sample[i].get('emotion', 'None') for i in order
|
977 |
+
]
|
978 |
+
sorted_gender = [
|
979 |
+
sample[i].get('gender', 'None') for i in order
|
980 |
+
]
|
981 |
+
# sorted_duration = [
|
982 |
+
# sample[i]['duration'] for i in order
|
983 |
+
# ]
|
984 |
+
sorted_task = [
|
985 |
+
sample[i].get('task', '<TRANSCRIBE>') for i in order
|
986 |
+
]
|
987 |
+
|
988 |
+
batch = {
|
989 |
+
"keys": sorted_keys,
|
990 |
+
"feats": padded_feats,
|
991 |
+
"target": padding_labels,
|
992 |
+
"feats_lengths": feats_lengths,
|
993 |
+
"target_lengths": label_lengths,
|
994 |
+
"pcm": padded_wavs,
|
995 |
+
"pcm_length": wav_lengths,
|
996 |
+
"speech_tokens": padding_speech_tokens,
|
997 |
+
"speech_tokens_length": speech_token_lengths,
|
998 |
+
"lang": sorted_lang,
|
999 |
+
"speaker": sorted_speaker,
|
1000 |
+
"emotion": sorted_emotion,
|
1001 |
+
"gender": sorted_gender,
|
1002 |
+
"task": sorted_task
|
1003 |
+
}
|
1004 |
+
if 'prompt' in sample[0]:
|
1005 |
+
sorted_prompts = [
|
1006 |
+
torch.tensor(sample[i]['prompt'], dtype=torch.int64
|
1007 |
+
) for i in order
|
1008 |
+
]
|
1009 |
+
prompt_lengths = torch.tensor([x.size(0) for x in
|
1010 |
+
sorted_prompts], dtype=torch.int32)
|
1011 |
+
padding_prompts = pad_sequence(sorted_prompts,
|
1012 |
+
batch_first=True,
|
1013 |
+
padding_value=-1)
|
1014 |
+
batch['prompt'] = padding_prompts
|
1015 |
+
batch['prompt_lengths'] = prompt_lengths
|
1016 |
+
|
1017 |
+
if 'output_type' in sample[0] and sample[0]['output_type'] == 'speech2text_token':
|
1018 |
+
batch['output_type'] = 'speech2text_token'
|
1019 |
+
elif 'output_type' in sample[0] and sample[0]['output_type'] == 'text2token':
|
1020 |
+
batch['output_type'] = 'text2token'
|
1021 |
+
else:
|
1022 |
+
batch['output_type'] = 'text'
|
1023 |
+
yield batch
|
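A small sketch of the sort-then-pad pattern used in padding (illustrative, not part of the commit): feats are ordered by length descending, so feats_lengths[i] always lines up with padded_feats[i].
import torch
from torch.nn.utils.rnn import pad_sequence
feats = [torch.ones(5, 3), torch.ones(2, 3), torch.ones(7, 3)]
lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)
order = torch.argsort(lengths, descending=True)
padded = pad_sequence([feats[i] for i in order], batch_first=True, padding_value=0)
print(padded.shape)                        # torch.Size([3, 7, 3])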
wenet/dataset/kaldi_io.py
ADDED
@@ -0,0 +1,772 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely)
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License")
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import sys, os, re, gzip, struct
|
9 |
+
|
10 |
+
#################################################
|
11 |
+
# Adding kaldi tools to shell path,
|
12 |
+
|
13 |
+
# Select kaldi,
|
14 |
+
if not 'KALDI_ROOT' in os.environ:
|
15 |
+
# Default! To change run python with 'export KALDI_ROOT=/some_dir python'
|
16 |
+
os.environ['KALDI_ROOT'] = '/mnt/matylda5/iveselyk/Tools/kaldi-trunk'
|
17 |
+
|
18 |
+
# Add kaldi tools to path,
|
19 |
+
os.environ['PATH'] = os.popen(
|
20 |
+
'echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/'
|
21 |
+
).readline().strip() + ':' + os.environ['PATH']
|
22 |
+
|
23 |
+
|
24 |
+
#################################################
|
25 |
+
# Define all custom exceptions,
|
26 |
+
class UnsupportedDataType(Exception):
|
27 |
+
pass
|
28 |
+
|
29 |
+
|
30 |
+
class UnknownVectorHeader(Exception):
|
31 |
+
pass
|
32 |
+
|
33 |
+
|
34 |
+
class UnknownMatrixHeader(Exception):
|
35 |
+
pass
|
36 |
+
|
37 |
+
|
38 |
+
class BadSampleSize(Exception):
|
39 |
+
pass
|
40 |
+
|
41 |
+
|
42 |
+
class BadInputFormat(Exception):
|
43 |
+
pass
|
44 |
+
|
45 |
+
|
46 |
+
class SubprocessFailed(Exception):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
#################################################
|
51 |
+
# Data-type independent helper functions,
|
52 |
+
|
53 |
+
|
54 |
+
def open_or_fd(file, mode='rb'):
|
55 |
+
""" fd = open_or_fd(file)
|
56 |
+
Open file, gzipped file, pipe, or forward the file-descriptor.
|
57 |
+
If the 'file' argument contains a ':offset' suffix, seeks to that offset.
|
58 |
+
"""
|
59 |
+
offset = None
|
60 |
+
try:
|
61 |
+
# strip 'ark:' prefix from r{x,w}filename (optional),
|
62 |
+
if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:',
|
63 |
+
file):
|
64 |
+
(prefix, file) = file.split(':', 1)
|
65 |
+
# separate offset from filename (optional),
|
66 |
+
if re.search(':[0-9]+$', file):
|
67 |
+
(file, offset) = file.rsplit(':', 1)
|
68 |
+
# input pipe?
|
69 |
+
if file[-1] == '|':
|
70 |
+
fd = popen(file[:-1], 'rb') # custom,
|
71 |
+
# output pipe?
|
72 |
+
elif file[0] == '|':
|
73 |
+
fd = popen(file[1:], 'wb') # custom,
|
74 |
+
# is it gzipped?
|
75 |
+
elif file.split('.')[-1] == 'gz':
|
76 |
+
fd = gzip.open(file, mode)
|
77 |
+
# a normal file...
|
78 |
+
else:
|
79 |
+
fd = open(file, mode)
|
80 |
+
except TypeError:
|
81 |
+
# 'file' is opened file descriptor,
|
82 |
+
fd = file
|
83 |
+
# Eventually seek to offset,
|
84 |
+
if offset != None: fd.seek(int(offset))
|
85 |
+
return fd
|
86 |
+
|
87 |
+
|
88 |
+
# based on '/usr/local/lib/python3.4/os.py'
|
89 |
+
def popen(cmd, mode="rb"):
|
90 |
+
if not isinstance(cmd, str):
|
91 |
+
raise TypeError("invalid cmd type (%s, expected string)" % type(cmd))
|
92 |
+
|
93 |
+
import subprocess, io, threading
|
94 |
+
|
95 |
+
# cleanup function for subprocesses,
|
96 |
+
def cleanup(proc, cmd):
|
97 |
+
ret = proc.wait()
|
98 |
+
if ret > 0:
|
99 |
+
raise SubprocessFailed('cmd %s returned %d !' % (cmd, ret))
|
100 |
+
return
|
101 |
+
|
102 |
+
# text-mode,
|
103 |
+
if mode == "r":
|
104 |
+
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
|
105 |
+
threading.Thread(target=cleanup,
|
106 |
+
args=(proc, cmd)).start() # clean-up thread,
|
107 |
+
return io.TextIOWrapper(proc.stdout)
|
108 |
+
elif mode == "w":
|
109 |
+
proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
|
110 |
+
threading.Thread(target=cleanup,
|
111 |
+
args=(proc, cmd)).start() # clean-up thread,
|
112 |
+
return io.TextIOWrapper(proc.stdin)
|
113 |
+
# binary,
|
114 |
+
elif mode == "rb":
|
115 |
+
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
|
116 |
+
threading.Thread(target=cleanup,
|
117 |
+
args=(proc, cmd)).start() # clean-up thread,
|
118 |
+
return proc.stdout
|
119 |
+
elif mode == "wb":
|
120 |
+
proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE)
|
121 |
+
threading.Thread(target=cleanup,
|
122 |
+
args=(proc, cmd)).start() # clean-up thread,
|
123 |
+
return proc.stdin
|
124 |
+
# sanity,
|
125 |
+
else:
|
126 |
+
raise ValueError("invalid mode %s" % mode)
|
127 |
+
|
128 |
+
|
129 |
+
def read_key(fd):
|
130 |
+
""" [key] = read_key(fd)
|
131 |
+
Read the utterance-key from the opened ark/stream descriptor 'fd'.
|
132 |
+
"""
|
133 |
+
key = ''
|
134 |
+
while 1:
|
135 |
+
char = fd.read(1).decode("latin1")
|
136 |
+
if char == '': break
|
137 |
+
if char == ' ': break
|
138 |
+
key += char
|
139 |
+
key = key.strip()
|
140 |
+
if key == '': return None # end of file,
|
141 |
+
assert (re.match(r'^\S+$', key) is not None) # check format (no whitespace!)
|
142 |
+
return key
|
143 |
+
|
144 |
+
|
145 |
+
#################################################
|
146 |
+
# Integer vectors (alignments, ...),
|
147 |
+
|
148 |
+
|
149 |
+
def read_ali_ark(file_or_fd):
|
150 |
+
""" Alias to 'read_vec_int_ark()' """
|
151 |
+
return read_vec_int_ark(file_or_fd)
|
152 |
+
|
153 |
+
|
154 |
+
def read_vec_int_ark(file_or_fd):
|
155 |
+
""" generator(key,vec) = read_vec_int_ark(file_or_fd)
|
156 |
+
Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
|
157 |
+
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
|
158 |
+
|
159 |
+
Read ark to a 'dictionary':
|
160 |
+
d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) }
|
161 |
+
"""
|
162 |
+
fd = open_or_fd(file_or_fd)
|
163 |
+
try:
|
164 |
+
key = read_key(fd)
|
165 |
+
while key:
|
166 |
+
ali = read_vec_int(fd)
|
167 |
+
yield key, ali
|
168 |
+
key = read_key(fd)
|
169 |
+
finally:
|
170 |
+
if fd is not file_or_fd: fd.close()
|
171 |
+
|
172 |
+
|
173 |
+
def read_vec_int_scp(file_or_fd):
|
174 |
+
""" generator(key,vec) = read_vec_int_scp(file_or_fd)
|
175 |
+
Returns generator of (key,vector<int>) tuples, read according to kaldi scp.
|
176 |
+
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
|
177 |
+
|
178 |
+
Iterate the scp:
|
179 |
+
for key,vec in kaldi_io.read_vec_int_scp(file):
|
180 |
+
...
|
181 |
+
|
182 |
+
Read scp to a 'dictionary':
|
183 |
+
d = { key:vec for key,mat in kaldi_io.read_vec_int_scp(file) }
|
184 |
+
"""
|
185 |
+
fd = open_or_fd(file_or_fd)
|
186 |
+
try:
|
187 |
+
for line in fd:
|
188 |
+
(key, rxfile) = line.decode().split(' ')
|
189 |
+
vec = read_vec_int(rxfile)
|
190 |
+
yield key, vec
|
191 |
+
finally:
|
192 |
+
if fd is not file_or_fd: fd.close()
|
193 |
+
|
194 |
+
|
195 |
+
def read_vec_int(file_or_fd):
|
196 |
+
""" [int-vec] = read_vec_int(file_or_fd)
|
197 |
+
Read kaldi integer vector, ascii or binary input,
|
198 |
+
"""
|
199 |
+
fd = open_or_fd(file_or_fd)
|
200 |
+
binary = fd.read(2).decode()
|
201 |
+
if binary == '\0B': # binary flag
|
202 |
+
assert (fd.read(1).decode() == '\4')
|
203 |
+
# int-size
|
204 |
+
vec_size = np.frombuffer(fd.read(4), dtype='int32',
|
205 |
+
count=1)[0] # vector dim
|
206 |
+
# Elements from int32 vector are stored in tuples: (sizeof(int32), value),
|
207 |
+
vec = np.frombuffer(fd.read(vec_size * 5),
|
208 |
+
dtype=[('size', 'int8'), ('value', 'int32')],
|
209 |
+
count=vec_size)
|
210 |
+
assert (vec[0]['size'] == 4) # int32 size,
|
211 |
+
ans = vec[:]['value'] # values are in 2nd column,
|
212 |
+
else: # ascii,
|
213 |
+
arr = (binary + fd.readline().decode()).strip().split()
|
214 |
+
try:
|
215 |
+
arr.remove('[')
|
216 |
+
arr.remove(']') # optionally
|
217 |
+
except ValueError:
|
218 |
+
pass
|
219 |
+
ans = np.array(arr, dtype=int)
|
220 |
+
if fd is not file_or_fd: fd.close() # cleanup
|
221 |
+
return ans
|
222 |
+
|
223 |
+
|
224 |
+
# Writing,
|
225 |
+
def write_vec_int(file_or_fd, v, key=''):
|
226 |
+
""" write_vec_int(f, v, key='')
|
227 |
+
Write a binary kaldi integer vector to filename or stream.
|
228 |
+
Arguments:
|
229 |
+
file_or_fd : filename or opened file descriptor for writing,
|
230 |
+
v : the vector to be stored,
|
231 |
+
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
|
232 |
+
|
233 |
+
Example of writing single vector:
|
234 |
+
kaldi_io.write_vec_int(filename, vec)
|
235 |
+
|
236 |
+
Example of writing arkfile:
|
237 |
+
with open(ark_file,'w') as f:
|
238 |
+
for key,vec in dict.iteritems():
|
239 |
+
kaldi_io.write_vec_flt(f, vec, key=key)
|
240 |
+
"""
|
241 |
+
fd = open_or_fd(file_or_fd, mode='wb')
|
242 |
+
if sys.version_info[0] == 3: assert (fd.mode == 'wb')
|
243 |
+
try:
|
244 |
+
if key != '':
|
245 |
+
fd.write(
|
246 |
+
(key +
|
247 |
+
' ').encode("latin1")) # ark-files have keys (utterance-id),
|
248 |
+
fd.write('\0B'.encode()) # we write binary!
|
249 |
+
# dim,
|
250 |
+
fd.write('\4'.encode()) # int32 type,
|
251 |
+
fd.write(struct.pack(np.dtype('int32').char, v.shape[0]))
|
252 |
+
# data,
|
253 |
+
for i in range(len(v)):
|
254 |
+
fd.write('\4'.encode()) # int32 type,
|
255 |
+
fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary,
|
256 |
+
finally:
|
257 |
+
if fd is not file_or_fd: fd.close()
|
258 |
+
|
259 |
+
|
260 |
+
#################################################
|
261 |
+
# Float vectors (confidences, ivectors, ...),
|
262 |
+
|
263 |
+
|
264 |
+
# Reading,
|
265 |
+
def read_vec_flt_scp(file_or_fd):
|
266 |
+
""" generator(key,mat) = read_vec_flt_scp(file_or_fd)
|
267 |
+
Returns generator of (key,vector) tuples, read according to kaldi scp.
|
268 |
+
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
|
269 |
+
|
270 |
+
Iterate the scp:
|
271 |
+
for key,vec in kaldi_io.read_vec_flt_scp(file):
|
272 |
+
...
|
273 |
+
|
274 |
+
Read scp to a 'dictionary':
|
275 |
+
d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
|
276 |
+
"""
|
277 |
+
fd = open_or_fd(file_or_fd)
|
278 |
+
try:
|
279 |
+
for line in fd:
|
280 |
+
(key, rxfile) = line.decode().split(' ')
|
281 |
+
vec = read_vec_flt(rxfile)
|
282 |
+
yield key, vec
|
283 |
+
finally:
|
284 |
+
if fd is not file_or_fd: fd.close()
|
285 |
+
|
286 |
+
|
287 |
+
def read_vec_flt_ark(file_or_fd):
|
288 |
+
""" generator(key,vec) = read_vec_flt_ark(file_or_fd)
|
289 |
+
Create generator of (key,vector<float>) tuples, reading from an ark file/stream.
|
290 |
+
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
|
291 |
+
|
292 |
+
Read ark to a 'dictionary':
|
293 |
+
d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) }
|
294 |
+
"""
|
295 |
+
fd = open_or_fd(file_or_fd)
|
296 |
+
try:
|
297 |
+
key = read_key(fd)
|
298 |
+
while key:
|
299 |
+
ali = read_vec_flt(fd)
|
300 |
+
yield key, ali
|
301 |
+
key = read_key(fd)
|
302 |
+
finally:
|
303 |
+
if fd is not file_or_fd: fd.close()
|
304 |
+
|
305 |
+
|
306 |
+
def read_vec_flt(file_or_fd):
|
307 |
+
""" [flt-vec] = read_vec_flt(file_or_fd)
|
308 |
+
Read kaldi float vector, ascii or binary input,
|
309 |
+
"""
|
310 |
+
fd = open_or_fd(file_or_fd)
|
311 |
+
binary = fd.read(2).decode()
|
312 |
+
if binary == '\0B': # binary flag
|
313 |
+
# Data type,
|
314 |
+
header = fd.read(3).decode()
|
315 |
+
if header == 'FV ': sample_size = 4 # floats
|
316 |
+
elif header == 'DV ': sample_size = 8 # doubles
|
317 |
+
else: raise UnknownVectorHeader("The header contained '%s'" % header)
|
318 |
+
assert (sample_size > 0)
|
319 |
+
# Dimension,
|
320 |
+
assert (fd.read(1).decode() == '\4')
|
321 |
+
# int-size
|
322 |
+
vec_size = np.frombuffer(fd.read(4), dtype='int32',
|
323 |
+
count=1)[0] # vector dim
|
324 |
+
# Read whole vector,
|
325 |
+
buf = fd.read(vec_size * sample_size)
|
326 |
+
if sample_size == 4: ans = np.frombuffer(buf, dtype='float32')
|
327 |
+
elif sample_size == 8: ans = np.frombuffer(buf, dtype='float64')
|
328 |
+
else: raise BadSampleSize
|
329 |
+
return ans
|
330 |
+
else: # ascii,
|
331 |
+
arr = (binary + fd.readline().decode()).strip().split()
|
332 |
+
try:
|
333 |
+
arr.remove('[')
|
334 |
+
arr.remove(']') # optionally
|
335 |
+
except ValueError:
|
336 |
+
pass
|
337 |
+
ans = np.array(arr, dtype=float)
|
338 |
+
if fd is not file_or_fd: fd.close() # cleanup
|
339 |
+
return ans
|
340 |
+
|
341 |
+
|
342 |
+
# Writing,
|
343 |
+
def write_vec_flt(file_or_fd, v, key=''):
|
344 |
+
""" write_vec_flt(f, v, key='')
|
345 |
+
Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats.
|
346 |
+
Arguments:
|
347 |
+
file_or_fd : filename or opened file descriptor for writing,
|
348 |
+
v : the vector to be stored,
|
349 |
+
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
|
350 |
+
|
351 |
+
Example of writing single vector:
|
352 |
+
kaldi_io.write_vec_flt(filename, vec)
|
353 |
+
|
354 |
+
Example of writing arkfile:
|
355 |
+
with open(ark_file,'w') as f:
|
356 |
+
for key,vec in dict.iteritems():
|
357 |
+
kaldi_io.write_vec_flt(f, vec, key=key)
|
358 |
+
"""
|
359 |
+
fd = open_or_fd(file_or_fd, mode='wb')
|
360 |
+
if sys.version_info[0] == 3: assert (fd.mode == 'wb')
|
361 |
+
try:
|
362 |
+
if key != '':
|
363 |
+
fd.write(
|
364 |
+
(key +
|
365 |
+
' ').encode("latin1")) # ark-files have keys (utterance-id),
|
366 |
+
fd.write('\0B'.encode()) # we write binary!
|
367 |
+
# Data-type,
|
368 |
+
if v.dtype == 'float32': fd.write('FV '.encode())
|
369 |
+
elif v.dtype == 'float64': fd.write('DV '.encode())
|
370 |
+
else:
|
371 |
+
raise UnsupportedDataType(
|
372 |
+
"'%s', please use 'float32' or 'float64'" % v.dtype)
|
373 |
+
# Dim,
|
374 |
+
fd.write('\04'.encode())
|
375 |
+
fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim
|
376 |
+
# Data,
|
377 |
+
fd.write(v.tobytes())
|
378 |
+
finally:
|
379 |
+
if fd is not file_or_fd: fd.close()
|
380 |
+
|
381 |
+
|
382 |
+
#################################################
|
383 |
+
# Float matrices (features, transformations, ...),
|
384 |
+
|
385 |
+
|
386 |
+
# Reading,
|
387 |
+
def read_mat_scp(file_or_fd):
|
388 |
+
""" generator(key,mat) = read_mat_scp(file_or_fd)
|
389 |
+
Returns generator of (key,matrix) tuples, read according to kaldi scp.
|
390 |
+
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
|
391 |
+
|
392 |
+
Iterate the scp:
|
393 |
+
for key,mat in kaldi_io.read_mat_scp(file):
|
394 |
+
...
|
395 |
+
|
396 |
+
Read scp to a 'dictionary':
|
397 |
+
d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
|
398 |
+
"""
|
399 |
+
fd = open_or_fd(file_or_fd)
|
400 |
+
try:
|
401 |
+
for line in fd:
|
402 |
+
(key, rxfile) = line.decode().split(' ')
|
403 |
+
mat = read_mat(rxfile)
|
404 |
+
yield key, mat
|
405 |
+
finally:
|
406 |
+
if fd is not file_or_fd: fd.close()
|
407 |
+
|
408 |
+
|
409 |
+
def read_mat_ark(file_or_fd):
|
410 |
+
""" generator(key,mat) = read_mat_ark(file_or_fd)
|
411 |
+
Returns generator of (key,matrix) tuples, read from ark file/stream.
|
412 |
+
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
|
413 |
+
|
414 |
+
Iterate the ark:
|
415 |
+
for key,mat in kaldi_io.read_mat_ark(file):
|
416 |
+
...
|
417 |
+
|
418 |
+
Read ark to a 'dictionary':
|
419 |
+
d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) }
|
420 |
+
"""
|
421 |
+
fd = open_or_fd(file_or_fd)
|
422 |
+
try:
|
423 |
+
key = read_key(fd)
|
424 |
+
while key:
|
425 |
+
mat = read_mat(fd)
|
426 |
+
yield key, mat
|
427 |
+
key = read_key(fd)
|
428 |
+
finally:
|
429 |
+
if fd is not file_or_fd: fd.close()
|
430 |
+
|
431 |
+
|
432 |
+
def read_mat(file_or_fd):
|
433 |
+
""" [mat] = read_mat(file_or_fd)
|
434 |
+
Reads single kaldi matrix, supports ascii and binary.
|
435 |
+
file_or_fd : file, gzipped file, pipe or opened file descriptor.
|
436 |
+
"""
|
437 |
+
fd = open_or_fd(file_or_fd)
|
438 |
+
try:
|
439 |
+
binary = fd.read(2).decode()
|
440 |
+
if binary == '\0B':
|
441 |
+
mat = _read_mat_binary(fd)
|
442 |
+
else:
|
443 |
+
assert (binary == ' [')
|
444 |
+
mat = _read_mat_ascii(fd)
|
445 |
+
finally:
|
446 |
+
if fd is not file_or_fd: fd.close()
|
447 |
+
return mat
|
448 |
+
|
449 |
+
|
450 |
+
def _read_mat_binary(fd):
|
451 |
+
# Data type
|
452 |
+
header = fd.read(3).decode()
|
453 |
+
# 'CM', 'CM2', 'CM3' are possible values,
|
454 |
+
if header.startswith('CM'): return _read_compressed_mat(fd, header)
|
455 |
+
elif header == 'FM ': sample_size = 4 # floats
|
456 |
+
elif header == 'DM ': sample_size = 8 # doubles
|
457 |
+
else: raise UnknownMatrixHeader("The header contained '%s'" % header)
|
458 |
+
assert (sample_size > 0)
|
459 |
+
# Dimensions
|
460 |
+
s1, rows, s2, cols = np.frombuffer(fd.read(10),
|
461 |
+
dtype='int8,int32,int8,int32',
|
462 |
+
count=1)[0]
|
463 |
+
# Read whole matrix
|
464 |
+
buf = fd.read(rows * cols * sample_size)
|
465 |
+
if sample_size == 4: vec = np.frombuffer(buf, dtype='float32')
|
466 |
+
elif sample_size == 8: vec = np.frombuffer(buf, dtype='float64')
|
467 |
+
else: raise BadSampleSize
|
468 |
+
mat = np.reshape(vec, (rows, cols))
|
469 |
+
return mat
|
470 |
+
|
471 |
+
|
472 |
+
def _read_mat_ascii(fd):
|
473 |
+
rows = []
|
474 |
+
while 1:
|
475 |
+
line = fd.readline().decode()
|
476 |
+
if (len(line) == 0): raise BadInputFormat # eof, should not happen!
|
477 |
+
if len(line.strip()) == 0: continue # skip empty line
|
478 |
+
arr = line.strip().split()
|
479 |
+
if arr[-1] != ']':
|
480 |
+
rows.append(np.array(arr, dtype='float32')) # not last line
|
481 |
+
else:
|
482 |
+
rows.append(np.array(arr[:-1], dtype='float32')) # last line
|
483 |
+
mat = np.vstack(rows)
|
484 |
+
return mat
|
485 |
+
|
486 |
+
|
487 |
+
def _read_compressed_mat(fd, format):
|
488 |
+
""" Read a compressed matrix,
|
489 |
+
see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h
|
490 |
+
methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...),
|
491 |
+
"""
|
492 |
+
assert (format == 'CM ') # The formats CM2, CM3 are not supported...
|
493 |
+
|
494 |
+
# Format of header 'struct',
|
495 |
+
global_header = np.dtype([('minvalue', 'float32'), ('range', 'float32'),
|
496 |
+
('num_rows', 'int32'), ('num_cols', 'int32')
|
497 |
+
]) # member '.format' is not written,
|
498 |
+
per_col_header = np.dtype([('percentile_0', 'uint16'),
|
499 |
+
('percentile_25', 'uint16'),
|
500 |
+
('percentile_75', 'uint16'),
|
501 |
+
('percentile_100', 'uint16')])
|
502 |
+
|
503 |
+
# Mapping for percentiles in col-headers,
|
504 |
+
def uint16_to_float(value, min, range):
|
505 |
+
return np.float32(min + range * 1.52590218966964e-05 * value)
|
506 |
+
|
507 |
+
# Mapping for matrix elements,
|
508 |
+
def uint8_to_float_v2(vec, p0, p25, p75, p100):
|
509 |
+
# Split the vector by masks,
|
510 |
+
mask_0_64 = (vec <= 64)
|
511 |
+
mask_193_255 = (vec > 192)
|
512 |
+
mask_65_192 = (~(mask_0_64 | mask_193_255))
|
513 |
+
# Sanity check (useful but slow...),
|
514 |
+
# assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255])))
|
515 |
+
# assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0)))
|
516 |
+
# Build the float vector,
|
517 |
+
ans = np.empty(len(vec), dtype='float32')
|
518 |
+
ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64]
|
519 |
+
ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64)
|
520 |
+
ans[mask_193_255] = p75 + (p100 - p75) / 63. * (vec[mask_193_255] -
|
521 |
+
192)
|
522 |
+
return ans
|
523 |
+
|
524 |
+
# Read global header,
|
525 |
+
globmin, globrange, rows, cols = np.frombuffer(fd.read(16),
|
526 |
+
dtype=global_header,
|
527 |
+
count=1)[0]
|
528 |
+
|
529 |
+
# The data is structured as [Colheader, ... , Colheader, Data, Data , .... ]
|
530 |
+
# { cols }{ size }
|
531 |
+
col_headers = np.frombuffer(fd.read(cols * 8),
|
532 |
+
dtype=per_col_header,
|
533 |
+
count=cols)
|
534 |
+
data = np.reshape(np.frombuffer(fd.read(cols * rows),
|
535 |
+
dtype='uint8',
|
536 |
+
count=cols * rows),
|
537 |
+
newshape=(cols, rows)) # stored as col-major,
|
538 |
+
|
539 |
+
mat = np.empty((cols, rows), dtype='float32')
|
540 |
+
for i, col_header in enumerate(col_headers):
|
541 |
+
col_header_flt = [
|
542 |
+
uint16_to_float(percentile, globmin, globrange)
|
543 |
+
for percentile in col_header
|
544 |
+
]
|
545 |
+
mat[i] = uint8_to_float_v2(data[i], *col_header_flt)
|
546 |
+
|
547 |
+
return mat.T # transpose! col-major -> row-major,
|
548 |
+
|
549 |
+
|
550 |
+
def write_ark_scp(key, mat, ark_fout, scp_out):
|
551 |
+
mat_offset = write_mat(ark_fout, mat, key)
|
552 |
+
scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset)
|
553 |
+
scp_out.write(scp_line)
|
554 |
+
scp_out.write('\n')
|
555 |
+
|
556 |
+
|
557 |
+
# Writing,
|
558 |
+
def write_mat(file_or_fd, m, key=''):
|
559 |
+
""" write_mat(f, m, key='')
|
560 |
+
Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats.
|
561 |
+
Arguments:
|
562 |
+
file_or_fd : filename of opened file descriptor for writing,
|
563 |
+
m : the matrix to be stored,
|
564 |
+
key (optional) : used for writing ark-file, the utterance-id gets written before the matrix.
|
565 |
+
|
566 |
+
Example of writing single matrix:
|
567 |
+
kaldi_io.write_mat(filename, mat)
|
568 |
+
|
569 |
+
Example of writing arkfile:
|
570 |
+
with open(ark_file,'w') as f:
|
571 |
+
for key,mat in dict.iteritems():
|
572 |
+
kaldi_io.write_mat(f, mat, key=key)
|
573 |
+
"""
|
574 |
+
mat_offset = 0
|
575 |
+
fd = open_or_fd(file_or_fd, mode='wb')
|
576 |
+
if sys.version_info[0] == 3: assert (fd.mode == 'wb')
|
577 |
+
try:
|
578 |
+
if key != '':
|
579 |
+
fd.write(
|
580 |
+
(key +
|
581 |
+
' ').encode("latin1")) # ark-files have keys (utterance-id),
|
582 |
+
mat_offset = fd.tell()
|
583 |
+
fd.write('\0B'.encode()) # we write binary!
|
584 |
+
# Data-type,
|
585 |
+
if m.dtype == 'float32': fd.write('FM '.encode())
|
586 |
+
elif m.dtype == 'float64': fd.write('DM '.encode())
|
587 |
+
else:
|
588 |
+
raise UnsupportedDataType(
|
589 |
+
"'%s', please use 'float32' or 'float64'" % m.dtype)
|
590 |
+
# Dims,
|
591 |
+
fd.write('\04'.encode())
|
592 |
+
fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows
|
593 |
+
fd.write('\04'.encode())
|
594 |
+
fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols
|
595 |
+
# Data,
|
596 |
+
fd.write(m.tobytes())
|
597 |
+
finally:
|
598 |
+
if fd is not file_or_fd: fd.close()
|
599 |
+
return mat_offset
|
600 |
+
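A hedged round-trip sketch for the binary matrix writer/reader above (the import path is an assumption based on this repository layout):
import numpy as np
from wenet.dataset import kaldi_io        # assumed import path
mat = np.random.rand(4, 3).astype('float32')
with open('demo.ark', 'wb') as f:
    kaldi_io.write_mat(f, mat, key='utt1')
for key, m in kaldi_io.read_mat_ark('demo.ark'):
    assert key == 'utt1' and np.allclose(m, mat)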
|
601 |
+
|
602 |
+
#################################################
|
603 |
+
# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...)
|
604 |
+
# Corresponds to: vector<vector<tuple<int,float> > >
|
605 |
+
# - outer vector: time axis
|
606 |
+
# - inner vector: records at the time
|
607 |
+
# - tuple: int = index, float = value
|
608 |
+
#
|
609 |
+
|
610 |
+
|
611 |
+
def read_cnet_ark(file_or_fd):
|
612 |
+
""" Alias of function 'read_post_ark()', 'cnet' = confusion network """
|
613 |
+
return read_post_ark(file_or_fd)
|
614 |
+
|
615 |
+
|
616 |
+
def read_post_ark(file_or_fd):
|
617 |
+
""" generator(key,vec<vec<int,float>>) = read_post_ark(file)
|
618 |
+
Returns generator of (key,posterior) tuples, read from ark file.
|
619 |
+
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
|
620 |
+
|
621 |
+
Iterate the ark:
|
622 |
+
for key,post in kaldi_io.read_post_ark(file):
|
623 |
+
...
|
624 |
+
|
625 |
+
Read ark to a 'dictionary':
|
626 |
+
d = { key:post for key,post in kaldi_io.read_post_ark(file) }
|
627 |
+
"""
|
628 |
+
fd = open_or_fd(file_or_fd)
|
629 |
+
try:
|
630 |
+
key = read_key(fd)
|
631 |
+
while key:
|
632 |
+
post = read_post(fd)
|
633 |
+
yield key, post
|
634 |
+
key = read_key(fd)
|
635 |
+
finally:
|
636 |
+
if fd is not file_or_fd: fd.close()
|
637 |
+
|
638 |
+
|
639 |
+
def read_post(file_or_fd):
|
640 |
+
""" [post] = read_post(file_or_fd)
|
641 |
+
Reads single kaldi 'Posterior' in binary format.
|
642 |
+
|
643 |
+
The 'Posterior' is C++ type 'vector<vector<tuple<int,float> > >',
|
644 |
+
the outer-vector is usually time axis, inner-vector are the records
|
645 |
+
at given time, and the tuple is composed of an 'index' (integer)
|
646 |
+
and a 'float-value'. The 'float-value' can represent a probability
|
647 |
+
or any other numeric value.
|
648 |
+
|
649 |
+
Returns vector of vectors of tuples.
|
650 |
+
"""
|
651 |
+
fd = open_or_fd(file_or_fd)
|
652 |
+
ans = []
|
653 |
+
binary = fd.read(2).decode()
|
654 |
+
assert (binary == '\0B')
|
655 |
+
# binary flag
|
656 |
+
assert (fd.read(1).decode() == '\4')
|
657 |
+
# int-size
|
658 |
+
outer_vec_size = np.frombuffer(fd.read(4), dtype='int32',
|
659 |
+
count=1)[0] # number of frames (or bins)
|
660 |
+
|
661 |
+
# Loop over 'outer-vector',
|
662 |
+
for i in range(outer_vec_size):
|
663 |
+
assert (fd.read(1).decode() == '\4')
|
664 |
+
# int-size
|
665 |
+
inner_vec_size = np.frombuffer(
|
666 |
+
fd.read(4), dtype='int32',
|
667 |
+
count=1)[0] # number of records for frame (or bin)
|
668 |
+
data = np.frombuffer(fd.read(inner_vec_size * 10),
|
669 |
+
dtype=[('size_idx', 'int8'), ('idx', 'int32'),
|
670 |
+
('size_post', 'int8'),
|
671 |
+
('post', 'float32')],
|
672 |
+
count=inner_vec_size)
|
673 |
+
assert (data[0]['size_idx'] == 4)
|
674 |
+
assert (data[0]['size_post'] == 4)
|
675 |
+
ans.append(data[['idx', 'post']].tolist())
|
676 |
+
|
677 |
+
if fd is not file_or_fd: fd.close()
|
678 |
+
return ans
|
679 |
+
|
680 |
+
|
681 |
+
#################################################
|
682 |
+
# Kaldi Confusion Network bin begin/end times,
|
683 |
+
# (kaldi stores CNs time info separately from the Posterior).
|
684 |
+
#
|
685 |
+
|
686 |
+
|
687 |
+
def read_cntime_ark(file_or_fd):
|
688 |
+
""" generator(key,vec<tuple<float,float>>) = read_cntime_ark(file_or_fd)
|
689 |
+
Returns generator of (key,cntime) tuples, read from ark file.
|
690 |
+
file_or_fd : file, gzipped file, pipe or opened file descriptor.
|
691 |
+
|
692 |
+
Iterate the ark:
|
693 |
+
for key,time in kaldi_io.read_cntime_ark(file):
|
694 |
+
...
|
695 |
+
|
696 |
+
Read ark to a 'dictionary':
|
697 |
+
d = { key:time for key,time in kaldi_io.read_post_ark(file) }
|
698 |
+
"""
|
699 |
+
fd = open_or_fd(file_or_fd)
|
700 |
+
try:
|
701 |
+
key = read_key(fd)
|
702 |
+
while key:
|
703 |
+
cntime = read_cntime(fd)
|
704 |
+
yield key, cntime
|
705 |
+
key = read_key(fd)
|
706 |
+
finally:
|
707 |
+
if fd is not file_or_fd: fd.close()
|
708 |
+
|
709 |
+
|
710 |
+
def read_cntime(file_or_fd):
|
711 |
+
""" [cntime] = read_cntime(file_or_fd)
|
712 |
+
Reads single kaldi 'Confusion Network time info', in binary format:
|
713 |
+
C++ type: vector<tuple<float,float> >.
|
714 |
+
(begin/end times of bins at the confusion network).
|
715 |
+
|
716 |
+
Binary layout is '<num-bins> <beg1> <end1> <beg2> <end2> ...'
|
717 |
+
|
718 |
+
file_or_fd : file, gzipped file, pipe or opened file descriptor.
|
719 |
+
|
720 |
+
Returns vector of tuples.
|
721 |
+
"""
|
722 |
+
fd = open_or_fd(file_or_fd)
|
723 |
+
binary = fd.read(2).decode()
|
724 |
+
assert (binary == '\0B')
|
725 |
+
# assuming it's binary
|
726 |
+
|
727 |
+
assert (fd.read(1).decode() == '\4')
|
728 |
+
# int-size
|
729 |
+
vec_size = np.frombuffer(fd.read(4), dtype='int32',
|
730 |
+
count=1)[0] # number of frames (or bins)
|
731 |
+
|
732 |
+
data = np.frombuffer(fd.read(vec_size * 10),
|
733 |
+
dtype=[('size_beg', 'int8'), ('t_beg', 'float32'),
|
734 |
+
('size_end', 'int8'), ('t_end', 'float32')],
|
735 |
+
count=vec_size)
|
736 |
+
assert (data[0]['size_beg'] == 4)
|
737 |
+
assert (data[0]['size_end'] == 4)
|
738 |
+
ans = data[['t_beg',
|
739 |
+
't_end']].tolist() # Return vector of tuples (t_beg,t_end),
|
740 |
+
|
741 |
+
if fd is not file_or_fd: fd.close()
|
742 |
+
return ans
|
743 |
+
|
744 |
+
|
745 |
+
#################################################
|
746 |
+
# Segments related,
|
747 |
+
#
|
748 |
+
|
749 |
+
|
750 |
+
# Segments as 'Bool vectors' can be handy,
|
751 |
+
# - for 'superposing' the segmentations,
|
752 |
+
# - for frame-selection in Speaker-ID experiments,
|
753 |
+
def read_segments_as_bool_vec(segments_file):
|
754 |
+
""" [ bool_vec ] = read_segments_as_bool_vec(segments_file)
|
755 |
+
using kaldi 'segments' file for 1 wav, format : '<utt> <rec> <t-beg> <t-end>'
|
756 |
+
- t-beg, t-end is in seconds,
|
757 |
+
- assumed 100 frames/second,
|
758 |
+
"""
|
759 |
+
segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1)
|
760 |
+
# Sanity checks,
|
761 |
+
assert (len(segs) > 0) # empty segmentation is an error,
|
762 |
+
assert (len(np.unique([rec[1] for rec in segs])) == 1
|
763 |
+
) # segments with only 1 wav-file,
|
764 |
+
# Convert time to frame-indexes,
|
765 |
+
start = np.rint([100 * rec[2] for rec in segs]).astype(int)
|
766 |
+
end = np.rint([100 * rec[3] for rec in segs]).astype(int)
|
767 |
+
# Taken from 'read_lab_to_bool_vec', htk.py,
|
768 |
+
frms = np.repeat(
|
769 |
+
np.r_[np.tile([False, True], len(end)), False],
|
770 |
+
np.r_[np.c_[start - np.r_[0, end[:-1]], end - start].flat, 0])
|
771 |
+
assert np.sum(end - start) == np.sum(frms)
|
772 |
+
return frms
|
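The boolean-vector construction above packs all segment boundaries into a single np.repeat call. A minimal, self-contained sketch of that trick, using a hypothetical two-segment example and the 100 frames/second convention stated in the docstring:

import numpy as np

# Hypothetical (t_beg, t_end) pairs in seconds for one recording.
segments = [(0.25, 0.50), (0.80, 1.00)]
start = np.rint([100 * s for s, _ in segments]).astype(int)   # [25, 80]
end = np.rint([100 * e for _, e in segments]).astype(int)     # [50, 100]
# Alternating False/True template, repeated by the gap/segment lengths.
frms = np.repeat(
    np.r_[np.tile([False, True], len(end)), False],
    np.r_[np.c_[start - np.r_[0, end[:-1]], end - start].flat, 0])
assert frms.sum() == (end - start).sum()     # 45 frames selected
print(len(frms), int(frms.sum()))            # 100 frames total, 45 True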
wenet/dataset/processor.py
ADDED
@@ -0,0 +1,694 @@
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import librosa
|
17 |
+
import logging
|
18 |
+
import json
|
19 |
+
import random
|
20 |
+
import tarfile
|
21 |
+
from subprocess import PIPE, Popen
|
22 |
+
from urllib.parse import urlparse
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torchaudio
|
26 |
+
import torchaudio.compliance.kaldi as kaldi
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from torch.nn.utils.rnn import pad_sequence
|
29 |
+
from wenet.text.base_tokenizer import BaseTokenizer
|
30 |
+
|
31 |
+
torchaudio.utils.sox_utils.set_buffer_size(16500)
|
32 |
+
|
33 |
+
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
34 |
+
|
35 |
+
|
36 |
+
def url_opener(data):
|
37 |
+
""" Give url or local file, return file descriptor
|
38 |
+
Inplace operation.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
data(Iterable[str]): url or local file list
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Iterable[{src, stream}]
|
45 |
+
"""
|
46 |
+
for sample in data:
|
47 |
+
assert 'src' in sample
|
48 |
+
# TODO(Binbin Zhang): support HTTP
|
49 |
+
url = sample['src']
|
50 |
+
try:
|
51 |
+
pr = urlparse(url)
|
52 |
+
# local file
|
53 |
+
if pr.scheme == '' or pr.scheme == 'file':
|
54 |
+
stream = open(url, 'rb')
|
55 |
+
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
|
56 |
+
else:
|
57 |
+
cmd = f'wget -q -O - {url}'
|
58 |
+
process = Popen(cmd, shell=True, stdout=PIPE)
|
59 |
+
sample.update(process=process)
|
60 |
+
stream = process.stdout
|
61 |
+
sample.update(stream=stream)
|
62 |
+
yield sample
|
63 |
+
except Exception as ex:
|
64 |
+
logging.warning('Failed to open {}'.format(url))
|
65 |
+
|
66 |
+
|
67 |
+
def tar_file_and_group(data):
|
68 |
+
""" Expand a stream of open tar files into a stream of tar file contents.
|
69 |
+
And group the files that share the same prefix.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
data: Iterable[{src, stream}]
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Iterable[{key, wav, txt, sample_rate}]
|
76 |
+
"""
|
77 |
+
for sample in data:
|
78 |
+
assert 'stream' in sample
|
79 |
+
stream = None
|
80 |
+
try:
|
81 |
+
stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
|
82 |
+
prev_prefix = None
|
83 |
+
example = {}
|
84 |
+
valid = True
|
85 |
+
for tarinfo in stream:
|
86 |
+
name = tarinfo.name
|
87 |
+
pos = name.rfind('.')
|
88 |
+
assert pos > 0
|
89 |
+
prefix, postfix = name[:pos], name[pos + 1:]
|
90 |
+
if prev_prefix is not None and prefix != prev_prefix:
|
91 |
+
example['key'] = prev_prefix
|
92 |
+
if valid:
|
93 |
+
yield example
|
94 |
+
example = {}
|
95 |
+
valid = True
|
96 |
+
with stream.extractfile(tarinfo) as file_obj:
|
97 |
+
try:
|
98 |
+
if postfix == 'txt':
|
99 |
+
example['txt'] = file_obj.read().decode(
|
100 |
+
'utf8').strip()
|
101 |
+
elif postfix in AUDIO_FORMAT_SETS:
|
102 |
+
waveform, sample_rate = torchaudio.load(file_obj)
|
103 |
+
example['wav'] = waveform
|
104 |
+
example['sample_rate'] = sample_rate
|
105 |
+
else:
|
106 |
+
example[postfix] = file_obj.read()
|
107 |
+
except Exception as ex:
|
108 |
+
valid = False
|
109 |
+
logging.warning('error to parse {}'.format(name))
|
110 |
+
prev_prefix = prefix
|
111 |
+
if prev_prefix is not None:
|
112 |
+
example['key'] = prev_prefix
|
113 |
+
yield example
|
114 |
+
except Exception as ex:
|
115 |
+
logging.warning(
|
116 |
+
'In tar_file_and_group: {} when processing {}'.format(
|
117 |
+
ex, sample['src']))
|
118 |
+
finally:
|
119 |
+
if stream is not None:
|
120 |
+
stream.close()
|
121 |
+
if 'process' in sample:
|
122 |
+
sample['process'].communicate()
|
123 |
+
sample['stream'].close()
|
124 |
+
|
125 |
+
|
126 |
+
def parse_raw(data):
|
127 |
+
""" Parse key/wav/txt from json line
|
128 |
+
|
129 |
+
Args:
|
130 |
+
data: Iterable[str], str is a json line has key/wav/txt
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
Iterable[{key, wav, txt, sample_rate}]
|
134 |
+
"""
|
135 |
+
for sample in data:
|
136 |
+
assert 'src' in sample
|
137 |
+
json_line = sample['src']
|
138 |
+
obj = json.loads(json_line)
|
139 |
+
assert 'key' in obj
|
140 |
+
assert 'wav' in obj
|
141 |
+
assert 'txt' in obj
|
142 |
+
key = obj['key']
|
143 |
+
wav_file = obj['wav']
|
144 |
+
txt = obj['txt']
|
145 |
+
try:
|
146 |
+
if 'start' in obj:
|
147 |
+
assert 'end' in obj
|
148 |
+
sample_rate = torchaudio.info(wav_file).sample_rate
|
149 |
+
start_frame = int(obj['start'] * sample_rate)
|
150 |
+
end_frame = int(obj['end'] * sample_rate)
|
151 |
+
waveform, _ = torchaudio.load(filepath=wav_file,
|
152 |
+
num_frames=end_frame -
|
153 |
+
start_frame,
|
154 |
+
frame_offset=start_frame)
|
155 |
+
else:
|
156 |
+
waveform, sample_rate = torchaudio.load(wav_file)
|
157 |
+
example = copy.deepcopy(obj) # copy and keep all the fields
|
158 |
+
example['wav'] = waveform # overwrite wav
|
159 |
+
example['sample_rate'] = sample_rate
|
160 |
+
yield example
|
161 |
+
except Exception as ex:
|
162 |
+
logging.warning('Failed to read {}'.format(wav_file))
|
163 |
+
|
164 |
+
|
165 |
+
def parse_speaker(data, speaker_table_path):
|
166 |
+
speaker_dict = {}
|
167 |
+
with open(speaker_table_path, 'r', encoding='utf8') as fin:
|
168 |
+
for line in fin:
|
169 |
+
arr = line.strip().split()
|
170 |
+
speaker_dict[arr[0]] = int(arr[1])
|
171 |
+
for sample in data:
|
172 |
+
assert 'speaker' in sample
|
173 |
+
speaker = sample['speaker']
|
174 |
+
sample['speaker'] = speaker_dict.get(speaker, 0)
|
175 |
+
yield sample
|
176 |
+
|
177 |
+
|
178 |
+
def filter(data,
|
179 |
+
max_length=10240,
|
180 |
+
min_length=10,
|
181 |
+
token_max_length=200,
|
182 |
+
token_min_length=1,
|
183 |
+
min_output_input_ratio=0.0005,
|
184 |
+
max_output_input_ratio=1):
|
185 |
+
""" Filter sample according to feature and label length
|
186 |
+
Inplace operation.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
190 |
+
max_length: drop utterances longer than max_length (in 10 ms frames)
|
191 |
+
min_length: drop utterances shorter than min_length (in 10 ms frames)
|
192 |
+
token_max_length: drop utterance which is greater than
|
193 |
+
token_max_length, especially when use char unit for
|
194 |
+
english modeling
|
195 |
+
token_min_length: drop utterance which is
|
196 |
+
less than token_min_length
|
197 |
+
min_output_input_ratio: minimal ratio of
|
198 |
+
token_length / feats_length(10ms)
|
199 |
+
max_output_input_ratio: maximum ratio of
|
200 |
+
token_length / feats_length(10ms)
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
Iterable[{key, wav, label, sample_rate}]
|
204 |
+
"""
|
205 |
+
for sample in data:
|
206 |
+
try:
|
207 |
+
assert 'sample_rate' in sample
|
208 |
+
assert 'wav' in sample
|
209 |
+
assert 'label' in sample
|
210 |
+
except:
|
211 |
+
continue
|
212 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
213 |
+
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
|
214 |
+
if num_frames < min_length:
|
215 |
+
continue
|
216 |
+
if num_frames > max_length:
|
217 |
+
continue
|
218 |
+
if len(sample['label']) < token_min_length:
|
219 |
+
continue
|
220 |
+
if len(sample['label']) > token_max_length:
|
221 |
+
continue
|
222 |
+
if num_frames != 0:
|
223 |
+
if len(sample['label']) / num_frames < min_output_input_ratio:
|
224 |
+
continue
|
225 |
+
if len(sample['label']) / num_frames > max_output_input_ratio:
|
226 |
+
continue
|
227 |
+
yield sample
|
228 |
+
|
229 |
+
|
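A small worked example of the length bookkeeping used by filter() above: lengths are counted in 10 ms frames (100 per second of audio), so the default bounds of 10 and 10240 correspond to 0.1 s and roughly 102 s, and the output/input ratios compare token count against that frame count. The numbers below are illustrative only:

import torch

wav = torch.randn(1, 48000)                  # 3 s of audio at 16 kHz
num_frames = wav.size(1) / 16000 * 100       # 300.0 frames of 10 ms
keep_length = 10 <= num_frames <= 10240      # True
ratio = 12 / num_frames                      # 12 tokens -> 0.04 tokens/frame
keep_ratio = 0.0005 <= ratio <= 1            # True
print(num_frames, keep_length and keep_ratio)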
230 |
+
def resample(data, resample_rate=16000):
|
231 |
+
""" Resample data.
|
232 |
+
Inplace operation.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
236 |
+
resample_rate: target resample rate
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
Iterable[{key, wav, label, sample_rate}]
|
240 |
+
"""
|
241 |
+
for sample in data:
|
242 |
+
assert 'sample_rate' in sample
|
243 |
+
assert 'wav' in sample
|
244 |
+
sample_rate = sample['sample_rate']
|
245 |
+
waveform = sample['wav']
|
246 |
+
if sample_rate != resample_rate:
|
247 |
+
sample['sample_rate'] = resample_rate
|
248 |
+
sample['wav'] = torchaudio.transforms.Resample(
|
249 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
250 |
+
yield sample
|
251 |
+
|
252 |
+
|
253 |
+
def speed_perturb(data, speeds=None):
|
254 |
+
""" Apply speed perturb to the data.
|
255 |
+
Inplace operation.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
259 |
+
speeds(List[float]): optional speed
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
Iterable[{key, wav, label, sample_rate}]
|
263 |
+
"""
|
264 |
+
if speeds is None:
|
265 |
+
speeds = [0.9, 1.0, 1.1]
|
266 |
+
for sample in data:
|
267 |
+
assert 'sample_rate' in sample
|
268 |
+
assert 'wav' in sample
|
269 |
+
sample_rate = sample['sample_rate']
|
270 |
+
waveform = sample['wav']
|
271 |
+
speed = random.choice(speeds)
|
272 |
+
if speed != 1.0:
|
273 |
+
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
|
274 |
+
waveform, sample_rate,
|
275 |
+
[['speed', str(speed)], ['rate', str(sample_rate)]])
|
276 |
+
sample['wav'] = wav
|
277 |
+
|
278 |
+
yield sample
|
279 |
+
|
280 |
+
|
281 |
+
def compute_fbank(data,
|
282 |
+
num_mel_bins=23,
|
283 |
+
frame_length=25,
|
284 |
+
frame_shift=10,
|
285 |
+
dither=0.0):
|
286 |
+
""" Extract fbank
|
287 |
+
|
288 |
+
Args:
|
289 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
Iterable[{key, feat, label}]
|
293 |
+
"""
|
294 |
+
for sample in data:
|
295 |
+
assert 'sample_rate' in sample
|
296 |
+
assert 'wav' in sample
|
297 |
+
assert 'key' in sample
|
298 |
+
assert 'label' in sample
|
299 |
+
sample_rate = sample['sample_rate']
|
300 |
+
waveform = sample['wav']
|
301 |
+
waveform = waveform * (1 << 15)
|
302 |
+
# Only keep key, feat, label
|
303 |
+
mat = kaldi.fbank(waveform,
|
304 |
+
num_mel_bins=num_mel_bins,
|
305 |
+
frame_length=frame_length,
|
306 |
+
frame_shift=frame_shift,
|
307 |
+
dither=dither,
|
308 |
+
energy_floor=0.0,
|
309 |
+
sample_frequency=sample_rate)
|
310 |
+
sample['feat'] = mat
|
311 |
+
yield sample
|
312 |
+
|
313 |
+
|
314 |
+
def compute_mfcc(data,
|
315 |
+
num_mel_bins=23,
|
316 |
+
frame_length=25,
|
317 |
+
frame_shift=10,
|
318 |
+
dither=0.0,
|
319 |
+
num_ceps=40,
|
320 |
+
high_freq=0.0,
|
321 |
+
low_freq=20.0):
|
322 |
+
""" Extract mfcc
|
323 |
+
|
324 |
+
Args:
|
325 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
Iterable[{key, feat, label}]
|
329 |
+
"""
|
330 |
+
for sample in data:
|
331 |
+
assert 'sample_rate' in sample
|
332 |
+
assert 'wav' in sample
|
333 |
+
assert 'key' in sample
|
334 |
+
assert 'label' in sample
|
335 |
+
sample_rate = sample['sample_rate']
|
336 |
+
waveform = sample['wav']
|
337 |
+
waveform = waveform * (1 << 15)
|
338 |
+
# Only keep key, feat, label
|
339 |
+
mat = kaldi.mfcc(waveform,
|
340 |
+
num_mel_bins=num_mel_bins,
|
341 |
+
frame_length=frame_length,
|
342 |
+
frame_shift=frame_shift,
|
343 |
+
dither=dither,
|
344 |
+
num_ceps=num_ceps,
|
345 |
+
high_freq=high_freq,
|
346 |
+
low_freq=low_freq,
|
347 |
+
sample_frequency=sample_rate)
|
348 |
+
sample['feat'] = mat
|
349 |
+
yield sample
|
350 |
+
|
351 |
+
|
352 |
+
def compute_log_mel_spectrogram(data,
|
353 |
+
n_fft=400,
|
354 |
+
hop_length=160,
|
355 |
+
num_mel_bins=80,
|
356 |
+
padding=0):
|
357 |
+
""" Extract log mel spectrogram, modified from openai-whisper, see:
|
358 |
+
- https://github.com/openai/whisper/blob/main/whisper/audio.py
|
359 |
+
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
|
360 |
+
|
361 |
+
Args:
|
362 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
Iterable[{key, feat, label}]
|
366 |
+
"""
|
367 |
+
for sample in data:
|
368 |
+
assert 'sample_rate' in sample
|
369 |
+
assert 'wav' in sample
|
370 |
+
assert 'key' in sample
|
371 |
+
assert 'label' in sample
|
372 |
+
sample_rate = sample['sample_rate']
|
373 |
+
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
|
374 |
+
if padding > 0:
|
375 |
+
waveform = F.pad(waveform, (0, padding))
|
376 |
+
window = torch.hann_window(n_fft)
|
377 |
+
stft = torch.stft(waveform,
|
378 |
+
n_fft,
|
379 |
+
hop_length,
|
380 |
+
window=window,
|
381 |
+
return_complex=True)
|
382 |
+
magnitudes = stft[..., :-1].abs()**2
|
383 |
+
|
384 |
+
filters = torch.from_numpy(
|
385 |
+
librosa.filters.mel(sr=sample_rate,
|
386 |
+
n_fft=n_fft,
|
387 |
+
n_mels=num_mel_bins))
|
388 |
+
mel_spec = filters @ magnitudes
|
389 |
+
|
390 |
+
# NOTE(xcsong): https://github.com/openai/whisper/discussions/269
|
391 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
392 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
393 |
+
log_spec = (log_spec + 4.0) / 4.0
|
394 |
+
sample['feat'] = log_spec.transpose(0, 1)
|
395 |
+
yield sample
|
396 |
+
|
397 |
+
|
398 |
+
def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
|
399 |
+
""" Decode text to chars or BPE
|
400 |
+
Inplace operation
|
401 |
+
|
402 |
+
Args:
|
403 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
407 |
+
"""
|
408 |
+
for sample in data:
|
409 |
+
assert 'txt' in sample
|
410 |
+
if 'task' in sample:
|
411 |
+
task_name = sample['task']
|
412 |
+
if "<AGE>" in task_name:
|
413 |
+
txt = sample['txt'].replace("<YOUTH>", "<ADULT>")
|
414 |
+
else:
|
415 |
+
txt = sample['txt']
|
416 |
+
else:
|
417 |
+
txt = sample['txt']
|
418 |
+
tokens, label = tokenizer.tokenize(txt)
|
419 |
+
sample['tokens'] = tokens
|
420 |
+
sample['label'] = label + [tokenizer.eod_id]
|
421 |
+
if 'task' in sample:
|
422 |
+
task_name = sample['task']
|
423 |
+
random_index = random.randint(0, len(global_prompt_dict[task_name])-1)
|
424 |
+
prompt = global_prompt_dict[task_name][random_index]
|
425 |
+
sample['prompt'] = tokenizer.tokenize(prompt)
|
426 |
+
yield sample
|
427 |
+
|
428 |
+
|
429 |
+
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
|
430 |
+
""" Do spec augmentation
|
431 |
+
Inplace operation
|
432 |
+
|
433 |
+
Args:
|
434 |
+
data: Iterable[{key, feat, label}]
|
435 |
+
num_t_mask: number of time mask to apply
|
436 |
+
num_f_mask: number of freq mask to apply
|
437 |
+
max_t: max width of time mask
|
438 |
+
max_f: max width of freq mask
|
439 |
+
max_w: max width of time warp
|
440 |
+
|
441 |
+
Returns:
|
442 |
+
Iterable[{key, feat, label}]
|
443 |
+
"""
|
444 |
+
for sample in data:
|
445 |
+
assert 'feat' in sample
|
446 |
+
x = sample['feat']
|
447 |
+
assert isinstance(x, torch.Tensor)
|
448 |
+
y = x.clone().detach()
|
449 |
+
max_frames = y.size(0)
|
450 |
+
max_freq = y.size(1)
|
451 |
+
# time mask
|
452 |
+
for i in range(num_t_mask):
|
453 |
+
start = random.randint(0, max_frames - 1)
|
454 |
+
length = random.randint(1, max_t)
|
455 |
+
end = min(max_frames, start + length)
|
456 |
+
y[start:end, :] = 0
|
457 |
+
# freq mask
|
458 |
+
for i in range(num_f_mask):
|
459 |
+
start = random.randint(0, max_freq - 1)
|
460 |
+
length = random.randint(1, max_f)
|
461 |
+
end = min(max_freq, start + length)
|
462 |
+
y[:, start:end] = 0
|
463 |
+
sample['feat'] = y
|
464 |
+
yield sample
|
465 |
+
|
466 |
+
|
467 |
+
def spec_sub(data, max_t=20, num_t_sub=3):
|
468 |
+
""" Do spec substitute
|
469 |
+
Inplace operation
|
470 |
+
ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]
|
471 |
+
|
472 |
+
Args:
|
473 |
+
data: Iterable[{key, feat, label}]
|
474 |
+
max_t: max width of time substitute
|
475 |
+
num_t_sub: number of time substitute to apply
|
476 |
+
|
477 |
+
Returns:
|
478 |
+
Iterable[{key, feat, label}]
|
479 |
+
"""
|
480 |
+
for sample in data:
|
481 |
+
assert 'feat' in sample
|
482 |
+
x = sample['feat']
|
483 |
+
assert isinstance(x, torch.Tensor)
|
484 |
+
y = x.clone().detach()
|
485 |
+
max_frames = y.size(0)
|
486 |
+
for i in range(num_t_sub):
|
487 |
+
start = random.randint(0, max_frames - 1)
|
488 |
+
length = random.randint(1, max_t)
|
489 |
+
end = min(max_frames, start + length)
|
490 |
+
# substitute the current span with an earlier span chosen at random
|
491 |
+
pos = random.randint(0, start)
|
492 |
+
y[start:end, :] = x[start - pos:end - pos, :]
|
493 |
+
sample['feat'] = y
|
494 |
+
yield sample
|
495 |
+
|
496 |
+
|
497 |
+
def spec_trim(data, max_t=20):
|
498 |
+
""" Trim tailing frames. Inplace operation.
|
499 |
+
ref: TrimTail [https://arxiv.org/abs/2211.00522]
|
500 |
+
|
501 |
+
Args:
|
502 |
+
data: Iterable[{key, feat, label}]
|
503 |
+
max_t: max width of length trimming
|
504 |
+
|
505 |
+
Returns:
|
506 |
+
Iterable[{key, feat, label}]
|
507 |
+
"""
|
508 |
+
for sample in data:
|
509 |
+
assert 'feat' in sample
|
510 |
+
x = sample['feat']
|
511 |
+
assert isinstance(x, torch.Tensor)
|
512 |
+
max_frames = x.size(0)
|
513 |
+
length = random.randint(1, max_t)
|
514 |
+
if length < max_frames / 2:
|
515 |
+
y = x.clone().detach()[:max_frames - length]
|
516 |
+
sample['feat'] = y
|
517 |
+
yield sample
|
518 |
+
|
519 |
+
|
520 |
+
def shuffle(data, shuffle_size=10000):
|
521 |
+
""" Local shuffle the data
|
522 |
+
|
523 |
+
Args:
|
524 |
+
data: Iterable[{key, feat, label}]
|
525 |
+
shuffle_size: buffer size for shuffle
|
526 |
+
|
527 |
+
Returns:
|
528 |
+
Iterable[{key, feat, label}]
|
529 |
+
"""
|
530 |
+
buf = []
|
531 |
+
for sample in data:
|
532 |
+
buf.append(sample)
|
533 |
+
if len(buf) >= shuffle_size:
|
534 |
+
random.shuffle(buf)
|
535 |
+
for x in buf:
|
536 |
+
yield x
|
537 |
+
buf = []
|
538 |
+
# The sample left over
|
539 |
+
random.shuffle(buf)
|
540 |
+
for x in buf:
|
541 |
+
yield x
|
542 |
+
|
543 |
+
|
544 |
+
def sort(data, sort_size=500):
|
545 |
+
""" Sort the data by feature length.
|
546 |
+
Sort is used after shuffle and before batch, so we can group
|
547 |
+
utts with similar lengths into a batch, and `sort_size` should
|
548 |
+
be less than `shuffle_size`
|
549 |
+
|
550 |
+
Args:
|
551 |
+
data: Iterable[{key, feat, label}]
|
552 |
+
sort_size: buffer size for sort
|
553 |
+
|
554 |
+
Returns:
|
555 |
+
Iterable[{key, feat, label}]
|
556 |
+
"""
|
557 |
+
|
558 |
+
buf = []
|
559 |
+
for sample in data:
|
560 |
+
buf.append(sample)
|
561 |
+
if len(buf) >= sort_size:
|
562 |
+
buf.sort(key=lambda x: x['feat'].size(0))
|
563 |
+
for x in buf:
|
564 |
+
yield x
|
565 |
+
buf = []
|
566 |
+
# The sample left over
|
567 |
+
buf.sort(key=lambda x: x['feat'].size(0))
|
568 |
+
for x in buf:
|
569 |
+
yield x
|
570 |
+
|
571 |
+
|
572 |
+
def static_batch(data, batch_size=16):
|
573 |
+
""" Static batch the data by `batch_size`
|
574 |
+
|
575 |
+
Args:
|
576 |
+
data: Iterable[{key, feat, label}]
|
577 |
+
batch_size: batch size
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
Iterable[List[{key, feat, label}]]
|
581 |
+
"""
|
582 |
+
buf = []
|
583 |
+
for sample in data:
|
584 |
+
buf.append(sample)
|
585 |
+
if len(buf) >= batch_size:
|
586 |
+
yield buf
|
587 |
+
buf = []
|
588 |
+
if len(buf) > 0:
|
589 |
+
yield buf
|
590 |
+
|
591 |
+
|
592 |
+
def dynamic_batch(data, max_frames_in_batch=12000):
|
593 |
+
""" Dynamic batch the data until the total frames in batch
|
594 |
+
reach `max_frames_in_batch`
|
595 |
+
|
596 |
+
Args:
|
597 |
+
data: Iterable[{key, feat, label}]
|
598 |
+
max_frames_in_batch: max_frames in one batch
|
599 |
+
|
600 |
+
Returns:
|
601 |
+
Iterable[List[{key, feat, label}]]
|
602 |
+
"""
|
603 |
+
buf = []
|
604 |
+
longest_frames = 0
|
605 |
+
for sample in data:
|
606 |
+
assert 'feat' in sample
|
607 |
+
assert isinstance(sample['feat'], torch.Tensor)
|
608 |
+
new_sample_frames = sample['feat'].size(0)
|
609 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
610 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
611 |
+
if frames_after_padding > max_frames_in_batch:
|
612 |
+
yield buf
|
613 |
+
buf = [sample]
|
614 |
+
longest_frames = new_sample_frames
|
615 |
+
else:
|
616 |
+
buf.append(sample)
|
617 |
+
if len(buf) > 0:
|
618 |
+
yield buf
|
619 |
+
|
620 |
+
|
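A small trace of the budget rule in dynamic_batch() above: the cost of a batch is estimated as longest_feat * batch_size, i.e. as if every utterance were padded to the longest one, and the buffer is emitted as soon as adding the next sample would exceed max_frames_in_batch. With hypothetical frame counts and a budget of 300:

feats = [120, 90, 150, 60]        # hypothetical per-utterance frame counts
buf, longest = [], 0
for n in feats:
    longest = max(longest, n)
    if longest * (len(buf) + 1) > 300:
        print('emit batch of', len(buf))     # emits [120, 90]
        buf, longest = [n], n
    else:
        buf.append(n)
print('emit batch of', len(buf))             # emits [150, 60]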
621 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
|
622 |
+
""" Wrapper for static/dynamic batch
|
623 |
+
"""
|
624 |
+
if batch_type == 'static':
|
625 |
+
return static_batch(data, batch_size)
|
626 |
+
elif batch_type == 'dynamic':
|
627 |
+
return dynamic_batch(data, max_frames_in_batch)
|
628 |
+
else:
|
629 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
630 |
+
|
631 |
+
|
632 |
+
def padding(data):
|
633 |
+
""" Padding the data into training data
|
634 |
+
|
635 |
+
Args:
|
636 |
+
data: Iterable[List[{key, feat, label}]]
|
637 |
+
|
638 |
+
Returns:
|
639 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
640 |
+
"""
|
641 |
+
for sample in data:
|
642 |
+
assert isinstance(sample, list)
|
643 |
+
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
|
644 |
+
dtype=torch.int32)
|
645 |
+
order = torch.argsort(feats_length, descending=True)
|
646 |
+
feats_lengths = torch.tensor(
|
647 |
+
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
|
648 |
+
sorted_feats = [sample[i]['feat'] for i in order]
|
649 |
+
sorted_keys = [sample[i]['key'] for i in order]
|
650 |
+
sorted_labels = [
|
651 |
+
torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
|
652 |
+
]
|
653 |
+
sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
|
654 |
+
label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
|
655 |
+
dtype=torch.int32)
|
656 |
+
wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
|
657 |
+
dtype=torch.int32)
|
658 |
+
|
659 |
+
padded_feats = pad_sequence(sorted_feats,
|
660 |
+
batch_first=True,
|
661 |
+
padding_value=0)
|
662 |
+
padding_labels = pad_sequence(sorted_labels,
|
663 |
+
batch_first=True,
|
664 |
+
padding_value=-1)
|
665 |
+
padded_wavs = pad_sequence(sorted_wavs,
|
666 |
+
batch_first=True,
|
667 |
+
padding_value=0)
|
668 |
+
batch = {
|
669 |
+
"keys": sorted_keys,
|
670 |
+
"feats": padded_feats,
|
671 |
+
"target": padding_labels,
|
672 |
+
"feats_lengths": feats_lengths,
|
673 |
+
"target_lengths": label_lengths,
|
674 |
+
"pcm": padded_wavs,
|
675 |
+
"pcm_length": wav_lengths,
|
676 |
+
}
|
677 |
+
if 'speaker' in sample[0]:
|
678 |
+
speaker = torch.tensor([sample[i]['speaker'] for i in order],
|
679 |
+
dtype=torch.int32)
|
680 |
+
batch['speaker'] = speaker
|
681 |
+
if 'prompt' in sample[0]:
|
682 |
+
sorted_prompts = [
|
683 |
+
torch.tensor(sample[i]['prompt'], dtype=torch.int64
|
684 |
+
) for i in order
|
685 |
+
]
|
686 |
+
prompt_lengths = torch.tensor([x.size(0) for x in
|
687 |
+
sorted_prompts], dtype=torch.int32)
|
688 |
+
padding_prompts = pad_sequence(sorted_prompts,
|
689 |
+
batch_first=True,
|
690 |
+
padding_value=-1)
|
691 |
+
batch['prompt'] = padding_prompts
|
692 |
+
batch['prompt_lengths'] = prompt_lengths
|
693 |
+
|
694 |
+
yield batch
|
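Since every stage above is a plain generator over sample dicts, a pipeline is built by simply nesting the calls. A minimal sketch with in-memory fake samples, assuming this file is importable as wenet.dataset.processor from a wenet checkout and that the installed torchaudio build supports the sox buffer call made at import time:

import torch
from wenet.dataset import processor

def fake_samples(n=4):
    for i in range(n):
        yield {
            'key': f'utt{i}',
            'wav': torch.randn(1, 16000),   # 1 s of fake 16 kHz audio
            'sample_rate': 16000,
            'label': [1, 2, 3, 4],          # already-tokenized ids
        }

pipe = fake_samples()
pipe = processor.resample(pipe, resample_rate=16000)
pipe = processor.compute_fbank(pipe, num_mel_bins=80)
pipe = processor.static_batch(pipe, batch_size=2)
pipe = processor.padding(pipe)

for batch in pipe:
    # feats: [2, ~98, 80], target: [2, 4], pcm: [2, 16000]
    print(batch['feats'].shape, batch['target'].shape, batch['pcm'].shape)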
wenet/dataset/wav_distortion.py
ADDED
@@ -0,0 +1,336 @@
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Chao Yang)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import sys
|
16 |
+
import random
|
17 |
+
import math
|
18 |
+
|
19 |
+
import torchaudio
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
def db2amp(db):
|
24 |
+
return pow(10, db / 20)
|
25 |
+
|
26 |
+
|
27 |
+
def amp2db(amp):
|
28 |
+
return 20 * math.log10(amp)
|
29 |
+
|
30 |
+
|
31 |
+
def make_poly_distortion(conf):
|
32 |
+
"""Generate a db-domain ploynomial distortion function
|
33 |
+
|
34 |
+
f(x) = a * x^m * (1-x)^n + x
|
35 |
+
|
36 |
+
Args:
|
37 |
+
conf: a dict {'a': #int, 'm': #int, 'n': #int}
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
The polynomial function, which can be applied to
|
41 |
+
a float amplitude value
|
42 |
+
"""
|
43 |
+
a = conf['a']
|
44 |
+
m = conf['m']
|
45 |
+
n = conf['n']
|
46 |
+
|
47 |
+
def poly_distortion(x):
|
48 |
+
abs_x = abs(x)
|
49 |
+
if abs_x < 0.000001:
|
50 |
+
x = x
|
51 |
+
else:
|
52 |
+
db_norm = amp2db(abs_x) / 100 + 1
|
53 |
+
if db_norm < 0:
|
54 |
+
db_norm = 0
|
55 |
+
db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
|
56 |
+
if db_norm > 1:
|
57 |
+
db_norm = 1
|
58 |
+
db = (db_norm - 1) * 100
|
59 |
+
amp = db2amp(db)
|
60 |
+
if amp >= 0.9997:
|
61 |
+
amp = 0.9997
|
62 |
+
if x > 0:
|
63 |
+
x = amp
|
64 |
+
else:
|
65 |
+
x = -amp
|
66 |
+
return x
|
67 |
+
|
68 |
+
return poly_distortion
|
69 |
+
|
70 |
+
|
71 |
+
def make_quad_distortion():
|
72 |
+
return make_poly_distortion({'a': 1, 'm': 1, 'n': 1})
|
73 |
+
|
74 |
+
|
75 |
+
# the amplitude is set to max for all non-zero points
|
76 |
+
def make_max_distortion(conf):
|
77 |
+
"""Generate a max distortion function
|
78 |
+
|
79 |
+
Args:
|
80 |
+
conf: a dict {'max_db': float }
|
81 |
+
'max_db': the maximum value.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
The max function, which can be applied to
|
85 |
+
a float amplitude value
|
86 |
+
"""
|
87 |
+
max_db = conf['max_db']
|
88 |
+
if max_db:
|
89 |
+
max_amp = db2amp(max_db) # < 0.997
|
90 |
+
else:
|
91 |
+
max_amp = 0.997
|
92 |
+
|
93 |
+
def max_distortion(x):
|
94 |
+
if x > 0:
|
95 |
+
x = max_amp
|
96 |
+
elif x < 0:
|
97 |
+
x = -max_amp
|
98 |
+
else:
|
99 |
+
x = 0.0
|
100 |
+
return x
|
101 |
+
|
102 |
+
return max_distortion
|
103 |
+
|
104 |
+
|
105 |
+
def make_amp_mask(db_mask=None):
|
106 |
+
"""Get a amplitude domain mask from db domain mask
|
107 |
+
|
108 |
+
Args:
|
109 |
+
db_mask: Optional. A list of tuples. If None, the default mask is used.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
A list of tuples. The amplitude-domain mask
|
113 |
+
"""
|
114 |
+
if db_mask is None:
|
115 |
+
db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
|
116 |
+
amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask]
|
117 |
+
return amp_mask
|
118 |
+
|
119 |
+
|
120 |
+
default_mask = make_amp_mask()
|
121 |
+
|
122 |
+
|
123 |
+
def generate_amp_mask(mask_num):
|
124 |
+
"""Generate amplitude domain mask randomly in [-100db, 0db]
|
125 |
+
|
126 |
+
Args:
|
127 |
+
mask_num: the slot number of the mask
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
A list of tuples; each tuple defines a slot.
|
131 |
+
e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
|
132 |
+
for #mask_num = 4
|
133 |
+
"""
|
134 |
+
a = [0] * 2 * mask_num
|
135 |
+
a[0] = 0
|
136 |
+
m = []
|
137 |
+
for i in range(1, 2 * mask_num):
|
138 |
+
a[i] = a[i - 1] + random.uniform(0.5, 1)
|
139 |
+
max_val = a[2 * mask_num - 1]
|
140 |
+
for i in range(0, mask_num):
|
141 |
+
l = ((a[2 * i] - max_val) / max_val) * 100
|
142 |
+
r = ((a[2 * i + 1] - max_val) / max_val) * 100
|
143 |
+
m.append((l, r))
|
144 |
+
return make_amp_mask(m)
|
145 |
+
|
146 |
+
|
147 |
+
def make_fence_distortion(conf):
|
148 |
+
"""Generate a fence distortion function
|
149 |
+
|
150 |
+
In this fence-like shape function, the values in mask slots are
|
151 |
+
set to the maximum, while the values not in mask slots are set to 0.
|
152 |
+
Use separate masks for positive and negative amplitudes.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
conf: a dict {'mask_number': int,'max_db': float }
|
156 |
+
'mask_number': the slot number in mask.
|
157 |
+
'max_db': the maximum value.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
The fence function, which can be applied to
|
161 |
+
a float amplitude value
|
162 |
+
"""
|
163 |
+
mask_number = conf['mask_number']
|
164 |
+
max_db = conf['max_db']
|
165 |
+
max_amp = db2amp(max_db) # 0.997
|
166 |
+
if mask_number <= 0:
|
167 |
+
positive_mask = default_mask
|
168 |
+
negative_mask = make_amp_mask([(-50, 0)])
|
169 |
+
else:
|
170 |
+
positive_mask = generate_amp_mask(mask_number)
|
171 |
+
negative_mask = generate_amp_mask(mask_number)
|
172 |
+
|
173 |
+
def fence_distortion(x):
|
174 |
+
is_in_mask = False
|
175 |
+
if x > 0:
|
176 |
+
for mask in positive_mask:
|
177 |
+
if x >= mask[0] and x <= mask[1]:
|
178 |
+
is_in_mask = True
|
179 |
+
return max_amp
|
180 |
+
if not is_in_mask:
|
181 |
+
return 0.0
|
182 |
+
elif x < 0:
|
183 |
+
abs_x = abs(x)
|
184 |
+
for mask in negative_mask:
|
185 |
+
if abs_x >= mask[0] and abs_x <= mask[1]:
|
186 |
+
is_in_mask = True
|
187 |
+
return max_amp
|
188 |
+
if not is_in_mask:
|
189 |
+
return 0.0
|
190 |
+
return x
|
191 |
+
|
192 |
+
return fence_distortion
|
193 |
+
|
194 |
+
|
195 |
+
#
|
196 |
+
def make_jag_distortion(conf):
|
197 |
+
"""Generate a jag distortion function
|
198 |
+
|
199 |
+
In this jag-like shape function, the values in mask slots are
|
200 |
+
not changed, while the values not in mask slots are set to 0.
|
201 |
+
Use separate masks for positive and negative amplitudes.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
conf: a dict {'mask_number': #int}
|
205 |
+
'mask_number': the slot number in mask.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
The jag function, which can be applied to
|
209 |
+
a float amplitude value
|
210 |
+
"""
|
211 |
+
mask_number = conf['mask_number']
|
212 |
+
if mask_number <= 0:
|
213 |
+
positive_mask = default_mask
|
214 |
+
negative_mask = make_amp_mask([(-50, 0)])
|
215 |
+
else:
|
216 |
+
positive_mask = generate_amp_mask(mask_number)
|
217 |
+
negative_mask = generate_amp_mask(mask_number)
|
218 |
+
|
219 |
+
def jag_distortion(x):
|
220 |
+
is_in_mask = False
|
221 |
+
if x > 0:
|
222 |
+
for mask in positive_mask:
|
223 |
+
if x >= mask[0] and x <= mask[1]:
|
224 |
+
is_in_mask = True
|
225 |
+
return x
|
226 |
+
if not is_in_mask:
|
227 |
+
return 0.0
|
228 |
+
elif x < 0:
|
229 |
+
abs_x = abs(x)
|
230 |
+
for mask in negative_mask:
|
231 |
+
if abs_x >= mask[0] and abs_x <= mask[1]:
|
232 |
+
is_in_mask = True
|
233 |
+
return x
|
234 |
+
if not is_in_mask:
|
235 |
+
return 0.0
|
236 |
+
return x
|
237 |
+
|
238 |
+
return jag_distortion
|
239 |
+
|
240 |
+
|
241 |
+
# gaining 20db means amp = amp * 10
|
242 |
+
# gaining -20db means amp = amp / 10
|
243 |
+
def make_gain_db(conf):
|
244 |
+
"""Generate a db domain gain function
|
245 |
+
|
246 |
+
Args:
|
247 |
+
conf: a dict {'db': #float}
|
248 |
+
'db': the gaining value
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
The db gain function, which can be applied to
|
252 |
+
a float amplitude value
|
253 |
+
"""
|
254 |
+
db = conf['db']
|
255 |
+
|
256 |
+
def gain_db(x):
|
257 |
+
return min(0.997, x * pow(10, db / 20))
|
258 |
+
|
259 |
+
return gain_db
|
260 |
+
|
261 |
+
|
262 |
+
def distort(x, func, rate=0.8):
|
263 |
+
"""Distort a waveform in sample point level
|
264 |
+
|
265 |
+
Args:
|
266 |
+
x: the original waveform
|
267 |
+
func: the distort function
|
268 |
+
rate: sample point-level distort probability
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
the distorted waveform
|
272 |
+
"""
|
273 |
+
for i in range(0, x.shape[1]):
|
274 |
+
a = random.uniform(0, 1)
|
275 |
+
if a < rate:
|
276 |
+
x[0][i] = func(float(x[0][i]))
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
def distort_chain(x, funcs, rate=0.8):
|
281 |
+
for i in range(0, x.shape[1]):
|
282 |
+
a = random.uniform(0, 1)
|
283 |
+
if a < rate:
|
284 |
+
for func in funcs:
|
285 |
+
x[0][i] = func(float(x[0][i]))
|
286 |
+
return x
|
287 |
+
|
288 |
+
|
289 |
+
# x is numpy
|
290 |
+
def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
|
291 |
+
if distort_type == 'gain_db':
|
292 |
+
gain_db = make_gain_db(distort_conf)
|
293 |
+
x = distort(x, gain_db)
|
294 |
+
elif distort_type == 'max_distortion':
|
295 |
+
max_distortion = make_max_distortion(distort_conf)
|
296 |
+
x = distort(x, max_distortion, rate=rate)
|
297 |
+
elif distort_type == 'fence_distortion':
|
298 |
+
fence_distortion = make_fence_distortion(distort_conf)
|
299 |
+
x = distort(x, fence_distortion, rate=rate)
|
300 |
+
elif distort_type == 'jag_distortion':
|
301 |
+
jag_distortion = make_jag_distortion(distort_conf)
|
302 |
+
x = distort(x, jag_distortion, rate=rate)
|
303 |
+
elif distort_type == 'poly_distortion':
|
304 |
+
poly_distortion = make_poly_distortion(distort_conf)
|
305 |
+
x = distort(x, poly_distortion, rate=rate)
|
306 |
+
elif distort_type == 'quad_distortion':
|
307 |
+
quad_distortion = make_quad_distortion()
|
308 |
+
x = distort(x, quad_distortion, rate=rate)
|
309 |
+
elif distort_type == 'none_distortion':
|
310 |
+
pass
|
311 |
+
else:
|
312 |
+
print('unsupported distortion type')
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in,
|
317 |
+
wav_out):
|
318 |
+
x, sr = torchaudio.load(wav_in)
|
319 |
+
x = x.detach().numpy()
|
320 |
+
out = distort_wav_conf(x, distort_type, distort_conf, rate)
|
321 |
+
torchaudio.save(wav_out, torch.from_numpy(out), sr)
|
322 |
+
|
323 |
+
|
324 |
+
if __name__ == "__main__":
|
325 |
+
distort_type = sys.argv[1]
|
326 |
+
wav_in = sys.argv[2]
|
327 |
+
wav_out = sys.argv[3]
|
328 |
+
conf = None
|
329 |
+
rate = 0.1
|
330 |
+
if distort_type == 'new_jag_distortion':
|
331 |
+
conf = {'mask_number': 4}
|
332 |
+
elif distort_type == 'new_fence_distortion':
|
333 |
+
conf = {'mask_number': 1, 'max_db': -30}
|
334 |
+
elif distort_type == 'poly_distortion':
|
335 |
+
conf = {'a': 4, 'm': 2, "n": 2}
|
336 |
+
distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)
|
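A minimal sketch of applying one of these distortions to an in-memory waveform, assuming this file is importable as wenet.dataset.wav_distortion from a wenet checkout; distort_wav_conf expects a numpy array shaped (1, num_samples) and rewrites individual sample points with the given probability:

import numpy as np
from wenet.dataset import wav_distortion

t = np.arange(16000) / 16000.0                        # 1 s of fake audio
x = (0.1 * np.sin(2 * np.pi * 440 * t)).astype('float32').reshape(1, -1)
out = wav_distortion.distort_wav_conf(
    x, 'max_distortion', {'max_db': -10}, rate=0.1)
# roughly 10% of the points are clipped to +/- db2amp(-10) ~= 0.316
print(out.shape, float(np.abs(out).max()))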
wenet/e_branchformer/encoder.py
ADDED
@@ -0,0 +1,165 @@
|
1 |
+
# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
|
2 |
+
# 2023 Voicecomm Inc (Kai Li)
|
3 |
+
# 2023 Lucky Wong
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""Encoder definition."""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from typing import List, Optional, Union
|
21 |
+
from wenet.branchformer.encoder import LayerDropModuleList
|
22 |
+
|
23 |
+
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
|
24 |
+
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
|
25 |
+
from wenet.transformer.encoder import ConformerEncoder
|
26 |
+
from wenet.utils.class_utils import (
|
27 |
+
WENET_ACTIVATION_CLASSES,
|
28 |
+
WENET_ATTENTION_CLASSES,
|
29 |
+
WENET_MLP_CLASSES,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
class EBranchformerEncoder(ConformerEncoder):
|
34 |
+
"""E-Branchformer encoder module."""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
input_size: int,
|
39 |
+
output_size: int = 256,
|
40 |
+
attention_heads: int = 4,
|
41 |
+
linear_units: int = 2048,
|
42 |
+
selfattention_layer_type: str = "rel_selfattn",
|
43 |
+
pos_enc_layer_type: str = "rel_pos",
|
44 |
+
activation_type: str = "swish",
|
45 |
+
cgmlp_linear_units: int = 2048,
|
46 |
+
cgmlp_conv_kernel: int = 31,
|
47 |
+
use_linear_after_conv: bool = False,
|
48 |
+
gate_activation: str = "identity",
|
49 |
+
num_blocks: int = 12,
|
50 |
+
dropout_rate: float = 0.1,
|
51 |
+
positional_dropout_rate: float = 0.1,
|
52 |
+
attention_dropout_rate: float = 0.0,
|
53 |
+
input_layer: str = "conv2d",
|
54 |
+
stochastic_depth_rate: Union[float, List[float]] = 0.0,
|
55 |
+
static_chunk_size: int = 0,
|
56 |
+
use_dynamic_chunk: bool = False,
|
57 |
+
global_cmvn: torch.nn.Module = None,
|
58 |
+
use_dynamic_left_chunk: bool = False,
|
59 |
+
causal: bool = False,
|
60 |
+
merge_conv_kernel: int = 3,
|
61 |
+
use_ffn: bool = True,
|
62 |
+
macaron_style: bool = True,
|
63 |
+
query_bias: bool = True,
|
64 |
+
key_bias: bool = True,
|
65 |
+
value_bias: bool = True,
|
66 |
+
conv_bias: bool = True,
|
67 |
+
gradient_checkpointing: bool = False,
|
68 |
+
use_sdpa: bool = False,
|
69 |
+
layer_norm_type: str = 'layer_norm',
|
70 |
+
norm_eps: float = 1e-5,
|
71 |
+
n_kv_head: Optional[int] = None,
|
72 |
+
head_dim: Optional[int] = None,
|
73 |
+
mlp_type: str = 'position_wise_feed_forward',
|
74 |
+
mlp_bias: bool = True,
|
75 |
+
n_expert: int = 8,
|
76 |
+
n_expert_activated: int = 2,
|
77 |
+
):
|
78 |
+
super().__init__(input_size,
|
79 |
+
output_size,
|
80 |
+
attention_heads,
|
81 |
+
linear_units,
|
82 |
+
num_blocks,
|
83 |
+
dropout_rate,
|
84 |
+
positional_dropout_rate,
|
85 |
+
attention_dropout_rate,
|
86 |
+
input_layer,
|
87 |
+
pos_enc_layer_type,
|
88 |
+
True,
|
89 |
+
static_chunk_size,
|
90 |
+
use_dynamic_chunk,
|
91 |
+
global_cmvn,
|
92 |
+
use_dynamic_left_chunk,
|
93 |
+
1,
|
94 |
+
macaron_style,
|
95 |
+
selfattention_layer_type,
|
96 |
+
activation_type,
|
97 |
+
query_bias=query_bias,
|
98 |
+
key_bias=key_bias,
|
99 |
+
value_bias=value_bias,
|
100 |
+
conv_bias=conv_bias,
|
101 |
+
gradient_checkpointing=gradient_checkpointing,
|
102 |
+
use_sdpa=use_sdpa,
|
103 |
+
layer_norm_type=layer_norm_type,
|
104 |
+
norm_eps=norm_eps,
|
105 |
+
n_kv_head=n_kv_head,
|
106 |
+
head_dim=head_dim,
|
107 |
+
mlp_type=mlp_type,
|
108 |
+
mlp_bias=mlp_bias,
|
109 |
+
n_expert=n_expert,
|
110 |
+
n_expert_activated=n_expert_activated)
|
111 |
+
|
112 |
+
encoder_selfattn_layer_args = (
|
113 |
+
attention_heads,
|
114 |
+
output_size,
|
115 |
+
attention_dropout_rate,
|
116 |
+
query_bias,
|
117 |
+
key_bias,
|
118 |
+
value_bias,
|
119 |
+
use_sdpa,
|
120 |
+
n_kv_head,
|
121 |
+
head_dim,
|
122 |
+
)
|
123 |
+
|
124 |
+
cgmlp_layer = ConvolutionalGatingMLP
|
125 |
+
cgmlp_layer_args = (output_size, cgmlp_linear_units, cgmlp_conv_kernel,
|
126 |
+
dropout_rate, use_linear_after_conv,
|
127 |
+
gate_activation, causal)
|
128 |
+
|
129 |
+
# feed-forward module definition
|
130 |
+
mlp_class = WENET_MLP_CLASSES[mlp_type]
|
131 |
+
activation = WENET_ACTIVATION_CLASSES[activation_type]()
|
132 |
+
positionwise_layer_args = (
|
133 |
+
output_size,
|
134 |
+
linear_units,
|
135 |
+
dropout_rate,
|
136 |
+
activation,
|
137 |
+
mlp_bias,
|
138 |
+
n_expert,
|
139 |
+
n_expert_activated,
|
140 |
+
)
|
141 |
+
|
142 |
+
if isinstance(stochastic_depth_rate, float):
|
143 |
+
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
|
144 |
+
if len(stochastic_depth_rate) != num_blocks:
|
145 |
+
raise ValueError(
|
146 |
+
f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
|
147 |
+
f"should be equal to num_blocks ({num_blocks})")
|
148 |
+
|
149 |
+
self.encoders = LayerDropModuleList(
|
150 |
+
p=stochastic_depth_rate,
|
151 |
+
modules=[
|
152 |
+
EBranchformerEncoderLayer(
|
153 |
+
output_size,
|
154 |
+
WENET_ATTENTION_CLASSES[selfattention_layer_type](
|
155 |
+
*encoder_selfattn_layer_args),
|
156 |
+
cgmlp_layer(*cgmlp_layer_args),
|
157 |
+
mlp_class(*positionwise_layer_args) if use_ffn else None,
|
158 |
+
mlp_class(*positionwise_layer_args)
|
159 |
+
if use_ffn and macaron_style else None,
|
160 |
+
dropout_rate,
|
161 |
+
merge_conv_kernel=merge_conv_kernel,
|
162 |
+
causal=causal,
|
163 |
+
stochastic_depth_rate=stochastic_depth_rate[lnum],
|
164 |
+
) for lnum in range(num_blocks)
|
165 |
+
])
|
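Most constructor arguments are forwarded to the ConformerEncoder base class; only self.encoders is rebuilt with E-Branchformer layers. A minimal instantiation sketch, assuming a full wenet checkout on sys.path so that the base class, attention and cgmlp modules resolve (the forward pass itself is inherited and not shown in this file):

from wenet.e_branchformer.encoder import EBranchformerEncoder

encoder = EBranchformerEncoder(
    input_size=80,            # e.g. 80-dim fbank features
    output_size=256,
    attention_heads=4,
    linear_units=1024,
    cgmlp_linear_units=1024,
    num_blocks=2,             # kept small for the sketch
)
n_params = sum(p.numel() for p in encoder.parameters())
print(f'{n_params / 1e6:.1f}M parameters')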
wenet/e_branchformer/encoder_layer.py
ADDED
@@ -0,0 +1,187 @@
|
1 |
+
# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
|
2 |
+
# 2023 Voicecomm Inc (Kai Li)
|
3 |
+
# 2023 Lucky Wong
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""EBranchformerEncoderLayer definition."""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from typing import Optional, Tuple
|
22 |
+
|
23 |
+
from wenet.transformer.attention import T_CACHE
|
24 |
+
|
25 |
+
|
26 |
+
class EBranchformerEncoderLayer(torch.nn.Module):
|
27 |
+
"""E-Branchformer encoder layer module.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
size (int): model dimension
|
31 |
+
attn: standard self-attention or efficient attention
|
32 |
+
cgmlp: ConvolutionalGatingMLP
|
33 |
+
feed_forward: feed-forward module, optional
|
34 |
+
feed_forward_macaron: macaron-style feed-forward module, optional
|
35 |
+
dropout_rate (float): dropout probability
|
36 |
+
merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
size: int,
|
42 |
+
attn: torch.nn.Module,
|
43 |
+
cgmlp: torch.nn.Module,
|
44 |
+
feed_forward: Optional[torch.nn.Module],
|
45 |
+
feed_forward_macaron: Optional[torch.nn.Module],
|
46 |
+
dropout_rate: float,
|
47 |
+
merge_conv_kernel: int = 3,
|
48 |
+
causal: bool = True,
|
49 |
+
stochastic_depth_rate=0.0,
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
self.size = size
|
54 |
+
self.attn = attn
|
55 |
+
self.cgmlp = cgmlp
|
56 |
+
|
57 |
+
self.feed_forward = feed_forward
|
58 |
+
self.feed_forward_macaron = feed_forward_macaron
|
59 |
+
self.ff_scale = 1.0
|
60 |
+
if self.feed_forward is not None:
|
61 |
+
self.norm_ff = nn.LayerNorm(size)
|
62 |
+
if self.feed_forward_macaron is not None:
|
63 |
+
self.ff_scale = 0.5
|
64 |
+
self.norm_ff_macaron = nn.LayerNorm(size)
|
65 |
+
|
66 |
+
self.norm_mha = nn.LayerNorm(size) # for the MHA module
|
67 |
+
self.norm_mlp = nn.LayerNorm(size) # for the MLP module
|
68 |
+
# for the final output of the block
|
69 |
+
self.norm_final = nn.LayerNorm(size)
|
70 |
+
|
71 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
72 |
+
|
73 |
+
if causal:
|
74 |
+
padding = 0
|
75 |
+
self.lorder = merge_conv_kernel - 1
|
76 |
+
else:
|
77 |
+
# kernel_size should be an odd number for non-causal convolution
|
78 |
+
assert (merge_conv_kernel - 1) % 2 == 0
|
79 |
+
padding = (merge_conv_kernel - 1) // 2
|
80 |
+
self.lorder = 0
|
81 |
+
self.depthwise_conv_fusion = torch.nn.Conv1d(
|
82 |
+
size + size,
|
83 |
+
size + size,
|
84 |
+
kernel_size=merge_conv_kernel,
|
85 |
+
stride=1,
|
86 |
+
padding=padding,
|
87 |
+
groups=size + size,
|
88 |
+
bias=True,
|
89 |
+
)
|
90 |
+
self.merge_proj = torch.nn.Linear(size + size, size)
|
91 |
+
self.stochastic_depth_rate = stochastic_depth_rate
|
92 |
+
|
93 |
+
def _forward(
|
94 |
+
self,
|
95 |
+
x: torch.Tensor,
|
96 |
+
mask: torch.Tensor,
|
97 |
+
pos_emb: torch.Tensor,
|
98 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
99 |
+
att_cache: T_CACHE = (torch.zeros(
|
100 |
+
(0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
|
101 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
102 |
+
stoch_layer_coeff: float = 1.0
|
103 |
+
) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
|
104 |
+
|
105 |
+
if self.feed_forward_macaron is not None:
|
106 |
+
residual = x
|
107 |
+
x = self.norm_ff_macaron(x)
|
108 |
+
x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
|
109 |
+
self.feed_forward_macaron(x))
|
110 |
+
|
111 |
+
# Two branches
|
112 |
+
x1 = x
|
113 |
+
x2 = x
|
114 |
+
|
115 |
+
# Branch 1: multi-headed attention module
|
116 |
+
x1 = self.norm_mha(x1)
|
117 |
+
x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, att_cache)
|
118 |
+
x1 = self.dropout(x_att)
|
119 |
+
|
120 |
+
# Branch 2: convolutional gating mlp
|
121 |
+
# Fake new cnn cache here, and then change it in conv_module
|
122 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
123 |
+
x2 = self.norm_mlp(x2)
|
124 |
+
x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache)
|
125 |
+
x2 = self.dropout(x2)
|
126 |
+
|
127 |
+
# Merge two branches
|
128 |
+
x_concat = torch.cat([x1, x2], dim=-1)
|
129 |
+
x_tmp = x_concat.transpose(1, 2)
|
130 |
+
if self.lorder > 0:
|
131 |
+
x_tmp = nn.functional.pad(x_tmp, (self.lorder, 0), "constant", 0.0)
|
132 |
+
assert x_tmp.size(2) > self.lorder
|
133 |
+
x_tmp = self.depthwise_conv_fusion(x_tmp)
|
134 |
+
x_tmp = x_tmp.transpose(1, 2)
|
135 |
+
x = x + stoch_layer_coeff * self.dropout(
|
136 |
+
self.merge_proj(x_concat + x_tmp))
|
137 |
+
|
138 |
+
if self.feed_forward is not None:
|
139 |
+
# feed forward module
|
140 |
+
residual = x
|
141 |
+
x = self.norm_ff(x)
|
142 |
+
x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
|
143 |
+
self.feed_forward(x))
|
144 |
+
|
145 |
+
x = self.norm_final(x)
|
146 |
+
|
147 |
+
return x, mask, new_att_cache, new_cnn_cache
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
x: torch.Tensor,
|
152 |
+
mask: torch.Tensor,
|
153 |
+
pos_emb: torch.Tensor,
|
154 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
155 |
+
att_cache: T_CACHE = (torch.zeros(
|
156 |
+
(0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)),
|
157 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
158 |
+
) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]:
|
159 |
+
"""Compute encoded features.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
|
163 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
|
164 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
165 |
+
for BranchformerEncoderLayer.
|
166 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
167 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
168 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
169 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
170 |
+
cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
|
171 |
+
(#batch=1, size, cache_t2)
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
175 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
176 |
+
torch.Tensor: att_cache tensor,
|
177 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
178 |
+
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
|
179 |
+
"""
|
180 |
+
|
181 |
+
stoch_layer_coeff = 1.0
|
182 |
+
# with stochastic depth, residual connection `x + f(x)` becomes
|
183 |
+
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
184 |
+
if self.training:
|
185 |
+
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
186 |
+
return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
|
187 |
+
stoch_layer_coeff)
|
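The rescaling in forward() follows the stochastic-depth convention noted in the comment: when the residual branch is kept with probability 1 - p at training time (the random dropping itself is handled by LayerDropModuleList in the encoder above), its output is multiplied by 1 / (1 - p) so that the expected contribution matches evaluation, where nothing is dropped or rescaled. A tiny numeric check:

p = 0.1                                    # stochastic_depth_rate of the layer
train_coeff = 1.0 / (1.0 - p)              # ~1.111, used when self.training
expected = (1.0 - p) * train_coeff         # 1.0, matches the eval-time scale
print(train_coeff, expected)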
wenet/efficient_conformer/__init__.py
ADDED
File without changes
|
wenet/efficient_conformer/attention.py
ADDED
@@ -0,0 +1,257 @@
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2022 Xingchen Song ([email protected])
|
4 |
+
# 2022 58.com(Wuba) Inc AI Lab.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Multi-Head Attention layer definition."""
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import Tuple, Optional
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from wenet.transformer.attention import MultiHeadedAttention
|
26 |
+
|
27 |
+
|
28 |
+
class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention):
|
29 |
+
"""Multi-Head Attention layer with relative position encoding.
|
30 |
+
Paper:
|
31 |
+
https://arxiv.org/abs/1901.02860
|
32 |
+
https://arxiv.org/abs/2109.01163
|
33 |
+
Args:
|
34 |
+
n_head (int): The number of heads.
|
35 |
+
n_feat (int): The number of features.
|
36 |
+
dropout_rate (float): Dropout rate.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, n_head, n_feat, dropout_rate, group_size=3):
|
40 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
41 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
42 |
+
# linear transformation for positional encoding
|
43 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
44 |
+
self.group_size = group_size
|
45 |
+
self.d_k = n_feat // n_head # for GroupedAttention
|
46 |
+
self.n_feat = n_feat
|
47 |
+
# these two learnable bias are used in matrix c and matrix d
|
48 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
49 |
+
self.pos_bias_u = nn.Parameter(
|
50 |
+
torch.Tensor(self.h, self.d_k * self.group_size))
|
51 |
+
self.pos_bias_v = nn.Parameter(
|
52 |
+
torch.Tensor(self.h, self.d_k * self.group_size))
|
53 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
54 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
55 |
+
|
56 |
+
def rel_shift(self, x, zero_triu: bool = False):
|
57 |
+
"""Compute relative positinal encoding.
|
58 |
+
Args:
|
59 |
+
x (torch.Tensor): Input tensor (batch, time, size).
|
60 |
+
zero_triu (bool): If true, return the lower triangular part of
|
61 |
+
the matrix.
|
62 |
+
Returns:
|
63 |
+
torch.Tensor: Output tensor.
|
64 |
+
"""
|
65 |
+
|
66 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
67 |
+
device=x.device,
|
68 |
+
dtype=x.dtype)
|
69 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
70 |
+
|
71 |
+
x_padded = x_padded.view(x.size()[0],
|
72 |
+
x.size()[1],
|
73 |
+
x.size(3) + 1, x.size(2))
|
74 |
+
x = x_padded[:, :, 1:].view_as(x)
|
75 |
+
|
76 |
+
if zero_triu:
|
77 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
78 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
def pad4group(self, Q, K, V, P, mask, group_size: int = 3):
|
83 |
+
"""
|
84 |
+
q: (#batch, time1, size) -> (#batch, head, time1, size/head)
|
85 |
+
k,v: (#batch, time2, size) -> (#batch, head, time2, size/head)
|
86 |
+
p: (#batch, time2, size)
|
87 |
+
"""
|
88 |
+
# Compute Overflows
|
89 |
+
overflow_Q = Q.size(2) % group_size
|
90 |
+
overflow_KV = K.size(2) % group_size
|
91 |
+
|
92 |
+
# if-else for ONNX export
|
93 |
+
# 0 // 0.00000000000000001 = 0
|
94 |
+
# 1 // 1.00000000000000001 = 1
|
95 |
+
padding_Q = (group_size - overflow_Q) * int(
|
96 |
+
overflow_Q // (overflow_Q + 0.00000000000000001))
|
97 |
+
padding_KV = (group_size - overflow_KV) * int(
|
98 |
+
overflow_KV // (overflow_KV + 0.00000000000000001))
|
99 |
+
|
100 |
+
batch_size, _, seq_len_KV, _ = K.size()
|
101 |
+
|
102 |
+
# Input Padding (B, T, D) -> (B, T + P, D)
|
103 |
+
Q = F.pad(Q, (0, 0, 0, padding_Q), value=0.0)
|
104 |
+
K = F.pad(K, (0, 0, 0, padding_KV), value=0.0)
|
105 |
+
V = F.pad(V, (0, 0, 0, padding_KV), value=0.0)
|
106 |
+
|
107 |
+
if mask is not None and mask.size(2) > 0: # time2 > 0:
|
108 |
+
mask = mask[:, ::group_size, ::group_size]
|
109 |
+
|
110 |
+
Q = Q.transpose(1, 2).contiguous().view(
|
111 |
+
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
|
112 |
+
K = K.transpose(1, 2).contiguous().view(
|
113 |
+
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
|
114 |
+
V = V.transpose(1, 2).contiguous().view(
|
115 |
+
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
|
116 |
+
|
117 |
+
# process pos_emb
|
118 |
+
P_batch_size = P.size(0)
|
119 |
+
overflow_P = P.size(1) % group_size
|
120 |
+
padding_P = group_size - overflow_P if overflow_P else 0
|
121 |
+
P = F.pad(P, (0, 0, 0, padding_P), value=0.0)
|
122 |
+
P = P.view(P_batch_size, -1, self.h,
|
123 |
+
self.d_k * group_size).transpose(1, 2)
|
124 |
+
|
125 |
+
return Q, K, V, P, mask, padding_Q
|
126 |
+
|
127 |
+
def forward_attention(self,
|
128 |
+
value: torch.Tensor,
|
129 |
+
scores: torch.Tensor,
|
130 |
+
mask: torch.Tensor = torch.ones((0, 0, 0),
|
131 |
+
dtype=torch.bool),
|
132 |
+
padding_q: Optional[int] = None) -> torch.Tensor:
|
133 |
+
"""Compute attention context vector.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
value (torch.Tensor): Transformed value, size
|
137 |
+
(#batch, n_head, time2, d_k).
|
138 |
+
scores (torch.Tensor): Attention score, size
|
139 |
+
(#batch, n_head, time1, time2).
|
140 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
141 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
142 |
+
padding_q : for GroupedAttention in efficent conformer
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
146 |
+
weighted by the attention score (#batch, time1, time2).
|
147 |
+
|
148 |
+
"""
|
149 |
+
n_batch = value.size(0)
|
150 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
151 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
152 |
+
# 1st chunk to ease the onnx export.]
|
153 |
+
# 2. pytorch training
|
154 |
+
if mask.size(2) > 0: # time2 > 0
|
155 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
156 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
157 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
158 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
159 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
160 |
+
mask, 0.0) # (batch, head, time1, time2)
|
161 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
162 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
163 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
164 |
+
else:
|
165 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
166 |
+
|
167 |
+
p_attn = self.dropout(attn)
|
168 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
169 |
+
|
170 |
+
# n_feat!=h*d_k may be happened in GroupAttention
|
171 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.n_feat)
|
172 |
+
) # (batch, time1, d_model)
|
173 |
+
if padding_q is not None:
|
174 |
+
# for GroupedAttention in efficent conformer
|
175 |
+
x = x[:, :x.size(1) - padding_q]
|
176 |
+
|
177 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
178 |
+
|
179 |
+
def forward(
|
180 |
+
self,
|
181 |
+
query: torch.Tensor,
|
182 |
+
key: torch.Tensor,
|
183 |
+
value: torch.Tensor,
|
184 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
185 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
186 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
187 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
188 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
189 |
+
Args:
|
190 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
191 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
192 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
193 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
194 |
+
(#batch, time1, time2).
|
195 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
196 |
+
(#batch, time2, size).
|
197 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
198 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
199 |
+
and `head * d_k == size`
|
200 |
+
Returns:
|
201 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
202 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
203 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
204 |
+
and `head * d_k == size`
|
205 |
+
"""
|
206 |
+
q = self.linear_q(query)
|
207 |
+
k = self.linear_k(key) # (#batch, time2, size)
|
208 |
+
v = self.linear_v(value)
|
209 |
+
p = self.linear_pos(pos_emb) # (#batch, time2, size)
|
210 |
+
|
211 |
+
batch_size, seq_len_KV, _ = k.size() # seq_len_KV = time2
|
212 |
+
|
213 |
+
# (#batch, time2, size) -> (#batch, head, time2, size/head)
|
214 |
+
q = q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
|
215 |
+
k = k.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
|
216 |
+
v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
|
217 |
+
if cache.size(0) > 0:
|
218 |
+
# use attention cache
|
219 |
+
key_cache, value_cache = torch.split(cache,
|
220 |
+
cache.size(-1) // 2,
|
221 |
+
dim=-1)
|
222 |
+
k = torch.cat([key_cache, k], dim=2)
|
223 |
+
v = torch.cat([value_cache, v], dim=2)
|
224 |
+
new_cache = torch.cat((k, v), dim=-1)
|
225 |
+
|
226 |
+
# May be k and p does not match. eg. time2=18+18/2=27 > mask=36/2=18
|
227 |
+
if mask is not None and mask.size(2) > 0:
|
228 |
+
time2 = mask.size(2)
|
229 |
+
k = k[:, :, -time2:, :]
|
230 |
+
v = v[:, :, -time2:, :]
|
231 |
+
|
232 |
+
# q k v p: (batch, head, time1, d_k)
|
233 |
+
q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask,
|
234 |
+
self.group_size)
|
235 |
+
|
236 |
+
# q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k)
|
237 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
238 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
239 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
240 |
+
|
241 |
+
# compute attention score
|
242 |
+
# first compute matrix a and matrix c
|
243 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
244 |
+
# (batch, head, time1, time2)
|
245 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
246 |
+
|
247 |
+
# compute matrix b and matrix d
|
248 |
+
# (batch, head, time1, time2)
|
249 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
250 |
+
# Remove rel_shift since it is useless in speech recognition,
|
251 |
+
# and it requires special attention for streaming.
|
252 |
+
# matrix_bd = self.rel_shift(matrix_bd)
|
253 |
+
|
254 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
255 |
+
self.d_k * self.group_size) # (batch, head, time1, time2)
|
256 |
+
|
257 |
+
return self.forward_attention(v, scores, mask, padding_q), new_cache
|
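A minimal shape-check sketch for the grouped attention above (a sketch, assuming the wenet package is importable and that the relative positional embedding is fed with the same length as the keys; the tensor sizes are illustrative):

import torch
from wenet.efficient_conformer.attention import \
    GroupedRelPositionMultiHeadedAttention

n_head, n_feat, group_size = 4, 256, 3
attn = GroupedRelPositionMultiHeadedAttention(n_head, n_feat,
                                              dropout_rate=0.0,
                                              group_size=group_size)
x = torch.randn(2, 18, n_feat)            # (batch, time, feat), time % 3 == 0
pos_emb = torch.randn(2, 18, n_feat)      # same length as the keys here
mask = torch.ones(2, 1, 18, dtype=torch.bool)
out, cache = attn(x, x, x, mask, pos_emb)
print(out.shape)    # torch.Size([2, 18, 256]); query padding is stripped on return
print(cache.shape)  # torch.Size([2, 4, 18, 128]): concatenated key/value cache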
wenet/efficient_conformer/convolution.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True,
                 stride: int = 1):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (int): Whether use causal convolution or not
            stride (int): Stride Convolution, for efficient Conformer
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0

        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=stride,  # for depthwise_conv in StrideConv
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation
        self.stride = stride

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) meas fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                # When export ONNX,the first cache is not None but all-zero,
                # cause shape error in residual block,
                # eg. cache14 + x9 = 23, 23-7+1=17 != 9
                cache = cache[:, :, -self.lorder:]
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is requried,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            if mask_pad.size(2) != x.size(2):
                mask_pad = mask_pad[:, :, ::self.stride]
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
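A small sketch of the strided depthwise convolution above (a sketch with illustrative sizes; with stride=2 the time axis is roughly halved, which is what the StrideConv blocks rely on):

import torch
from wenet.efficient_conformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=7, stride=2)
x = torch.randn(2, 16, 256)          # (batch, time, channels)
y, new_cache = conv(x)               # default fake mask/cache tensors are used
print(y.shape)                       # torch.Size([2, 8, 256]): time halved by the stride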
wenet/efficient_conformer/encoder.py
ADDED
@@ -0,0 +1,560 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from EfficientConformer(https://github.com/burchim/EfficientConformer)
#               Paper(https://arxiv.org/abs/2109.01163)
"""Encoder definition."""
from typing import Tuple, Optional, List, Union

import torch
import logging
import torch.nn.functional as F

from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.encoder_layer import ConformerEncoderLayer

from wenet.efficient_conformer.convolution import ConvolutionModule
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer

from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.utils.class_utils import (
    WENET_ATTENTION_CLASSES,
    WENET_EMB_CLASSES,
    WENET_SUBSAMPLE_CLASSES,
    WENET_ACTIVATION_CLASSES,
)


class EfficientConformerEncoder(torch.nn.Module):
    """Conformer encoder module."""

    def __init__(self,
                 input_size: int,
                 output_size: int = 256,
                 attention_heads: int = 4,
                 linear_units: int = 2048,
                 num_blocks: int = 6,
                 dropout_rate: float = 0.1,
                 positional_dropout_rate: float = 0.1,
                 attention_dropout_rate: float = 0.0,
                 input_layer: str = "conv2d",
                 pos_enc_layer_type: str = "rel_pos",
                 normalize_before: bool = True,
                 static_chunk_size: int = 0,
                 use_dynamic_chunk: bool = False,
                 global_cmvn: torch.nn.Module = None,
                 use_dynamic_left_chunk: bool = False,
                 macaron_style: bool = True,
                 activation_type: str = "swish",
                 use_cnn_module: bool = True,
                 cnn_module_kernel: int = 15,
                 causal: bool = False,
                 cnn_module_norm: str = "batch_norm",
                 stride_layer_idx: Optional[Union[int, List[int]]] = 3,
                 stride: Optional[Union[int, List[int]]] = 2,
                 group_layer_idx: Optional[Union[int, List[int],
                                                 tuple]] = (0, 1, 2, 3),
                 group_size: int = 3,
                 stride_kernel: bool = True,
                 **kwargs):
        """Construct Efficient Conformer Encoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            stride_layer_idx (list): layer id with StrideConv, start from 0
            stride (list): stride size of each StrideConv in efficient conformer
            group_layer_idx (list): layer id with GroupedAttention, start from 0
            group_size (int): group size of every GroupedAttention layer
            stride_kernel (bool): default True. True: recompute cnn kernels with stride.
        """
        super().__init__()
        self._output_size = output_size

        logging.info(
            f"input_layer = {input_layer}, "
            f"subsampling_class = {WENET_SUBSAMPLE_CLASSES[input_layer]}")

        self.global_cmvn = global_cmvn
        self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                  positional_dropout_rate),
        )
        self.input_layer = input_layer
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

        activation = WENET_ACTIVATION_CLASSES[activation_type]()
        self.num_blocks = num_blocks
        self.attention_heads = attention_heads
        self.cnn_module_kernel = cnn_module_kernel
        self.global_chunk_size = 0
        self.chunk_feature_map = 0

        # efficient conformer configs
        self.stride_layer_idx = [stride_layer_idx] \
            if type(stride_layer_idx) == int else stride_layer_idx
        self.stride = [stride] \
            if type(stride) == int else stride
        self.group_layer_idx = [group_layer_idx] \
            if type(group_layer_idx) == int else group_layer_idx
        self.grouped_size = group_size  # group size of every GroupedAttention layer

        assert len(self.stride) == len(self.stride_layer_idx)
        self.cnn_module_kernels = [cnn_module_kernel
                                   ]  # kernel size of each StridedConv
        for i in self.stride:
            if stride_kernel:
                self.cnn_module_kernels.append(self.cnn_module_kernels[-1] //
                                               i)
            else:
                self.cnn_module_kernels.append(self.cnn_module_kernels[-1])

        logging.info(f"stride_layer_idx= {self.stride_layer_idx}, "
                     f"stride = {self.stride}, "
                     f"cnn_module_kernel = {self.cnn_module_kernels}, "
                     f"group_layer_idx = {self.group_layer_idx}, "
                     f"grouped_size = {self.grouped_size}")

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule

        # encoder definition
        index = 0
        layers = []
        for i in range(num_blocks):
            # self-attention module definition
            if i in self.group_layer_idx:
                encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
                    "grouped_rel_selfattn"]
                encoder_selfattn_layer_args = (attention_heads, output_size,
                                               attention_dropout_rate,
                                               self.grouped_size)
            else:
                if pos_enc_layer_type == "no_pos":
                    encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
                        "selfattn"]
                else:
                    encoder_selfattn_layer = WENET_ATTENTION_CLASSES[
                        "rel_selfattn"]
                encoder_selfattn_layer_args = (attention_heads, output_size,
                                               attention_dropout_rate)

            # conformer module definition
            if i in self.stride_layer_idx:
                # conformer block with downsampling
                convolution_layer_args_stride = (
                    output_size, self.cnn_module_kernels[index], activation,
                    cnn_module_norm, causal, True, self.stride[index])
                layers.append(
                    StrideConformerEncoderLayer(
                        output_size,
                        encoder_selfattn_layer(*encoder_selfattn_layer_args),
                        positionwise_layer(*positionwise_layer_args),
                        positionwise_layer(*positionwise_layer_args)
                        if macaron_style else None,
                        convolution_layer(*convolution_layer_args_stride)
                        if use_cnn_module else None,
                        torch.nn.AvgPool1d(
                            kernel_size=self.stride[index],
                            stride=self.stride[index],
                            padding=0,
                            ceil_mode=True,
                            count_include_pad=False),  # pointwise_conv_layer
                        dropout_rate,
                        normalize_before,
                    ))
                index = index + 1
            else:
                # conformer block
                convolution_layer_args_normal = (
                    output_size, self.cnn_module_kernels[index], activation,
                    cnn_module_norm, causal)
                layers.append(
                    ConformerEncoderLayer(
                        output_size,
                        encoder_selfattn_layer(*encoder_selfattn_layer_args),
                        positionwise_layer(*positionwise_layer_args),
                        positionwise_layer(*positionwise_layer_args)
                        if macaron_style else None,
                        convolution_layer(*convolution_layer_args_normal)
                        if use_cnn_module else None,
                        dropout_rate,
                        normalize_before,
                    ))

        self.encoders = torch.nn.ModuleList(layers)

    def set_global_chunk_size(self, chunk_size):
        """Used in ONNX export.
        """
        logging.info(f"set global chunk size: {chunk_size}, default is 0.")
        self.global_chunk_size = chunk_size
        if self.embed.subsampling_rate == 2:
            self.chunk_feature_map = 2 * self.global_chunk_size + 1
        elif self.embed.subsampling_rate == 6:
            self.chunk_feature_map = 6 * self.global_chunk_size + 5
        elif self.embed.subsampling_rate == 8:
            self.chunk_feature_map = 8 * self.global_chunk_size + 7
        else:
            self.chunk_feature_map = 4 * self.global_chunk_size + 3

    def output_size(self) -> int:
        return self._output_size

    def calculate_downsampling_factor(self, i: int) -> int:
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.
        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        index = 0  # traverse stride
        for i, layer in enumerate(self.encoders):
            # layer return : x, mask, new_att_cache, new_cnn_cache
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
            if i in self.stride_layer_idx:
                masks = masks[:, :, ::self.stride[index]]
                chunk_masks = chunk_masks[:, ::self.stride[index], ::self.
                                          stride[index]]
                mask_pad = masks
                pos_emb = pos_emb[:, ::self.stride[index], :]
                index = index + 1

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                compuation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            att_mask : mask matrix of self attention

        Returns:
            torch.Tensor: output of current input xs
            torch.Tensor: subsampling cache required for next chunk computation
            List[torch.Tensor]: encoder layers output cache required for next
                chunk computation
            List[torch.Tensor]: conformer cnn cache

        """
        assert xs.size(0) == 1

        # using downsampling factor to recover offset
        offset *= self.calculate_downsampling_factor(self.num_blocks + 1)

        chunk_masks = torch.ones(1,
                                 xs.size(1),
                                 device=xs.device,
                                 dtype=torch.bool)
        chunk_masks = chunk_masks.unsqueeze(1)  # (1, 1, xs-time)

        real_len = 0
        if self.global_chunk_size > 0:
            # for ONNX decode simulation, padding xs to chunk_size
            real_len = xs.size(1)
            pad_len = self.chunk_feature_map - real_len
            xs = F.pad(xs, (0, 0, 0, pad_len), value=0.0)
            chunk_masks = F.pad(chunk_masks, (0, pad_len), value=0.0)

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, chunk_masks = self.embed(xs, chunk_masks, offset)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        # shape(pos_emb) = (b=1, chunk_size, emb_size=output_size=hidden-dim)

        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)  # batchPad (b=1, 1, time=chunk_size)

        if self.global_chunk_size > 0:
            # for ONNX decode simulation
            pos_emb = self.embed.position_encoding(
                offset=max(offset - cache_t1, 0),
                size=cache_t1 + self.global_chunk_size)
            att_mask[:, :, -self.global_chunk_size:] = chunk_masks
            mask_pad = chunk_masks.to(torch.bool)
        else:
            pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                                   size=attention_key_size)

        max_att_len, max_cnn_len = 0, 0  # for repeat_interleave of new_att_cache
        for i, layer in enumerate(self.encoders):
            factor = self.calculate_downsampling_factor(i)
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
            att_cache_trunc = 0
            if xs.size(1) + att_cache.size(2) / factor > pos_emb.size(1):
                # The time step is not divisible by the downsampling multiple
                att_cache_trunc = xs.size(1) + \
                    att_cache.size(2) // factor - pos_emb.size(1) + 1
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i:i +
                                    1, :, ::factor, :][:, :,
                                                       att_cache_trunc:, :],
                cnn_cache=cnn_cache[i, :, :, :]
                if cnn_cache.size(0) > 0 else cnn_cache)

            if i in self.stride_layer_idx:
                # compute time dimension for next block
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index]]
                mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index]]
                pos_emb = pos_emb[:, ::self.stride[efficient_index], :]

            # shape(new_att_cache) = [batch, head, time2, outdim]
            new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
            # shape(new_cnn_cache) = [1, batch, outdim, cache_t2]
            new_cnn_cache = new_cnn_cache.unsqueeze(0)

            # use repeat_interleave to new_att_cache
            new_att_cache = new_att_cache.repeat_interleave(repeats=factor,
                                                            dim=2)
            # padding new_cnn_cache to cnn.lorder for casual convolution
            new_cnn_cache = F.pad(
                new_cnn_cache,
                (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0))

            if i == 0:
                # record length for the first block as max length
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)

            # update real shape of att_cache and cnn_cache
            r_att_cache.append(new_att_cache[:, :, -max_att_len:, :])
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])

        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        if self.global_chunk_size > 0 and real_len:
            chunk_real_len = real_len // self.embed.subsampling_rate // \
                self.calculate_downsampling_factor(self.num_blocks + 1)
            # Keeping 1 more timestep can mitigate information leakage
            #   from the encoder caused by the padding
            xs = xs[:, :chunk_real_len + 1, :]

        return xs, r_att_cache, r_cnn_cache

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
        use_onnx=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache for:
            1. We can control subsampling module to output the right result by
               overlapping input instead of cache left context, even though it
               wastes some computation, but subsampling only takes a very
               small fraction of computation in the whole model.
            2. Typically, there are several covolution layers with subsampling
               in subsampling module, it is tricky and complicated to do cache
               with different convolution layers with different subsampling
               rate.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling, we need to rewrite it to make it work
               with cache, which is not prefered.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
            num_decoding_left_chunks (int):
            use_onnx (bool): True for simulating ONNX model inference.
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)

        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        if use_onnx:
            logging.info("Simulating for ONNX runtime ...")
            att_cache: torch.Tensor = torch.zeros(
                (self.num_blocks, self.attention_heads, required_cache_size,
                 self.output_size() // self.attention_heads * 2),
                device=xs.device)
            cnn_cache: torch.Tensor = torch.zeros(
                (self.num_blocks, 1, self.output_size(),
                 self.cnn_module_kernel - 1),
                device=xs.device)
            self.set_global_chunk_size(chunk_size=decoding_chunk_size)
        else:
            logging.info("Simulating for JIT runtime ...")
            att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0),
                                                  device=xs.device)
            cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0),
                                                  device=xs.device)

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            logging.info(f"-->> frame chunk msg: cur={cur}, "
                         f"end={end}, num_frames={end-cur}, "
                         f"decoding_window={decoding_window}")
            if use_onnx:
                att_mask: torch.Tensor = torch.ones(
                    (1, 1, required_cache_size + decoding_chunk_size),
                    dtype=torch.bool,
                    device=xs.device)
                if cur == 0:
                    att_mask[:, :, :required_cache_size] = 0
            else:
                att_mask: torch.Tensor = torch.ones((0, 0, 0),
                                                    dtype=torch.bool,
                                                    device=xs.device)

            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = \
                self.forward_chunk(
                    chunk_xs, offset, required_cache_size,
                    att_cache, cnn_cache, att_mask)
            outputs.append(y)
            offset += y.size(1)

        ys = torch.cat(outputs, 1)
        masks = torch.ones(1,
                           1,
                           ys.size(1),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks
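A minimal offline-forward sketch for the encoder above (a sketch, assuming the wenet package is importable and 80-dim fbank features; the hyperparameters are illustrative and smaller than typical recipes):

import torch
from wenet.efficient_conformer.encoder import EfficientConformerEncoder

encoder = EfficientConformerEncoder(input_size=80,
                                    output_size=256,
                                    attention_heads=4,
                                    linear_units=1024,
                                    num_blocks=6,
                                    static_chunk_size=16,
                                    stride_layer_idx=3,
                                    stride=2,
                                    group_layer_idx=(0, 1, 2, 3),
                                    group_size=3)
encoder.eval()
xs = torch.randn(2, 200, 80)              # (batch, frames, feat)
xs_lens = torch.tensor([200, 160])
with torch.no_grad():
    ys, masks = encoder(xs, xs_lens)
# Conv2d subsampling (x4) plus one StrideConv layer (x2) => roughly T/8 frames out.
print(ys.shape, masks.shape)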
wenet/efficient_conformer/encoder_layer.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple
import torch
from torch import nn


class StrideConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(self,
                 size: int,
                 self_attn: torch.nn.Module,
                 feed_forward: Optional[nn.Module] = None,
                 feed_forward_macaron: Optional[nn.Module] = None,
                 conv_module: Optional[nn.Module] = None,
                 pointwise_conv_layer: Optional[nn.Module] = None,
                 dropout_rate: float = 0.1,
                 normalize_before: bool = True):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.pointwise_conv_layer = pointwise_conv_layer
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1,time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)

        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)

            # add pointwise_conv for efficient conformer
            # pointwise_conv_layer does not change shape
            if self.pointwise_conv_layer is not None:
                residual = residual.transpose(1, 2)
                residual = self.pointwise_conv_layer(residual)
                residual = residual.transpose(1, 2)
                assert residual.size(0) == x.size(0)
                assert residual.size(1) == x.size(1)
                assert residual.size(2) == x.size(2)

            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
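The distinctive piece above is the optional pointwise_conv_layer: in the stride blocks it is an AvgPool1d that downsamples the residual path so it can be added to the strided convolution output. A tiny sketch of that shape bookkeeping (illustrative sizes only):

import torch

pool = torch.nn.AvgPool1d(kernel_size=2, stride=2, padding=0,
                          ceil_mode=True, count_include_pad=False)
residual = torch.randn(2, 17, 256)                      # (batch, time, size)
residual = pool(residual.transpose(1, 2)).transpose(1, 2)
print(residual.shape)                                   # torch.Size([2, 9, 256]) = ceil(17 / 2)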
wenet/efficient_conformer/subsampling.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch
from wenet.transformer.subsampling import BaseSubsampling


class Conv2dSubsampling2(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2),
                                        torch.nn.ReLU())
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 2 = (3 - 1) * 1
        self.right_context = 2

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.
            torch.Tensor: positional encoding

        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, :-2:2]
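A small shape check for Conv2dSubsampling2 (a sketch, assuming wenet's PositionalEncoding can serve as the pos_enc_class; the sizes are illustrative):

import torch
from wenet.transformer.embedding import PositionalEncoding
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2

sub = Conv2dSubsampling2(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=PositionalEncoding(256, 0.1))
x = torch.randn(1, 100, 80)
x_mask = torch.ones(1, 1, 100, dtype=torch.bool)
out, pos_emb, out_mask = sub(x, x_mask)
print(out.shape, out_mask.shape)   # (1, 49, 256) and (1, 1, 49): time halved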
wenet/finetune/lora/__init__.py
ADDED
File without changes
wenet/finetune/lora/config.yaml
ADDED
@@ -0,0 +1,13 @@
init_batch_size: 2
init_iters: 8
init_config:
  mode: "gradient" # option: "simple", "svd", "gradient"
  lora_A: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
  lora_A_std: 0.01 # only needed when lora_A is "gaussian"
  lora_B: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
  lora_B_std: 0.01 # only needed when lora_B is "gaussian"
  scale: "stable" # option: "default", "stable", "unit", "normalized", "gd", "weightS"
  stable_gamma: 2 # only needed when scale is "stable"
  direction: "ArB2r" # option: "ArBr", "A2rBr", "ArB2r"(only needed when mode is "gradient")
  dtype: "fp32" # option: "bf16", "fp32"
  norm_clip: false # norm clipping
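A minimal sketch of reading this LoRA-initialization config (assuming the file is saved at wenet/finetune/lora/config.yaml; the loader below is plain PyYAML, not a wenet API):

import yaml

with open('wenet/finetune/lora/config.yaml') as fin:
    cfg = yaml.safe_load(fin)

print(cfg['init_batch_size'], cfg['init_iters'])   # 2 8
print(cfg['init_config']['mode'])                  # "gradient"
print(cfg['init_config']['stable_gamma'])          # 2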
wenet/finetune/lora/layers.py
ADDED
@@ -0,0 +1,350 @@
# Copyright (c) 2021 microsoft
#               2023 Alan ([email protected])
# -----------------------------------------------------------------------------
# Licensed under the MIT License (MIT). See LICENSE in the repo root for
# license information.
# -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import List


class LoRALayer():

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = self.identity
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

    def identity(self, x):
        return x


class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 merge_weights: bool = True,
                 **kwargs):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    temp = (self.lora_B @ self.lora_A).transpose(0, 1)
                    self.weight.data -= temp * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    temp = (self.lora_B @ self.lora_A).transpose(0, 1)
                    self.weight.data += temp * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(x, self.lora_A.transpose(0, 1),
                                  self.padding_idx, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq,
                                  self.sparse)
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)


class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,
            # Set this to True if the layer to replace stores weight like (fan_in,
            # fan_out)
            merge_weights: bool = True,
            **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros(
                (out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def T(self, w):
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data -= temp * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data += temp * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = F.linear(x, self.T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                       @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, self.T(self.weight), bias=self.bias)


class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 enable_lora: List[bool] = None,
                 fan_in_fan_out: bool = False,
                 merge_weights: bool = True,
                 **kwargs):
        if enable_lora is None:
            enable_lora = [False]
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros(
                    (out_features // len(enable_lora) * sum(enable_lora), r)))
            # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        result = x.new_zeros((len(self.lora_ind), *x.size()[1:]))
        result[self.lora_ind] = x
        return result

    def T(self, w):
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_AB(self):
        delta_w = F.conv1d(self.lora_A.unsqueeze(0),
                           self.lora_B.unsqueeze(-1),
                           groups=sum(self.enable_lora)).squeeze(0)
        return self.T(delta_w)

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.merged:
            return F.linear(x, self.T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, self.T(self.weight), bias=self.bias)
            if self.r > 0:
                temp = self.T(self.merge_AB().T)
                result += self.lora_dropout(x) @ temp * self.scaling
            return result


class ConvLoRA(nn.Module, LoRALayer):

    def __init__(self,
                 conv_module,
                 in_channels,
                 out_channels,
                 kernel_size,
                 r=0,
                 lora_alpha=1,
                 lora_dropout=0.,
                 merge_weights=True,
                 **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size,
                                **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros(
                    (r * kernel_size, in_channels * kernel_size)))
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros(
                    (out_channels // self.conv.groups * kernel_size,
                     r * kernel_size)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(
                        self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(
                        self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x, self.conv.weight +
                (self.lora_B @ self.lora_A).view(self.conv.weight.shape) *
                self.scaling, self.conv.bias)
        return self.conv(x)


class Conv2d(ConvLoRA):

    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class Conv1d(ConvLoRA):

    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)


# Can Extend to other ones like this
class Conv3d(ConvLoRA):

    def __init__(self, *args, **kwargs):
        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
wenet/finetune/lora/utils.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) 2021 microsoft
#               2023 Alan ([email protected])
# -----------------------------------------------------------------------------
# Licensed under the MIT License (MIT). See LICENSE in the repo root for
# license information.
# -----------------------------------------------------------------------------

import logging
import torch
import torch.nn as nn

from typing import Dict, List

import wenet.finetune.lora.layers as lora


def get_nested_attr(module, attr_path):
    attrs = attr_path.split('.')
    for attr in attrs:
        if hasattr(module, attr):
            module = getattr(module, attr)
        else:
            return None
    return module


def inject_lora(module, lora_config):
    lora_rank = lora_config["lora_rank"]
    lora_alpha = lora_config["lora_alpha"]
    lora_dropout = lora_config["lora_dropout"]
    for lora_attr in lora_config["lora_list"]:
        if hasattr(module, lora_attr):
            submodule = getattr(module, lora_attr)
            n_feat = submodule.in_features
            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
                                      lora_alpha=lora_alpha,
                                      lora_dropout=lora_dropout)
            setattr(module, lora_attr, lora_linear)


def inject_lora_to_model(model, lora_config):
    lora_modules = []
    for module in lora_config["lora_modules"]:
        submodule = get_nested_attr(model, module)
        for layer in submodule:
            lora_modules.append(layer)

    updated_lora_modules = []
    for i in range(len(lora_modules)):
        for attn_attr in lora_config["lora_attn_attr"]:
            if hasattr(lora_modules[i], attn_attr):
                updated_lora_modules.append(getattr(lora_modules[i], attn_attr))

    for lora_module in updated_lora_modules:
        inject_lora(lora_module, lora_config)


def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    logging.info('freezing all params except lora module.')
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
    if bias == 'none':
        return
    elif bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, lora.LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


def lora_state_dict(model: nn.Module,
                    bias: str = 'none') -> Dict[str, torch.Tensor]:
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {
            k: my_state_dict[k]
            for k in my_state_dict if 'lora_' in k or 'bias' in k
        }
    elif bias == 'lora_only':
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError


def get_record_gradient_hook(model, record_dict):

    def record_gradient_hook(grad):
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                if n not in record_dict:
                    record_dict[n] = p.grad.cpu()
                else:
                    record_dict[n] += p.grad.cpu()
                p.grad = None
        return grad

    return record_gradient_hook


def estimate_gradient(
    model, dataloader, max_iters: int = 8,
    device: torch.device = torch.device("cpu")
) -> Dict[str, List[torch.Tensor]]:
    r"""
    Estimate the gradient of the model on the given dataset
    """
    logging.info("Estimating gradient layer by layer, time needed")
    model.train()
    named_grads = {}
    hooks = []
    requires_grad_states = {}
    for name, param in model.named_parameters():
        requires_grad_states[name] = param.requires_grad
        param.requires_grad = True
        hook = param.register_hook(get_record_gradient_hook(model, named_grads))
        hooks.append(hook)
    num = 0
    for _, batch_dict in enumerate(dataloader):
        num += 1
        if max_iters is not None and num >= max_iters:
            break
        outputs = model(batch_dict, device)
        outputs['loss'].backward()
        get_record_gradient_hook(model, named_grads)(None)  # get gradient of last layer
        # make sure the gradient is cleared
        for n, p in model.named_parameters():
            if p.grad is not None:
                p.grad = None
    for n, _ in named_grads.items():
        named_grads[n] /= num
    for hook in hooks:
        hook.remove()
    # recover original requires_grad states
    for name, param in model.named_parameters():
        param.requires_grad = requires_grad_states[name]
    torch.cuda.empty_cache()
    return named_grads


@torch.no_grad()
def reinit_lora_modules(name, module, init_config, **kwargs):
    r"""Refer to https://github.com/Outsider565/LoRA-GA/blob/
    c185846309ea9012d0bcd46ebd30347dda1c592c/run_exp.py#L67
    Reinitialize the lora model with the given configuration.
    """
    import math
    lora_r = min(module.lora_A.shape)
    a_dim = max(module.lora_A.shape)
    b_dim = max(module.lora_B.shape)
    if init_config.mode == "simple":
        match init_config.lora_A:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=init_config.lora_A_std
                )
            case "kaiming":
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                torch.nn.init.kaiming_uniform_(module.lora_A,
                                               a=math.sqrt(5))
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_A, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_A)
            case "zeros":
                torch.nn.init.zeros_(module.lora_A)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=1.0 / (a_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_A)
            case _:
                raise ValueError(
                    f"Unknown lora_A initialization: {init_config.lora_A}"
                )
        match init_config.lora_B:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=init_config.lora_B_std
                )
            case "kaiming":
                torch.nn.init.kaiming_normal_(module.lora_B)
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_B, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_B)
            case "zeros":
                torch.nn.init.zeros_(module.lora_B)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=1.0 / (b_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_B)
            case _:
                raise ValueError(
                    f"Unknown lora_B initialization: {init_config.lora_B}"
                )
        if getattr(init_config, 'scale', '') == "stable":
            gamma = init_config.stable_gamma
            m, n = module.weight.shape
            module.lora_B.data *= (m**0.25) / gamma**0.5
            module.lora_A.data *= (n**0.25) / gamma**0.5
    elif init_config.mode == "svd":
        U, S, V = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
                                    niter=4)
        V = V.T
        m, n = module.weight.shape
        if init_config.scale == "default":
            S = S / module.scaling
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * torch.sqrt(S[:lora_r])).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :].T * torch.sqrt(S[:lora_r])).T.contiguous()
            )
        elif init_config.scale == "stable":
            gamma = init_config.stable_gamma
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * (m**0.25) / gamma**0.5).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :] * (n**0.25) / gamma**0.5).contiguous()
            )
        elif init_config.scale == "unit":
            module.lora_B = torch.nn.Parameter((U[:, :lora_r]).contiguous())
            module.lora_A = torch.nn.Parameter((V[:lora_r, :]).contiguous())
        elif init_config.scale == "normalized":
            S_sum = S[:lora_r].sum()
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * torch.sqrt(S[:lora_r])
                 / torch.sqrt(S_sum) * lora_r**0.5).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :].T * torch.sqrt(S[:lora_r])
                 / torch.sqrt(S_sum) * lora_r**0.5).T.contiguous()
            )
    elif init_config.mode == "gradient":
        named_grad = kwargs["named_grads"]
        grad_name = name + ".weight"
        grads = named_grad[grad_name]
        U, S, V = torch.svd_lowrank(grads.cuda().float(), q=4 * lora_r, niter=4)
        V = V.T
        # set direction
        if init_config.direction == "ArBr":
            B = U[:, 0:2 * lora_r:2]
            A = V[1:2 * lora_r:2, :]
        elif init_config.direction == "A2rBr":
            B = U[:, :lora_r]
            A = V[lora_r:2 * lora_r, :]
        elif init_config.direction == "ArB2r":
            B = U[:, lora_r:2 * lora_r]
            A = V[:lora_r, :]
        scaling_factor = module.scaling
        if init_config.scale == "gd":
            A = A / scaling_factor
            B = B / scaling_factor
        elif init_config.scale == "unit":
            # Because A,B is orthogonal, do not need to scale
            pass
        elif init_config.scale == "stable":
            m, n = grads.shape
            # m: feature_out, n: feature_in
            # the scale of output is only related to the feature_out
            gamma = init_config.stable_gamma
            B = B * m**0.25 / gamma**0.5
            A = A * m**0.25 / gamma**0.5
        elif init_config.scale == "weightS":
            _, S, _ = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
                                        niter=4)
            S = S / module.scaling
            avg_s = torch.sqrt(S[:lora_r]).mean().to(A.device)
            B = B * avg_s
            A = A * avg_s
        module.lora_B = torch.nn.Parameter(B.contiguous().cuda())
        module.lora_A = torch.nn.Parameter(A.contiguous().cuda())

    with torch.no_grad():
        # consider dtype not in init_config
        if not hasattr(init_config, "dtype"):
            pass
        elif init_config.dtype == "bf16":
            module.lora_A.data = module.lora_A.data.to(torch.bfloat16)
            module.lora_B.data = module.lora_B.data.to(torch.bfloat16)
        elif init_config.dtype == "fp32":
            module.lora_A.data = module.lora_A.data.to(torch.float32)
            module.lora_B.data = module.lora_B.data.to(torch.float32)
        # If lora_A@lora_B is not zero,
        # then we need to subtract lora_A@lora_B from the original weight matrix
        offset = (
            module.lora_B @ module.lora_A
        ).to(module.weight.data.device)
        scaling_factor = module.scaling
        offset *= scaling_factor
        if hasattr(init_config, "norm_clip") and init_config.norm_clip:
            # for numerical stability,
            # offset's largest value must be less then weight's largest value
            ratio = torch.max(torch.abs(module.weight.data)) / torch.max(
                torch.abs(offset)
            )
            if ratio < 1:
                offset *= ratio
                module.lora_A.data *= ratio**0.5
                module.lora_B.data *= ratio**0.5
                logging.warning(f"Clipping offset by {ratio}")
        try:
            module.weight.data -= offset
        except Exception as e:
            logging.warning(f"{e}")
            breakpoint()
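A sketch of how the helpers above appear intended to fit together for the gradient-based ("gradient" mode) initialization: estimate per-weight gradients on a few batches, then reinitialize each injected LoRA layer from their SVD. Here `model`, `train_dataloader`, `lora_config`, and `init_config` are placeholders; the real wiring lives in the training scripts, not in this file.

# Hypothetical wiring, assuming lora_config / init_config follow
# wenet/finetune/lora/config.yaml and the model is a wenet ASR model whose
# forward(batch_dict, device) returns a dict containing 'loss'.
import torch

import wenet.finetune.lora.layers as lora
from wenet.finetune.lora.utils import (estimate_gradient, inject_lora_to_model,
                                       mark_only_lora_as_trainable,
                                       reinit_lora_modules)

inject_lora_to_model(model, lora_config)      # wrap the configured projections
named_grads = estimate_gradient(model, train_dataloader, max_iters=8,
                                device=torch.device("cuda"))
for name, module in model.named_modules():
    if isinstance(module, lora.Linear):
        reinit_lora_modules(name, module, init_config, named_grads=named_grads)
mark_only_lora_as_trainable(model, bias='none')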
wenet/k2/__init__.py
ADDED
File without changes
wenet/k2/model.py
ADDED
@@ -0,0 +1,304 @@
# Copyright (c) 2023 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder
from wenet.utils.common import (IGNORE_ID, add_sos_eos, reverse_pad_list)


class K2Model(ASRModel):

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        lfmmi_dir: str = '',
        special_tokens: dict = None,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(vocab_size,
                         encoder,
                         decoder,
                         ctc,
                         ctc_weight,
                         ignore_id,
                         reverse_weight,
                         lsm_weight,
                         length_normalized_loss,
                         special_tokens=special_tokens)
        self.lfmmi_dir = lfmmi_dir
        self.device = device
        if self.lfmmi_dir != '':
            self.load_lfmmi_resource()

    @torch.jit.unused
    def _forward_ctc(
            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
                                                    text)
        return loss_ctc, ctc_probs

    @torch.jit.unused
    def load_lfmmi_resource(self):
        try:
            import icefall
        except ImportError:
            print('Error: Failed to import icefall')
        with open('{}/tokens.txt'.format(self.lfmmi_dir), 'r') as fin:
            for line in fin:
                arr = line.strip().split()
                if arr[0] == '<sos/eos>':
                    self.sos_eos_id = int(arr[1])
        device = torch.device(self.device)
        self.graph_compiler = icefall.mmi_graph_compiler.MmiTrainingGraphCompiler(
            self.lfmmi_dir,
            device=device,
            oov="<UNK>",
            sos_id=self.sos_eos_id,
            eos_id=self.sos_eos_id,
        )
        self.lfmmi = icefall.mmi.LFMMILoss(
            graph_compiler=self.graph_compiler,
            den_scale=1,
            use_pruned_intersect=False,
        )
        self.word_table = {}
        with open('{}/words.txt'.format(self.lfmmi_dir), 'r') as fin:
            for line in fin:
                arr = line.strip().split()
                assert len(arr) == 2
                self.word_table[int(arr[1])] = arr[0]

    @torch.jit.unused
    def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
        try:
            import k2
        except ImportError:
            print('Error: Failed to import k2')
        ctc_probs = self.ctc.log_softmax(encoder_out)
        supervision_segments = torch.stack((
            torch.arange(len(encoder_mask)),
            torch.zeros(len(encoder_mask)),
            encoder_mask.squeeze(dim=1).sum(dim=1).to('cpu'),
        ), 1).to(torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_probs,
            supervision_segments,
            allow_truncate=3,
        )
        text = [
            ' '.join([self.word_table[j.item()] for j in i if j != -1])
            for i in text
        ]
        loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
        return loss, ctc_probs

    def load_hlg_resource_if_necessary(self, hlg, word):
        try:
            import k2
        except ImportError:
            print('Error: Failed to import k2')
        if not hasattr(self, 'hlg'):
            device = torch.device(self.device)
            self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
        if not hasattr(self.hlg, "lm_scores"):
            self.hlg.lm_scores = self.hlg.scores.clone()
        if not hasattr(self, 'word_table'):
            self.word_table = {}
            with open(word, 'r') as fin:
                for line in fin:
                    arr = line.strip().split()
                    assert len(arr) == 2
                    self.word_table[int(arr[1])] = arr[0]

    @torch.no_grad()
    def hlg_onebest(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        hlg: str = '',
        word: str = '',
        symbol_table: Dict[str, int] = None,
    ) -> List[int]:
        try:
            import icefall
        except ImportError:
            print('Error: Failed to import icefall')
        self.load_hlg_resource_if_necessary(hlg, word)
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = icefall.decode.get_lattice(
            nnet_output=ctc_probs,
            decoding_graph=self.hlg,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=7,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=4)
        best_path = icefall.decode.one_best_decoding(lattice=lattice,
                                                     use_double_scores=True)
        hyps = icefall.utils.get_texts(best_path)
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps

    @torch.no_grad()
    def hlg_rescore(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        lm_scale: float = 0,
        decoder_scale: float = 0,
        r_decoder_scale: float = 0,
        hlg: str = '',
        word: str = '',
        symbol_table: Dict[str, int] = None,
    ) -> List[int]:
        try:
            import k2
            import icefall
        except ImportError:
            print('Error: Failed to import k2 & icefall')
        self.load_hlg_resource_if_necessary(hlg, word)
        device = speech.device
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = icefall.decode.get_lattice(
            nnet_output=ctc_probs,
            decoding_graph=self.hlg,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=7,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=4)
        nbest = icefall.decode.Nbest.from_lattice(
            lattice=lattice,
            num_paths=100,
            use_double_scores=True,
            nbest_scale=0.5,
        )
        nbest = nbest.intersect(lattice)
        assert hasattr(nbest.fsa, "lm_scores")
        assert hasattr(nbest.fsa, "tokens")
        assert isinstance(nbest.fsa.tokens, torch.Tensor)

        tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
        tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
        tokens = tokens.remove_values_leq(0)
        hyps = tokens.tolist()

        # cal attention_score
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at begining
        encoder_out_repeat = []
        tot_scores = nbest.tot_scores()
        repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
        for i in range(len(encoder_out)):
            encoder_out_repeat.append(encoder_out[i:i + 1].repeat(
                repeats[i], 1, 1))
        encoder_out = torch.concat(encoder_out_repeat, dim=0)
        encoder_mask = torch.ones(encoder_out.size(0),
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)
        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        reverse_weight = 0.5
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out
        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out

        decoder_scores = torch.tensor([
            sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
            for i in range(len(hyps))
        ],
                                      device=device)  # noqa
        r_decoder_scores = []
        for i in range(len(hyps)):
            score = 0
            for j in range(len(hyps[i])):
                score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
            score += r_decoder_out[i, len(hyps[i]), self.eos]
            r_decoder_scores.append(score)
        r_decoder_scores = torch.tensor(r_decoder_scores, device=device)

        am_scores = nbest.compute_am_scores()
        ngram_lm_scores = nbest.compute_lm_scores()
        tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
            decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
        hyps = icefall.utils.get_texts(best_path)
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps
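For orientation, a hedged sketch of calling the HLG one-best decoding defined above; the feature tensors, file paths, and the symbol table below are placeholders and would normally come from the recognition script and the prepared lang directory.

# Illustrative only: assumes `model` is an initialized K2Model and
# `symbol_table` maps modeling units (chars/BPE pieces) to ids, as
# hlg_onebest expects.
import torch

speech = torch.randn(1, 200, 80)        # (batch, frames, fbank dim), dummy features
speech_lengths = torch.tensor([200])
hyps = model.hlg_onebest(speech, speech_lengths,
                         hlg='data/lang/HLG.pt',      # placeholder path
                         word='data/lang/words.txt',  # placeholder path
                         symbol_table=symbol_table)
print(hyps[0])  # token ids of the best path for the first utterance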
wenet/llm_asr/__init__.py
ADDED
File without changes