Xuan2060320350 Plachta committed on
Commit cd6614b · 0 Parent(s):

Duplicate from Plachta/VALL-E-X


Co-authored-by: ElderFrog <[email protected]>

This view is limited to 50 files because it contains too many changes.
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Songting
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: VALL E X
3
+ emoji: 🎙
4
+ colorFrom: green
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.39.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: Plachta/VALL-E-X
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import data, models, modules, utils
app.py ADDED
@@ -0,0 +1,574 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import time
6
+ import tempfile
7
+ import platform
8
+ if platform.system().lower() == 'windows':
9
+ temp = pathlib.PosixPath
10
+ pathlib.PosixPath = pathlib.WindowsPath
11
+ elif platform.system().lower() == 'linux':
12
+ temp = pathlib.WindowsPath
13
+ pathlib.WindowsPath = pathlib.PosixPath
14
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
15
+
16
+ import langid
17
+ langid.set_languages(['en', 'zh', 'ja'])
18
+
19
+ import torch
20
+ import torchaudio
21
+ import random
22
+
23
+ import numpy as np
24
+
25
+ from data.tokenizer import (
26
+ AudioTokenizer,
27
+ tokenize_audio,
28
+ )
29
+ from data.collation import get_text_token_collater
30
+ from models.vallex import VALLE
31
+ from utils.g2p import PhonemeBpeTokenizer
32
+ from descriptions import *
33
+ from macros import *
34
+
35
+ import gradio as gr
36
+ import whisper
37
+ import multiprocessing
38
+
39
+ thread_count = multiprocessing.cpu_count()
40
+
41
+ print("Using", thread_count, "CPU cores for computing")
42
+
43
+ torch.set_num_threads(thread_count)
44
+ torch.set_num_interop_threads(thread_count)
45
+ torch._C._jit_set_profiling_executor(False)
46
+ torch._C._jit_set_profiling_mode(False)
47
+ torch._C._set_graph_executor_optimize(False)
48
+
49
+ text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
50
+ text_collater = get_text_token_collater()
51
+
52
+ device = torch.device("cpu")
53
+ if torch.cuda.is_available():
54
+ device = torch.device("cuda", 0)
55
+
56
+ # VALL-E-X model
57
+ model = VALLE(
58
+ N_DIM,
59
+ NUM_HEAD,
60
+ NUM_LAYERS,
61
+ norm_first=True,
62
+ add_prenet=False,
63
+ prefix_mode=PREFIX_MODE,
64
+ share_embedding=True,
65
+ nar_scale_factor=1.0,
66
+ prepend_bos=True,
67
+ num_quantizers=NUM_QUANTIZERS,
68
+ )
69
+ checkpoint = torch.load("./epoch-10.pt", map_location='cpu')
70
+ missing_keys, unexpected_keys = model.load_state_dict(
71
+ checkpoint["model"], strict=True
72
+ )
73
+ assert not missing_keys
74
+ model.eval()
75
+
76
+ # Encodec model
77
+ audio_tokenizer = AudioTokenizer(device)
78
+
79
+ # ASR
80
+ whisper_model = whisper.load_model("medium").cpu()
81
+
82
+ # Voice Presets
83
+ preset_list = os.walk("./presets/").__next__()[2]
84
+ preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
85
+
86
+ def clear_prompts():
87
+ try:
88
+ path = tempfile.gettempdir()
89
+ for eachfile in os.listdir(path):
90
+ filename = os.path.join(path, eachfile)
91
+ if os.path.isfile(filename) and filename.endswith(".npz"):
92
+ lastmodifytime = os.stat(filename).st_mtime
93
+ endfiletime = time.time() - 60
94
+ if endfiletime > lastmodifytime:
95
+ os.remove(filename)
96
+ except:
97
+ return
98
+
99
+ def transcribe_one(model, audio_path):
100
+ # load audio and pad/trim it to fit 30 seconds
101
+ audio = whisper.load_audio(audio_path)
102
+ audio = whisper.pad_or_trim(audio)
103
+
104
+ # make log-Mel spectrogram and move to the same device as the model
105
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
106
+
107
+ # detect the spoken language
108
+ _, probs = model.detect_language(mel)
109
+ print(f"Detected language: {max(probs, key=probs.get)}")
110
+ lang = max(probs, key=probs.get)
111
+ # decode the audio
112
+ options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
113
+ result = whisper.decode(model, mel, options)
114
+
115
+ # print the recognized text
116
+ print(result.text)
117
+
118
+ text_pr = result.text
119
+ if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
120
+ text_pr += "."
121
+ return lang, text_pr
122
+
123
+ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
124
+ global model, text_collater, text_tokenizer, audio_tokenizer
125
+ clear_prompts()
126
+ audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
127
+ sr, wav_pr = audio_prompt
128
+ if len(wav_pr) / sr > 15:
129
+ return "Rejected, Audio too long (should be less than 15 seconds)", None
130
+ if not isinstance(wav_pr, torch.FloatTensor):
131
+ wav_pr = torch.FloatTensor(wav_pr)
132
+ if wav_pr.abs().max() > 1:
133
+ wav_pr /= wav_pr.abs().max()
134
+ if wav_pr.size(-1) == 2:
135
+ wav_pr = wav_pr[:, 0]
136
+ if wav_pr.ndim == 1:
137
+ wav_pr = wav_pr.unsqueeze(0)
138
+ assert wav_pr.ndim and wav_pr.size(0) == 1
139
+
140
+ if transcript_content == "":
141
+ text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
142
+ else:
143
+ lang_pr = langid.classify(str(transcript_content))[0]
144
+ lang_token = lang2token[lang_pr]
145
+ text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
146
+ # tokenize audio
147
+ encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
148
+ audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
149
+
150
+ # tokenize text
151
+ phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
152
+ text_tokens, enroll_x_lens = text_collater(
153
+ [
154
+ phonemes
155
+ ]
156
+ )
157
+
158
+ message = f"Detected language: {lang_pr}\nDetected text: {text_pr}\n"
159
+
160
+ # save as npz file
161
+ np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
162
+ audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
163
+ return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
164
+
165
+
166
+ def make_prompt(name, wav, sr, save=True):
167
+ global whisper_model
168
+ whisper_model.to(device)
169
+ if not isinstance(wav, torch.FloatTensor):
170
+ wav = torch.tensor(wav)
171
+ if wav.abs().max() > 1:
172
+ wav /= wav.abs().max()
173
+ if wav.size(-1) == 2:
174
+ wav = wav.mean(-1, keepdim=False)
175
+ if wav.ndim == 1:
176
+ wav = wav.unsqueeze(0)
177
+ assert wav.ndim and wav.size(0) == 1
178
+ torchaudio.save(f"./prompts/{name}.wav", wav, sr)
179
+ lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
180
+ lang_token = lang2token[lang]
181
+ text = lang_token + text + lang_token
182
+ with open(f"./prompts/{name}.txt", 'w') as f:
183
+ f.write(text)
184
+ if not save:
185
+ os.remove(f"./prompts/{name}.wav")
186
+ os.remove(f"./prompts/{name}.txt")
187
+
188
+ whisper_model.cpu()
189
+ torch.cuda.empty_cache()
190
+ return text, lang
191
+
192
+ @torch.no_grad()
193
+ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
194
+ if len(text) > 150:
195
+ return "Rejected, Text too long (should be less than 150 characters)", None
196
+ global model, text_collater, text_tokenizer, audio_tokenizer
197
+ model.to(device)
198
+ audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
199
+ sr, wav_pr = audio_prompt
200
+ if len(wav_pr) / sr > 15:
201
+ return "Rejected, Audio too long (should be less than 15 seconds)", None
202
+ if not isinstance(wav_pr, torch.FloatTensor):
203
+ wav_pr = torch.FloatTensor(wav_pr)
204
+ if wav_pr.abs().max() > 1:
205
+ wav_pr /= wav_pr.abs().max()
206
+ if wav_pr.size(-1) == 2:
207
+ wav_pr = wav_pr[:, 0]
208
+ if wav_pr.ndim == 1:
209
+ wav_pr = wav_pr.unsqueeze(0)
210
+ assert wav_pr.ndim and wav_pr.size(0) == 1
211
+
212
+ if transcript_content == "":
213
+ text_pr, lang_pr = make_prompt('dummy', wav_pr, sr, save=False)
214
+ else:
215
+ lang_pr = langid.classify(str(transcript_content))[0]
216
+ lang_token = lang2token[lang_pr]
217
+ text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
218
+
219
+ if language == 'auto-detect':
220
+ lang_token = lang2token[langid.classify(text)[0]]
221
+ else:
222
+ lang_token = langdropdown2token[language]
223
+ lang = token2lang[lang_token]
224
+ text = lang_token + text + lang_token
225
+
226
+ # move the model onto the inference device
227
+ model.to(device)
228
+
229
+ # tokenize audio
230
+ encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
231
+ audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
232
+
233
+ # tokenize text
234
+ logging.info(f"synthesize text: {text}")
235
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
236
+ text_tokens, text_tokens_lens = text_collater(
237
+ [
238
+ phone_tokens
239
+ ]
240
+ )
241
+
242
+ enroll_x_lens = None
243
+ if text_pr:
244
+ text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
245
+ text_prompts, enroll_x_lens = text_collater(
246
+ [
247
+ text_prompts
248
+ ]
249
+ )
250
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
251
+ text_tokens_lens += enroll_x_lens
252
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
253
+ encoded_frames = model.inference(
254
+ text_tokens.to(device),
255
+ text_tokens_lens.to(device),
256
+ audio_prompts,
257
+ enroll_x_lens=enroll_x_lens,
258
+ top_k=-100,
259
+ temperature=1,
260
+ prompt_language=lang_pr,
261
+ text_language=langs if accent == "no-accent" else lang,
262
+ )
263
+ samples = audio_tokenizer.decode(
264
+ [(encoded_frames.transpose(2, 1), None)]
265
+ )
266
+
267
+ # offload model
268
+ model.to('cpu')
269
+ torch.cuda.empty_cache()
270
+
271
+ message = f"text prompt: {text_pr}\nsynthesized text: {text}"
272
+ return message, (24000, samples[0][0].cpu().numpy())
273
+
274
+ @torch.no_grad()
275
+ def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
276
+ if len(text) > 150:
277
+ return "Rejected, Text too long (should be less than 150 characters)", None
278
+ clear_prompts()
279
+ model.to(device)
280
+ # text to synthesize
281
+ if language == 'auto-detect':
282
+ lang_token = lang2token[langid.classify(text)[0]]
283
+ else:
284
+ lang_token = langdropdown2token[language]
285
+ lang = token2lang[lang_token]
286
+ text = lang_token + text + lang_token
287
+
288
+ # load prompt
289
+ if prompt_file is not None:
290
+ prompt_data = np.load(prompt_file.name)
291
+ else:
292
+ prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
293
+ audio_prompts = prompt_data['audio_tokens']
294
+ text_prompts = prompt_data['text_tokens']
295
+ lang_pr = prompt_data['lang_code']
296
+ lang_pr = code2lang[int(lang_pr)]
297
+
298
+ # numpy to tensor
299
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
300
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
301
+
302
+ enroll_x_lens = text_prompts.shape[-1]
303
+ logging.info(f"synthesize text: {text}")
304
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
305
+ text_tokens, text_tokens_lens = text_collater(
306
+ [
307
+ phone_tokens
308
+ ]
309
+ )
310
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
311
+ text_tokens_lens += enroll_x_lens
312
+ # accent control
313
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
314
+ encoded_frames = model.inference(
315
+ text_tokens.to(device),
316
+ text_tokens_lens.to(device),
317
+ audio_prompts,
318
+ enroll_x_lens=enroll_x_lens,
319
+ top_k=-100,
320
+ temperature=1,
321
+ prompt_language=lang_pr,
322
+ text_language=langs if accent == "no-accent" else lang,
323
+ )
324
+ samples = audio_tokenizer.decode(
325
+ [(encoded_frames.transpose(2, 1), None)]
326
+ )
327
+ model.to('cpu')
328
+ torch.cuda.empty_cache()
329
+
330
+ message = f"synthesized text: {text}"
331
+ return message, (24000, samples[0][0].cpu().numpy())
332
+
333
+
334
+ from utils.sentence_cutter import split_text_into_sentences
335
+ @torch.no_grad()
336
+ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'):
337
+ """
338
+ For long audio generation, two modes are available.
339
+ fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
340
+ sliding-window: This mode will use the last generated sentence as the prompt for the next one, but speaker identity may drift over long passages.
341
+ """
342
+ if len(text) > 1000:
343
+ return "Rejected, Text too long (should be less than 1000 characters)", None
344
+ mode = 'fixed-prompt'
345
+ global model, audio_tokenizer, text_tokenizer, text_collater
346
+ model.to(device)
347
+ if (prompt is None or prompt == "") and preset_prompt == "":
348
+ mode = 'sliding-window' # If no prompt is given, use sliding-window mode
349
+ sentences = split_text_into_sentences(text)
350
+ # detect language
351
+ if language == "auto-detect":
352
+ language = langid.classify(text)[0]
353
+ else:
354
+ language = token2lang[langdropdown2token[language]]
355
+
356
+ # if initial prompt is given, encode it
357
+ if prompt is not None and prompt != "":
358
+ # load prompt
359
+ prompt_data = np.load(prompt.name)
360
+ audio_prompts = prompt_data['audio_tokens']
361
+ text_prompts = prompt_data['text_tokens']
362
+ lang_pr = prompt_data['lang_code']
363
+ lang_pr = code2lang[int(lang_pr)]
364
+
365
+ # numpy to tensor
366
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
367
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
368
+ elif preset_prompt is not None and preset_prompt != "":
369
+ prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
370
+ audio_prompts = prompt_data['audio_tokens']
371
+ text_prompts = prompt_data['text_tokens']
372
+ lang_pr = prompt_data['lang_code']
373
+ lang_pr = code2lang[int(lang_pr)]
374
+
375
+ # numpy to tensor
376
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
377
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
378
+ else:
379
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
380
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
381
+ lang_pr = language if language != 'mix' else 'en'
382
+ if mode == 'fixed-prompt':
383
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
384
+ for text in sentences:
385
+ text = text.replace("\n", "").strip(" ")
386
+ if text == "":
387
+ continue
388
+ lang_token = lang2token[language]
389
+ lang = token2lang[lang_token]
390
+ text = lang_token + text + lang_token
391
+
392
+ enroll_x_lens = text_prompts.shape[-1]
393
+ logging.info(f"synthesize text: {text}")
394
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
395
+ text_tokens, text_tokens_lens = text_collater(
396
+ [
397
+ phone_tokens
398
+ ]
399
+ )
400
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
401
+ text_tokens_lens += enroll_x_lens
402
+ # accent control
403
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
404
+ encoded_frames = model.inference(
405
+ text_tokens.to(device),
406
+ text_tokens_lens.to(device),
407
+ audio_prompts,
408
+ enroll_x_lens=enroll_x_lens,
409
+ top_k=-100,
410
+ temperature=1,
411
+ prompt_language=lang_pr,
412
+ text_language=langs if accent == "no-accent" else lang,
413
+ )
414
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
415
+ samples = audio_tokenizer.decode(
416
+ [(complete_tokens, None)]
417
+ )
418
+ model.to('cpu')
419
+ message = f"Cut into {len(sentences)} sentences"
420
+ return message, (24000, samples[0][0].cpu().numpy())
421
+ elif mode == "sliding-window":
422
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
423
+ original_audio_prompts = audio_prompts
424
+ original_text_prompts = text_prompts
425
+ for text in sentences:
426
+ text = text.replace("\n", "").strip(" ")
427
+ if text == "":
428
+ continue
429
+ lang_token = lang2token[language]
430
+ lang = token2lang[lang_token]
431
+ text = lang_token + text + lang_token
432
+
433
+ enroll_x_lens = text_prompts.shape[-1]
434
+ logging.info(f"synthesize text: {text}")
435
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
436
+ text_tokens, text_tokens_lens = text_collater(
437
+ [
438
+ phone_tokens
439
+ ]
440
+ )
441
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
442
+ text_tokens_lens += enroll_x_lens
443
+ # accent control
444
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
445
+ encoded_frames = model.inference(
446
+ text_tokens.to(device),
447
+ text_tokens_lens.to(device),
448
+ audio_prompts,
449
+ enroll_x_lens=enroll_x_lens,
450
+ top_k=-100,
451
+ temperature=1,
452
+ prompt_language=lang_pr,
453
+ text_language=langs if accent == "no-accent" else lang,
454
+ )
455
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
456
+ if torch.rand(1) < 1.0:
457
+ audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
458
+ text_prompts = text_tokens[:, enroll_x_lens:]
459
+ else:
460
+ audio_prompts = original_audio_prompts
461
+ text_prompts = original_text_prompts
462
+ samples = audio_tokenizer.decode(
463
+ [(complete_tokens, None)]
464
+ )
465
+ model.to('cpu')
466
+ message = f"Cut into {len(sentences)} sentences"
467
+ return message, (24000, samples[0][0].cpu().numpy())
468
+ else:
469
+ raise ValueError(f"No such mode {mode}")
470
+
471
+
472
+ def main():
473
+ app = gr.Blocks()
474
+ with app:
475
+ gr.Markdown(top_md)
476
+ with gr.Tab("Infer from audio"):
477
+ gr.Markdown(infer_from_audio_md)
478
+ with gr.Row():
479
+ with gr.Column():
480
+
481
+ textbox = gr.TextArea(label="Text",
482
+ placeholder="Type your sentence here",
483
+ value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
484
+ language_dropdown = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='English', label='language')
485
+ accent_dropdown = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent', label='accent')
486
+ textbox_transcript = gr.TextArea(label="Transcript",
487
+ placeholder="Write transcript here. (leave empty to use whisper)",
488
+ value="", elem_id=f"prompt-name")
489
+ upload_audio_prompt = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
490
+ record_audio_prompt = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
491
+ with gr.Column():
492
+ text_output = gr.Textbox(label="Message")
493
+ audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
494
+ btn = gr.Button("Generate!")
495
+ btn.click(infer_from_audio,
496
+ inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
497
+ outputs=[text_output, audio_output])
498
+ textbox_mp = gr.TextArea(label="Prompt name",
499
+ placeholder="Name your prompt here",
500
+ value="prompt_1", elem_id=f"prompt-name")
501
+ btn_mp = gr.Button("Make prompt!")
502
+ prompt_output = gr.File(interactive=False)
503
+ btn_mp.click(make_npz_prompt,
504
+ inputs=[textbox_mp, upload_audio_prompt, record_audio_prompt, textbox_transcript],
505
+ outputs=[text_output, prompt_output])
506
+ with gr.Tab("Make prompt"):
507
+ gr.Markdown(make_prompt_md)
508
+ with gr.Row():
509
+ with gr.Column():
510
+ textbox2 = gr.TextArea(label="Prompt name",
511
+ placeholder="Name your prompt here",
512
+ value="prompt_1", elem_id=f"prompt-name")
513
+ # Area for choosing the language and entering the transcript
514
+ textbox_transcript2 = gr.TextArea(label="Transcript",
515
+ placeholder="Write transcript here. (leave empty to use whisper)",
516
+ value="", elem_id=f"prompt-name")
517
+ upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
518
+ record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
519
+ with gr.Column():
520
+ text_output_2 = gr.Textbox(label="Message")
521
+ prompt_output_2 = gr.File(interactive=False)
522
+ btn_2 = gr.Button("Make!")
523
+ btn_2.click(make_npz_prompt,
524
+ inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
525
+ outputs=[text_output_2, prompt_output_2])
526
+ with gr.Tab("Infer from prompt"):
527
+ gr.Markdown(infer_from_prompt_md)
528
+ with gr.Row():
529
+ with gr.Column():
530
+ textbox_3 = gr.TextArea(label="Text",
531
+ placeholder="Type your sentence here",
532
+ value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
533
+ language_dropdown_3 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語', 'Mix'], value='auto-detect',
534
+ label='language')
535
+ accent_dropdown_3 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
536
+ label='accent')
537
+ preset_dropdown_3 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
538
+ prompt_file = gr.File(file_count='single', file_types=['.npz'], interactive=True)
539
+ with gr.Column():
540
+ text_output_3 = gr.Textbox(label="Message")
541
+ audio_output_3 = gr.Audio(label="Output Audio", elem_id="tts-audio")
542
+ btn_3 = gr.Button("Generate!")
543
+ btn_3.click(infer_from_prompt,
544
+ inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
545
+ outputs=[text_output_3, audio_output_3])
546
+ with gr.Tab("Infer long text"):
547
+ gr.Markdown("This is a long-text generation demo. Use it to synthesize long audio from longer passages of text.")
548
+ with gr.Row():
549
+ with gr.Column():
550
+ textbox_4 = gr.TextArea(label="Text",
551
+ placeholder="Type your sentence here",
552
+ value=long_text_example, elem_id=f"tts-input")
553
+ language_dropdown_4 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect',
554
+ label='language')
555
+ accent_dropdown_4 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
556
+ label='accent')
557
+ preset_dropdown_4 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
558
+ prompt_file_4 = gr.File(file_count='single', file_types=['.npz'], interactive=True)
559
+ with gr.Column():
560
+ text_output_4 = gr.TextArea(label="Message")
561
+ audio_output_4 = gr.Audio(label="Output Audio", elem_id="tts-audio")
562
+ btn_4 = gr.Button("Generate!")
563
+ btn_4.click(infer_long_text,
564
+ inputs=[textbox_4, preset_dropdown_4, prompt_file_4, language_dropdown_4, accent_dropdown_4],
565
+ outputs=[text_output_4, audio_output_4])
566
+
567
+ app.launch()
568
+
569
+ if __name__ == "__main__":
570
+ formatter = (
571
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
572
+ )
573
+ logging.basicConfig(format=formatter, level=logging.INFO)
574
+ main()
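
For readers following the two-step workflow above (encode a voice prompt once with make_npz_prompt, then reuse it with infer_from_prompt), here is a minimal sketch of driving those functions outside the Gradio UI. It assumes the checkpoint, Whisper model and presets load as they do at the top of app.py; the file names and transcript are illustrative, and the (sample_rate, samples) tuple mimics what gr.Audio passes in.

import torch, torchaudio
from types import SimpleNamespace
from app import make_npz_prompt, infer_from_prompt  # importing app loads the models

wav, sr = torchaudio.load("my_voice_sample.wav")   # illustrative 3-10 s mono clip
audio_tuple = (sr, wav.squeeze(0).numpy())         # gr.Audio-style (sample_rate, samples)

# Step 1: encode the voice prompt once; giving a transcript skips the Whisper pass.
message, npz_path = make_npz_prompt("my_voice", audio_tuple, None,
                                    "An illustrative transcript of the clip.")
print(message)

# Step 2: reuse the cached prompt for fast synthesis; .name mimics a gr.File object.
message, (out_sr, samples) = infer_from_prompt(
    "Hello, this is a test.", "English", "no-accent", None, SimpleNamespace(name=npz_path))
torchaudio.save("output.wav", torch.from_numpy(samples).unsqueeze(0), out_sr)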
data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .collation import *
data/collation.py ADDED
@@ -0,0 +1,118 @@
1
+ from pathlib import Path
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ class TextTokenCollater:
9
+ """Collate list of text tokens
10
+
11
+ Map sentences to integers. Sentences are padded to equal length.
12
+ Beginning and end-of-sequence symbols can be added.
13
+
14
+ Example:
15
+ >>> token_collater = TextTokenCollater(text_tokens)
16
+ >>> tokens_batch, tokens_lens = token_collater(text)
17
+
18
+ Returns:
19
+ tokens_batch: IntTensor of shape (B, L)
20
+ B: batch dimension, number of input sentences
21
+ L: length of the longest sentence
22
+ tokens_lens: IntTensor of shape (B,)
23
+ Length of each sentence after adding <eos> and <bos>
24
+ but before padding.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ text_tokens: List[str],
30
+ add_eos: bool = True,
31
+ add_bos: bool = True,
32
+ pad_symbol: str = "<pad>",
33
+ bos_symbol: str = "<bos>",
34
+ eos_symbol: str = "<eos>",
35
+ ):
36
+ self.pad_symbol = pad_symbol
37
+
38
+ self.add_eos = add_eos
39
+ self.add_bos = add_bos
40
+
41
+ self.bos_symbol = bos_symbol
42
+ self.eos_symbol = eos_symbol
43
+
44
+ unique_tokens = (
45
+ [pad_symbol]
46
+ + ([bos_symbol] if add_bos else [])
47
+ + ([eos_symbol] if add_eos else [])
48
+ + sorted(text_tokens)
49
+ )
50
+
51
+ self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
52
+ self.idx2token = [token for token in unique_tokens]
53
+
54
+ def index(
55
+ self, tokens_list: List[str]
56
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
57
+ seqs, seq_lens = [], []
58
+ for tokens in tokens_list:
59
+ assert all(s in self.token2idx for s in tokens)
63
+ seq = (
64
+ ([self.bos_symbol] if self.add_bos else [])
65
+ + list(tokens)
66
+ + ([self.eos_symbol] if self.add_eos else [])
67
+ )
68
+ seqs.append(seq)
69
+ seq_lens.append(len(seq))
70
+
71
+ max_len = max(seq_lens)
72
+ for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
73
+ seq.extend([self.pad_symbol] * (max_len - seq_len))
74
+
75
+ tokens = torch.from_numpy(
76
+ np.array(
77
+ [[self.token2idx[token] for token in seq] for seq in seqs],
78
+ dtype=np.int64,
79
+ )
80
+ )
81
+ tokens_lens = torch.IntTensor(seq_lens)
82
+
83
+ return tokens, tokens_lens
84
+
85
+ def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
86
+ tokens_seqs = [[p for p in text] for text in texts]
87
+ max_len = len(max(tokens_seqs, key=len))
88
+
89
+ seqs = [
90
+ ([self.bos_symbol] if self.add_bos else [])
91
+ + list(seq)
92
+ + ([self.eos_symbol] if self.add_eos else [])
93
+ + [self.pad_symbol] * (max_len - len(seq))
94
+ for seq in tokens_seqs
95
+ ]
96
+
97
+ tokens_batch = torch.from_numpy(
98
+ np.array(
99
+ [seq for seq in seqs],
100
+ dtype=np.int64,
101
+ )
102
+ )
103
+
104
+ tokens_lens = torch.IntTensor(
105
+ [
106
+ len(seq) + int(self.add_eos) + int(self.add_bos)
107
+ for seq in tokens_seqs
108
+ ]
109
+ )
110
+
111
+ return tokens_batch, tokens_lens
112
+
113
+
114
+ def get_text_token_collater() -> TextTokenCollater:
115
+ collater = TextTokenCollater(
116
+ ['0'], add_bos=False, add_eos=False
117
+ )
118
+ return collater
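
As used in app.py, the collater above is always called with a single phoneme sequence, and the int64 cast in __call__ implies that the tokenizer emits numeric id strings. A short sketch of the expected shapes under those assumptions (the ids below are made up):

collater = get_text_token_collater()           # no <bos>/<eos>, '<pad>' only
phoneme_ids = ["12", "7", "133", "4"]          # illustrative output of PhonemeBpeTokenizer.tokenize
tokens_batch, tokens_lens = collater([phoneme_ids])
print(tokens_batch.shape, tokens_batch.dtype)  # torch.Size([1, 4]) torch.int64
print(tokens_lens)                             # tensor([4], dtype=torch.int32)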
data/tokenizer.py ADDED
@@ -0,0 +1,117 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ from encodec import EncodecModel
24
+ from encodec.utils import convert_audio
25
+
26
+ def remove_encodec_weight_norm(model):
27
+ from encodec.modules import SConv1d
28
+ from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
29
+ from torch.nn.utils import remove_weight_norm
30
+
31
+ encoder = model.encoder.model
32
+ for key in encoder._modules:
33
+ if isinstance(encoder._modules[key], SEANetResnetBlock):
34
+ remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
35
+ block_modules = encoder._modules[key].block._modules
36
+ for skey in block_modules:
37
+ if isinstance(block_modules[skey], SConv1d):
38
+ remove_weight_norm(block_modules[skey].conv.conv)
39
+ elif isinstance(encoder._modules[key], SConv1d):
40
+ remove_weight_norm(encoder._modules[key].conv.conv)
41
+
42
+ decoder = model.decoder.model
43
+ for key in decoder._modules:
44
+ if isinstance(decoder._modules[key], SEANetResnetBlock):
45
+ remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
46
+ block_modules = decoder._modules[key].block._modules
47
+ for skey in block_modules:
48
+ if isinstance(block_modules[skey], SConv1d):
49
+ remove_weight_norm(block_modules[skey].conv.conv)
50
+ elif isinstance(decoder._modules[key], SConvTranspose1d):
51
+ remove_weight_norm(decoder._modules[key].convtr.convtr)
52
+ elif isinstance(decoder._modules[key], SConv1d):
53
+ remove_weight_norm(decoder._modules[key].conv.conv)
54
+
55
+
56
+ class AudioTokenizer:
57
+ """EnCodec audio tokenizer."""
58
+
59
+ def __init__(
60
+ self,
61
+ device: Any = None,
62
+ ) -> None:
63
+ # Instantiate a pretrained EnCodec model
64
+ model = EncodecModel.encodec_model_24khz()
65
+ model.set_target_bandwidth(6.0)
66
+ remove_encodec_weight_norm(model)
67
+
68
+ if not device:
69
+ device = torch.device("cpu")
70
+ if torch.cuda.is_available():
71
+ device = torch.device("cuda:0")
72
+
73
+ self._device = device
74
+
75
+ self.codec = model.to(device)
76
+ self.sample_rate = model.sample_rate
77
+ self.channels = model.channels
78
+
79
+ @property
80
+ def device(self):
81
+ return self._device
82
+
83
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
84
+ return self.codec.encode(wav.to(self.device))
85
+
86
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
87
+ return self.codec.decode(frames)
88
+
89
+
90
+ def tokenize_audio(tokenizer: AudioTokenizer, audio):
91
+ # Load and pre-process the audio waveform
92
+ if isinstance(audio, str):
93
+ wav, sr = torchaudio.load(audio)
94
+ else:
95
+ wav, sr = audio
96
+ wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
97
+ wav = wav.unsqueeze(0)
98
+
99
+ # Extract discrete codes from EnCodec
100
+ with torch.no_grad():
101
+ encoded_frames = tokenizer.encode(wav)
102
+ return encoded_frames
103
+
104
+
105
+ if __name__ == "__main__":
106
+ model = EncodecModel.encodec_model_24khz()
107
+ model.set_target_bandwidth(6.0)
108
+
109
+ samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
110
+ torch.float32
111
+ )
112
+ codes_raw = model.encode(samples)
113
+
114
+ remove_encodec_weight_norm(model)
115
+ codes_norm = model.encode(samples)
116
+
117
+ assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
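
The __main__ block above only checks that stripping weight norm leaves the codes unchanged; for context, a sketch of the encode/decode round trip that AudioTokenizer and tokenize_audio expose (the file name is illustrative):

import torch
import torchaudio

tokenizer = AudioTokenizer()                       # 24 kHz EnCodec at 6 kbps -> 8 codebooks
frames = tokenize_audio(tokenizer, "sample.wav")   # list of (codes, scale) tuples
codes = frames[0][0]                               # LongTensor of shape (1, 8, T_frames)
with torch.no_grad():
    wav_hat = tokenizer.decode(frames)             # (1, channels, T_samples) reconstruction
torchaudio.save("sample_reconstructed.wav", wav_hat[0].cpu(), tokenizer.sample_rate)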
descriptions.py ADDED
@@ -0,0 +1,27 @@
1
+ top_md = """
2
+ # VALL-E X
3
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)
4
+ VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
5
+ an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.<br>
6
+ This implementation supports zero-shot, monolingual and cross-lingual text-to-speech in three languages (English, Chinese, Japanese).<br>
7
+ See this [demo](https://plachtaa.github.io/) page for more details.
8
+ """
9
+
10
+ infer_from_audio_md = """
11
+ Upload a 3–10 second speech clip as the audio prompt and type in the text you'd like to synthesize.<br>
12
+ The model will synthesize the given text in the same voice as your audio prompt.<br>
13
+ It also tends to preserve the emotion and acoustic environment of the prompt speech.<br>
14
+ For faster inference, use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, then use it with **"Infer from prompt"**
15
+ """
16
+
17
+ make_prompt_md = """
18
+ Upload a 3–10 second speech clip as the audio prompt.<br>
19
+ You will get a `.npz` file as the encoded audio prompt. Use it with **"Infer from prompt"**
20
+ """
21
+
22
+ infer_from_prompt_md = """
23
+ Faster than **"Infer from audio"**.<br>
24
+ You need to run **"Make prompt"** first, then upload the encoded prompt (a `.npz` file).
25
+ """
26
+
27
+ long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded."
epoch-10.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5fcd05ee0c9c84a16a7b44495c46262177e66d5d454c20ca5f1da9832dbd5ac
3
+ size 1482302113
images/vallex_framework.jpg ADDED
macros.py ADDED
@@ -0,0 +1,39 @@
1
+ NUM_LAYERS = 12
2
+ NUM_HEAD = 16
3
+ N_DIM = 1024
4
+ PREFIX_MODE = 1
5
+ NUM_QUANTIZERS = 8
6
+ SAMPLE_RATE = 24000
7
+
8
+ lang2token = {
9
+ 'zh': "[ZH]",
10
+ 'ja': "[JA]",
11
+ "en": "[EN]",
12
+ 'mix': "",
13
+ }
14
+
15
+ lang2code = {
16
+ 'zh': 0,
17
+ 'ja': 1,
18
+ "en": 2,
19
+ }
20
+
21
+ token2lang = {
22
+ '[ZH]': "zh",
23
+ '[JA]': "ja",
24
+ "[EN]": "en",
25
+ "": "mix"
26
+ }
27
+
28
+ code2lang = {
29
+ 0: 'zh',
30
+ 1: 'ja',
31
+ 2: "en",
32
+ }
33
+
34
+ langdropdown2token = {
35
+ 'English': "[EN]",
36
+ '中文': "[ZH]",
37
+ '日本語': "[JA]",
38
+ 'Mix': "",
39
+ }
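
These tables are the glue between the Gradio dropdowns and the model: a dropdown choice maps to a language token, the token brackets the text fed to the phoneme tokenizer, and the ISO code is what gets stored in an .npz prompt. A quick illustration of the round trip as app.py uses it:

choice = "中文"                                # value picked in a language dropdown
lang_token = langdropdown2token[choice]        # "[ZH]"
lang = token2lang[lang_token]                  # "zh"
tagged_text = lang_token + "你好" + lang_token  # "[ZH]你好[ZH]", fed to the phoneme tokenizer
assert code2lang[lang2code[lang]] == lang      # 0 -> "zh": code stored as lang_code in .npz prompts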
models/__init__.py ADDED
@@ -0,0 +1,126 @@
1
+ import argparse
2
+
3
+ import torch.nn as nn
4
+ # from icefall.utils import AttributeDict, str2bool
5
+
6
+ from .macros import (
7
+ NUM_AUDIO_TOKENS,
8
+ NUM_MEL_BINS,
9
+ NUM_SPEAKER_CLASSES,
10
+ NUM_TEXT_TOKENS,
11
+ SPEAKER_EMBEDDING_DIM,
12
+ )
13
+ from .vallex import VALLE, VALLF
14
+
15
+
16
+ def add_model_arguments(parser: argparse.ArgumentParser):
17
+ parser.add_argument(
18
+ "--model-name",
19
+ type=str,
20
+ default="VALL-E",
21
+ help="VALL-E, VALL-F, Transformer.",
22
+ )
23
+ parser.add_argument(
24
+ "--decoder-dim",
25
+ type=int,
26
+ default=1024,
27
+ help="Embedding dimension in the decoder model.",
28
+ )
29
+ parser.add_argument(
30
+ "--nhead",
31
+ type=int,
32
+ default=16,
33
+ help="Number of attention heads in the Decoder layers.",
34
+ )
35
+ parser.add_argument(
36
+ "--num-decoder-layers",
37
+ type=int,
38
+ default=12,
39
+ help="Number of Decoder layers.",
40
+ )
41
+ parser.add_argument(
42
+ "--scale-factor",
43
+ type=float,
44
+ default=1.0,
45
+ help="Model scale factor which will be assigned different meanings in different models.",
46
+ )
47
+ parser.add_argument(
48
+ "--norm-first",
49
+ type=bool,
50
+ default=True,
51
+ help="Pre or Post Normalization.",
52
+ )
53
+ parser.add_argument(
54
+ "--add-prenet",
55
+ type=bool,
56
+ default=False,
57
+ help="Whether add PreNet after Inputs.",
58
+ )
59
+
60
+ # VALL-E & F
61
+ parser.add_argument(
62
+ "--prefix-mode",
63
+ type=int,
64
+ default=1,
65
+ help="The mode for how to prefix VALL-E NAR Decoder, "
66
+ "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
67
+ )
68
+ parser.add_argument(
69
+ "--share-embedding",
70
+ type=bool,
71
+ default=True,
72
+ help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
73
+ )
74
+ parser.add_argument(
75
+ "--prepend-bos",
76
+ type=bool,
77
+ default=False,
78
+ help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
79
+ )
80
+ parser.add_argument(
81
+ "--num-quantizers",
82
+ type=int,
83
+ default=8,
84
+ help="Number of Audio/Semantic quantization layers.",
85
+ )
86
+
87
+ # Transformer
88
+ parser.add_argument(
89
+ "--scaling-xformers",
90
+ type=bool,
91
+ default=False,
92
+ help="Apply Reworked Conformer scaling on Transformers.",
93
+ )
94
+
95
+
96
+ def get_model(params) -> nn.Module:
97
+ if params.model_name.lower() in ["vall-f", "vallf"]:
98
+ model = VALLF(
99
+ params.decoder_dim,
100
+ params.nhead,
101
+ params.num_decoder_layers,
102
+ norm_first=params.norm_first,
103
+ add_prenet=params.add_prenet,
104
+ prefix_mode=params.prefix_mode,
105
+ share_embedding=params.share_embedding,
106
+ nar_scale_factor=params.scale_factor,
107
+ prepend_bos=params.prepend_bos,
108
+ num_quantizers=params.num_quantizers,
109
+ )
110
+ elif params.model_name.lower() in ["vall-e", "valle"]:
111
+ model = VALLE(
112
+ params.decoder_dim,
113
+ params.nhead,
114
+ params.num_decoder_layers,
115
+ norm_first=params.norm_first,
116
+ add_prenet=params.add_prenet,
117
+ prefix_mode=params.prefix_mode,
118
+ share_embedding=params.share_embedding,
119
+ nar_scale_factor=params.scale_factor,
120
+ prepend_bos=params.prepend_bos,
121
+ num_quantizers=params.num_quantizers,
122
+ )
123
+ else:
124
+ raise ValueError(f"No such model: {params.model_name}")
125
+
126
+ return model
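
get_model expects a params object carrying the fields registered by add_model_arguments; the defaults here match N_DIM, NUM_HEAD and NUM_LAYERS in macros.py. A minimal sketch, assuming the package is importable as `models`:

import argparse
from models import add_model_arguments, get_model

parser = argparse.ArgumentParser()
add_model_arguments(parser)
params = parser.parse_args(["--model-name", "VALL-E"])   # 1024-dim, 16 heads, 12 layers by default
model = get_model(params)                                # returns a VALLE instance
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")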
models/macros.py ADDED
@@ -0,0 +1,11 @@
1
+ # Text
2
+ NUM_TEXT_TOKENS = 2048
3
+
4
+ # Audio
5
+ NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6
+ NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7
+
8
+
9
+ # Speaker
10
+ NUM_SPEAKER_CLASSES = 4096
11
+ SPEAKER_EMBEDDING_DIM = 64
models/vallex.py ADDED
@@ -0,0 +1,830 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from typing import Dict, Iterator, List, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ # from icefall.utils import make_pad_mask
23
+ # from torchmetrics.classification import MulticlassAccuracy
24
+
25
+ from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26
+ from modules.transformer import (
27
+ AdaptiveLayerNorm,
28
+ LayerNorm,
29
+ TransformerDecoderLayer,
30
+ TransformerEncoder,
31
+ TransformerEncoderLayer,
32
+ )
33
+
34
+ from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
35
+
36
+
37
+ class Transpose(nn.Identity):
38
+ """(N, T, D) -> (N, D, T)"""
39
+
40
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
41
+ return input.transpose(1, 2)
42
+
43
+
44
+ # NOTE: There are two ways to implement the model
45
+ # 1) [VALL-F] standard TransformerDecoder, use x as memory
46
+ # 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
47
+ # use x as the prefix of decoder inputs
48
+ class VALLF(nn.Module):
49
+ """It implements https://arxiv.org/abs/2301.02111
50
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ d_model: int,
56
+ nhead: int,
57
+ num_layers: int,
58
+ norm_first: bool = True,
59
+ add_prenet: bool = False,
60
+ decoder_cls: Union[
61
+ nn.TransformerDecoder, nn.TransformerEncoder
62
+ ] = nn.TransformerDecoder,
63
+ decoder_layer_cls: Union[
64
+ TransformerDecoderLayer, TransformerEncoderLayer
65
+ ] = TransformerDecoderLayer,
66
+ prefix_mode: int = 0,
67
+ share_embedding: bool = True,
68
+ nar_scale_factor: float = 1.0,
69
+ prepend_bos: bool = True,
70
+ num_quantizers: int = 8,
71
+ ):
72
+ """
73
+ Args:
74
+ d_model:
75
+ The number of expected features in the input (required).
76
+ nhead:
77
+ The number of heads in the multiheadattention models (required).
78
+ num_layers:
79
+ The number of sub-decoder-layers in the decoder (required).
80
+ """
81
+ super().__init__()
82
+ nar_d_model = int(d_model * nar_scale_factor)
83
+
84
+ self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
85
+ self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
86
+
87
+ # ID NUM_AUDIO_TOKENS -> PAD
88
+ # ID NUM_AUDIO_TOKENS + 1 -> BOS
89
+ self.ar_audio_prepend_bos = prepend_bos
90
+ self.ar_audio_embedding = TokenEmbedding(
91
+ d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
92
+ )
93
+
94
+ # PreNet
95
+ if add_prenet:
96
+ self.ar_text_prenet = nn.Sequential(
97
+ Transpose(),
98
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
99
+ nn.BatchNorm1d(d_model),
100
+ nn.ReLU(),
101
+ nn.Dropout(0.5),
102
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
103
+ nn.BatchNorm1d(d_model),
104
+ nn.ReLU(),
105
+ nn.Dropout(0.5),
106
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
107
+ nn.BatchNorm1d(d_model),
108
+ nn.ReLU(),
109
+ nn.Dropout(0.5),
110
+ Transpose(),
111
+ nn.Linear(d_model, d_model),
112
+ )
113
+
114
+ self.ar_audio_prenet = nn.Sequential(
115
+ nn.Linear(d_model, 256),
116
+ nn.ReLU(),
117
+ nn.Dropout(0.25),
118
+ nn.Linear(256, 256),
119
+ nn.ReLU(),
120
+ nn.Dropout(0.25),
121
+ nn.Linear(256, d_model),
122
+ )
123
+ else:
124
+ self.ar_text_prenet = nn.Identity()
125
+ self.ar_audio_prenet = nn.Identity()
126
+
127
+ self.ar_text_position = SinePositionalEmbedding(
128
+ d_model,
129
+ dropout=0.1,
130
+ scale=False,
131
+ alpha=True,
132
+ )
133
+ self.ar_audio_position = SinePositionalEmbedding(
134
+ d_model,
135
+ dropout=0.1,
136
+ scale=False,
137
+ alpha=True,
138
+ )
139
+
140
+ self.ar_decoder = decoder_cls(
141
+ decoder_layer_cls(
142
+ d_model,
143
+ nhead,
144
+ dim_feedforward=d_model * 4,
145
+ dropout=0.1,
146
+ batch_first=True,
147
+ norm_first=norm_first,
148
+ ),
149
+ num_layers=num_layers,
150
+ norm=LayerNorm(d_model) if norm_first else None,
151
+ )
152
+ self.ar_predict_layer = nn.Linear(
153
+ d_model, NUM_AUDIO_TOKENS + 1, bias=False
154
+ )
155
+
156
+ self.rng = random.Random(0)
157
+ self.num_heads = nhead
158
+ self.prefix_mode = prefix_mode
159
+ self.num_quantizers = num_quantizers
160
+
161
+ assert num_quantizers >= 1
162
+ if num_quantizers > 1:
163
+ self.nar_audio_embeddings = nn.ModuleList(
164
+ [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
165
+ + [
166
+ TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
167
+ for i in range(num_quantizers - 1)
168
+ ]
169
+ ) # W_a
170
+
171
+ # PreNet
172
+ if add_prenet:
173
+ self.nar_text_prenet = nn.Sequential(
174
+ Transpose(),
175
+ nn.Conv1d(
176
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
177
+ ),
178
+ nn.BatchNorm1d(nar_d_model),
179
+ nn.ReLU(),
180
+ nn.Dropout(0.5),
181
+ nn.Conv1d(
182
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
183
+ ),
184
+ nn.BatchNorm1d(nar_d_model),
185
+ nn.ReLU(),
186
+ nn.Dropout(0.5),
187
+ nn.Conv1d(
188
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
189
+ ),
190
+ nn.BatchNorm1d(nar_d_model),
191
+ nn.ReLU(),
192
+ nn.Dropout(0.5),
193
+ Transpose(),
194
+ nn.Linear(nar_d_model, nar_d_model),
195
+ )
196
+ self.nar_audio_prenet = nn.Sequential(
197
+ nn.Linear(nar_d_model, 256),
198
+ nn.ReLU(),
199
+ nn.Dropout(0.25),
200
+ nn.Linear(256, 256),
201
+ nn.ReLU(),
202
+ nn.Dropout(0.25),
203
+ nn.Linear(256, nar_d_model),
204
+ )
205
+ else:
206
+ self.nar_text_prenet = nn.Identity()
207
+ self.nar_audio_prenet = nn.Identity()
208
+
209
+ self.nar_text_position = SinePositionalEmbedding(
210
+ nar_d_model,
211
+ dropout=0.0,
212
+ scale=False,
213
+ alpha=False,
214
+ )
215
+ self.nar_audio_position = SinePositionalEmbedding(
216
+ nar_d_model,
217
+ dropout=0.1,
218
+ scale=False,
219
+ alpha=False,
220
+ )
221
+
222
+ self.nar_decoder = decoder_cls(
223
+ decoder_layer_cls(
224
+ nar_d_model,
225
+ int(nhead * nar_scale_factor),
226
+ dim_feedforward=nar_d_model * 4,
227
+ dropout=0.1,
228
+ batch_first=True,
229
+ norm_first=norm_first,
230
+ adaptive_layer_norm=True,
231
+ ),
232
+ num_layers=int(num_layers * nar_scale_factor),
233
+ norm=AdaptiveLayerNorm(
234
+ nar_d_model, norm=nn.LayerNorm(nar_d_model)
235
+ )
236
+ if norm_first
237
+ else None,
238
+ )
239
+ self.nar_predict_layers = nn.ModuleList(
240
+ [
241
+ nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
242
+ for i in range(num_quantizers - 1)
243
+ ]
244
+ )
245
+ self.nar_stage_embeddings = nn.ModuleList(
246
+ [
247
+ TokenEmbedding(nar_d_model, 1)
248
+ for i in range(num_quantizers - 1)
249
+ ]
250
+ )
251
+
252
+ if share_embedding:
253
+ # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
254
+ # NOTE(Feiteng): In the experiment, this undermines accuracy
255
+ # self.ar_predict_layer.weight = self.ar_audio_embedding.weight
256
+
257
+ # We also share the parameters of the acoustic embedding layer and the output prediction layer,
258
+ # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
259
+ for j in range(0, num_quantizers - 2):
260
+ self.nar_predict_layers[
261
+ j
262
+ ].weight = self.nar_audio_embeddings[j + 2].weight
263
+
264
+ def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
265
+ assert stage > 0
266
+ if stage == 1:
267
+ for name, param in self.named_parameters():
268
+ if name.startswith("ar_"):
269
+ print(f" AR parameter: {name}")
270
+ yield param
271
+
272
+ if stage == 2:
273
+ for name, param in self.named_parameters():
274
+ if name.startswith("nar_"):
275
+ print(f"NAR parameter: {name}")
276
+ yield param
277
+
278
+ def stage_named_parameters(
279
+ self, stage: int = 1
280
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
281
+ assert stage > 0
282
+ if stage == 1:
283
+ for pair in self.named_parameters():
284
+ if pair[0].startswith("ar_"):
285
+ yield pair
286
+
287
+ if stage == 2:
288
+ for pair in self.named_parameters():
289
+ if pair[0].startswith("nar_"):
290
+ yield pair
291
+
292
+ def pad_y_eos(self, y, y_mask_int, eos_id):
293
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
294
+ y_mask_int, (0, 1), value=1
295
+ )
296
+ # inputs, targets
297
+ if self.ar_audio_prepend_bos:
298
+ return (
299
+ F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
300
+ targets,
301
+ )
302
+
303
+ return targets[:, :-1], targets[:, 1:]
304
+
305
+ def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
306
+ # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
307
+ # from the same utterance.
308
+ # We implement this differently.
309
+ if prefix_mode == 0:
310
+ # no prefix
311
+ prefix_len = 0
312
+ y_emb = self.nar_audio_embeddings[0](y)
313
+ for j in range(1, nar_stage):
314
+ # Formula (4) (5)
315
+ y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
316
+ elif prefix_mode == 1:
317
+ # prefix at the beginning
318
+ int_low = (0.25 * y_lens.min()).type(torch.int64).item()
319
+ prefix_len = torch.randint(0, int_low * 2, size=()).item()
320
+ prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
321
+
322
+ y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
323
+ y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
324
+ for j in range(1, self.num_quantizers):
325
+ y_prompts += self.nar_audio_embeddings[j](
326
+ codes[:, :prefix_len, j]
327
+ )
328
+ if j < nar_stage:
329
+ y_emb += self.nar_audio_embeddings[j](
330
+ codes[:, prefix_len:, j]
331
+ )
332
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
333
+ elif prefix_mode in [2, 4]:
334
+ if prefix_mode == 2:
335
+ # random prefix
336
+ prefix_len = min(225, int(0.25 * y_lens.min().item()))
337
+
338
+ y_prompts_codes = []
339
+ for b in range(codes.shape[0]):
340
+ start = self.rng.randint(0, y_lens[b].item() - prefix_len)
341
+ y_prompts_codes.append(
342
+ torch.clone(codes[b, start : start + prefix_len])
343
+ )
344
+ codes[
345
+ b, start : start + prefix_len, nar_stage
346
+ ] = NUM_AUDIO_TOKENS
347
+ y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
348
+ else:
349
+ prefix_len = y_prompts_codes.shape[1]
350
+
351
+ y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
352
+ y_emb = self.nar_audio_embeddings[0](y)
353
+ for j in range(1, self.num_quantizers):
354
+ y_prompts += self.nar_audio_embeddings[j](
355
+ y_prompts_codes[..., j]
356
+ )
357
+ if j < nar_stage:
358
+ y_emb += self.nar_audio_embeddings[j](codes[..., j])
359
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
360
+ else:
361
+ raise ValueError
362
+
363
+ return y_emb, prefix_len
364
+
365
+ def forward(
366
+ self,
367
+ x: torch.Tensor,
368
+ x_lens: torch.Tensor,
369
+ y: Union[torch.Tensor],
370
+ y_lens: Union[torch.Tensor],
371
+ reduction: str = "sum",
372
+ train_stage: int = 0,
373
+ **kwargs,
374
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
375
+ raise NotImplementedError
376
+
377
+ def inference(
378
+ self,
379
+ x: torch.Tensor,
380
+ x_lens: torch.Tensor,
381
+ y: torch.Tensor,
382
+ enroll_x_lens: Union[torch.Tensor, None] = None,
383
+ top_k: int = -100,
384
+ temperature: float = 1.0,
385
+ ) -> torch.Tensor:
386
+ raise NotImplementedError
387
+
388
+ def visualize(
389
+ self,
390
+ predicts: Tuple[torch.Tensor],
391
+ batch: Dict[str, Union[List, torch.Tensor]],
392
+ output_dir: str,
393
+ limit: int = 4,
394
+ ) -> None:
395
+ raise NotImplementedError
396
+
397
+
398
+ class VALLE(VALLF):
399
+ """It implements https://arxiv.org/abs/2301.02111
400
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
401
+ """
402
+
403
+ def __init__(
404
+ self,
405
+ d_model: int,
406
+ nhead: int,
407
+ num_layers: int,
408
+ norm_first: bool = True,
409
+ add_prenet: bool = False,
410
+ prefix_mode: int = 0,
411
+ share_embedding: bool = True,
412
+ nar_scale_factor: float = 1.0,
413
+ **kwargs,
414
+ ):
415
+ """
416
+ Args:
417
+ d_model:
418
+ The number of expected features in the input (required).
419
+ nhead:
420
+ The number of heads in the multiheadattention models (required).
421
+ num_layers:
422
+ The number of sub-decoder-layers in the decoder (required).
423
+ """
424
+ super(VALLE, self).__init__(
425
+ d_model,
426
+ nhead,
427
+ num_layers,
428
+ norm_first=norm_first,
429
+ add_prenet=add_prenet,
430
+ decoder_cls=TransformerEncoder,
431
+ decoder_layer_cls=TransformerEncoderLayer,
432
+ prefix_mode=prefix_mode,
433
+ share_embedding=share_embedding,
434
+ nar_scale_factor=nar_scale_factor,
435
+ **kwargs,
436
+ )
437
+ self.language_ID = {
438
+ 'en': 0,
439
+ 'zh': 1,
440
+ 'ja': 2,
441
+ }
442
+ self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
443
+ self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
444
+
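A minimal illustration of what the language-ID embeddings above provide (the nn.Embedding stand-in and the d_model value are assumptions for this sketch, not taken from the commit): each utterance's language tag is mapped to a learned vector that is added to the text embeddings.

import torch
import torch.nn as nn

d_model = 1024                                       # illustrative model width
language_ID = {'en': 0, 'zh': 1, 'ja': 2}
lang_emb = nn.Embedding(len(language_ID), d_model)   # stand-in for TokenEmbedding

lang_id = torch.tensor([language_ID['ja']])          # shape (1,)
lang_vec = lang_emb(lang_id)                         # shape (1, d_model), added to the text embedding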
445
+ def forward(
446
+ self,
447
+ x: torch.Tensor,
448
+ x_lens: torch.Tensor,
449
+ y: torch.Tensor,
450
+ y_lens: torch.Tensor,
451
+ reduction: str = "sum",
452
+ train_stage: int = 0,
453
+ **kwargs,
454
+ ):
455
+ raise NotImplementedError
456
+ def inference(
457
+ self,
458
+ x: torch.Tensor,
459
+ x_lens: torch.Tensor,
460
+ y: torch.Tensor,
461
+ enroll_x_lens: torch.Tensor,
462
+ top_k: int = -100,
463
+ temperature: float = 1.0,
464
+ prompt_language: str = None,
465
+ text_language: str = None,
466
+ ) -> torch.Tensor:
467
+ """
468
+ Args:
469
+ x:
470
+ A 2-D tensor of shape (1, S).
471
+ x_lens:
472
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
473
+ before padding.
474
+ y:
475
+ A 3-D tensor of shape (1, T, 8).
476
+ top_k: (`optional`) int
477
+ The number of highest probability tokens to keep for top-k filtering. Defaults to -100, which disables top-k filtering.
478
+ temperature: (`optional`) float
479
+ The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
480
+ Returns:
481
+ Return the predicted audio code matrix.
482
+ """
483
+ assert x.ndim == 2, x.shape
484
+ assert x_lens.ndim == 1, x_lens.shape
485
+ assert y.ndim == 3, y.shape
486
+ assert y.shape[0] == 1, y.shape
487
+
488
+ assert torch.all(x_lens > 0)
489
+
490
+ # NOTE: x has been padded in TextTokenCollater
491
+ text = x
492
+ x = self.ar_text_embedding(text)
493
+ # Add language embedding
494
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
495
+ if isinstance(text_language, str):
496
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
497
+ elif isinstance(text_language, List):
498
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
499
+ x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
500
+ x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
501
+ x = self.ar_text_prenet(x)
502
+ x = self.ar_text_position(x)
503
+
504
+ text_len = x_lens.max()
505
+ prompts = y
506
+ prefix_len = y.shape[1]
507
+
508
+ # AR Decoder
509
+ # TODO: manage decoder steps to avoid repetitive computation
510
+ y = prompts[..., 0]
511
+ if self.ar_audio_prepend_bos:
512
+ y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
513
+
514
+ x_len = x_lens.max()
515
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
516
+
517
+ kv_cache = None
518
+ use_kv_caching = True
519
+ while True:
520
+ y_emb = self.ar_audio_embedding(y)
521
+ y_emb = self.ar_audio_prenet(y_emb)
522
+ y_pos = self.ar_audio_position(y_emb)
523
+ xy_pos = torch.concat([x, y_pos], dim=1)
524
+
525
+ y_len = y.shape[1]
526
+ x_attn_mask_pad = F.pad(
527
+ x_attn_mask,
528
+ (0, y_len),
529
+ value=True,
530
+ )
531
+ y_attn_mask = F.pad(
532
+ torch.triu(
533
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
534
+ ),
535
+ (x_len, 0),
536
+ value=False,
537
+ )
538
+ xy_attn_mask = torch.concat(
539
+ [x_attn_mask_pad, y_attn_mask], dim=0
540
+ ).to(y.device)
541
+
542
+
543
+ if use_kv_caching and kv_cache is not None:
544
+ xy_pos = xy_pos[:, [-1]]
545
+ else:
546
+ pass
547
+
548
+ xy_dec, kv_cache = self.ar_decoder.infer(
549
+ xy_pos,
550
+ mask=xy_attn_mask,
551
+ past_kv=kv_cache,
552
+ use_cache=use_kv_caching,
553
+ )
554
+ # xy_dec, _ = self.ar_decoder(
555
+ # (xy_pos, None),
556
+ # mask=xy_attn_mask,
557
+ # )
558
+
559
+ logits = self.ar_predict_layer(xy_dec[:, -1])
560
+ samples = topk_sampling(
561
+ logits, top_k=top_k, top_p=1, temperature=temperature
562
+ )
563
+
564
+ if (
565
+ torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
566
+ or samples[0, 0] == NUM_AUDIO_TOKENS
567
+ or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
568
+ ):
569
+ if prompts.shape[1] == y.shape[1]:
570
+ raise SyntaxError(
571
+ "well trained model shouldn't reach here."
572
+ )
573
+
574
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
575
+ break
576
+
577
+ y = torch.concat([y, samples], dim=1)
578
+
579
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
580
+ if self.num_quantizers == 1:
581
+ return torch.stack(codes, dim=-1)
582
+
583
+ # Non-AR Decoders
584
+ y_emb = self.nar_audio_embeddings[0](
585
+ y[:, int(self.ar_audio_prepend_bos) :]
586
+ )
587
+
588
+ if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
589
+ enrolled_len = enroll_x_lens.max().item()
590
+ # SOS + Synthesis Text + EOS
591
+ text = torch.concat(
592
+ [
593
+ text[:, :1],
594
+ text[:, enrolled_len - 1 :],
595
+ ],
596
+ dim=1,
597
+ )
598
+ text_len = text_len - (enrolled_len - 2)
599
+ assert text.shape[0] == 1
600
+
601
+ x = self.nar_text_embedding(text)
602
+ # Add language embedding
603
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
604
+ if isinstance(text_language, str):
605
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
606
+ elif isinstance(text_language, List):
607
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
608
+ x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
609
+ x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
610
+ x = self.nar_text_prenet(x)
611
+ x = self.nar_text_position(x)
612
+
613
+ if self.prefix_mode == 0:
614
+ for i, (predict_layer, embedding_layer) in enumerate(
615
+ zip(
616
+ self.nar_predict_layers,
617
+ self.nar_audio_embeddings[1:],
618
+ )
619
+ ):
620
+ y_pos = self.nar_audio_prenet(y_emb)
621
+ y_pos = self.nar_audio_position(y_pos)
622
+ xy_pos = torch.concat([x, y_pos], dim=1)
623
+
624
+ xy_dec, _ = self.nar_decoder(
625
+ (xy_pos, self.nar_stage_embeddings[i].weight)
626
+ )
627
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
628
+
629
+ samples = torch.argmax(logits, dim=-1)
630
+ codes.append(samples)
631
+
632
+ if i < self.num_quantizers - 2:
633
+ y_emb[:, :prefix_len] += embedding_layer(
634
+ prompts[..., i + 1]
635
+ )
636
+ y_emb[:, prefix_len:] += embedding_layer(samples)
637
+ else:
638
+ for j in range(1, self.num_quantizers):
639
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
640
+ prompts[..., j]
641
+ )
642
+
643
+ for i, (predict_layer, embedding_layer) in enumerate(
644
+ zip(
645
+ self.nar_predict_layers,
646
+ self.nar_audio_embeddings[1:],
647
+ )
648
+ ):
649
+ y_pos = self.nar_audio_prenet(y_emb)
650
+ y_pos = self.nar_audio_position(y_pos)
651
+ xy_pos = torch.concat([x, y_pos], dim=1)
652
+
653
+ xy_dec, _ = self.nar_decoder(
654
+ (xy_pos, self.nar_stage_embeddings[i].weight)
655
+ )
656
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
657
+
658
+ samples = torch.argmax(logits, dim=-1)
659
+ codes.append(samples)
660
+
661
+ if i < self.num_quantizers - 2:
662
+ y_emb[:, prefix_len:] += embedding_layer(samples)
663
+
664
+ assert len(codes) == self.num_quantizers
665
+ return torch.stack(codes, dim=-1)
666
+
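A hedged usage sketch for the inference path above. Shapes follow the docstring; the model construction, checkpoint loading, text tokenization, and codec encoding/decoding steps are assumed and not shown, and the variable names below are placeholders.

import torch

# Assumptions: `model` is a trained VALLE instance, `text_tokens` holds the phoneme IDs of
# prompt text + synthesis text with shape (1, S), `prompt_codes` holds the enrolled audio's
# codec codes with shape (1, T, 8), and `enroll_len` is the prompt-text token count.
with torch.no_grad():
    codes = model.inference(
        text_tokens,
        torch.tensor([text_tokens.shape[1]]),      # x_lens
        prompt_codes,
        enroll_x_lens=torch.tensor([enroll_len]),
        top_k=-100,
        temperature=1.0,
        prompt_language="en",
        text_language="en",
    )
# `codes` has shape (1, T', 8) and would be decoded back to a waveform by the codec.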
667
+ def continual(
668
+ self,
669
+ x: torch.Tensor,
670
+ x_lens: torch.Tensor,
671
+ y: torch.Tensor,
672
+ ) -> torch.Tensor:
673
+ """
674
+ Args:
675
+ x:
676
+ A 2-D tensor of shape (1, S).
677
+ x_lens:
678
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
679
+ before padding.
680
+ y:
681
+ A 3-D tensor of shape (1, T, 8).
682
+ Returns:
683
+ Return the predicted audio code matrix.
684
+ """
685
+ assert x.ndim == 2, x.shape
686
+ assert x_lens.ndim == 1, x_lens.shape
687
+ assert y.ndim == 3, y.shape
688
+ assert y.shape[0] == 1, y.shape
689
+
690
+ assert torch.all(x_lens > 0)
691
+ assert self.num_quantizers == 8
692
+
693
+ # NOTE: x has been padded in TextTokenCollater
694
+ text = x
695
+ x = self.ar_text_embedding(text)
696
+ x = self.ar_text_prenet(x)
697
+ x = self.ar_text_position(x)
698
+
699
+ text_len = x_lens.max()
700
+
701
+ prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
702
+
703
+ # AR Decoder
704
+ prompts = y[:, :prefix_len]
705
+
706
+ codes = [y[:, prefix_len:, 0]]
707
+ # Non-AR Decoders
708
+ x = self.nar_text_embedding(text)
709
+ x = self.nar_text_prenet(x)
710
+ x = self.nar_text_position(x)
711
+
712
+ y_emb = self.nar_audio_embeddings[0](y[..., 0])
713
+
714
+ if self.prefix_mode == 0:
715
+ for i, (predict_layer, embedding_layer) in enumerate(
716
+ zip(
717
+ self.nar_predict_layers,
718
+ self.nar_audio_embeddings[1:],
719
+ )
720
+ ):
721
+ y_pos = self.nar_audio_position(y_emb)
722
+ y_pos = self.nar_audio_prenet(y_pos)
723
+ xy_pos = torch.concat([x, y_pos], dim=1)
724
+
725
+ xy_dec, _ = self.nar_decoder(
726
+ (xy_pos, self.nar_stage_embeddings[i].weight)
727
+ )
728
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
729
+
730
+ samples = torch.argmax(logits, dim=-1)
731
+ codes.append(samples)
732
+
733
+ if i < 6:
734
+ y_emb[:, :prefix_len] += embedding_layer(
735
+ prompts[..., i + 1]
736
+ )
737
+ y_emb[:, prefix_len:] += embedding_layer(samples)
738
+ else:
739
+ for j in range(1, 8):
740
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
741
+ prompts[..., j]
742
+ )
743
+
744
+ for i, (predict_layer, embedding_layer) in enumerate(
745
+ zip(
746
+ self.nar_predict_layers,
747
+ self.nar_audio_embeddings[1:],
748
+ )
749
+ ):
750
+ y_pos = self.nar_audio_prenet(y_emb)
751
+ y_pos = self.nar_audio_position(y_pos)
752
+ xy_pos = torch.concat([x, y_pos], dim=1)
753
+
754
+ xy_dec, _ = self.nar_decoder(
755
+ (xy_pos, self.nar_stage_embeddings[i].weight)
756
+ )
757
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
758
+
759
+ samples = torch.argmax(logits, dim=-1)
760
+ codes.append(samples)
761
+
762
+ if i < 6:
763
+ y_emb[:, prefix_len:] += embedding_layer(samples)
764
+
765
+ assert len(codes) == 8
766
+ return torch.stack(codes, dim=-1)
767
+
768
+
769
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
770
+ def top_k_top_p_filtering(
771
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
772
+ ):
773
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
774
+ Args:
775
+ logits: logits distribution shape (batch size, vocabulary size)
776
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
777
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
778
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
779
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
780
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
781
+ """
782
+ if top_k > 0:
783
+ top_k = min(
784
+ max(top_k, min_tokens_to_keep), logits.size(-1)
785
+ ) # Safety check
786
+ # Remove all tokens with a probability less than the last token of the top-k
787
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
788
+ logits[indices_to_remove] = filter_value
789
+
790
+ if top_p < 1.0:
791
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
792
+ cumulative_probs = torch.cumsum(
793
+ F.softmax(sorted_logits, dim=-1), dim=-1
794
+ )
795
+
796
+ # Remove tokens with cumulative probability above the threshold (the highest-probability token is always kept)
797
+ sorted_indices_to_remove = cumulative_probs > top_p
798
+ if min_tokens_to_keep > 1:
799
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
800
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
801
+ # Shift the indices to the right to keep also the first token above the threshold
802
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
803
+ ..., :-1
804
+ ].clone()
805
+ sorted_indices_to_remove[..., 0] = 0
806
+
807
+ # scatter sorted tensors to original indexing
808
+ indices_to_remove = sorted_indices_to_remove.scatter(
809
+ 1, sorted_indices, sorted_indices_to_remove
810
+ )
811
+ logits[indices_to_remove] = filter_value
812
+ return logits
813
+
814
+
815
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
816
+ # temperature: (`optional`) float
817
+ # The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
818
+ # top_k: (`optional`) int
819
+ # The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 10.
820
+ # top_p: (`optional`) float
821
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
822
+
823
+ # Temperature (higher temperature => more likely to sample low probability tokens)
824
+ if temperature != 1.0:
825
+ logits = logits / temperature
826
+ # Top-p/top-k filtering
827
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
828
+ # Sample
829
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
830
+ return token
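A small standalone exercise of the two sampling helpers above, using synthetic logits (a sketch only; the batch size and vocabulary size are arbitrary):

import torch

torch.manual_seed(0)
logits = torch.randn(1, 1024)                    # (batch, vocab)
# Clone because top_k_top_p_filtering writes the filter value into the tensor in place.
token = topk_sampling(logits.clone(), top_k=10, top_p=0.9, temperature=0.8)
print(token.shape)                               # torch.Size([1, 1])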
modules/__init__.py ADDED
File without changes
modules/activation.py ADDED
@@ -0,0 +1,612 @@
1
+ from typing import Optional, Tuple, List
2
+ import math
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+
12
+ def _in_projection_packed(
13
+ q: Tensor,
14
+ k: Tensor,
15
+ v: Tensor,
16
+ w: Tensor,
17
+ b: Optional[Tensor] = None,
18
+ ) -> List[Tensor]:
19
+ r"""
20
+ Performs the in-projection step of the attention operation, using packed weights.
21
+ Output is a triple containing projection tensors for query, key and value.
22
+
23
+ Args:
24
+ q, k, v: query, key and value tensors to be projected. For self-attention,
25
+ these are typically the same tensor; for encoder-decoder attention,
26
+ k and v are typically the same tensor. (We take advantage of these
27
+ identities for performance if they are present.) Regardless, q, k and v
28
+ must share a common embedding dimension; otherwise their shapes may vary.
29
+ w: projection weights for q, k and v, packed into a single tensor. Weights
30
+ are packed along dimension 0, in q, k, v order.
31
+ b: optional projection biases for q, k and v, packed into a single tensor
32
+ in q, k, v order.
33
+
34
+ Shape:
35
+ Inputs:
36
+ - q: :math:`(..., E)` where E is the embedding dimension
37
+ - k: :math:`(..., E)` where E is the embedding dimension
38
+ - v: :math:`(..., E)` where E is the embedding dimension
39
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
40
+ - b: :math:`E * 3` where E is the embedding dimension
41
+
42
+ Output:
43
+ - in output list :math:`[q', k', v']`, each output tensor will have the
44
+ same shape as the corresponding input tensor.
45
+ """
46
+ E = q.size(-1)
47
+ if k is v:
48
+ if q is k:
49
+ # self-attention
50
+ return F.linear(q, w, b).chunk(3, dim=-1)
51
+ else:
52
+ # encoder-decoder attention
53
+ w_q, w_kv = w.split([E, E * 2])
54
+ if b is None:
55
+ b_q = b_kv = None
56
+ else:
57
+ b_q, b_kv = b.split([E, E * 2])
58
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
59
+ else:
60
+ w_q, w_k, w_v = w.chunk(3)
61
+ if b is None:
62
+ b_q = b_k = b_v = None
63
+ else:
64
+ b_q, b_k, b_v = b.chunk(3)
65
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
66
+
67
+ def _scaled_dot_product_attention(
68
+ q: Tensor,
69
+ k: Tensor,
70
+ v: Tensor,
71
+ attn_mask: Optional[Tensor] = None,
72
+ dropout_p: float = 0.0,
73
+ ) -> Tuple[Tensor, Tensor]:
74
+ r"""
75
+ Computes scaled dot product attention on query, key and value tensors, using
76
+ an optional attention mask if passed, and applying dropout if a probability
77
+ greater than 0.0 is specified.
78
+ Returns a tensor pair containing attended values and attention weights.
79
+
80
+ Args:
81
+ q, k, v: query, key and value tensors. See Shape section for shape details.
82
+ attn_mask: optional tensor containing mask values to be added to calculated
83
+ attention. May be 2D or 3D; see Shape section for details.
84
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
85
+
86
+ Shape:
87
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
88
+ and E is embedding dimension.
89
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
90
+ and E is embedding dimension.
91
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
92
+ and E is embedding dimension.
93
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
94
+ shape :math:`(Nt, Ns)`.
95
+
96
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
97
+ have shape :math:`(B, Nt, Ns)`
98
+ """
99
+ B, Nt, E = q.shape
100
+ q = q / math.sqrt(E)
101
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
102
+ if attn_mask is not None:
103
+ attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
104
+ else:
105
+ attn = torch.bmm(q, k.transpose(-2, -1))
106
+
107
+ attn = F.softmax(attn, dim=-1)
108
+ if dropout_p > 0.0:
109
+ attn = F.dropout(attn, p=dropout_p)
110
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
111
+ output = torch.bmm(attn, v)
112
+ return output, attn
113
+
114
+ def multi_head_attention_forward(
115
+ x,
116
+ ipw,
117
+ ipb,
118
+ opw,
119
+ opb,
120
+ n_head,
121
+ attn_mask,
122
+ past_kv=None,
123
+ use_cache=False,
124
+ ):
125
+ # x = x.transpose(1, 0)
126
+ # tgt_len, bsz, embed_dim = x.shape
127
+ # head_dim = embed_dim // n_head
128
+ # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
129
+ # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
130
+ # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
131
+ # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)
132
+
133
+ # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
134
+ # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
135
+ # attn_mask = new_attn_mask
136
+ #
137
+ # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
138
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
139
+ # attn_output = torch._C._nn.linear(attn_output, opw, opb)
140
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
141
+
142
+ B, T, C = x.size()
143
+
144
+ q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
145
+ k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
146
+ q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
147
+ v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
148
+ if past_kv is not None:
149
+ past_key = past_kv[0]
150
+ past_value = past_kv[1]
151
+ k = torch.cat((past_key, k), dim=-2)
152
+ v = torch.cat((past_value, v), dim=-2)
153
+
154
+ FULL_T = k.shape[-2]
155
+
156
+ if use_cache is True:
157
+ present = (k, v)
158
+ else:
159
+ present = None
160
+
161
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
162
+ att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
163
+ att = F.softmax(att, dim=-1)
164
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
165
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
166
+ y = torch._C._nn.linear(y, opw, opb)
167
+ return (y, present)
168
+
169
+
170
+ class MultiheadAttention(Module):
171
+ r"""Allows the model to jointly attend to information
172
+ from different representation subspaces as described in the paper:
173
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
174
+
175
+ Multi-Head Attention is defined as:
176
+
177
+ .. math::
178
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
179
+
180
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
181
+
182
+ ``forward()`` will use a special optimized implementation if all of the following
183
+ conditions are met:
184
+
185
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
186
+ restriction will be loosened in the future.)
187
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
188
+ - training is disabled (using ``.eval()``)
189
+ - dropout is 0
190
+ - ``add_bias_kv`` is ``False``
191
+ - ``add_zero_attn`` is ``False``
192
+ - ``batch_first`` is ``True`` and the input is batched
193
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
194
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
195
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
196
+ nor ``attn_mask`` is passed
197
+
198
+ If the optimized implementation is in use, a
199
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
200
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
201
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
202
+ will be returned, and an additional speedup proportional to the fraction of the input
203
+ that is padding can be expected.
204
+
205
+ Args:
206
+ embed_dim: Total dimension of the model.
207
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
208
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
209
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
210
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
211
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
212
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
213
+ Default: ``False``.
214
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
215
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
216
+ batch_first: If ``True``, then the input and output tensors are provided
217
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
218
+
219
+ Examples::
220
+
221
+ >>> # xdoctest: +SKIP
222
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
223
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
224
+
225
+ """
226
+ __constants__ = ["batch_first"]
227
+ bias_k: Optional[torch.Tensor]
228
+ bias_v: Optional[torch.Tensor]
229
+
230
+ def __init__(
231
+ self,
232
+ embed_dim,
233
+ num_heads,
234
+ dropout=0.0,
235
+ bias=True,
236
+ add_bias_kv=False,
237
+ add_zero_attn=False,
238
+ kdim=None,
239
+ vdim=None,
240
+ batch_first=False,
241
+ linear1_cls=Linear,
242
+ linear2_cls=Linear,
243
+ device=None,
244
+ dtype=None,
245
+ ) -> None:
246
+ factory_kwargs = {"device": device, "dtype": dtype}
247
+ super(MultiheadAttention, self).__init__()
248
+ self.embed_dim = embed_dim
249
+ self.kdim = kdim if kdim is not None else embed_dim
250
+ self.vdim = vdim if vdim is not None else embed_dim
251
+ self._qkv_same_embed_dim = (
252
+ self.kdim == embed_dim and self.vdim == embed_dim
253
+ )
254
+
255
+ self.num_heads = num_heads
256
+ self.dropout = dropout
257
+ self.batch_first = batch_first
258
+ self.head_dim = embed_dim // num_heads
259
+ assert (
260
+ self.head_dim * num_heads == self.embed_dim
261
+ ), "embed_dim must be divisible by num_heads"
262
+
263
+ if add_bias_kv:
264
+ self.bias_k = Parameter(
265
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
266
+ )
267
+ self.bias_v = Parameter(
268
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
269
+ )
270
+ else:
271
+ self.bias_k = self.bias_v = None
272
+
273
+ if linear1_cls == Linear:
274
+ if not self._qkv_same_embed_dim:
275
+ self.q_proj_weight = Parameter(
276
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
277
+ )
278
+ self.k_proj_weight = Parameter(
279
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
280
+ )
281
+ self.v_proj_weight = Parameter(
282
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
283
+ )
284
+ self.register_parameter("in_proj_weight", None)
285
+ else:
286
+ self.in_proj_weight = Parameter(
287
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
288
+ )
289
+ self.register_parameter("q_proj_weight", None)
290
+ self.register_parameter("k_proj_weight", None)
291
+ self.register_parameter("v_proj_weight", None)
292
+
293
+ if bias:
294
+ self.in_proj_bias = Parameter(
295
+ torch.empty(3 * embed_dim, **factory_kwargs)
296
+ )
297
+ else:
298
+ self.register_parameter("in_proj_bias", None)
299
+ self.out_proj = NonDynamicallyQuantizableLinear(
300
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
301
+ )
302
+
303
+ self._reset_parameters()
304
+ else:
305
+ if not self._qkv_same_embed_dim:
306
+ raise NotImplementedError
307
+ else:
308
+ self.in_proj_linear = linear1_cls(
309
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
310
+ )
311
+ self.in_proj_weight = self.in_proj_linear.weight
312
+
313
+ self.register_parameter("q_proj_weight", None)
314
+ self.register_parameter("k_proj_weight", None)
315
+ self.register_parameter("v_proj_weight", None)
316
+
317
+ if bias:
318
+ self.in_proj_bias = self.in_proj_linear.bias
319
+ else:
320
+ self.register_parameter("in_proj_bias", None)
321
+
322
+ self.out_proj = linear2_cls(
323
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
324
+ )
325
+
326
+ if self.bias_k is not None:
327
+ xavier_normal_(self.bias_k)
328
+ if self.bias_v is not None:
329
+ xavier_normal_(self.bias_v)
330
+
331
+ self.add_zero_attn = add_zero_attn
332
+
333
+ def _reset_parameters(self):
334
+ if self._qkv_same_embed_dim:
335
+ xavier_uniform_(self.in_proj_weight)
336
+ else:
337
+ xavier_uniform_(self.q_proj_weight)
338
+ xavier_uniform_(self.k_proj_weight)
339
+ xavier_uniform_(self.v_proj_weight)
340
+
341
+ if self.in_proj_bias is not None:
342
+ constant_(self.in_proj_bias, 0.0)
343
+ constant_(self.out_proj.bias, 0.0)
344
+
345
+ if self.bias_k is not None:
346
+ xavier_normal_(self.bias_k)
347
+ if self.bias_v is not None:
348
+ xavier_normal_(self.bias_v)
349
+
350
+ def __setstate__(self, state):
351
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
352
+ if "_qkv_same_embed_dim" not in state:
353
+ state["_qkv_same_embed_dim"] = True
354
+
355
+ super(MultiheadAttention, self).__setstate__(state)
356
+
357
+ def forward(
358
+ self,
359
+ query: Tensor,
360
+ key: Tensor,
361
+ value: Tensor,
362
+ key_padding_mask: Optional[Tensor] = None,
363
+ need_weights: bool = True,
364
+ attn_mask: Optional[Tensor] = None,
365
+ average_attn_weights: bool = True,
366
+ ) -> Tuple[Tensor, Optional[Tensor]]:
367
+ r"""
368
+ Args:
369
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
370
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
371
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
372
+ Queries are compared against key-value pairs to produce the output.
373
+ See "Attention Is All You Need" for more details.
374
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
375
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
376
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
377
+ See "Attention Is All You Need" for more details.
378
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
379
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
380
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
381
+ See "Attention Is All You Need" for more details.
382
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
383
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
384
+ Binary and byte masks are supported.
385
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
386
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
387
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
388
+ Default: ``True``.
389
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
390
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
391
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
392
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
393
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
394
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
395
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
396
+ the attention weight.
397
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
398
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
399
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
400
+
401
+ Outputs:
402
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
403
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
404
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
405
+ embedding dimension ``embed_dim``.
406
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
407
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
408
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
409
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
410
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
411
+
412
+ .. note::
413
+ `batch_first` argument is ignored for unbatched inputs.
414
+ """
415
+ is_batched = query.dim() == 3
416
+ if key_padding_mask is not None:
417
+ _kpm_dtype = key_padding_mask.dtype
418
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
419
+ key_padding_mask
420
+ ):
421
+ raise AssertionError(
422
+ "only bool and floating types of key_padding_mask are supported"
423
+ )
424
+ why_not_fast_path = ""
425
+ if not is_batched:
426
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
427
+ elif query is not key or key is not value:
428
+ # When lifting this restriction, don't forget to either
429
+ # enforce that the dtypes all match or test cases where
430
+ # they don't!
431
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
432
+ elif (
433
+ self.in_proj_bias is not None
434
+ and query.dtype != self.in_proj_bias.dtype
435
+ ):
436
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
437
+ elif (
438
+ self.in_proj_weight is not None
439
+ and query.dtype != self.in_proj_weight.dtype
440
+ ):
441
+ # this case will fail anyway, but at least they'll get a useful error message.
442
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
443
+ elif self.training:
444
+ why_not_fast_path = "training is enabled"
445
+ elif not self.batch_first:
446
+ why_not_fast_path = "batch_first was not True"
447
+ elif self.bias_k is not None:
448
+ why_not_fast_path = "self.bias_k was not None"
449
+ elif self.bias_v is not None:
450
+ why_not_fast_path = "self.bias_v was not None"
451
+ elif self.dropout:
452
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
453
+ elif self.add_zero_attn:
454
+ why_not_fast_path = "add_zero_attn was enabled"
455
+ elif not self._qkv_same_embed_dim:
456
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
457
+ elif attn_mask is not None:
458
+ why_not_fast_path = "attn_mask was not None"
459
+ elif query.is_nested and key_padding_mask is not None:
460
+ why_not_fast_path = (
461
+ "key_padding_mask is not supported with NestedTensor input"
462
+ )
463
+ elif self.num_heads % 2 == 1:
464
+ why_not_fast_path = "num_heads is odd"
465
+ elif torch.is_autocast_enabled():
466
+ why_not_fast_path = "autocast is enabled"
467
+
468
+ if not why_not_fast_path:
469
+ tensor_args = (
470
+ query,
471
+ key,
472
+ value,
473
+ self.in_proj_weight,
474
+ self.in_proj_bias,
475
+ self.out_proj.weight,
476
+ self.out_proj.bias,
477
+ )
478
+ # We have to use list comprehensions below because TorchScript does not support
479
+ # generator expressions.
480
+ if torch.overrides.has_torch_function(tensor_args):
481
+ why_not_fast_path = "some Tensor argument has_torch_function"
482
+ elif not all(
483
+ [
484
+ (x is None or x.is_cuda or "cpu" in str(x.device))
485
+ for x in tensor_args
486
+ ]
487
+ ):
488
+ why_not_fast_path = (
489
+ "some Tensor argument is neither CUDA nor CPU"
490
+ )
491
+ elif torch.is_grad_enabled() and any(
492
+ [x is not None and x.requires_grad for x in tensor_args]
493
+ ):
494
+ why_not_fast_path = (
495
+ "grad is enabled and at least one of query or the "
496
+ "input/output projection weights or biases requires_grad"
497
+ )
498
+ if not why_not_fast_path:
499
+ return torch._native_multi_head_attention(
500
+ query,
501
+ key,
502
+ value,
503
+ self.embed_dim,
504
+ self.num_heads,
505
+ self.in_proj_weight,
506
+ self.in_proj_bias,
507
+ self.out_proj.weight,
508
+ self.out_proj.bias,
509
+ key_padding_mask
510
+ if key_padding_mask is not None
511
+ else attn_mask,
512
+ need_weights,
513
+ average_attn_weights,
514
+ 1
515
+ if key_padding_mask is not None
516
+ else 0
517
+ if attn_mask is not None
518
+ else None,
519
+ )
520
+
521
+ any_nested = query.is_nested or key.is_nested or value.is_nested
522
+ assert not any_nested, (
523
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
524
+ + f"The fast path was not hit because {why_not_fast_path}"
525
+ )
526
+
527
+ if self.batch_first and is_batched:
528
+ # make sure that the transpose op does not affect the "is" property
529
+ if key is value:
530
+ if query is key:
531
+ query = key = value = query.transpose(1, 0)
532
+ else:
533
+ query, key = [x.transpose(1, 0) for x in (query, key)]
534
+ value = key
535
+ else:
536
+ query, key, value = [
537
+ x.transpose(1, 0) for x in (query, key, value)
538
+ ]
539
+
540
+ if not self._qkv_same_embed_dim:
541
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
542
+ query,
543
+ key,
544
+ value,
545
+ self.embed_dim,
546
+ self.num_heads,
547
+ self.in_proj_weight,
548
+ self.in_proj_bias,
549
+ self.bias_k,
550
+ self.bias_v,
551
+ self.add_zero_attn,
552
+ self.dropout,
553
+ self.out_proj.weight,
554
+ self.out_proj.bias,
555
+ training=self.training,
556
+ key_padding_mask=key_padding_mask,
557
+ need_weights=need_weights,
558
+ attn_mask=attn_mask,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj_weight,
561
+ k_proj_weight=self.k_proj_weight,
562
+ v_proj_weight=self.v_proj_weight,
563
+ average_attn_weights=average_attn_weights,
564
+ )
565
+ else:
566
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
567
+ query,
568
+ key,
569
+ value,
570
+ self.embed_dim,
571
+ self.num_heads,
572
+ self.in_proj_weight,
573
+ self.in_proj_bias,
574
+ self.bias_k,
575
+ self.bias_v,
576
+ self.add_zero_attn,
577
+ self.dropout,
578
+ self.out_proj.weight,
579
+ self.out_proj.bias,
580
+ training=self.training,
581
+ key_padding_mask=key_padding_mask,
582
+ need_weights=need_weights,
583
+ attn_mask=attn_mask,
584
+ average_attn_weights=average_attn_weights,
585
+ )
586
+ if self.batch_first and is_batched:
587
+ return attn_output.transpose(1, 0), attn_output_weights
588
+ else:
589
+ return attn_output, attn_output_weights
590
+
591
+ def infer(self,
592
+ x: Tensor,
593
+ key_padding_mask: Optional[Tensor] = None,
594
+ need_weights: bool = True,
595
+ attn_mask: Optional[Tensor] = None,
596
+ average_attn_weights: bool = True,
597
+ past_kv = None,
598
+ use_cache = False
599
+ ):
600
+ # x = x.transpose(1, 0)
601
+ y, kv = multi_head_attention_forward(
602
+ x=x,
603
+ ipw=self.in_proj_weight,
604
+ ipb=self.in_proj_bias,
605
+ opw=self.out_proj.weight,
606
+ opb=self.out_proj.bias,
607
+ n_head=self.num_heads,
608
+ attn_mask=attn_mask,
609
+ past_kv=past_kv,
610
+ use_cache=use_cache,
611
+ )
612
+ return (y, kv)
modules/embedding.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ class TokenEmbedding(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim_model: int,
25
+ vocab_size: int,
26
+ dropout: float = 0.0,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.vocab_size = vocab_size
31
+ self.dim_model = dim_model
32
+
33
+ self.dropout = torch.nn.Dropout(p=dropout)
34
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35
+
36
+ @property
37
+ def weight(self) -> torch.Tensor:
38
+ return self.word_embeddings.weight
39
+
40
+ def embedding(self, index: int) -> torch.Tensor:
41
+ return self.word_embeddings.weight[index : index + 1]
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ X = self.word_embeddings(x)
45
+ X = self.dropout(X)
46
+
47
+ return X
48
+
49
+
50
+ class SinePositionalEmbedding(nn.Module):
51
+ def __init__(
52
+ self,
53
+ dim_model: int,
54
+ dropout: float = 0.0,
55
+ scale: bool = False,
56
+ alpha: bool = False,
57
+ ):
58
+ super().__init__()
59
+ self.dim_model = dim_model
60
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
61
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62
+ self.dropout = torch.nn.Dropout(p=dropout)
63
+
64
+ self.reverse = False
65
+ self.pe = None
66
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67
+
68
+ def extend_pe(self, x):
69
+ """Reset the positional encodings."""
70
+ if self.pe is not None:
71
+ if self.pe.size(1) >= x.size(1):
72
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
73
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74
+ return
75
+ pe = torch.zeros(x.size(1), self.dim_model)
76
+ if self.reverse:
77
+ position = torch.arange(
78
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
79
+ ).unsqueeze(1)
80
+ else:
81
+ position = torch.arange(
82
+ 0, x.size(1), dtype=torch.float32
83
+ ).unsqueeze(1)
84
+ div_term = torch.exp(
85
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86
+ * -(math.log(10000.0) / self.dim_model)
87
+ )
88
+ pe[:, 0::2] = torch.sin(position * div_term)
89
+ pe[:, 1::2] = torch.cos(position * div_term)
90
+ pe = pe.unsqueeze(0)
91
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ self.extend_pe(x)
95
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
96
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97
+ return self.dropout(output)
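A quick shape check of the positional embedding above (assuming the class is imported from modules.embedding):

import torch

pos_emb = SinePositionalEmbedding(dim_model=32, dropout=0.0, scale=False, alpha=False)
x = torch.randn(2, 10, 32)                       # (batch, time, dim)
out = pos_emb(x)                                 # sinusoidal table added elementwise
print(out.shape)                                 # torch.Size([2, 10, 32])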
modules/scaling.py ADDED
@@ -0,0 +1,1401 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import random
21
+ import math
22
+ from functools import reduce
23
+ from itertools import repeat
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch import Tensor
30
+ from torch.nn import Embedding as ScaledEmbedding
31
+
32
+ from utils import Transpose
33
+
34
+
35
+ class ActivationBalancerFunction(torch.autograd.Function):
36
+ @staticmethod
37
+ def forward(
38
+ ctx,
39
+ x: Tensor,
40
+ scale_factor: Tensor,
41
+ sign_factor: Optional[Tensor],
42
+ channel_dim: int,
43
+ ) -> Tensor:
44
+ if channel_dim < 0:
45
+ channel_dim += x.ndim
46
+ ctx.channel_dim = channel_dim
47
+ xgt0 = x > 0
48
+ if sign_factor is None:
49
+ ctx.save_for_backward(xgt0, scale_factor)
50
+ else:
51
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
52
+ return x
53
+
54
+ @staticmethod
55
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
56
+ if len(ctx.saved_tensors) == 3:
57
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
58
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
59
+ scale_factor = scale_factor.unsqueeze(-1)
60
+ sign_factor = sign_factor.unsqueeze(-1)
61
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
62
+ else:
63
+ xgt0, scale_factor = ctx.saved_tensors
64
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
+ scale_factor = scale_factor.unsqueeze(-1)
66
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
67
+ neg_delta_grad = x_grad.abs() * factor
68
+ return (
69
+ x_grad - neg_delta_grad,
70
+ None,
71
+ None,
72
+ None,
73
+ )
74
+
75
+
76
+ def _compute_scale_factor(
77
+ x: Tensor,
78
+ channel_dim: int,
79
+ min_abs: float,
80
+ max_abs: float,
81
+ gain_factor: float,
82
+ max_factor: float,
83
+ ) -> Tensor:
84
+ if channel_dim < 0:
85
+ channel_dim += x.ndim
86
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
87
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
88
+
89
+ if min_abs == 0.0:
90
+ below_threshold = 0.0
91
+ else:
92
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
93
+ # x_abs)_mean , min_abs.
94
+ below_threshold = (
95
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
96
+ ).clamp(min=0, max=max_factor)
97
+
98
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
99
+ min=0, max=max_factor
100
+ )
101
+
102
+ return below_threshold - above_threshold
103
+
104
+
105
+ def _compute_sign_factor(
106
+ x: Tensor,
107
+ channel_dim: int,
108
+ min_positive: float,
109
+ max_positive: float,
110
+ gain_factor: float,
111
+ max_factor: float,
112
+ ) -> Tensor:
113
+ if channel_dim < 0:
114
+ channel_dim += x.ndim
115
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
116
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
117
+ if min_positive == 0.0:
118
+ factor1 = 0.0
119
+ else:
120
+ # 0 if proportion_positive >= min_positive, else can be
121
+ # as large as max_factor.
122
+ factor1 = (
123
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
124
+ ).clamp_(min=0, max=max_factor)
125
+
126
+ if max_positive == 1.0:
127
+ factor2 = 0.0
128
+ else:
129
+ # 0 if self.proportion_positive <= max_positive, else can be
130
+ # as large as -max_factor.
131
+ factor2 = (
132
+ (proportion_positive - max_positive)
133
+ * (gain_factor / (1.0 - max_positive))
134
+ ).clamp_(min=0, max=max_factor)
135
+ sign_factor = factor1 - factor2
136
+ # require min_positive != 0 or max_positive != 1:
137
+ assert not isinstance(sign_factor, float)
138
+ return sign_factor
139
+
140
+
141
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
142
+ """
143
+ This object is used in class ActivationBalancer when the user specified
144
+ min_positive=0, max_positive=1, so there are no constraints on the signs
145
+ of the activations and only the absolute value has a constraint.
146
+ """
147
+
148
+ @staticmethod
149
+ def forward(
150
+ ctx,
151
+ x: Tensor,
152
+ sign_factor: Tensor,
153
+ scale_factor: Tensor,
154
+ channel_dim: int,
155
+ ) -> Tensor:
156
+ if channel_dim < 0:
157
+ channel_dim += x.ndim
158
+ ctx.channel_dim = channel_dim
159
+ xgt0 = x > 0
160
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
161
+ return x
162
+
163
+ @staticmethod
164
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
165
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
166
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
167
+ sign_factor = sign_factor.unsqueeze(-1)
168
+ scale_factor = scale_factor.unsqueeze(-1)
169
+
170
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
171
+ neg_delta_grad = x_grad.abs() * factor
172
+ return (
173
+ x_grad - neg_delta_grad,
174
+ None,
175
+ None,
176
+ None,
177
+ )
178
+
179
+
180
+ class RandomClampFunction(torch.autograd.Function):
181
+ @staticmethod
182
+ def forward(
183
+ ctx,
184
+ x: Tensor,
185
+ min: Optional[float],
186
+ max: Optional[float],
187
+ prob: float,
188
+ reflect: float,
189
+ ) -> Tensor:
190
+ x_clamped = torch.clamp(x, min=min, max=max)
191
+ mask = torch.rand_like(x) < prob
192
+ ans = torch.where(mask, x_clamped, x)
193
+ if x.requires_grad:
194
+ ctx.save_for_backward(ans == x)
195
+ ctx.reflect = reflect
196
+ if reflect != 0.0:
197
+ ans = ans * (1.0 + reflect) - (x * reflect)
198
+ return ans
199
+
200
+ @staticmethod
201
+ def backward(
202
+ ctx, ans_grad: Tensor
203
+ ) -> Tuple[Tensor, None, None, None, None]:
204
+ (is_same,) = ctx.saved_tensors
205
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
206
+ reflect = ctx.reflect
207
+ if reflect != 0.0:
208
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
209
+ return x_grad, None, None, None, None
210
+
211
+
212
+ def random_clamp(
213
+ x: Tensor,
214
+ min: Optional[float] = None,
215
+ max: Optional[float] = None,
216
+ prob: float = 0.5,
217
+ reflect: float = 0.0,
218
+ ):
219
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
220
+
221
+
222
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
223
+ """
224
+ A randomized way of casting a floating point value to half precision.
225
+ """
226
+ if x.dtype == torch.float16:
227
+ return x
228
+ x_abs = x.abs()
229
+ is_too_small = x_abs < min_abs
230
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
231
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
232
+ # for those elements].
233
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
234
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
235
+
236
+
237
+ class RandomGradFunction(torch.autograd.Function):
238
+ """
239
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
240
+ randomized approach that preserves expectations (intended to reduce roundoff).
241
+ """
242
+
243
+ @staticmethod
244
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
245
+ ctx.min_abs = min_abs
246
+ return x
247
+
248
+ @staticmethod
249
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
250
+ if ans_grad.dtype == torch.float16:
251
+ return (
252
+ random_cast_to_half(
253
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
254
+ ),
255
+ None,
256
+ )
257
+ else:
258
+ return ans_grad, None
259
+
260
+
261
+ class RandomGrad(torch.nn.Module):
262
+ """
263
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
264
+ accuracy of training when using amp (automatic mixed precision)
265
+ """
266
+
267
+ def __init__(self, min_abs: float = 5.0e-06):
268
+ super(RandomGrad, self).__init__()
269
+ self.min_abs = min_abs
270
+
271
+ def forward(self, x: Tensor):
272
+ if (
273
+ torch.jit.is_scripting()
274
+ or not self.training
275
+ or torch.jit.is_tracing()
276
+ ):
277
+ return x
278
+ else:
279
+ return RandomGradFunction.apply(x, self.min_abs)
280
+
281
+
282
+ class SoftmaxFunction(torch.autograd.Function):
283
+ """
284
+ Tries to handle half-precision derivatives in a randomized way that should
285
+ be more accurate for training than the default behavior.
286
+ """
287
+
288
+ @staticmethod
289
+ def forward(ctx, x: Tensor, dim: int):
290
+ ans = x.softmax(dim=dim)
291
+ # if x dtype is float16, x.softmax() returns a float32 because
292
+ # (presumably) that op does not support float16, and autocast
293
+ # is enabled.
294
+ if torch.is_autocast_enabled():
295
+ ans = ans.to(torch.float16)
296
+ ctx.save_for_backward(ans)
297
+ ctx.x_dtype = x.dtype
298
+ ctx.dim = dim
299
+ return ans
300
+
301
+ @staticmethod
302
+ def backward(ctx, ans_grad: Tensor):
303
+ (ans,) = ctx.saved_tensors
304
+ with torch.cuda.amp.autocast(enabled=False):
305
+ ans_grad = ans_grad.to(torch.float32)
306
+ ans = ans.to(torch.float32)
307
+ x_grad = ans_grad * ans
308
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
309
+ return x_grad, None
310
+
311
+
312
+ def softmax(x: Tensor, dim: int):
313
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
314
+ return x.softmax(dim)
315
+
316
+ return SoftmaxFunction.apply(x, dim)
317
+
318
+
319
+ class MaxEigLimiterFunction(torch.autograd.Function):
320
+ @staticmethod
321
+ def forward(
322
+ ctx,
323
+ x: Tensor,
324
+ coeffs: Tensor,
325
+ direction: Tensor,
326
+ channel_dim: int,
327
+ grad_scale: float,
328
+ ) -> Tensor:
329
+ ctx.channel_dim = channel_dim
330
+ ctx.grad_scale = grad_scale
331
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
332
+ return x
333
+
334
+ @staticmethod
335
+ def backward(ctx, x_grad, *args):
336
+ with torch.enable_grad():
337
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
338
+ x_orig.requires_grad = True
339
+ num_channels = x_orig.shape[ctx.channel_dim]
340
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
341
+ new_direction.requires_grad = False
342
+ x = x - x.mean(dim=0)
343
+ x_var = (x ** 2).mean()
344
+ x_residual = x - coeffs * new_direction
345
+ x_residual_var = (x_residual ** 2).mean()
346
+ # `variance_proportion` is the proportion of the variance accounted for
347
+ # by the top eigen-direction. This is to be minimized.
348
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
349
+ variance_proportion.backward()
350
+ x_orig_grad = x_orig.grad
351
+ x_extra_grad = (
352
+ x_orig.grad
353
+ * ctx.grad_scale
354
+ * x_grad.norm()
355
+ / (x_orig_grad.norm() + 1.0e-20)
356
+ )
357
+ return x_grad + x_extra_grad.detach(), None, None, None, None
358
+
359
+
360
+ class BasicNorm(torch.nn.Module):
361
+ """
362
+ This is intended to be a simpler, and hopefully cheaper, replacement for
363
+ LayerNorm. The observation this is based on, is that Transformer-type
364
+ networks, especially with pre-norm, sometimes seem to set one of the
365
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
366
+ the LayerNorm because the output magnitude is then not strongly dependent
367
+ on the other (useful) features. Presumably the weight and bias of the
368
+ LayerNorm are required to allow it to do this.
369
+
370
+ So the idea is to introduce this large constant value as an explicit
371
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
372
+ doesn't have to do this trick. We make the "eps" learnable.
373
+
374
+ Args:
375
+ num_channels: the number of channels, e.g. 512.
376
+ channel_dim: the axis/dimension corresponding to the channel,
377
+ interpreted as an offset from the input's ndim if negative.
378
+ This is NOT the num_channels; it should typically be one of
379
+ {-2, -1, 0, 1, 2, 3}.
380
+ eps: the initial "epsilon" that we add as ballast in:
381
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
382
+ Note: our epsilon is actually large, but we keep the name
383
+ to indicate the connection with conventional LayerNorm.
384
+ learn_eps: if true, we learn epsilon; if false, we keep it
385
+ at the initial value.
386
+ eps_min: float
387
+ eps_max: float
388
+ """
389
+
390
+ def __init__(
391
+ self,
392
+ num_channels: int,
393
+ channel_dim: int = -1, # CAUTION: see documentation.
394
+ eps: float = 0.25,
395
+ learn_eps: bool = True,
396
+ eps_min: float = -3.0,
397
+ eps_max: float = 3.0,
398
+ ) -> None:
399
+ super(BasicNorm, self).__init__()
400
+ self.num_channels = num_channels
401
+ self.channel_dim = channel_dim
402
+ if learn_eps:
403
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
404
+ else:
405
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
406
+ self.eps_min = eps_min
407
+ self.eps_max = eps_max
408
+
409
+ def forward(self, x: Tensor) -> Tensor:
410
+ assert x.shape[self.channel_dim] == self.num_channels
411
+ eps = self.eps
412
+ if self.training and random.random() < 0.25:
413
+ # with probability 0.25, in training mode, clamp eps between the min
414
+ # and max; this will encourage it to learn parameters within the
415
+ # allowed range by making parameters that are outside the allowed
416
+ # range noisy. Using the clamped value in the forward pass also gives
417
+ # gradients to allow the parameter to get back into the allowed
418
+ # region if it happens to exit it.
419
+
420
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
421
+ scales = (
422
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
423
+ ) ** -0.5
424
+ return x * scales
425
+
426
+
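A minimal usage sketch of BasicNorm (illustrative only; the shapes and sizes below are arbitrary):

    norm = BasicNorm(num_channels=512, channel_dim=-1)
    x = torch.randn(8, 100, 512)
    y = norm(x)  # same shape as x; each frame is scaled by (mean(x**2) + exp(eps)) ** -0.5
    assert y.shape == x.shape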
427
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
428
+ """
429
+ Behaves like a constructor of a modified version of nn.Linear
430
+ that gives an easy way to set the default initial parameter scale.
431
+
432
+ Args:
433
+ Accepts the standard args and kwargs that nn.Linear accepts
434
+ e.g. in_features, out_features, bias=False.
435
+
436
+ initial_scale: you can override this if you want to increase
437
+ or decrease the initial magnitude of the module's output
438
+ (affects the initialization of weight_scale and bias_scale).
439
+ Another option, if you want to do something like this, is
440
+ to re-initialize the parameters.
441
+ """
442
+ ans = nn.Linear(*args, **kwargs)
443
+ with torch.no_grad():
444
+ ans.weight[:] *= initial_scale
445
+ if ans.bias is not None:
446
+ torch.nn.init.uniform_(
447
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
448
+ )
449
+ return ans
450
+
451
+
452
+ def ScaledConv1d(
453
+ *args,
454
+ initial_scale: float = 1.0,
455
+ kernel_size: int = 3,
456
+ padding: str = "same",
457
+ **kwargs,
458
+ ) -> nn.Conv1d:
459
+ """
460
+ Behaves like a constructor of a modified version of nn.Conv1d
461
+ that gives an easy way to set the default initial parameter scale.
462
+
463
+ Args:
464
+ Accepts the standard args and kwargs that nn.Conv1d accepts,
465
+ e.g. in_channels, out_channels, bias=False.
466
+
467
+ initial_scale: you can override this if you want to increase
468
+ or decrease the initial magnitude of the module's output
469
+ (affects the initialization of weight_scale and bias_scale).
470
+ Another option, if you want to do something like this, is
471
+ to re-initialize the parameters.
472
+ """
473
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
474
+ with torch.no_grad():
475
+ ans.weight[:] *= initial_scale
476
+ if ans.bias is not None:
477
+ torch.nn.init.uniform_(
478
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
479
+ )
480
+ return ans
481
+
482
+
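For illustration (arbitrary sizes): these factory functions return ordinary nn.Linear / nn.Conv1d modules whose initial output magnitude is reduced by initial_scale:

    proj = ScaledLinear(512, 1024, bias=True, initial_scale=0.25)
    conv = ScaledConv1d(512, 512, kernel_size=3, padding="same", initial_scale=0.25)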
483
+ def TransposeScaledConv1d(
484
+ *args,
485
+ initial_scale: float = 1.0,
486
+ kernel_size: int = 3,
487
+ padding: str = "same",
488
+ **kwargs,
489
+ ) -> nn.Sequential:
490
+ """
491
+ Transpose -> ScaledConv1d
492
+ """
493
+ return nn.Sequential(
494
+ Transpose(),
495
+ ScaledConv1d(
496
+ *args,
497
+ initial_scale=initial_scale,
498
+ kernel_size=kernel_size,
499
+ padding=padding,
500
+ **kwargs,
501
+ ),
502
+ )
503
+
504
+
505
+ def ScaledConv1dTranspose(
506
+ *args,
507
+ initial_scale: float = 1.0,
508
+ kernel_size: int = 3,
509
+ padding: str = "same",
510
+ **kwargs,
511
+ ) -> nn.Sequential:
512
+ """
513
+ ScaledConv1d -> Transpose
514
+ """
515
+ return nn.Sequential(
516
+ ScaledConv1d(
517
+ *args,
518
+ initial_scale=initial_scale,
519
+ kernel_size=kernel_size,
520
+ padding=padding,
521
+ **kwargs,
522
+ ),
523
+ Transpose(),
524
+ )
525
+
526
+
527
+ def TransposeConv1d(
528
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
529
+ ) -> nn.Sequential:
530
+ """
531
+ Transpose -> Conv1d
532
+ """
533
+ return nn.Sequential(
534
+ Transpose(),
535
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
536
+ )
537
+
538
+
539
+ def Conv1dTranspose(
540
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
541
+ ) -> nn.Sequential:
542
+ """
543
+ Conv1d -> Transpose
544
+ """
545
+ return nn.Sequential(
546
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
547
+ Transpose(),
548
+ )
549
+
550
+
551
+ class SRLinear(nn.Linear):
552
+ """https://arxiv.org/abs/2303.06296
553
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
554
+ """
555
+
556
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
557
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
558
+ self.register_buffer(
559
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
560
+ )
561
+ with torch.no_grad():
562
+ sigma = self.get_sigma()
563
+ self.register_buffer("spectral_norm", sigma)
564
+ self.sigma = nn.Parameter(torch.ones(1))
565
+
566
+ def get_sigma(self):
567
+ with torch.no_grad():
568
+ u = self.u
569
+ v = self.weight.mv(u)
570
+ v = nn.functional.normalize(v, dim=0)
571
+ u = self.weight.T.mv(v)
572
+ u = nn.functional.normalize(u, dim=0)
573
+ self.u.data.copy_(u)
574
+ return torch.einsum("c,cd,d->", v, self.weight, u)
575
+
576
+ def get_weight(self):
577
+ sigma = self.get_sigma()
578
+ if self.training:
579
+ self.spectral_norm.data.copy_(sigma)
580
+ weight = (self.sigma / sigma) * self.weight
581
+ return weight
582
+
583
+ def forward(self, x):
584
+ return nn.functional.linear(x, self.get_weight(), self.bias)
585
+
586
+
587
+ class SRConv1d(SRLinear):
588
+ def __init__(
589
+ self,
590
+ in_features,
591
+ out_features,
592
+ kernel_size,
593
+ stride: int = 1,
594
+ padding: str = "same",
595
+ bias: bool = True,
596
+ **kwargs,
597
+ ):
598
+ in_features = in_features * kernel_size
599
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
600
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
601
+ self.kernel_size = kernel_size
602
+ self.stride = stride
603
+ self.padding = padding
604
+
605
+ def forward(self, x):
606
+ in_features = self.in_features // self.kernel_size
607
+ weight = self.get_weight().view(
608
+ self.out_features, in_features, self.kernel_size
609
+ )
610
+ return nn.functional.conv1d(
611
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
612
+ )
613
+
614
+
615
+ def TransposeSRConv1d(
616
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
617
+ ) -> nn.Sequential:
618
+ """
619
+ Transpose -> SRConv1d
620
+ """
621
+ return nn.Sequential(
622
+ Transpose(),
623
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
624
+ )
625
+
626
+
627
+ def SRConv1dTranspose(
628
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
629
+ ) -> nn.Sequential:
630
+ """
631
+ SRConv1d -> Transpose
632
+ """
633
+ return nn.Sequential(
634
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
635
+ Transpose(),
636
+ )
637
+
638
+
639
+ class ActivationBalancer(torch.nn.Module):
640
+ """
641
+ Modifies the backpropped derivatives of a function to try to encourage, for
642
+ each channel, that it is positive at least a proportion `threshold` of the
643
+ time. It does this by multiplying negative derivative values by up to
644
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
645
+ interpolated from 1 at the threshold to those extremal values when none
646
+ of the inputs are positive.
647
+
648
+ Args:
649
+ num_channels: the number of channels
650
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
651
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
652
+ min_positive: the minimum, per channel, of the proportion of the time
653
+ that (x > 0), below which we start to modify the derivatives.
654
+ max_positive: the maximum, per channel, of the proportion of the time
655
+ that (x > 0), above which we start to modify the derivatives.
656
+ max_factor: the maximum factor by which we modify the derivatives for
657
+ either the sign constraint or the magnitude constraint;
658
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
659
+ values in the range [0.98..1.02].
660
+ sign_gain_factor: determines the 'gain' with which we increase the
661
+ change in gradient once the constraints on min_positive and max_positive
662
+ are violated.
663
+ scale_gain_factor: determines the 'gain' with which we increase the
664
+ change in gradient once the constraints on min_abs and max_abs
665
+ are violated.
666
+ min_abs: the minimum average-absolute-value difference from the mean
667
+ value per channel, which we allow, before we start to modify
668
+ the derivatives to prevent this.
669
+ max_abs: the maximum average-absolute-value difference from the mean
670
+ value per channel, which we allow, before we start to modify
671
+ the derivatives to prevent this.
672
+ min_prob: determines the minimum probability with which we modify the
673
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
674
+ on each forward(). This is done randomly to prevent all layers
675
+ from doing it at the same time. Early in training we may use
676
+ higher probabilities than this; it will decay to this value.
677
+ """
678
+
679
+ def __init__(
680
+ self,
681
+ num_channels: int,
682
+ channel_dim: int,
683
+ min_positive: float = 0.05,
684
+ max_positive: float = 0.95,
685
+ max_factor: float = 0.04,
686
+ sign_gain_factor: float = 0.01,
687
+ scale_gain_factor: float = 0.02,
688
+ min_abs: float = 0.2,
689
+ max_abs: float = 100.0,
690
+ min_prob: float = 0.1,
691
+ ):
692
+ super(ActivationBalancer, self).__init__()
693
+ self.num_channels = num_channels
694
+ self.channel_dim = channel_dim
695
+ self.min_positive = min_positive
696
+ self.max_positive = max_positive
697
+ self.max_factor = max_factor
698
+ self.min_abs = min_abs
699
+ self.max_abs = max_abs
700
+ self.min_prob = min_prob
701
+ self.sign_gain_factor = sign_gain_factor
702
+ self.scale_gain_factor = scale_gain_factor
703
+
704
+ # count measures how many times the forward() function has been called.
705
+ # We occasionally sync this to a buffer called `count`, which exists so
706
+ # that the value is saved and restored along with the model.
707
+ self.cpu_count = 0
708
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
709
+
710
+ def forward(self, x: Tensor) -> Tensor:
711
+ if (
712
+ torch.jit.is_scripting()
713
+ or not x.requires_grad
714
+ or torch.jit.is_tracing()
715
+ ):
716
+ return _no_op(x)
717
+
718
+ count = self.cpu_count
719
+ self.cpu_count += 1
720
+
721
+ if random.random() < 0.01:
722
+ # Occasionally sync self.cpu_count with self.count.
723
+ # count affects the decay of 'prob'. don't do this on every iter,
724
+ # because syncing with the GPU is slow.
725
+ self.cpu_count = max(self.cpu_count, self.count.item())
726
+ self.count.fill_(self.cpu_count)
727
+
728
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
729
+ # a floor at min_prob (==0.1, by default)
730
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
731
+
732
+ if random.random() < prob:
733
+ sign_gain_factor = 0.5
734
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
735
+ sign_factor = _compute_sign_factor(
736
+ x,
737
+ self.channel_dim,
738
+ self.min_positive,
739
+ self.max_positive,
740
+ gain_factor=self.sign_gain_factor / prob,
741
+ max_factor=self.max_factor,
742
+ )
743
+ else:
744
+ sign_factor = None
745
+
746
+ scale_factor = _compute_scale_factor(
747
+ x.detach(),
748
+ self.channel_dim,
749
+ min_abs=self.min_abs,
750
+ max_abs=self.max_abs,
751
+ gain_factor=self.scale_gain_factor / prob,
752
+ max_factor=self.max_factor,
753
+ )
754
+ return ActivationBalancerFunction.apply(
755
+ x,
756
+ scale_factor,
757
+ sign_factor,
758
+ self.channel_dim,
759
+ )
760
+ else:
761
+ return _no_op(x)
762
+
763
+
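A hedged usage sketch of ActivationBalancer (arbitrary shapes): the forward pass is an identity, and during training, with some probability, the backward pass nudges per-channel gradients towards the sign and magnitude constraints (the helpers _compute_sign_factor / _compute_scale_factor are defined earlier in this file):

    balancer = ActivationBalancer(num_channels=256, channel_dim=-1,
                                  min_positive=0.05, max_positive=0.95)
    x = torch.randn(8, 50, 256, requires_grad=True)
    y = balancer(x)      # y equals x in the forward pass
    y.sum().backward()   # x.grad may be gently modified, per channel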
764
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
765
+ """
766
+ Returns x unmodified, but in backprop will put a penalty for the excess of
767
+ the absolute values of elements of x over the limit "limit". E.g. if
768
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
769
+
770
+ Caution: the value of this penalty will be affected by grad scaling used
771
+ in automatic mixed precision training. In the places we use this,
772
+ that shouldn't really matter, and may even be helpful; we just use it
773
+ to disallow really implausible score values from being given to softmax.
774
+ """
775
+ x_sign = x.sign()
776
+ over_limit = (x.abs() - limit) > 0
777
+ # The following is a memory efficient way to penalize the absolute values of
778
+ # x that's over the limit. (The memory efficiency comes when you think
779
+ # about which items torch needs to cache for the autograd, and which ones it
780
+ # can throw away). The numerical value of aux_loss as computed here will
781
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
782
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
783
+ # limit).relu().
784
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
785
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
786
+ # sum() due to how with_loss() works.
787
+ x = with_loss(x, aux_loss)
788
+ # you must use x for something, or this will be ineffective.
789
+ return x
790
+
791
+
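A small illustrative example (arbitrary numbers): the forward value is unchanged, but elements whose magnitude exceeds `limit` pick up an extra gradient of `penalty` times their sign:

    scores = 20.0 * torch.randn(4, 50)
    scores.requires_grad_(True)
    limited = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04)
    limited.sum().backward()
    # scores.grad is 1.0 everywhere, plus +/- 1e-4 where |scores| > 10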
792
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
793
+ if x.ndim == 2:
794
+ return x.diag()
795
+ else:
796
+ (batch, dim, dim) = x.shape
797
+ x = x.reshape(batch, dim * dim)
798
+ x = x[:, :: dim + 1]
799
+ assert x.shape == (batch, dim)
800
+ return x
801
+
802
+
803
+ def _whitening_metric(x: Tensor, num_groups: int):
804
+ """
805
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
806
+ the centered feature covariance are the same within each group's covariance matrix
807
+ and also between groups.
808
+ Args:
809
+ x: a Tensor of shape (*, num_channels)
810
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
811
+ Returns:
812
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
813
+ greater than 1.0 otherwise.
814
+ """
815
+ assert x.dtype != torch.float16
816
+ x = x.reshape(-1, x.shape[-1])
817
+ (num_frames, num_channels) = x.shape
818
+ assert num_channels % num_groups == 0
819
+ channels_per_group = num_channels // num_groups
820
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
821
+ # x now has shape (num_groups, num_frames, channels_per_group)
822
+ # subtract the mean so we use the centered, not uncentered, covariance.
823
+ # My experience has been that when we "mess with the gradients" like this,
824
+ # it's better not do anything that tries to move the mean around, because
825
+ # that can easily cause instability.
826
+ x = x - x.mean(dim=1, keepdim=True)
827
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
828
+ x_covar = torch.matmul(x.transpose(1, 2), x)
829
+ x_covar_mean_diag = _diag(x_covar).mean()
830
+ # the following expression is what we'd get if we took the matrix product
831
+ # of each covariance and measured the mean of its trace, i.e.
832
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
833
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
834
+ num_groups * channels_per_group
835
+ )
836
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
837
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
838
+ return metric
839
+
840
+
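An illustrative check (arbitrary shapes): roughly white data gives a metric close to 1.0, while data with very unequal per-channel variances gives a clearly larger value:

    x_white = torch.randn(2000, 64)
    x_skewed = x_white * torch.linspace(0.1, 3.0, 64)
    print(_whitening_metric(x_white, num_groups=1))    # close to 1.0
    print(_whitening_metric(x_skewed, num_groups=1))   # clearly greater than 1.0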
841
+ class WhiteningPenaltyFunction(torch.autograd.Function):
842
+ @staticmethod
843
+ def forward(
844
+ ctx,
845
+ x: Tensor,
846
+ num_groups: int,
847
+ whitening_limit: float,
848
+ grad_scale: float,
849
+ ) -> Tensor:
850
+ ctx.save_for_backward(x)
851
+ ctx.num_groups = num_groups
852
+ ctx.whitening_limit = whitening_limit
853
+ ctx.grad_scale = grad_scale
854
+ return x
855
+
856
+ @staticmethod
857
+ def backward(ctx, x_grad: Tensor):
858
+ (x_orig,) = ctx.saved_tensors
859
+ with torch.enable_grad():
860
+ with torch.cuda.amp.autocast(enabled=False):
861
+ x_detached = x_orig.to(torch.float32).detach()
862
+ x_detached.requires_grad = True
863
+
864
+ metric = _whitening_metric(x_detached, ctx.num_groups)
865
+
866
+ if random.random() < 0.005 or __name__ == "__main__":
867
+ logging.info(
868
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
869
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
870
+ )
871
+
872
+ (metric - ctx.whitening_limit).relu().backward()
873
+ penalty_grad = x_detached.grad
874
+ scale = ctx.grad_scale * (
875
+ x_grad.to(torch.float32).norm()
876
+ / (penalty_grad.norm() + 1.0e-20)
877
+ )
878
+ penalty_grad = penalty_grad * scale
879
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
880
+
881
+
882
+ class Whiten(nn.Module):
883
+ def __init__(
884
+ self,
885
+ num_groups: int,
886
+ whitening_limit: float,
887
+ prob: Union[float, Tuple[float, float]],
888
+ grad_scale: float,
889
+ ):
890
+ """
891
+ Args:
892
+ num_groups: the number of groups to divide the channel dim into before
893
+ whitening. We will attempt to make the feature covariance
894
+ within each group, after mean subtraction, as "white" as possible,
895
+ while having the same trace across all groups.
896
+ whitening_limit: a value greater than 1.0, that dictates how much
897
+ freedom we have to violate the constraints. 1.0 would mean perfectly
898
+ white, with exactly the same trace across groups; larger values
899
+ give more freedom. E.g. 2.0.
900
+ prob: the probability with which we apply the gradient modification
901
+ (also affects the grad scale). May be supplied as a float,
902
+ or as a pair (min_prob, max_prob)
903
+
904
+ grad_scale: determines the scale on the gradient term from this object,
905
+ relative to the rest of the gradient on the attention weights.
906
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
907
+ """
908
+ super(Whiten, self).__init__()
909
+ assert num_groups >= 1
910
+ assert whitening_limit >= 1
911
+ assert grad_scale >= 0
912
+ self.num_groups = num_groups
913
+ self.whitening_limit = whitening_limit
914
+ if isinstance(prob, float):
915
+ assert 0 < prob <= 1
916
+ self.prob = prob
917
+ else:
918
+ (self.min_prob, self.max_prob) = prob
919
+ assert 0 < self.min_prob < self.max_prob <= 1
920
+ self.prob = self.max_prob
921
+
922
+ self.grad_scale = grad_scale
923
+
924
+ def forward(self, x: Tensor) -> Tensor:
925
+ """
926
+ In the forward pass, this function just returns the input unmodified.
927
+ In the backward pass, it will modify the gradients to ensure that the
928
+ distribution in each group has close to (lambda times I) as the covariance
929
+ after mean subtraction, with the same lambda across groups.
930
+ For whitening_limit > 1, there will be more freedom to violate this
931
+ constraint.
932
+
933
+ Args:
934
+ x: the input of shape (*, num_channels)
935
+
936
+ Returns:
937
+ x, unmodified. You should make sure
938
+ you use the returned value, or the graph will be freed
939
+ and nothing will happen in backprop.
940
+ """
941
+ if (
942
+ not x.requires_grad
943
+ or random.random() > self.prob
944
+ or self.grad_scale == 0
945
+ ):
946
+ return _no_op(x)
947
+ else:
948
+ if hasattr(self, "min_prob") and random.random() < 0.25:
949
+ # occasionally switch between min_prob and max_prob, based on whether
950
+ # we are above or below the threshold.
951
+ if (
952
+ _whitening_metric(x.to(torch.float32), self.num_groups)
953
+ > self.whitening_limit
954
+ ):
955
+ # there would be a change to the grad.
956
+ self.prob = self.max_prob
957
+ else:
958
+ self.prob = self.min_prob
959
+
960
+ return WhiteningPenaltyFunction.apply(
961
+ x, self.num_groups, self.whitening_limit, self.grad_scale
962
+ )
963
+
964
+
965
+ class WithLoss(torch.autograd.Function):
966
+ @staticmethod
967
+ def forward(ctx, x: Tensor, y: Tensor):
968
+ ctx.y_shape = y.shape
969
+ return x
970
+
971
+ @staticmethod
972
+ def backward(ctx, ans_grad: Tensor):
973
+ return ans_grad, torch.ones(
974
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
975
+ )
976
+
977
+
978
+ def with_loss(x, y):
979
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
980
+ return x
981
+ # returns x but adds y.sum() to the loss function.
982
+ return WithLoss.apply(x, y)
983
+
984
+
985
+ def _no_op(x: Tensor) -> Tensor:
986
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
987
+ return x
988
+ else:
989
+ # a no-op function that will have a node in the autograd graph,
990
+ # to avoid certain bugs relating to backward hooks
991
+ return x.chunk(1, dim=-1)[0]
992
+
993
+
994
+ class Identity(torch.nn.Module):
995
+ def __init__(self):
996
+ super(Identity, self).__init__()
997
+
998
+ def forward(self, x):
999
+ return _no_op(x)
1000
+
1001
+
1002
+ class MaxEig(torch.nn.Module):
1003
+ """
1004
+ Modifies the backpropped derivatives of a function to try to discourage
1005
+ that any given direction in activation space accounts for more than
1006
+ a specified proportion of the covariance (e.g. 0.2).
1007
+
1008
+
1009
+ Args:
1010
+ num_channels: the number of channels
1011
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1012
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1013
+ max_var_per_eig: the maximum proportion of the variance of the
1014
+ features/channels, after mean subtraction, that can come from
1015
+ any given eigenvalue.
1016
+ min_prob: the minimum probability with which we apply this during any invocation
1017
+ of forward(), assuming last time we applied the constraint it was
1018
+ not active; supplied for speed.
1019
+ scale: determines the scale with which we modify the gradients, relative
1020
+ to the existing / unmodified gradients
1021
+ """
1022
+
1023
+ def __init__(
1024
+ self,
1025
+ num_channels: int,
1026
+ channel_dim: int,
1027
+ max_var_per_eig: float = 0.2,
1028
+ min_prob: float = 0.01,
1029
+ scale: float = 0.01,
1030
+ ):
1031
+ super(MaxEig, self).__init__()
1032
+ self.num_channels = num_channels
1033
+ self.channel_dim = channel_dim
1034
+ self.scale = scale
1035
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1036
+ self.max_var_per_eig = max_var_per_eig
1037
+
1038
+ # we figure out the dominant direction using the power method: starting with
1039
+ # a random vector, keep multiplying by the covariance and renormalizing.
1040
+ with torch.no_grad():
1041
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1042
+ # random parameters unchanged for comparison
1043
+ direction = torch.arange(num_channels).to(torch.float)
1044
+ direction = direction / direction.norm()
1045
+ self.register_buffer("max_eig_direction", direction)
1046
+
1047
+ self.min_prob = min_prob
1048
+ # cur_prob is the current probability we'll use to apply the MaxEig constraint.
1049
+ # We'll regress this towards min_prob each time we try to apply it and it is not
1050
+ # active.
1051
+ self.cur_prob = 1.0
1052
+
1053
+ def forward(self, x: Tensor) -> Tensor:
1054
+ if (
1055
+ torch.jit.is_scripting()
1056
+ or self.max_var_per_eig <= 0
1057
+ or random.random() > self.cur_prob
1058
+ or torch.jit.is_tracing()
1059
+ ):
1060
+ return _no_op(x)
1061
+
1062
+ with torch.cuda.amp.autocast(enabled=False):
1063
+ eps = 1.0e-20
1064
+ orig_x = x
1065
+ x = x.to(torch.float32)
1066
+ with torch.no_grad():
1067
+ x = x.transpose(self.channel_dim, -1).reshape(
1068
+ -1, self.num_channels
1069
+ )
1070
+ x = x - x.mean(dim=0)
1071
+ new_direction, coeffs = self._find_direction_coeffs(
1072
+ x, self.max_eig_direction
1073
+ )
1074
+ x_var = (x ** 2).mean()
1075
+ x_residual = x - coeffs * new_direction
1076
+ x_residual_var = (x_residual ** 2).mean()
1077
+
1078
+ # `variance_proportion` is the proportion of the variance accounted for
1079
+ # by the top eigen-direction.
1080
+ variance_proportion = (x_var - x_residual_var) / (
1081
+ x_var + 1.0e-20
1082
+ )
1083
+
1084
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1085
+ self._set_direction(
1086
+ 0.1 * self.max_eig_direction + new_direction
1087
+ )
1088
+
1089
+ if random.random() < 0.01 or __name__ == "__main__":
1090
+ logging.info(
1091
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1092
+ )
1093
+
1094
+ if variance_proportion >= self.max_var_per_eig:
1095
+ # The constraint is active. Note: we should quite rarely
1096
+ # reach here; only near the beginning of training, if we are
1097
+ # starting to diverge, should this constraint be active.
1098
+ cur_prob = self.cur_prob
1099
+ self.cur_prob = (
1100
+ 1.0 # next time, do the update with probability 1.0.
1101
+ )
1102
+ return MaxEigLimiterFunction.apply(
1103
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1104
+ )
1105
+ else:
1106
+ # let self.cur_prob exponentially approach self.min_prob, as
1107
+ # long as the constraint is inactive.
1108
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1109
+ return orig_x
1110
+
1111
+ def _set_direction(self, direction: Tensor):
1112
+ """
1113
+ Sets self.max_eig_direction to a normalized version of `direction`
1114
+ """
1115
+ direction = direction.detach()
1116
+ direction = direction / direction.norm()
1117
+ direction_sum = direction.sum().item()
1118
+ if direction_sum - direction_sum == 0: # no inf/nan
1119
+ self.max_eig_direction[:] = direction
1120
+ else:
1121
+ logging.info(
1122
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1123
+ f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1124
+ )
1125
+
1126
+ def _find_direction_coeffs(
1127
+ self, x: Tensor, prev_direction: Tensor
1128
+ ) -> Tuple[Tensor, Tensor]:
1129
+ """
1130
+ Figure out (an approximation to) the proportion of the variance of a set of
1131
+ feature vectors that can be attributed to the top eigen-direction.
1132
+ Args:
1133
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1134
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1135
+ of the top eigen-direction, or a random direction if this is the first
1136
+ iteration. Does not have to be normalized, but should be nonzero.
1137
+
1138
+ Returns: (cur_direction, coeffs), where:
1139
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1140
+ estimate of the top eigen-direction.
1141
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1142
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1143
+ """
1144
+ (num_frames, num_channels) = x.shape
1145
+ assert num_channels > 1 and num_frames > 1
1146
+ assert prev_direction.shape == (num_channels,)
1147
+ # `coeffs` are the coefficients of `prev_direction` in x.
1148
+ # actually represent the coeffs up to a constant positive factor.
1149
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1150
+ cur_direction = (x * coeffs).sum(dim=0) / (
1151
+ (coeffs ** 2).sum() + 1.0e-20
1152
+ )
1153
+ return cur_direction, coeffs
1154
+
1155
+
1156
+ class DoubleSwishFunction(torch.autograd.Function):
1157
+ """
1158
+ double_swish(x) = x * torch.sigmoid(x-1)
1159
+ This is a definition, originally motivated by its close numerical
1160
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1161
+
1162
+ Memory-efficient derivative computation:
1163
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1164
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1165
+ Now, s'(x) = s(x) * (1-s(x)).
1166
+ double_swish'(x) = x * s'(x) + s(x).
1167
+ = x * s(x) * (1-s(x)) + s(x).
1168
+ = double_swish(x) * (1-s(x)) + s(x)
1169
+ ... so we just need to remember s(x) but not x itself.
1170
+ """
1171
+
1172
+ @staticmethod
1173
+ def forward(ctx, x: Tensor) -> Tensor:
1174
+ requires_grad = x.requires_grad
1175
+ x_dtype = x.dtype
1176
+ if x.dtype == torch.float16:
1177
+ x = x.to(torch.float32)
1178
+
1179
+ s = torch.sigmoid(x - 1.0)
1180
+ y = x * s
1181
+
1182
+ if requires_grad:
1183
+ deriv = y * (1 - s) + s
1184
+ # notes on derivative of x * sigmoid(x - 1):
1185
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1186
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
1187
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1188
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1189
+ # floors), should be expectation-preserving.
1190
+ floor = -0.043637
1191
+ ceil = 1.2
1192
+ d_scaled = (deriv - floor) * (
1193
+ 255.0 / (ceil - floor)
1194
+ ) + torch.rand_like(deriv)
1195
+ if __name__ == "__main__":
1196
+ # for self-testing only.
1197
+ assert d_scaled.min() >= 0.0
1198
+ assert d_scaled.max() < 256.0
1199
+ d_int = d_scaled.to(torch.uint8)
1200
+ ctx.save_for_backward(d_int)
1201
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1202
+ y = y.to(torch.float16)
1203
+ return y
1204
+
1205
+ @staticmethod
1206
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1207
+ (d,) = ctx.saved_tensors
1208
+ # the same constants as used in forward pass.
1209
+ floor = -0.043637
1210
+ ceil = 1.2
1211
+ d = d * ((ceil - floor) / 255.0) + floor
1212
+ return y_grad * d
1213
+
1214
+
1215
+ class DoubleSwish(torch.nn.Module):
1216
+ def forward(self, x: Tensor) -> Tensor:
1217
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1218
+ that we approximate closely with x * sigmoid(x-1).
1219
+ """
1220
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1221
+ return x * torch.sigmoid(x - 1.0)
1222
+ return DoubleSwishFunction.apply(x)
1223
+
1224
+
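As a quick illustrative sanity check, the module matches the closed form x * sigmoid(x - 1) in the forward pass:

    x = torch.linspace(-4.0, 4.0, steps=9)
    y = DoubleSwish()(x)
    assert torch.allclose(y, x * torch.sigmoid(x - 1.0))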
1225
+ def BalancedDoubleSwish(
1226
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1227
+ ) -> nn.Sequential:
1228
+ """
1229
+ ActivationBalancer -> DoubleSwish
1230
+ """
1231
+ balancer = ActivationBalancer(
1232
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1233
+ )
1234
+ return nn.Sequential(
1235
+ balancer,
1236
+ DoubleSwish(),
1237
+ )
1238
+
1239
+
1240
+ def _test_max_eig():
1241
+ for proportion in [0.1, 0.5, 10.0]:
1242
+ logging.info(f"proportion = {proportion}")
1243
+ x = torch.randn(100, 128)
1244
+ direction = torch.randn(128)
1245
+ coeffs = torch.randn(100, 1)
1246
+ x += proportion * direction * coeffs
1247
+
1248
+ x.requires_grad = True
1249
+
1250
+ num_channels = 128
1251
+ m = MaxEig(
1252
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1253
+ ) # grad_scale
1254
+
1255
+ for _ in range(4):
1256
+ y = m(x)
1257
+
1258
+ y_grad = torch.randn_like(x)
1259
+ y.backward(gradient=y_grad)
1260
+
1261
+ if proportion < 0.2:
1262
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1263
+ elif proportion > 1.0:
1264
+ assert not torch.allclose(x.grad, y_grad)
1265
+
1266
+
1267
+ def _test_whiten():
1268
+ for proportion in [0.1, 0.5, 10.0]:
1269
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1270
+ x = torch.randn(100, 128)
1271
+ direction = torch.randn(128)
1272
+ coeffs = torch.randn(100, 1)
1273
+ x += proportion * direction * coeffs
1274
+
1275
+ x.requires_grad = True
1276
+
1277
+ num_channels = 128
1278
+ m = Whiten(
1279
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1280
+ ) # grad_scale
1281
+
1282
+ for _ in range(4):
1283
+ y = m(x)
1284
+
1285
+ y_grad = torch.randn_like(x)
1286
+ y.backward(gradient=y_grad)
1287
+
1288
+ if proportion < 0.2:
1289
+ assert torch.allclose(x.grad, y_grad)
1290
+ elif proportion > 1.0:
1291
+ assert not torch.allclose(x.grad, y_grad)
1292
+
1293
+
1294
+ def _test_activation_balancer_sign():
1295
+ probs = torch.arange(0, 1, 0.01)
1296
+ N = 1000
1297
+ x = 1.0 * (
1298
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1299
+ )
1300
+ x = x.detach()
1301
+ x.requires_grad = True
1302
+ m = ActivationBalancer(
1303
+ probs.numel(),
1304
+ channel_dim=0,
1305
+ min_positive=0.05,
1306
+ max_positive=0.95,
1307
+ max_factor=0.2,
1308
+ min_abs=0.0,
1309
+ )
1310
+
1311
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1312
+
1313
+ y = m(x)
1314
+ y.backward(gradient=y_grad)
1315
+ print("_test_activation_balancer_sign: x = ", x)
1316
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1317
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1318
+
1319
+
1320
+ def _test_activation_balancer_magnitude():
1321
+ magnitudes = torch.arange(0, 1, 0.01)
1322
+ N = 1000
1323
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1324
+ -1
1325
+ )
1326
+ x = x.detach()
1327
+ x.requires_grad = True
1328
+ m = ActivationBalancer(
1329
+ magnitudes.numel(),
1330
+ channel_dim=0,
1331
+ min_positive=0.0,
1332
+ max_positive=1.0,
1333
+ max_factor=0.2,
1334
+ min_abs=0.2,
1335
+ max_abs=0.8,
1336
+ min_prob=1.0,
1337
+ )
1338
+
1339
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1340
+
1341
+ y = m(x)
1342
+ y.backward(gradient=y_grad)
1343
+ print("_test_activation_balancer_magnitude: x = ", x)
1344
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1345
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1346
+
1347
+
1348
+ def _test_basic_norm():
1349
+ num_channels = 128
1350
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1351
+
1352
+ x = torch.randn(500, num_channels)
1353
+
1354
+ y = m(x)
1355
+
1356
+ assert y.shape == x.shape
1357
+ x_rms = (x ** 2).mean().sqrt()
1358
+ y_rms = (y ** 2).mean().sqrt()
1359
+ print("x rms = ", x_rms)
1360
+ print("y rms = ", y_rms)
1361
+ assert y_rms < x_rms
1362
+ assert y_rms > 0.5 * x_rms
1363
+
1364
+
1365
+ def _test_double_swish_deriv():
1366
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1367
+ x.requires_grad = True
1368
+ m = DoubleSwish()
1369
+
1370
+ tol = (1.2 - (-0.043637)) / 255.0
1371
+ torch.autograd.gradcheck(m, x, atol=tol)
1372
+
1373
+ # for self-test.
1374
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1375
+ x.requires_grad = True
1376
+ y = m(x)
1377
+
1378
+
1379
+ def _test_softmax():
1380
+ a = torch.randn(2, 10, dtype=torch.float64)
1381
+ b = a.clone()
1382
+ a.requires_grad = True
1383
+ b.requires_grad = True
1384
+ a.softmax(dim=1)[:, 0].sum().backward()
1385
+ print("a grad = ", a.grad)
1386
+ softmax(b, dim=1)[:, 0].sum().backward()
1387
+ print("b grad = ", b.grad)
1388
+ assert torch.allclose(a.grad, b.grad)
1389
+
1390
+
1391
+ if __name__ == "__main__":
1392
+ logging.getLogger().setLevel(logging.INFO)
1393
+ torch.set_num_threads(1)
1394
+ torch.set_num_interop_threads(1)
1395
+ _test_softmax()
1396
+ _test_whiten()
1397
+ _test_max_eig()
1398
+ _test_activation_balancer_sign()
1399
+ _test_activation_balancer_magnitude()
1400
+ _test_basic_norm()
1401
+ _test_double_swish_deriv()
modules/transformer.py ADDED
@@ -0,0 +1,683 @@
1
+ import copy
2
+ import numbers
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+
10
+ from .activation import MultiheadAttention
11
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
12
+ from .scaling import BasicNorm as _BasicNorm
13
+
14
+ _shape_t = Union[int, List[int], torch.Size]
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
19
+ normalized_shape: Tuple[int, ...]
20
+ eps: float
21
+ elementwise_affine: bool
22
+
23
+ def __init__(
24
+ self,
25
+ normalized_shape: _shape_t,
26
+ eps: float = 1e-5,
27
+ elementwise_affine: bool = True,
28
+ device=None,
29
+ dtype=None,
30
+ ) -> None:
31
+ factory_kwargs = {"device": device, "dtype": dtype}
32
+ super(LayerNorm, self).__init__()
33
+ if isinstance(normalized_shape, numbers.Integral):
34
+ # mypy error: incompatible types in assignment
35
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
36
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
37
+ self.eps = eps
38
+ self.elementwise_affine = elementwise_affine
39
+ if self.elementwise_affine:
40
+ self.weight = nn.Parameter(
41
+ torch.empty(self.normalized_shape, **factory_kwargs)
42
+ )
43
+ self.bias = nn.Parameter(
44
+ torch.empty(self.normalized_shape, **factory_kwargs)
45
+ )
46
+ else:
47
+ self.register_parameter("weight", None)
48
+ self.register_parameter("bias", None)
49
+
50
+ self.reset_parameters()
51
+
52
+ def reset_parameters(self) -> None:
53
+ if self.elementwise_affine:
54
+ nn.init.ones_(self.weight)
55
+ nn.init.zeros_(self.bias)
56
+
57
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
58
+ if isinstance(input, tuple):
59
+ input, embedding = input
60
+ return (
61
+ F.layer_norm(
62
+ input,
63
+ self.normalized_shape,
64
+ self.weight,
65
+ self.bias,
66
+ self.eps,
67
+ ),
68
+ embedding,
69
+ )
70
+
71
+ assert embedding is None
72
+ return F.layer_norm(
73
+ input, self.normalized_shape, self.weight, self.bias, self.eps
74
+ )
75
+
76
+ def extra_repr(self) -> str:
77
+ return (
78
+ "{normalized_shape}, eps={eps}, "
79
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
80
+ )
81
+
82
+
83
+ class AdaptiveLayerNorm(nn.Module):
84
+ r"""Adaptive Layer Normalization"""
85
+
86
+ def __init__(self, d_model, norm) -> None:
87
+ super(AdaptiveLayerNorm, self).__init__()
88
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
89
+ self.norm = norm
90
+ self.d_model = d_model
91
+ self.eps = self.norm.eps
92
+
93
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
94
+ if isinstance(input, tuple):
95
+ input, embedding = input
96
+ weight, bias = torch.split(
97
+ self.project_layer(embedding),
98
+ split_size_or_sections=self.d_model,
99
+ dim=-1,
100
+ )
101
+ return (weight * self.norm(input) + bias, embedding)
102
+
103
+ weight, bias = torch.split(
104
+ self.project_layer(embedding),
105
+ split_size_or_sections=self.d_model,
106
+ dim=-1,
107
+ )
108
+ return weight * self.norm(input) + bias
109
+
110
+
111
+ class BasicNorm(_BasicNorm):
112
+ def __init__(
113
+ self,
114
+ d_model: int,
115
+ eps: float = 1e-5,
116
+ device=None,
117
+ dtype=None,
118
+ ):
119
+ super(BasicNorm, self).__init__(d_model, eps=eps)
120
+
121
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
122
+ if isinstance(input, tuple):
123
+ input, embedding = input
124
+ return (
125
+ super(BasicNorm, self).forward(input),
126
+ embedding,
127
+ )
128
+
129
+ assert embedding is None
130
+ return super(BasicNorm, self).forward(input)
131
+
132
+
133
+ class BalancedBasicNorm(nn.Module):
134
+ def __init__(
135
+ self,
136
+ d_model: int,
137
+ eps: float = 1e-5,
138
+ device=None,
139
+ dtype=None,
140
+ ):
141
+ super(BalancedBasicNorm, self).__init__()
142
+ self.balancer = ActivationBalancer(
143
+ d_model,
144
+ channel_dim=-1,
145
+ min_positive=0.45,
146
+ max_positive=0.55,
147
+ max_abs=6.0,
148
+ )
149
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
150
+
151
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
152
+ if isinstance(input, tuple):
153
+ input, embedding = input
154
+ return self.norm((self.balancer(input), embedding))
155
+
156
+ assert embedding is None
157
+ return self.norm(self.balancer(input))
158
+
159
+
160
+ class IdentityNorm(nn.Module):
161
+ def __init__(
162
+ self,
163
+ d_model: int,
164
+ eps: float = 1e-5,
165
+ device=None,
166
+ dtype=None,
167
+ ) -> None:
168
+ super(IdentityNorm, self).__init__()
169
+
170
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
171
+ if isinstance(input, tuple):
172
+ return input
173
+
174
+ assert embedding is None
175
+ return input
176
+
177
+
178
+ class TransformerEncoderLayer(nn.Module):
179
+ __constants__ = ["batch_first", "norm_first"]
180
+
181
+ def __init__(
182
+ self,
183
+ d_model: int,
184
+ nhead: int,
185
+ dim_feedforward: int = 2048,
186
+ dropout: float = 0.1,
187
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
188
+ batch_first: bool = False,
189
+ norm_first: bool = False,
190
+ device=None,
191
+ dtype=None,
192
+ linear1_self_attention_cls: nn.Module = nn.Linear,
193
+ linear2_self_attention_cls: nn.Module = nn.Linear,
194
+ linear1_feedforward_cls: nn.Module = nn.Linear,
195
+ linear2_feedforward_cls: nn.Module = nn.Linear,
196
+ layer_norm_cls: nn.Module = LayerNorm,
197
+ layer_norm_eps: float = 1e-5,
198
+ adaptive_layer_norm=False,
199
+ ) -> None:
200
+ factory_kwargs = {"device": device, "dtype": dtype}
201
+ super(TransformerEncoderLayer, self).__init__()
202
+ self.self_attn = MultiheadAttention(
203
+ d_model,
204
+ nhead,
205
+ dropout=dropout,
206
+ batch_first=batch_first,
207
+ linear1_cls=linear1_self_attention_cls,
208
+ linear2_cls=linear2_self_attention_cls,
209
+ **factory_kwargs,
210
+ )
211
+
212
+ # Implementation of Feedforward model
213
+ self.linear1 = linear1_feedforward_cls(
214
+ d_model, dim_feedforward, **factory_kwargs
215
+ )
216
+ self.dropout = nn.Dropout(dropout)
217
+ self.linear2 = linear2_feedforward_cls(
218
+ dim_feedforward, d_model, **factory_kwargs
219
+ )
220
+
221
+ self.norm_first = norm_first
222
+ self.dropout1 = nn.Dropout(dropout)
223
+ self.dropout2 = nn.Dropout(dropout)
224
+
225
+ # Legacy string support for activation function.
226
+ if isinstance(activation, str):
227
+ activation = _get_activation_fn(activation)
228
+ elif isinstance(activation, partial):
229
+ activation = activation(d_model)
230
+ elif activation == BalancedDoubleSwish:
231
+ activation = BalancedDoubleSwish(d_model)
232
+
233
+ # # We can't test self.activation in forward() in TorchScript,
234
+ # # so stash some information about it instead.
235
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
236
+ # self.activation_relu_or_gelu = 1
237
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
238
+ # self.activation_relu_or_gelu = 2
239
+ # else:
240
+ # self.activation_relu_or_gelu = 0
241
+ self.activation = activation
242
+
243
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244
+ if layer_norm_cls == IdentityNorm:
245
+ norm2 = BalancedBasicNorm(
246
+ d_model, eps=layer_norm_eps, **factory_kwargs
247
+ )
248
+ else:
249
+ norm2 = layer_norm_cls(
250
+ d_model, eps=layer_norm_eps, **factory_kwargs
251
+ )
252
+
253
+ if adaptive_layer_norm:
254
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
255
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
256
+ else:
257
+ self.norm1 = norm1
258
+ self.norm2 = norm2
259
+
260
+ def __setstate__(self, state):
261
+ super(TransformerEncoderLayer, self).__setstate__(state)
262
+ if not hasattr(self, "activation"):
263
+ self.activation = F.relu
264
+
265
+ def forward(
266
+ self,
267
+ src: Tensor,
268
+ src_mask: Optional[Tensor] = None,
269
+ src_key_padding_mask: Optional[Tensor] = None,
270
+ ) -> Tensor:
271
+ r"""Pass the input through the encoder layer.
272
+
273
+ Args:
274
+ src: the sequence to the encoder layer (required).
275
+ src_mask: the mask for the src sequence (optional).
276
+ src_key_padding_mask: the mask for the src keys per batch (optional).
277
+
278
+ Shape:
279
+ see the docs in Transformer class.
280
+ """
281
+ x, stage_embedding = src, None
282
+ is_src_tuple = False
283
+ if isinstance(src, tuple):
284
+ x, stage_embedding = src
285
+ is_src_tuple = True
286
+
287
+ if src_key_padding_mask is not None:
288
+ _skpm_dtype = src_key_padding_mask.dtype
289
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
290
+ src_key_padding_mask
291
+ ):
292
+ raise AssertionError(
293
+ "only bool and floating types of key_padding_mask are supported"
294
+ )
295
+
296
+ if self.norm_first:
297
+ x = x + self._sa_block(
298
+ self.norm1(x, stage_embedding),
299
+ src_mask,
300
+ src_key_padding_mask,
301
+ )
302
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
303
+ else:
304
+ x = self.norm1(
305
+ x + self._sa_block(x, src_mask, src_key_padding_mask),
306
+ stage_embedding,
307
+ )
308
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
309
+
310
+ if is_src_tuple:
311
+ return (x, stage_embedding)
312
+ return x
313
+
314
+ def infer(
315
+ self,
316
+ src: Tensor,
317
+ src_mask: Optional[Tensor] = None,
318
+ src_key_padding_mask: Optional[Tensor] = None,
319
+ past_kv: Optional[Tensor] = None,
320
+ use_cache: bool = False,
321
+ ):
322
+ x, stage_embedding = src, None
323
+ is_src_tuple = False
324
+ if isinstance(src, tuple):
325
+ x, stage_embedding = src
326
+ is_src_tuple = True
327
+
328
+ if src_key_padding_mask is not None:
329
+ _skpm_dtype = src_key_padding_mask.dtype
330
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
331
+ src_key_padding_mask
332
+ ):
333
+ raise AssertionError(
334
+ "only bool and floating types of key_padding_mask are supported"
335
+ )
336
+
337
+ if self.norm_first:
338
+ x_attn_out, kv = self.self_attn.infer(
339
+ self.norm1(x, stage_embedding),
340
+ attn_mask=src_mask,
341
+ key_padding_mask=src_key_padding_mask,
342
+ need_weights=False,
343
+ past_kv=past_kv,
344
+ use_cache=use_cache,
345
+ )
346
+ x = x + x_attn_out
347
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
348
+
349
+ if is_src_tuple:
350
+ return (x, stage_embedding)
351
+ return (x, kv)
352
+
353
+ # self-attention block
354
+ def _sa_block(
355
+ self,
356
+ x: Tensor,
357
+ attn_mask: Optional[Tensor],
358
+ key_padding_mask: Optional[Tensor],
359
+ ) -> Tensor:
360
+ x = self.self_attn(
361
+ x,
362
+ x,
363
+ x,
364
+ attn_mask=attn_mask,
365
+ key_padding_mask=key_padding_mask,
366
+ need_weights=False,
367
+ )[0]
368
+ return self.dropout1(x)
369
+
370
+ # feed forward block
371
+ def _ff_block(self, x: Tensor) -> Tensor:
372
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
373
+ return self.dropout2(x)
374
+
375
+
376
+ class TransformerEncoder(nn.Module):
377
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
378
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
379
+
380
+ Args:
381
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
382
+ num_layers: the number of sub-encoder-layers in the encoder (required).
383
+ norm: the layer normalization component (optional).
384
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
385
+ (and convert back on output). This will improve the overall performance of
386
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
387
+
388
+ Examples::
389
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
390
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
391
+ >>> src = torch.rand(10, 32, 512)
392
+ >>> out = transformer_encoder(src)
393
+ """
394
+ __constants__ = ["norm"]
395
+
396
+ def __init__(self, encoder_layer, num_layers, norm=None):
397
+ super(TransformerEncoder, self).__init__()
398
+ self.layers = _get_clones(encoder_layer, num_layers)
399
+ self.num_layers = num_layers
400
+ self.norm = norm
401
+
402
+ def forward(
403
+ self,
404
+ src: Tensor,
405
+ mask: Optional[Tensor] = None,
406
+ src_key_padding_mask: Optional[Tensor] = None,
407
+ return_layer_states: bool = False,
408
+ ) -> Tensor:
409
+ r"""Pass the input through the encoder layers in turn.
410
+
411
+ Args:
412
+ src: the sequence to the encoder (required).
413
+ mask: the mask for the src sequence (optional).
414
+ src_key_padding_mask: the mask for the src keys per batch (optional).
415
+ return_layer_states: return layers' state (optional).
416
+
417
+ Shape:
418
+ see the docs in Transformer class.
419
+ """
420
+ if return_layer_states:
421
+ layer_states = [] # layers' output
422
+ output = src
423
+ for mod in self.layers:
424
+ output = mod(
425
+ output,
426
+ src_mask=mask,
427
+ src_key_padding_mask=src_key_padding_mask,
428
+ )
429
+ layer_states.append(output[0])
430
+
431
+ if self.norm is not None:
432
+ output = self.norm(output)
433
+
434
+ return layer_states, output
435
+
436
+ output = src
437
+ for mod in self.layers:
438
+ output = mod(
439
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
440
+ )
441
+
442
+ if self.norm is not None:
443
+ output = self.norm(output)
444
+
445
+ return output
446
+
447
+ def infer(
448
+ self,
449
+ src: Tensor,
450
+ mask: Optional[Tensor] = None,
451
+ src_key_padding_mask: Optional[Tensor] = None,
452
+ return_layer_states: bool = False,
453
+ past_kv: Optional[Tensor] = None,
454
+ use_cache: bool = False,
455
+ ):
456
+ if past_kv is None:
457
+ past_length = 0
458
+ past_kv = tuple([None] * self.num_layers)
459
+ else:
460
+ past_length = past_kv[0][0].size(-2)
461
+ new_kv = () if use_cache else None
462
+ output = src
463
+ for mod, past_layer_kv in zip(self.layers, past_kv):
464
+ output, kv = mod.infer(
465
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
466
+ )
467
+ if use_cache:
468
+ new_kv = new_kv + (kv,)
469
+
470
+ if self.norm is not None:
471
+ output = self.norm(output)
472
+
473
+ return output, new_kv
474
+
475
+
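A hypothetical incremental-inference sketch, using the (seq, batch, d_model) layout of the docstring example above; it assumes the MultiheadAttention.infer implementation in modules/activation.py (not shown here) supports this calling pattern. Note that infer() only implements the pre-norm path, so norm_first=True is required:

    layer = TransformerEncoderLayer(d_model=512, nhead=8, norm_first=True)
    enc = TransformerEncoder(layer, num_layers=6)
    prompt = torch.rand(10, 1, 512)
    out, kv = enc.infer(prompt, use_cache=True)            # prime the per-layer KV cache
    step = torch.rand(1, 1, 512)
    out, kv = enc.infer(step, past_kv=kv, use_cache=True)  # feed one new frame, reuse the cache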
476
+ class TransformerDecoderLayer(nn.Module):
477
+ __constants__ = ["batch_first", "norm_first"]
478
+
479
+ def __init__(
480
+ self,
481
+         d_model: int,
+         nhead: int,
+         dim_feedforward: int = 2048,
+         dropout: float = 0.1,
+         activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+         linear1_self_attention_cls: nn.Module = nn.Linear,
+         linear2_self_attention_cls: nn.Module = nn.Linear,
+         linear1_feedforward_cls: nn.Module = nn.Linear,
+         linear2_feedforward_cls: nn.Module = nn.Linear,
+         batch_first: bool = False,
+         norm_first: bool = False,
+         device=None,
+         dtype=None,
+         layer_norm_cls: nn.Module = LayerNorm,
+         layer_norm_eps: float = 1e-5,
+         adaptive_layer_norm=False,
+     ) -> None:
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super(TransformerDecoderLayer, self).__init__()
+         self.self_attn = MultiheadAttention(
+             d_model,
+             nhead,
+             dropout=dropout,
+             batch_first=batch_first,
+             linear1_cls=linear1_self_attention_cls,
+             linear2_cls=linear2_self_attention_cls,
+             **factory_kwargs,
+         )
+         self.multihead_attn = MultiheadAttention(
+             d_model,
+             nhead,
+             dropout=dropout,
+             batch_first=batch_first,
+             linear1_cls=linear1_self_attention_cls,
+             linear2_cls=linear2_self_attention_cls,
+             **factory_kwargs,
+         )
+         # Implementation of Feedforward model
+         self.linear1 = linear1_feedforward_cls(
+             d_model, dim_feedforward, **factory_kwargs
+         )
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = linear2_feedforward_cls(
+             dim_feedforward, d_model, **factory_kwargs
+         )
+
+         self.norm_first = norm_first
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         # Legacy string support for activation function.
+         if isinstance(activation, str):
+             self.activation = _get_activation_fn(activation)
+         elif isinstance(activation, partial):
+             self.activation = activation(d_model)
+         elif activation == BalancedDoubleSwish:
+             self.activation = BalancedDoubleSwish(d_model)
+         else:
+             self.activation = activation
+
+         if adaptive_layer_norm:
+             norm1 = layer_norm_cls(
+                 d_model, eps=layer_norm_eps, **factory_kwargs
+             )
+             norm2 = layer_norm_cls(
+                 d_model, eps=layer_norm_eps, **factory_kwargs
+             )
+             norm3 = layer_norm_cls(
+                 d_model, eps=layer_norm_eps, **factory_kwargs
+             )
+
+             self.norm1 = AdaptiveLayerNorm(d_model, norm1)
+             self.norm2 = AdaptiveLayerNorm(d_model, norm2)
+             self.norm3 = AdaptiveLayerNorm(d_model, norm3)
+         else:
+             self.norm1 = layer_norm_cls(
+                 d_model, eps=layer_norm_eps, **factory_kwargs
+             )
+             self.norm2 = layer_norm_cls(
+                 d_model, eps=layer_norm_eps, **factory_kwargs
+             )
+             if layer_norm_cls == IdentityNorm:
+                 self.norm3 = BalancedBasicNorm(
+                     d_model, eps=layer_norm_eps, **factory_kwargs
+                 )
+             else:
+                 self.norm3 = layer_norm_cls(
+                     d_model, eps=layer_norm_eps, **factory_kwargs
+                 )
+
+     def forward(
+         self,
+         tgt: Tensor,
+         memory: Tensor,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         r"""Pass the inputs (and mask) through the decoder layer.
+
+         Args:
+             tgt: the sequence to the decoder layer (required).
+             memory: the sequence from the last layer of the encoder (required).
+             tgt_mask: the mask for the tgt sequence (optional).
+             memory_mask: the mask for the memory sequence (optional).
+             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+             memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+         Shape:
+             see the docs in Transformer class.
+         """
+         tgt_is_tuple = False
+         if isinstance(tgt, tuple):
+             x, stage_embedding = tgt
+             tgt_is_tuple = True
+         else:
+             x, stage_embedding = tgt, None
+
+         if self.norm_first:
+             x = x + self._sa_block(
+                 self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
+             )
+             x = x + self._mha_block(
+                 self.norm2(x, stage_embedding),
+                 memory,
+                 memory_mask,
+                 memory_key_padding_mask,
+             )
+             x = x + self._ff_block(self.norm3(x, stage_embedding))
+         else:
+             x = self.norm1(
+                 x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
+                 stage_embedding,
+             )
+             x = self.norm2(
+                 x
+                 + self._mha_block(
+                     x, memory, memory_mask, memory_key_padding_mask
+                 ),
+                 stage_embedding,
+             )
+             x = self.norm3(x + self._ff_block(x), stage_embedding)
+
+         if tgt_is_tuple:
+             return (x, stage_embedding)
+         return x
+
+     # self-attention block
+     def _sa_block(
+         self,
+         x: Tensor,
+         attn_mask: Optional[Tensor],
+         key_padding_mask: Optional[Tensor],
+     ) -> Tensor:
+         x = self.self_attn(
+             x,
+             x,
+             x,
+             attn_mask=attn_mask,
+             key_padding_mask=key_padding_mask,
+             need_weights=False,
+         )[0]
+         return self.dropout1(x)
+
+     # multihead attention block
+     def _mha_block(
+         self,
+         x: Tensor,
+         mem: Tensor,
+         attn_mask: Optional[Tensor],
+         key_padding_mask: Optional[Tensor],
+     ) -> Tensor:
+         x = self.multihead_attn(
+             x,
+             mem,
+             mem,
+             attn_mask=attn_mask,
+             key_padding_mask=key_padding_mask,
+             need_weights=False,
+         )[0]
+         return self.dropout2(x)
+
+     # feed forward block
+     def _ff_block(self, x: Tensor) -> Tensor:
+         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+         return self.dropout3(x)
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+     if activation == "relu":
+         return F.relu
+     elif activation == "gelu":
+         return F.gelu
+
+     raise RuntimeError(
+         "activation should be relu/gelu, not {}".format(activation)
+     )
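
Editorial note, not part of the commit above: the class assembled in this file is a decoder layer with self-attention, cross-attention over an encoder memory, and a feed-forward block, wired either pre-norm (norm_first=True) or post-norm, with optional adaptive layer norm driven by a stage embedding. A minimal usage sketch follows; it assumes the custom MultiheadAttention and LayerNorm defined earlier in this same module, and the shapes and hyperparameters are purely illustrative.

import torch

# Hypothetical configuration; batch_first=True means tensors are (batch, seq, d_model).
layer = TransformerDecoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    batch_first=True,
    norm_first=True,            # take the pre-norm branch in forward()
    adaptive_layer_norm=False,  # norm1/norm2/norm3 are plain (non-adaptive) norms
)

tgt = torch.randn(2, 10, 512)     # decoder input: (batch=2, tgt_len=10, d_model=512)
memory = torch.randn(2, 20, 512)  # encoder output: (batch=2, src_len=20, d_model=512)
out = layer(tgt, memory)          # same shape as tgt: (2, 10, 512)

When adaptive_layer_norm=True, forward() also accepts tgt as a (tensor, stage_embedding) tuple and returns a tuple of the same form; that is how the stage embedding reaches AdaptiveLayerNorm.
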
presets/acou_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:470ce66fc24a2d14e162343381f7d93ef0a3af51edf5fd37240c21f492b4e769
+ size 15650
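
Editorial note, not part of the commit: this and the following presets/*.npz entries are Git LFS pointer files, so only the object hash and size live in the repository while the actual NumPy archives holding the voice-prompt data are stored in LFS. As a hedged sketch, assuming the archives have been fetched (for example with git lfs pull), their contents can be listed with NumPy; the array names inside depend on how each preset was saved and are not specified by this diff.

import numpy as np

archive = np.load("presets/acou_1.npz")  # path relative to the repository root
print(archive.files)                     # names of the arrays stored in this preset
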
presets/acou_2.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec1c5328751cadeed5356d4264759799ad96d33ea8dd4f8a3d0a80dd8ddb0e74
+ size 15426
presets/acou_3.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03f241b094a32b3f542e74374183c6d15e8b70ae73ceeafb11bfd4ee6b8b4a3a
+ size 15410
presets/acou_4.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52b96f32863f13f84cf7ac4a27d2bc95cea70c350a037f4d1890b20b8da9501e
+ size 15506
presets/alan.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28838c3f0b2f9f315b34e9b940f30641306f0cadc5c527857cd1cc408547ed1c
+ size 50002
presets/amused.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df3e882f3a62805b9aaf300d81822cd4eddeafee480503b7b78e32be2085fb11
+ size 20882
presets/anger.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:959cec6dc0b30219db0d70cdd165fe00bbdc098165cf9d67ccdd1ecf7a5da5be
+ size 22090
presets/babara.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8106b2a98c3f70587f23ab46ed5bf73b1c9a770481c3620ab140bd3256010376
+ size 11526
presets/bronya_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02eaada2c3d58866c813887ed9f871587ef5a7e976abc23382ce46a17b208001
+ size 18106
presets/cafe.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d78d96f5829da8f69c327ff25958da5b451305fdc9c308f7e67f13cf8d640fea
+ size 22442
presets/dingzhen.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
+ size 18154
presets/dingzhen_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
+ size 18154
presets/disgust.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4443f0a395072700f2ec6101dbf2ad9d28968aa3e5809e384ea131832f894d7f
+ size 39386
presets/emo_amused.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38be2ea16dc79beae68b6c885d99d4dad516acbd88ed5ed6991dd97301f2f30b
+ size 15378
presets/emo_anger.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3261c3bdd5b7b4be9783d9293ee3d871be9d9d791f2b3a8bf62a1a0ee0ed93e6
+ size 15434
presets/emo_neutral.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2188c4154692316ed7c0edee3aa3dd8678be36f355ee2b8c8a3a6412c3673ba9
+ size 15578
presets/emo_sleepy.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a53255890beaf4ed339e1967f0837fdb87c34c9f7e18bf77cd4b08eba176963
+ size 15370
presets/emotion_sleepiness.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0f866a278a10c7b6b494fb62589a9d8fef778ccf272df3b0d5510f45b243b5c
+ size 33218
presets/en2zh_tts_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d4de4ed055448ea54f7b40091afae565197f960d954279035ac537ea5a01bc4
+ size 44354
presets/en2zh_tts_2.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcc066ea104daa27d1552fe76574d09359d56fa892241581cc19e931a696eca9
+ size 24178
presets/en2zh_tts_3.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7468944e6d0ed7f2da033e8037be07dbafc76bd1ed7c0f5996d85ff45aacda11
+ size 21410
presets/en2zh_tts_4.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fd8d0914e74769114310e9504d68d6b7b0c6aacd46763478cbfd4f9631ad54a
+ size 43826
presets/esta.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f944e135d901a00e74e7affe6757334e9a2679c10ad7ae4bcb5b33569d77eba
+ size 40250
presets/fuxuan_2.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17b90388d179ae309e1f577c28c3f10d9bed73c6ccbffdd829c00568eb3941e6
+ size 50330
presets/librispeech_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:415b244e43b45291fd651d71f15bb7a31c244e2054988c436f6bbc04465c6099
+ size 15650
presets/librispeech_2.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd74e77370248b025321b9dbae25b1572f13f98da63255e384d382d2b0c78227
+ size 15418
presets/librispeech_3.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1eceb3f4cc0f3a8856b5e3b5f1ca28c428d75305b1452da1ecf4013bc358ccaa
+ size 15634
presets/librispeech_4.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3939dde39f5e65bc01f5eba9acb7b8329465aaca3c38edf1b240aa714e687960
+ size 15594
presets/neutral.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8a63993526ffdc788a711b512d07a8b1c816151a1edb63913d0bfb48c2ea380
+ size 21050
presets/paimon_1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:452d5e0cd3a060db521bd65a16af818a6177f357801402aa5581eceb2c24039a
+ size 13762