andybi7676 commited on
Commit
e6b6af6
·
verified ·
1 Parent(s): 23b16e7

end-2-end reborn model for mls-french unsupervised phoneme recognition (iter2-stage1)

Browse files
Files changed (4) hide show
  1. config.json +92 -0
  2. configuration_reborn.py +105 -0
  3. modeling_reborn.py +381 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RebornUASRModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_reborn.RebornUASRConfig",
7
+ "AutoModel": "modeling_reborn.RebornUASRModel"
8
+ },
9
+ "discriminator_act_after_linear": false,
10
+ "discriminator_causal": true,
11
+ "discriminator_depth": 1,
12
+ "discriminator_dilation": 1,
13
+ "discriminator_dim": 256,
14
+ "discriminator_dropout": 0.0,
15
+ "discriminator_input_dim": 55,
16
+ "discriminator_kernel": 3,
17
+ "discriminator_linear_emb": false,
18
+ "discriminator_max_pool": false,
19
+ "discriminator_spectral_norm": false,
20
+ "discriminator_weight_norm": false,
21
+ "generator_bias": false,
22
+ "generator_bn_apply": false,
23
+ "generator_bn_init_weight": 30.0,
24
+ "generator_dilation": 1,
25
+ "generator_dropout": 0.0,
26
+ "generator_input_dim": 512,
27
+ "generator_kernel": 4,
28
+ "generator_output_dim": 55,
29
+ "generator_stride": 1,
30
+ "model_type": "reborn_uasr",
31
+ "phones": [
32
+ "\u0281",
33
+ "a",
34
+ "i",
35
+ "t",
36
+ "s",
37
+ "l",
38
+ "\u025b",
39
+ "e",
40
+ "k",
41
+ "d",
42
+ "n",
43
+ "m",
44
+ "o",
45
+ "\u0251\u0303",
46
+ "p",
47
+ "j",
48
+ "b",
49
+ "y",
50
+ "\u0254",
51
+ "v",
52
+ "\u0259",
53
+ "z",
54
+ "f",
55
+ "\u0254\u0303",
56
+ "\u0261",
57
+ "u",
58
+ "\u0292",
59
+ "w",
60
+ "\u025b\u0303",
61
+ "\u0283",
62
+ "\u00f8",
63
+ "\u0153",
64
+ "\u026a",
65
+ "\u0279",
66
+ "a\u02d0",
67
+ "i\u02d0",
68
+ "\u03b8",
69
+ "\u0272",
70
+ "e\u026a",
71
+ "\u0252",
72
+ "\u0259\u028a",
73
+ "\u0153\u0303",
74
+ "u\u02d0",
75
+ "\u0254\u02d0",
76
+ "a\u026a",
77
+ "h",
78
+ "\u014b",
79
+ "\u0251\u02d0",
80
+ "\u028c",
81
+ "\u025c\u02d0",
82
+ "<SIL>"
83
+ ],
84
+ "segmenter_dropout": 0.1,
85
+ "segmenter_hidden_dim": 512,
86
+ "segmenter_input_dim": 512,
87
+ "segmenter_kernel_size": 7,
88
+ "segmenter_type": "cnn",
89
+ "special_token_nums": 4,
90
+ "torch_dtype": "float32",
91
+ "transformers_version": "4.24.0"
92
+ }
configuration_reborn.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from transformers import PretrainedConfig
4
+
5
+ class RebornUASRConfig(PretrainedConfig):
6
+ '''
7
+ We can use this class to define the configuration of the reborn model.
8
+ The reborn UASR is composed of a segmenter, a discriminator, and a generator.
9
+ We only include the required configurations for the discriminator and the generator from fairseq's wav2vec-U model configuration.
10
+ '''
11
+ model_type = "reborn_uasr"
12
+
13
+ def __init__(self,
14
+ segmenter_type: str = "cnn",
15
+ segmenter_input_dim: int = 512,
16
+ segmenter_hidden_dim: int = 512,
17
+ segmenter_dropout: float = 0.1,
18
+ segmenter_kernel_size: int = 7,
19
+
20
+ discriminator_input_dim: int = 512,
21
+ discriminator_kernel: int = 3,
22
+ discriminator_dilation: int = 1,
23
+ discriminator_dim: int = 256,
24
+ discriminator_causal: bool = True,
25
+ discriminator_linear_emb: bool = False,
26
+ discriminator_depth: int = 1,
27
+ discriminator_max_pool: bool = False,
28
+ discriminator_act_after_linear: bool = False,
29
+ discriminator_dropout: float = 0.0,
30
+ discriminator_spectral_norm: bool = False,
31
+ discriminator_weight_norm: bool = False,
32
+
33
+ generator_input_dim: int = 512,
34
+ generator_output_dim: int = 40,
35
+ generator_kernel: int = 4,
36
+ generator_dilation: int = 1,
37
+ generator_stride: int = 1,
38
+ generator_bias: bool = False,
39
+ generator_dropout: float = 0.0,
40
+ generator_bn_apply: bool = False,
41
+ generator_bn_init_weight: float = 30.0,
42
+
43
+ phones: list = [],
44
+ dict_fpath: str = "",
45
+ special_token_nums: int = 4, # [<s>, <pad>, </s>, <unk>]
46
+ **kwargs
47
+ ):
48
+ super().__init__(**kwargs)
49
+ # read in all the configurations
50
+ self.segmenter_type = segmenter_type
51
+ self.segmenter_input_dim = segmenter_input_dim
52
+ self.segmenter_hidden_dim = segmenter_hidden_dim
53
+ self.segmenter_dropout = segmenter_dropout
54
+ self.segmenter_kernel_size = segmenter_kernel_size
55
+
56
+ self.discriminator_input_dim = discriminator_input_dim
57
+ self.discriminator_kernel = discriminator_kernel
58
+ self.discriminator_dilation = discriminator_dilation
59
+ self.discriminator_dim = discriminator_dim
60
+ self.discriminator_causal = discriminator_causal
61
+ self.discriminator_linear_emb = discriminator_linear_emb
62
+ self.discriminator_depth = discriminator_depth
63
+ self.discriminator_max_pool = discriminator_max_pool
64
+ self.discriminator_act_after_linear = discriminator_act_after_linear
65
+ self.discriminator_dropout = discriminator_dropout
66
+ self.discriminator_spectral_norm = discriminator_spectral_norm
67
+ self.discriminator_weight_norm = discriminator_weight_norm
68
+
69
+ self.generator_input_dim = generator_input_dim
70
+ self.generator_output_dim = generator_output_dim
71
+ self.generator_kernel = generator_kernel
72
+ self.generator_dilation = generator_dilation
73
+ self.generator_stride = generator_stride
74
+ self.generator_bias = generator_bias
75
+ self.generator_dropout = generator_dropout
76
+ self.generator_bn_apply = generator_bn_apply
77
+ self.generator_bn_init_weight = generator_bn_init_weight
78
+
79
+ self.special_token_nums = special_token_nums
80
+ if os.path.isfile(dict_fpath):
81
+ self.phones = self.read_phns_dict_from_fpath(dict_fpath)
82
+ else:
83
+ self.phones = phones
84
+ if len(self.phones) > 0:
85
+ self.generator_output_dim = len(self.phones) + self.special_token_nums
86
+ self.discriminator_input_dim = self.generator_output_dim
87
+
88
+ def read_phns_dict_from_fpath(self, fpath: str):
89
+ phns = []
90
+ with open(fpath, "r", encoding="utf-8") as f:
91
+ for l in f:
92
+ phn = l.strip().split('\t')[0].split(' ')[0]
93
+ phns.append(phn)
94
+ return phns
95
+
96
+ def main():
97
+ config = RebornUASRConfig(dict_fpath="/home/andybi7676/Desktop/uasr-rl/data/fr_mls/text/prep/phones/dict.phn.txt")
98
+ print(config)
99
+ output_fpath = "./reborn_uasr_configs/config_mls-fr.json"
100
+ with open(output_fpath, 'w', encoding='utf-8') as fw:
101
+ config_json_string = json.dumps(config.to_dict(), indent=2, sort_keys=True) + "\n"
102
+ fw.write(config_json_string)
103
+
104
+ if __name__ == "__main__":
105
+ main()
modeling_reborn.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from .configuration_reborn import RebornUASRConfig
5
+ from typing import Optional, Tuple, Union, List
6
+
7
+ class RebornSegmenter(nn.Module):
8
+ def __init__(self, config):
9
+ super().__init__()
10
+ self.config = config
11
+ self.conv1 = nn.Conv1d(config.segmenter_input_dim, config.segmenter_hidden_dim, config.segmenter_kernel_size, padding=config.segmenter_kernel_size//2)
12
+ self.conv2 = nn.Conv1d(config.segmenter_hidden_dim, config.segmenter_hidden_dim, 3, padding=1)
13
+ self.conv3 = nn.Conv1d(config.segmenter_hidden_dim, 2, 1)
14
+ self.dropout = nn.Dropout(config.segmenter_dropout)
15
+ self.relu = nn.ReLU()
16
+
17
+ def forward(self, x):
18
+ """
19
+ Input:
20
+ x: (B, T, C)
21
+ padding_mask: (B, T) # 0: not padding; 1: padding
22
+ Output:
23
+ boundary: (B, T, 2) # 0: not boundary; 1: boundary
24
+ """
25
+ x = x.transpose(1, 2)
26
+ x = self.dropout(self.relu(self.conv1(x)))
27
+ x = self.dropout(self.relu(self.conv2(x)))
28
+ x = self.conv3(x)
29
+ x = x.transpose(1, 2)
30
+ return x
31
+
32
+ def boundary_predict(self, x, padding_mask, deterministic=False):
33
+ """
34
+ Input:
35
+ x: (B, T, C)
36
+ padding_mask: (B, T)
37
+ Output:
38
+ boundary: (B, T) # 0: not boundary; 1: boundary
39
+ boundary_logits: (B, T, 2) # 0: not boundary; 1: boundary
40
+ """
41
+ boundary_logits = self.forward(x)
42
+ if deterministic:
43
+ boundary = boundary_logits.argmax(-1)
44
+ boundary[padding_mask] = -1
45
+ else:
46
+ boundary = torch.distributions.Categorical(logits=boundary_logits).sample()
47
+ boundary[padding_mask] = -1
48
+ return boundary, boundary_logits
49
+
50
+ def pre_segment(self, logits, padding_mask, return_boundary=False, deterministic=True):
51
+ """
52
+ Input:
53
+ logits: (B, T, C)
54
+ padding_mask: (B, T)
55
+ Output:
56
+ new_logits: (B, T', C)
57
+ new_padding_mask: (B, T')
58
+ """
59
+
60
+ bsz, tsz, csz = logits.size()
61
+
62
+ boundary, boundary_logits = self.boundary_predict(logits, padding_mask, deterministic=deterministic)
63
+
64
+ # max boundary number
65
+ # print("boundary", boundary)
66
+ # print(torch.sum(boundary==1, dim=1))
67
+ new_tsz = int(torch.max(torch.sum(boundary==1, dim=1)).item())+1 # add <bos>
68
+ new_logits = logits.new_zeros(bsz, new_tsz, csz)
69
+ new_pad = padding_mask.new_zeros(bsz, new_tsz)
70
+
71
+ for b in range(bsz):
72
+ # merge consecutive segments when meeting a boundary (mean_pool_join)
73
+ new_idx = 0
74
+ count = 0
75
+ for t in range(tsz):
76
+ if padding_mask[b, t] == 1:
77
+ break
78
+ if boundary[b, t] == 1:
79
+ new_logits[b, new_idx] /= count
80
+ new_idx += 1
81
+ count = 0
82
+ new_logits[b, new_idx] += logits[b, t]
83
+ count += 1
84
+ if count > 0:
85
+ # last segment
86
+ new_logits[b, new_idx] /= count
87
+ new_idx += 1
88
+ count = 0
89
+ if new_idx < new_tsz:
90
+ pad = new_tsz - new_idx
91
+ new_logits[b, -pad:] = 0
92
+ new_pad[b, -pad:] = True
93
+
94
+ if return_boundary:
95
+ return new_logits, new_pad, boundary, boundary_logits
96
+ return new_logits, new_pad
97
+
98
+ class RebornGenerator(nn.Module):
99
+ def __init__(self, config):
100
+ super().__init__()
101
+
102
+ self.config = config
103
+ self.output_dim = config.generator_output_dim
104
+ self.stride = config.generator_stride
105
+ self.dropout = nn.Dropout(config.generator_dropout)
106
+ cnn_input_dim = config.generator_input_dim
107
+ cnn_output_dim = config.generator_output_dim
108
+
109
+ padding = config.generator_kernel // 2
110
+ self.proj = nn.Sequential(
111
+ nn.Conv1d(
112
+ cnn_input_dim,
113
+ cnn_output_dim,
114
+ kernel_size=config.generator_kernel,
115
+ stride=config.generator_stride,
116
+ dilation=config.generator_dilation,
117
+ padding=padding,
118
+ bias=config.generator_bias,
119
+ ),
120
+ )
121
+
122
+ def forward(self, dense_x, tokens, dense_padding_mask):
123
+ dense_x = self.dropout(dense_x)
124
+ # (B, T, C) -> (B, C, T)
125
+ dense_x = dense_x.transpose(-2, -1)
126
+
127
+ dense_x = self.proj(dense_x)
128
+ # (B, C, T) -> (B, T, C)
129
+ dense_x = dense_x.transpose(-2, -1)
130
+ if self.stride > 1:
131
+ dense_padding_mask = dense_padding_mask[:, :: self.stride]
132
+
133
+ if dense_padding_mask.size(1) != dense_x.size(1):
134
+ new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
135
+ diff = new_padding.size(1) - dense_padding_mask.size(1)
136
+ assert (
137
+ diff > 0
138
+ ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
139
+ if diff > 0:
140
+ new_padding[:, diff:] = dense_padding_mask
141
+ else:
142
+ assert diff < 0
143
+ new_padding = dense_padding_mask[:, :diff]
144
+
145
+ dense_padding_mask = new_padding
146
+
147
+ result = {}
148
+
149
+ token_x = None
150
+ if tokens is not None:
151
+ token_x = dense_x.new_zeros(tokens.numel(), self.output_dim)
152
+ token_x.scatter_(1, tokens.view(-1, 1).long(), 1)
153
+ token_x = token_x.view(tokens.shape + (self.output_dim,))
154
+
155
+ result["dense_x"] = dense_x
156
+ result["token_x"] = token_x
157
+ result["dense_padding_mask"] = dense_padding_mask
158
+
159
+ return result
160
+
161
+ def get_item(tensor):
162
+ # tpu-comment: making this a no-op for xla devices.
163
+ if torch.is_tensor(tensor) and tensor.device.type == "xla":
164
+ return tensor.detach()
165
+ if hasattr(tensor, "item"):
166
+ return tensor.item()
167
+ if hasattr(tensor, "__getitem__"):
168
+ return tensor[0]
169
+ return tensor
170
+
171
+ def post_process(sentence: str, symbol: str):
172
+ if symbol == "sentencepiece":
173
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
174
+ elif symbol == "wordpiece":
175
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
176
+ elif symbol == "letter":
177
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
178
+ elif symbol == "silence":
179
+ import re
180
+ sentence = sentence.replace("<SIL>", "")
181
+ sentence = re.sub(' +', ' ', sentence).strip()
182
+ elif symbol == "_EOW":
183
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
184
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
185
+ if symbol == "subword_nmt":
186
+ symbol = "@@ "
187
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
188
+ elif symbol == "none":
189
+ pass
190
+ elif symbol is not None:
191
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
192
+ return sentence
193
+
194
+ class SimpleTokenizer(object):
195
+ def __init__(self,
196
+ phones: List[str],
197
+ bos="<s>",
198
+ pad="<pad>",
199
+ eos="</s>",
200
+ unk="<unk>",
201
+ extra_special_symbols=None,
202
+ ) -> None:
203
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
204
+ self.symbols = []
205
+ self.count = []
206
+ self.indices = {}
207
+ self.bos_index = self.add_symbol(bos)
208
+ self.pad_index = self.add_symbol(pad)
209
+ self.eos_index = self.add_symbol(eos)
210
+ self.unk_index = self.add_symbol(unk)
211
+ if extra_special_symbols:
212
+ for s in extra_special_symbols:
213
+ self.add_symbol(s)
214
+ self.nspecial = len(self.symbols)
215
+ for phone in phones:
216
+ self.add_symbol(phone)
217
+ self.postprocess_code = "silence"
218
+
219
+ def add_symbol(self, word, n=1, overwrite=False):
220
+ """Adds a word to the dictionary"""
221
+ if word in self.indices and not overwrite:
222
+ idx = self.indices[word]
223
+ self.count[idx] = self.count[idx] + n
224
+ return idx
225
+ else:
226
+ idx = len(self.symbols)
227
+ self.indices[word] = idx
228
+ self.symbols.append(word)
229
+ self.count.append(n)
230
+ return idx
231
+
232
+ def __eq__(self, other):
233
+ return self.indices == other.indices
234
+
235
+ def __getitem__(self, idx):
236
+ if idx < len(self.symbols):
237
+ return self.symbols[idx]
238
+ return self.unk_word
239
+
240
+ def get_count(self, idx):
241
+ return self.count[idx]
242
+
243
+ def __len__(self):
244
+ """Returns the number of symbols in the dictionary"""
245
+ return len(self.symbols)
246
+
247
+ def __contains__(self, sym):
248
+ return sym in self.indices
249
+
250
+ def index(self, sym):
251
+ """Returns the index of the specified symbol"""
252
+ assert isinstance(sym, str)
253
+ if sym in self.indices:
254
+ return self.indices[sym]
255
+ return self.unk_index
256
+
257
+ def string(
258
+ self,
259
+ tensor,
260
+ bpe_symbol=None,
261
+ escape_unk=False,
262
+ extra_symbols_to_ignore=None,
263
+ unk_string=None,
264
+ include_eos=False,
265
+ separator=" ",
266
+ ):
267
+ """Helper for converting a tensor of token indices to a string.
268
+
269
+ Can optionally remove BPE symbols or escape <unk> words.
270
+ """
271
+ if torch.is_tensor(tensor) and tensor.dim() == 2:
272
+ return "\n".join(
273
+ self.string(
274
+ t,
275
+ bpe_symbol,
276
+ escape_unk,
277
+ extra_symbols_to_ignore,
278
+ include_eos=include_eos,
279
+ )
280
+ for t in tensor
281
+ )
282
+
283
+ extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
284
+ if not include_eos:
285
+ extra_symbols_to_ignore.add(self.eos())
286
+
287
+ def token_string(i):
288
+ if i == self.unk():
289
+ if unk_string is not None:
290
+ return unk_string
291
+ else:
292
+ return self.unk_string(escape_unk)
293
+ else:
294
+ return self[i]
295
+
296
+ if hasattr(self, "bos_index"):
297
+ extra_symbols_to_ignore.add(self.bos())
298
+
299
+ sent = separator.join(
300
+ token_string(i)
301
+ for i in tensor
302
+ if get_item(i) not in extra_symbols_to_ignore
303
+ )
304
+
305
+ return post_process(sent, bpe_symbol)
306
+
307
+ def unk_string(self, escape=False):
308
+ """Return unknown string, optionally escaped as: <<unk>>"""
309
+ if escape:
310
+ return "<{}>".format(self.unk_word)
311
+ else:
312
+ return self.unk_word
313
+
314
+ def bos(self):
315
+ """Helper to get index of beginning-of-sentence symbol"""
316
+ return self.bos_index
317
+
318
+ def pad(self):
319
+ """Helper to get index of pad symbol"""
320
+ return self.pad_index
321
+
322
+ def eos(self):
323
+ """Helper to get index of end-of-sentence symbol"""
324
+ return self.eos_index
325
+
326
+ def unk(self):
327
+ """Helper to get index of unk symbol"""
328
+ return self.unk_index
329
+
330
+
331
+ class RebornUASRModel(PreTrainedModel):
332
+ config_class = RebornUASRConfig
333
+
334
+ def __init__(self, config):
335
+ super().__init__(config)
336
+ self.pca = nn.Linear(1024, 512)
337
+ self.segmenter = RebornSegmenter(config)
338
+ self.generator = RebornGenerator(config)
339
+ self.tokenizer = None
340
+ if len(config.phones) > 0:
341
+ self.tokenizer = SimpleTokenizer(config.phones)
342
+
343
+ def forward(
344
+ self,
345
+ x: Optional[torch.Tensor], # (B, T, C)
346
+ padding_mask: Optional[torch.Tensor], # (B, T)
347
+ ):
348
+ x_reduced = self.pca(x)
349
+ x_segmented, segmented_padding_mask = self.segmenter.pre_segment(x_reduced, padding_mask, deterministic=True)
350
+ x_generated = self.generator(x_segmented, None, segmented_padding_mask)
351
+
352
+ return {
353
+ 'x_reduced': x_reduced,
354
+ 'x_segmented': x_segmented,
355
+ 'x_generated': x_generated
356
+ }
357
+
358
+ def generate(self, x, padding_mask, merge_consecutive=True, remove_silence=True):
359
+ res = self.forward(x, padding_mask)
360
+ y_raw_logits = res['x_generated']['dense_x']
361
+ y_raw_padding = res['x_generated']['dense_padding_mask']
362
+ y_raw_logits[y_raw_padding][..., self.tokenizer.pad_index] = float('inf')
363
+ preds = y_raw_logits.argmax(-1)
364
+ hyps = []
365
+ postprocess_code = "silence" if remove_silence else "none"
366
+ for pred in preds:
367
+ if merge_consecutive:
368
+ # merge consecutive predictions
369
+ pred = torch.unique_consecutive(pred)
370
+ hyp = self.tokenizer.string(pred, bpe_symbol=postprocess_code)
371
+ hyps.append(hyp)
372
+ return hyps
373
+
374
+ def main():
375
+ model_config = RebornUASRConfig.from_pretrained("/home/andybi7676/Desktop/uasr-rl/reborn_uasr/config.json")
376
+ print(model_config)
377
+ model = RebornUASRModel(model_config)
378
+ print(model.tokenizer.indices)
379
+
380
+ if __name__ == "__main__":
381
+ main()
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3a6b7ef17c2c5d34c5163fc7f9ba0fd5a4ea51c848766afdbf0f1546a443837
3
+ size 13046733