Spaces:
Running
on
T4
Running
on
T4
tonic
commited on
Commit
β’
33d9042
1
Parent(s):
89d01e6
Laion WhisperSpeech Demo
Browse files- README.md +3 -3
- app.py +61 -0
- requirements.txt +3 -1
- whisperspeech/__init__.py +1 -0
- whisperspeech/_modidx.py +615 -0
- whisperspeech/a2wav.py +45 -0
- whisperspeech/extract_acoustic.py +56 -0
- whisperspeech/fetch_models.py +17 -0
- whisperspeech/languages.py +131 -0
- whisperspeech/modules.py +331 -0
- whisperspeech/pipeline.py +93 -0
- whisperspeech/prepare_s2a_dataset.py +112 -0
- whisperspeech/prepare_t2s_dataset.py +111 -0
- whisperspeech/s2a_delar_mup_wds.py +688 -0
- whisperspeech/s2a_delar_mup_wds_mlang.py +564 -0
- whisperspeech/t2s_up_wds.py +442 -0
- whisperspeech/t2s_up_wds_mlang_enclm.py +519 -0
- whisperspeech/train.py +271 -0
- whisperspeech/train_multi.py +263 -0
- whisperspeech/utils.py +159 -0
- whisperspeech/vad.py +71 -0
- whisperspeech/vq_stoks.py +493 -0
- whisperspeech/wer_metrics.py +77 -0
- whisperspeech/wh_transcribe.py +146 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: pink
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.15.0
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
---
|
12 |
|
|
|
1 |
---
|
2 |
+
title: WhisperSpeech
|
3 |
+
emoji: π¬οΈπ¬π
|
4 |
colorFrom: pink
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.15.0
|
8 |
app_file: app.py
|
9 |
+
pinned: True
|
10 |
license: mit
|
11 |
---
|
12 |
|
app.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
from whisperspeech.pipeline import Pipeline
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from whisperspeech.languages import LANGUAGES
|
8 |
+
from whisperspeech.pipeline import Pipeline
|
9 |
+
import tempfil
|
10 |
+
|
11 |
+
title = """#ππ»ββοΈ Welcome toπTonic'sπ¬οΈπ¬πWhisperSpeech
|
12 |
+
You can use this ZeroGPU Space to test out the current model [π¬οΈπ¬πcollabora/whisperspeech](https://huggingface.co/collabora/whisperspeech). π¬οΈπ¬πcollabora/whisperspeech is An Open Source text-to-speech system built by inverting Whisper. Previously known as spear-tts-pytorch. It's like Stable Diffusion but for speech β both powerful and easily customizable.
|
13 |
+
You can also use π¬οΈπ¬πWhisperSpeech by cloning this space. π§¬π¬π Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/laion-whisper?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></h3>
|
14 |
+
Join us : πTeamTonicπ is always making cool demos! Join our active builder'sπ οΈcommunity π» [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On π€Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On πGithub: [Polytonic](https://github.com/tonic-ai) & contribute to π [Poly](https://github.com/tonic-ai/poly) π€Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant π€
|
15 |
+
"""
|
16 |
+
|
17 |
+
@spaces.GPU
|
18 |
+
|
19 |
+
def whisper_speech_demo(text, lang, speaker_audio=None, mix_lang=None, mix_text=None):
|
20 |
+
pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')
|
21 |
+
|
22 |
+
# Use uploaded speaker audio if provided
|
23 |
+
speaker_url = None
|
24 |
+
if speaker_audio is not None:
|
25 |
+
speaker_url = speaker_audio.name
|
26 |
+
|
27 |
+
if mix_lang and mix_text:
|
28 |
+
mixed_langs = lang.split(',') + mix_lang.split(',')
|
29 |
+
mixed_texts = [text] + mix_text.split(',')
|
30 |
+
stoks = pipe.t2s.generate(mixed_texts, lang=mixed_langs)
|
31 |
+
audio_data = pipe.generate(stoks, speaker_url, lang=mixed_langs[0])
|
32 |
+
else:
|
33 |
+
audio_data = pipe.generate(text, speaker_url, lang)
|
34 |
+
|
35 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
36 |
+
tmp_file_name = tmp_file.name
|
37 |
+
with open(tmp_file_name, 'wb') as file:
|
38 |
+
file.write(audio_data)
|
39 |
+
|
40 |
+
return tmp_file_name
|
41 |
+
|
42 |
+
with gr.Blocks() as demo:
|
43 |
+
gr.Markdown(title)
|
44 |
+
with gr.Row():
|
45 |
+
text_input = gr.Textbox(label="Enter text")
|
46 |
+
lang_input = gr.Dropdown(choices=list(LANGUAGES.keys()), label="Language")
|
47 |
+
speaker_input = gr.File(label="Upload Speaker Audio (optional)", type="file", accepts=["audio/*"])
|
48 |
+
with gr.Row():
|
49 |
+
mix_lang_input = gr.Textbox(label="Mixed Languages (optional, comma-separated)", placeholder="e.g., en,pl")
|
50 |
+
mix_text_input = gr.Textbox(label="Mixed Texts (optional, for mixed languages)", placeholder="e.g., Hello, CzeΕΔ")
|
51 |
+
with gr.Row():
|
52 |
+
submit_button = gr.Button("Generate Speech")
|
53 |
+
output_audio = gr.Audio(label="Generated Speech")
|
54 |
+
|
55 |
+
submit_button.click(
|
56 |
+
whisper_speech_demo,
|
57 |
+
inputs=[text_input, lang_input, speaker_input, mix_lang_input, mix_text_input],
|
58 |
+
outputs=output_audio
|
59 |
+
)
|
60 |
+
|
61 |
+
demo.launch()
|
requirements.txt
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
accelerate
|
whisperspeech/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.5.6"
|
whisperspeech/_modidx.py
ADDED
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Autogenerated by nbdev
|
2 |
+
|
3 |
+
d = { 'settings': { 'branch': 'master',
|
4 |
+
'doc_baseurl': '/WhisperSpeech',
|
5 |
+
'doc_host': 'https://collabora.github.io',
|
6 |
+
'git_url': 'https://github.com/collabora/WhisperSpeech',
|
7 |
+
'lib_path': 'whisperspeech'},
|
8 |
+
'syms': { 'whisperspeech.a2wav': { 'whisperspeech.a2wav.Vocoder': ('6. quality-boosting vocoder.html#vocoder', 'whisperspeech/a2wav.py'),
|
9 |
+
'whisperspeech.a2wav.Vocoder.__init__': ( '6. quality-boosting vocoder.html#vocoder.__init__',
|
10 |
+
'whisperspeech/a2wav.py'),
|
11 |
+
'whisperspeech.a2wav.Vocoder.decode': ( '6. quality-boosting vocoder.html#vocoder.decode',
|
12 |
+
'whisperspeech/a2wav.py'),
|
13 |
+
'whisperspeech.a2wav.Vocoder.decode_to_file': ( '6. quality-boosting '
|
14 |
+
'vocoder.html#vocoder.decode_to_file',
|
15 |
+
'whisperspeech/a2wav.py'),
|
16 |
+
'whisperspeech.a2wav.Vocoder.decode_to_notebook': ( '6. quality-boosting '
|
17 |
+
'vocoder.html#vocoder.decode_to_notebook',
|
18 |
+
'whisperspeech/a2wav.py')},
|
19 |
+
'whisperspeech.extract_acoustic': { 'whisperspeech.extract_acoustic.extract_Atoks': ( '1. acoustic token '
|
20 |
+
'extraction.html#extract_atoks',
|
21 |
+
'whisperspeech/extract_acoustic.py'),
|
22 |
+
'whisperspeech.extract_acoustic.extract_acoustic': ( '1. acoustic token '
|
23 |
+
'extraction.html#extract_acoustic',
|
24 |
+
'whisperspeech/extract_acoustic.py'),
|
25 |
+
'whisperspeech.extract_acoustic.load': ( '1. acoustic token extraction.html#load',
|
26 |
+
'whisperspeech/extract_acoustic.py'),
|
27 |
+
'whisperspeech.extract_acoustic.load_model': ( '1. acoustic token '
|
28 |
+
'extraction.html#load_model',
|
29 |
+
'whisperspeech/extract_acoustic.py')},
|
30 |
+
'whisperspeech.extract_semb': { 'whisperspeech.extract_semb.encode_semantic': ( '2c. whisper semantic embedding '
|
31 |
+
'extraction.html#encode_semantic',
|
32 |
+
'whisperspeech/extract_semb.py'),
|
33 |
+
'whisperspeech.extract_semb.extract_semantic': ( '2c. whisper semantic embedding '
|
34 |
+
'extraction.html#extract_semantic',
|
35 |
+
'whisperspeech/extract_semb.py'),
|
36 |
+
'whisperspeech.extract_semb.load_model': ( '2c. whisper semantic embedding '
|
37 |
+
'extraction.html#load_model',
|
38 |
+
'whisperspeech/extract_semb.py')},
|
39 |
+
'whisperspeech.fetch_models': { 'whisperspeech.fetch_models.main': ( '0. download models.html#main',
|
40 |
+
'whisperspeech/fetch_models.py')},
|
41 |
+
'whisperspeech.modules': { 'whisperspeech.modules.Decoder': ('a. neural modules.html#decoder', 'whisperspeech/modules.py'),
|
42 |
+
'whisperspeech.modules.Decoder.__init__': ( 'a. neural modules.html#decoder.__init__',
|
43 |
+
'whisperspeech/modules.py'),
|
44 |
+
'whisperspeech.modules.Decoder.forward': ( 'a. neural modules.html#decoder.forward',
|
45 |
+
'whisperspeech/modules.py'),
|
46 |
+
'whisperspeech.modules.Encoder': ('a. neural modules.html#encoder', 'whisperspeech/modules.py'),
|
47 |
+
'whisperspeech.modules.Encoder.__init__': ( 'a. neural modules.html#encoder.__init__',
|
48 |
+
'whisperspeech/modules.py'),
|
49 |
+
'whisperspeech.modules.Encoder.forward': ( 'a. neural modules.html#encoder.forward',
|
50 |
+
'whisperspeech/modules.py'),
|
51 |
+
'whisperspeech.modules.LayerNorm': ('a. neural modules.html#layernorm', 'whisperspeech/modules.py'),
|
52 |
+
'whisperspeech.modules.LayerNorm.forward': ( 'a. neural modules.html#layernorm.forward',
|
53 |
+
'whisperspeech/modules.py'),
|
54 |
+
'whisperspeech.modules.LinearHead': ( 'a. neural modules.html#linearhead',
|
55 |
+
'whisperspeech/modules.py'),
|
56 |
+
'whisperspeech.modules.MultiHeadAttention': ( 'a. neural modules.html#multiheadattention',
|
57 |
+
'whisperspeech/modules.py'),
|
58 |
+
'whisperspeech.modules.MultiHeadAttention.__init__': ( 'a. neural '
|
59 |
+
'modules.html#multiheadattention.__init__',
|
60 |
+
'whisperspeech/modules.py'),
|
61 |
+
'whisperspeech.modules.MultiHeadAttention.forward': ( 'a. neural '
|
62 |
+
'modules.html#multiheadattention.forward',
|
63 |
+
'whisperspeech/modules.py'),
|
64 |
+
'whisperspeech.modules.MultiHeadAttention.qkv_attention_pth20': ( 'a. neural '
|
65 |
+
'modules.html#multiheadattention.qkv_attention_pth20',
|
66 |
+
'whisperspeech/modules.py'),
|
67 |
+
'whisperspeech.modules.MultiHeadAttention.qkv_attention_vanilla': ( 'a. neural '
|
68 |
+
'modules.html#multiheadattention.qkv_attention_vanilla',
|
69 |
+
'whisperspeech/modules.py'),
|
70 |
+
'whisperspeech.modules.MultiHeadAttention.qkv_attention_xformers': ( 'a. neural '
|
71 |
+
'modules.html#multiheadattention.qkv_attention_xformers',
|
72 |
+
'whisperspeech/modules.py'),
|
73 |
+
'whisperspeech.modules.QueryHead': ('a. neural modules.html#queryhead', 'whisperspeech/modules.py'),
|
74 |
+
'whisperspeech.modules.ResidualAttentionBlock': ( 'a. neural modules.html#residualattentionblock',
|
75 |
+
'whisperspeech/modules.py'),
|
76 |
+
'whisperspeech.modules.ResidualAttentionBlock.__init__': ( 'a. neural '
|
77 |
+
'modules.html#residualattentionblock.__init__',
|
78 |
+
'whisperspeech/modules.py'),
|
79 |
+
'whisperspeech.modules.ResidualAttentionBlock.forward': ( 'a. neural '
|
80 |
+
'modules.html#residualattentionblock.forward',
|
81 |
+
'whisperspeech/modules.py'),
|
82 |
+
'whisperspeech.modules.Rotary': ('a. neural modules.html#rotary', 'whisperspeech/modules.py'),
|
83 |
+
'whisperspeech.modules.Rotary.__init__': ( 'a. neural modules.html#rotary.__init__',
|
84 |
+
'whisperspeech/modules.py'),
|
85 |
+
'whisperspeech.modules.Rotary.forward': ( 'a. neural modules.html#rotary.forward',
|
86 |
+
'whisperspeech/modules.py'),
|
87 |
+
'whisperspeech.modules.SumDecoder': ( 'a. neural modules.html#sumdecoder',
|
88 |
+
'whisperspeech/modules.py'),
|
89 |
+
'whisperspeech.modules.SumDecoder.__init__': ( 'a. neural modules.html#sumdecoder.__init__',
|
90 |
+
'whisperspeech/modules.py'),
|
91 |
+
'whisperspeech.modules.SumDecoder.forward': ( 'a. neural modules.html#sumdecoder.forward',
|
92 |
+
'whisperspeech/modules.py'),
|
93 |
+
'whisperspeech.modules.apply_rotary_pos_emb': ( 'a. neural modules.html#apply_rotary_pos_emb',
|
94 |
+
'whisperspeech/modules.py'),
|
95 |
+
'whisperspeech.modules.init_transformer': ( 'a. neural modules.html#init_transformer',
|
96 |
+
'whisperspeech/modules.py'),
|
97 |
+
'whisperspeech.modules.rotate_half': ( 'a. neural modules.html#rotate_half',
|
98 |
+
'whisperspeech/modules.py'),
|
99 |
+
'whisperspeech.modules.sinusoids': ('a. neural modules.html#sinusoids', 'whisperspeech/modules.py')},
|
100 |
+
'whisperspeech.pipeline': { 'whisperspeech.pipeline.Pipeline': ('7. pipeline.html#pipeline', 'whisperspeech/pipeline.py'),
|
101 |
+
'whisperspeech.pipeline.Pipeline.__init__': ( '7. pipeline.html#pipeline.__init__',
|
102 |
+
'whisperspeech/pipeline.py'),
|
103 |
+
'whisperspeech.pipeline.Pipeline.generate': ( '7. pipeline.html#pipeline.generate',
|
104 |
+
'whisperspeech/pipeline.py'),
|
105 |
+
'whisperspeech.pipeline.Pipeline.generate_atoks': ( '7. pipeline.html#pipeline.generate_atoks',
|
106 |
+
'whisperspeech/pipeline.py'),
|
107 |
+
'whisperspeech.pipeline.Pipeline.generate_to_file': ( '7. pipeline.html#pipeline.generate_to_file',
|
108 |
+
'whisperspeech/pipeline.py'),
|
109 |
+
'whisperspeech.pipeline.Pipeline.generate_to_notebook': ( '7. '
|
110 |
+
'pipeline.html#pipeline.generate_to_notebook',
|
111 |
+
'whisperspeech/pipeline.py')},
|
112 |
+
'whisperspeech.prepare_s2a_dataset': { 'whisperspeech.prepare_s2a_dataset.flac_to_s2a_name': ( '4a. s2a dataset '
|
113 |
+
'preparation.html#flac_to_s2a_name',
|
114 |
+
'whisperspeech/prepare_s2a_dataset.py'),
|
115 |
+
'whisperspeech.prepare_s2a_dataset.prepare_s2a': ( '4a. s2a dataset '
|
116 |
+
'preparation.html#prepare_s2a',
|
117 |
+
'whisperspeech/prepare_s2a_dataset.py'),
|
118 |
+
'whisperspeech.prepare_s2a_dataset.resampler': ( '4a. s2a dataset '
|
119 |
+
'preparation.html#resampler',
|
120 |
+
'whisperspeech/prepare_s2a_dataset.py')},
|
121 |
+
'whisperspeech.prepare_t2s_dataset': { 'whisperspeech.prepare_t2s_dataset.Transcriber': ( '5a. t2s dataset '
|
122 |
+
'preparation.html#transcriber',
|
123 |
+
'whisperspeech/prepare_t2s_dataset.py'),
|
124 |
+
'whisperspeech.prepare_t2s_dataset.Transcriber.__init__': ( '5a. t2s dataset '
|
125 |
+
'preparation.html#transcriber.__init__',
|
126 |
+
'whisperspeech/prepare_t2s_dataset.py'),
|
127 |
+
'whisperspeech.prepare_t2s_dataset.Transcriber.transcribe': ( '5a. t2s dataset '
|
128 |
+
'preparation.html#transcriber.transcribe',
|
129 |
+
'whisperspeech/prepare_t2s_dataset.py'),
|
130 |
+
'whisperspeech.prepare_t2s_dataset.flac_to_t2s_name': ( '5a. t2s dataset '
|
131 |
+
'preparation.html#flac_to_t2s_name',
|
132 |
+
'whisperspeech/prepare_t2s_dataset.py'),
|
133 |
+
'whisperspeech.prepare_t2s_dataset.prepare_t2s': ( '5a. t2s dataset '
|
134 |
+
'preparation.html#prepare_t2s',
|
135 |
+
'whisperspeech/prepare_t2s_dataset.py')},
|
136 |
+
'whisperspeech.s2a_delar_mup_wds': { 'whisperspeech.s2a_delar_mup_wds.CMLMVisual': ( '4b. semantic to acoustic token '
|
137 |
+
'modeling.html#cmlmvisual',
|
138 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
139 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.__init__': ( '4b. semantic to acoustic token '
|
140 |
+
'modeling.html#cmlmvisual.__init__',
|
141 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
142 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_data': ( '4b. semantic to acoustic token '
|
143 |
+
'modeling.html#cmlmvisual.add_data',
|
144 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
145 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_table_row': ( '4b. semantic to acoustic '
|
146 |
+
'token '
|
147 |
+
'modeling.html#cmlmvisual.add_table_row',
|
148 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
149 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.hide': ( '4b. semantic to acoustic token '
|
150 |
+
'modeling.html#cmlmvisual.hide',
|
151 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
152 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.on_iter': ( '4b. semantic to acoustic token '
|
153 |
+
'modeling.html#cmlmvisual.on_iter',
|
154 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
155 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.plot': ( '4b. semantic to acoustic token '
|
156 |
+
'modeling.html#cmlmvisual.plot',
|
157 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
158 |
+
'whisperspeech.s2a_delar_mup_wds.CMLMVisual.show': ( '4b. semantic to acoustic token '
|
159 |
+
'modeling.html#cmlmvisual.show',
|
160 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
161 |
+
'whisperspeech.s2a_delar_mup_wds.DelSumDecoder': ( '4b. semantic to acoustic token '
|
162 |
+
'modeling.html#delsumdecoder',
|
163 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
164 |
+
'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.__init__': ( '4b. semantic to acoustic '
|
165 |
+
'token '
|
166 |
+
'modeling.html#delsumdecoder.__init__',
|
167 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
168 |
+
'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.forward': ( '4b. semantic to acoustic '
|
169 |
+
'token '
|
170 |
+
'modeling.html#delsumdecoder.forward',
|
171 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
172 |
+
'whisperspeech.s2a_delar_mup_wds.EmbeddingProjector': ( '4b. semantic to acoustic token '
|
173 |
+
'modeling.html#embeddingprojector',
|
174 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
175 |
+
'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention': ( '4b. semantic to acoustic token '
|
176 |
+
'modeling.html#multiheadattention',
|
177 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
178 |
+
'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.__init__': ( '4b. semantic to '
|
179 |
+
'acoustic token '
|
180 |
+
'modeling.html#multiheadattention.__init__',
|
181 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
182 |
+
'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.forward': ( '4b. semantic to acoustic '
|
183 |
+
'token '
|
184 |
+
'modeling.html#multiheadattention.forward',
|
185 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
186 |
+
'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_pth20': ( '4b. semantic '
|
187 |
+
'to acoustic '
|
188 |
+
'token '
|
189 |
+
'modeling.html#multiheadattention.qkv_attention_pth20',
|
190 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
191 |
+
'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_xformers': ( '4b. '
|
192 |
+
'semantic '
|
193 |
+
'to '
|
194 |
+
'acoustic '
|
195 |
+
'token '
|
196 |
+
'modeling.html#multiheadattention.qkv_attention_xformers',
|
197 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
198 |
+
'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock': ( '4b. semantic to acoustic '
|
199 |
+
'token '
|
200 |
+
'modeling.html#residualattentionblock',
|
201 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
202 |
+
'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.__init__': ( '4b. semantic to '
|
203 |
+
'acoustic token '
|
204 |
+
'modeling.html#residualattentionblock.__init__',
|
205 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
206 |
+
'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.forward': ( '4b. semantic to '
|
207 |
+
'acoustic token '
|
208 |
+
'modeling.html#residualattentionblock.forward',
|
209 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
210 |
+
'whisperspeech.s2a_delar_mup_wds.Rotary': ( '4b. semantic to acoustic token '
|
211 |
+
'modeling.html#rotary',
|
212 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
213 |
+
'whisperspeech.s2a_delar_mup_wds.Rotary.__init__': ( '4b. semantic to acoustic token '
|
214 |
+
'modeling.html#rotary.__init__',
|
215 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
216 |
+
'whisperspeech.s2a_delar_mup_wds.Rotary.forward': ( '4b. semantic to acoustic token '
|
217 |
+
'modeling.html#rotary.forward',
|
218 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
219 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer': ( '4b. semantic to acoustic token '
|
220 |
+
'modeling.html#sadelartransformer',
|
221 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
222 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.__init__': ( '4b. semantic to '
|
223 |
+
'acoustic token '
|
224 |
+
'modeling.html#sadelartransformer.__init__',
|
225 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
226 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.device': ( '4b. semantic to acoustic '
|
227 |
+
'token '
|
228 |
+
'modeling.html#sadelartransformer.device',
|
229 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
230 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.embed_stoks': ( '4b. semantic to '
|
231 |
+
'acoustic token '
|
232 |
+
'modeling.html#sadelartransformer.embed_stoks',
|
233 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
234 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.forward': ( '4b. semantic to acoustic '
|
235 |
+
'token '
|
236 |
+
'modeling.html#sadelartransformer.forward',
|
237 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
238 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.generate': ( '4b. semantic to '
|
239 |
+
'acoustic token '
|
240 |
+
'modeling.html#sadelartransformer.generate',
|
241 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
242 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_extra_state': ( '4b. semantic to '
|
243 |
+
'acoustic token '
|
244 |
+
'modeling.html#sadelartransformer.get_extra_state',
|
245 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
246 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_metrics': ( '4b. semantic to '
|
247 |
+
'acoustic token '
|
248 |
+
'modeling.html#sadelartransformer.get_metrics',
|
249 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
250 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.init_transformer': ( '4b. semantic to '
|
251 |
+
'acoustic token '
|
252 |
+
'modeling.html#sadelartransformer.init_transformer',
|
253 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
254 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_checkpoint': ( '4b. semantic to '
|
255 |
+
'acoustic token '
|
256 |
+
'modeling.html#sadelartransformer.load_checkpoint',
|
257 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
258 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_frozen_semantic_embeddings': ( '4b. '
|
259 |
+
'semantic '
|
260 |
+
'to '
|
261 |
+
'acoustic '
|
262 |
+
'token '
|
263 |
+
'modeling.html#sadelartransformer.load_frozen_semantic_embeddings',
|
264 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
265 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_model': ( '4b. semantic to '
|
266 |
+
'acoustic token '
|
267 |
+
'modeling.html#sadelartransformer.load_model',
|
268 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
269 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.save_model': ( '4b. semantic to '
|
270 |
+
'acoustic token '
|
271 |
+
'modeling.html#sadelartransformer.save_model',
|
272 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
273 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.set_extra_state': ( '4b. semantic to '
|
274 |
+
'acoustic token '
|
275 |
+
'modeling.html#sadelartransformer.set_extra_state',
|
276 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
277 |
+
'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.setup': ( '4b. semantic to acoustic '
|
278 |
+
'token '
|
279 |
+
'modeling.html#sadelartransformer.setup',
|
280 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
281 |
+
'whisperspeech.s2a_delar_mup_wds.Tunables': ( '4b. semantic to acoustic token '
|
282 |
+
'modeling.html#tunables',
|
283 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
284 |
+
'whisperspeech.s2a_delar_mup_wds.Tunables.__post_init__': ( '4b. semantic to acoustic '
|
285 |
+
'token '
|
286 |
+
'modeling.html#tunables.__post_init__',
|
287 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
288 |
+
'whisperspeech.s2a_delar_mup_wds.Tunables.upgrade': ( '4b. semantic to acoustic token '
|
289 |
+
'modeling.html#tunables.upgrade',
|
290 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
291 |
+
'whisperspeech.s2a_delar_mup_wds._make_model': ( '4b. semantic to acoustic token '
|
292 |
+
'modeling.html#_make_model',
|
293 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
294 |
+
'whisperspeech.s2a_delar_mup_wds.apply_rotary_pos_emb': ( '4b. semantic to acoustic token '
|
295 |
+
'modeling.html#apply_rotary_pos_emb',
|
296 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
297 |
+
'whisperspeech.s2a_delar_mup_wds.load_datasets': ( '4b. semantic to acoustic token '
|
298 |
+
'modeling.html#load_datasets',
|
299 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
300 |
+
'whisperspeech.s2a_delar_mup_wds.make_model': ( '4b. semantic to acoustic token '
|
301 |
+
'modeling.html#make_model',
|
302 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
303 |
+
'whisperspeech.s2a_delar_mup_wds.pad_samples': ( '4b. semantic to acoustic token '
|
304 |
+
'modeling.html#pad_samples',
|
305 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
306 |
+
'whisperspeech.s2a_delar_mup_wds.rand': ( '4b. semantic to acoustic token '
|
307 |
+
'modeling.html#rand',
|
308 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
309 |
+
'whisperspeech.s2a_delar_mup_wds.random_trunc': ( '4b. semantic to acoustic token '
|
310 |
+
'modeling.html#random_trunc',
|
311 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
312 |
+
'whisperspeech.s2a_delar_mup_wds.rotate_half': ( '4b. semantic to acoustic token '
|
313 |
+
'modeling.html#rotate_half',
|
314 |
+
'whisperspeech/s2a_delar_mup_wds.py'),
|
315 |
+
'whisperspeech.s2a_delar_mup_wds.speaker_id_extractor': ( '4b. semantic to acoustic token '
|
316 |
+
'modeling.html#speaker_id_extractor',
|
317 |
+
'whisperspeech/s2a_delar_mup_wds.py')},
|
318 |
+
'whisperspeech.t2s_up_wds': { 'whisperspeech.t2s_up_wds.CharTokenizer': ( '5b. text to semantic token '
|
319 |
+
'modeling.html#chartokenizer',
|
320 |
+
'whisperspeech/t2s_up_wds.py'),
|
321 |
+
'whisperspeech.t2s_up_wds.CharTokenizer.decode': ( '5b. text to semantic token '
|
322 |
+
'modeling.html#chartokenizer.decode',
|
323 |
+
'whisperspeech/t2s_up_wds.py'),
|
324 |
+
'whisperspeech.t2s_up_wds.CharTokenizer.encode': ( '5b. text to semantic token '
|
325 |
+
'modeling.html#chartokenizer.encode',
|
326 |
+
'whisperspeech/t2s_up_wds.py'),
|
327 |
+
'whisperspeech.t2s_up_wds.Decoder': ( '5b. text to semantic token modeling.html#decoder',
|
328 |
+
'whisperspeech/t2s_up_wds.py'),
|
329 |
+
'whisperspeech.t2s_up_wds.Decoder.__init__': ( '5b. text to semantic token '
|
330 |
+
'modeling.html#decoder.__init__',
|
331 |
+
'whisperspeech/t2s_up_wds.py'),
|
332 |
+
'whisperspeech.t2s_up_wds.Decoder.forward': ( '5b. text to semantic token '
|
333 |
+
'modeling.html#decoder.forward',
|
334 |
+
'whisperspeech/t2s_up_wds.py'),
|
335 |
+
'whisperspeech.t2s_up_wds.EmbeddingProjector': ( '5b. text to semantic token '
|
336 |
+
'modeling.html#embeddingprojector',
|
337 |
+
'whisperspeech/t2s_up_wds.py'),
|
338 |
+
'whisperspeech.t2s_up_wds.Encoder': ( '5b. text to semantic token modeling.html#encoder',
|
339 |
+
'whisperspeech/t2s_up_wds.py'),
|
340 |
+
'whisperspeech.t2s_up_wds.Encoder.__init__': ( '5b. text to semantic token '
|
341 |
+
'modeling.html#encoder.__init__',
|
342 |
+
'whisperspeech/t2s_up_wds.py'),
|
343 |
+
'whisperspeech.t2s_up_wds.Encoder.forward': ( '5b. text to semantic token '
|
344 |
+
'modeling.html#encoder.forward',
|
345 |
+
'whisperspeech/t2s_up_wds.py'),
|
346 |
+
'whisperspeech.t2s_up_wds.TSARTransformer': ( '5b. text to semantic token '
|
347 |
+
'modeling.html#tsartransformer',
|
348 |
+
'whisperspeech/t2s_up_wds.py'),
|
349 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.__init__': ( '5b. text to semantic token '
|
350 |
+
'modeling.html#tsartransformer.__init__',
|
351 |
+
'whisperspeech/t2s_up_wds.py'),
|
352 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.device': ( '5b. text to semantic token '
|
353 |
+
'modeling.html#tsartransformer.device',
|
354 |
+
'whisperspeech/t2s_up_wds.py'),
|
355 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.ensure_tokenizer': ( '5b. text to semantic token '
|
356 |
+
'modeling.html#tsartransformer.ensure_tokenizer',
|
357 |
+
'whisperspeech/t2s_up_wds.py'),
|
358 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.forward': ( '5b. text to semantic token '
|
359 |
+
'modeling.html#tsartransformer.forward',
|
360 |
+
'whisperspeech/t2s_up_wds.py'),
|
361 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.generate': ( '5b. text to semantic token '
|
362 |
+
'modeling.html#tsartransformer.generate',
|
363 |
+
'whisperspeech/t2s_up_wds.py'),
|
364 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.generate_batch': ( '5b. text to semantic token '
|
365 |
+
'modeling.html#tsartransformer.generate_batch',
|
366 |
+
'whisperspeech/t2s_up_wds.py'),
|
367 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.init_transformer': ( '5b. text to semantic token '
|
368 |
+
'modeling.html#tsartransformer.init_transformer',
|
369 |
+
'whisperspeech/t2s_up_wds.py'),
|
370 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.load_checkpoint': ( '5b. text to semantic token '
|
371 |
+
'modeling.html#tsartransformer.load_checkpoint',
|
372 |
+
'whisperspeech/t2s_up_wds.py'),
|
373 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.load_frozen_semantic_embeddings': ( '5b. text to '
|
374 |
+
'semantic token '
|
375 |
+
'modeling.html#tsartransformer.load_frozen_semantic_embeddings',
|
376 |
+
'whisperspeech/t2s_up_wds.py'),
|
377 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.load_model': ( '5b. text to semantic token '
|
378 |
+
'modeling.html#tsartransformer.load_model',
|
379 |
+
'whisperspeech/t2s_up_wds.py'),
|
380 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.save_model': ( '5b. text to semantic token '
|
381 |
+
'modeling.html#tsartransformer.save_model',
|
382 |
+
'whisperspeech/t2s_up_wds.py'),
|
383 |
+
'whisperspeech.t2s_up_wds.TSARTransformer.setup': ( '5b. text to semantic token '
|
384 |
+
'modeling.html#tsartransformer.setup',
|
385 |
+
'whisperspeech/t2s_up_wds.py'),
|
386 |
+
'whisperspeech.t2s_up_wds.Tunables': ( '5b. text to semantic token modeling.html#tunables',
|
387 |
+
'whisperspeech/t2s_up_wds.py'),
|
388 |
+
'whisperspeech.t2s_up_wds.Tunables.__post_init__': ( '5b. text to semantic token '
|
389 |
+
'modeling.html#tunables.__post_init__',
|
390 |
+
'whisperspeech/t2s_up_wds.py'),
|
391 |
+
'whisperspeech.t2s_up_wds._make_model': ( '5b. text to semantic token modeling.html#_make_model',
|
392 |
+
'whisperspeech/t2s_up_wds.py'),
|
393 |
+
'whisperspeech.t2s_up_wds.ar_padder': ( '5b. text to semantic token modeling.html#ar_padder',
|
394 |
+
'whisperspeech/t2s_up_wds.py'),
|
395 |
+
'whisperspeech.t2s_up_wds.build_speaker_map': ( '5b. text to semantic token '
|
396 |
+
'modeling.html#build_speaker_map',
|
397 |
+
'whisperspeech/t2s_up_wds.py'),
|
398 |
+
'whisperspeech.t2s_up_wds.char_per_seconder': ( '5b. text to semantic token '
|
399 |
+
'modeling.html#char_per_seconder',
|
400 |
+
'whisperspeech/t2s_up_wds.py'),
|
401 |
+
'whisperspeech.t2s_up_wds.load_datasets': ( '5b. text to semantic token '
|
402 |
+
'modeling.html#load_datasets',
|
403 |
+
'whisperspeech/t2s_up_wds.py'),
|
404 |
+
'whisperspeech.t2s_up_wds.make_model': ( '5b. text to semantic token modeling.html#make_model',
|
405 |
+
'whisperspeech/t2s_up_wds.py'),
|
406 |
+
'whisperspeech.t2s_up_wds.rand': ( '5b. text to semantic token modeling.html#rand',
|
407 |
+
'whisperspeech/t2s_up_wds.py'),
|
408 |
+
'whisperspeech.t2s_up_wds.speaker_id_extractor': ( '5b. text to semantic token '
|
409 |
+
'modeling.html#speaker_id_extractor',
|
410 |
+
'whisperspeech/t2s_up_wds.py'),
|
411 |
+
'whisperspeech.t2s_up_wds.tokenizer': ( '5b. text to semantic token modeling.html#tokenizer',
|
412 |
+
'whisperspeech/t2s_up_wds.py')},
|
413 |
+
'whisperspeech.train': { 'whisperspeech.train.SimpleVisual': ('b1. training.html#simplevisual', 'whisperspeech/train.py'),
|
414 |
+
'whisperspeech.train.SimpleVisual.__init__': ( 'b1. training.html#simplevisual.__init__',
|
415 |
+
'whisperspeech/train.py'),
|
416 |
+
'whisperspeech.train.SimpleVisual.add_data': ( 'b1. training.html#simplevisual.add_data',
|
417 |
+
'whisperspeech/train.py'),
|
418 |
+
'whisperspeech.train.SimpleVisual.add_table_row': ( 'b1. training.html#simplevisual.add_table_row',
|
419 |
+
'whisperspeech/train.py'),
|
420 |
+
'whisperspeech.train.SimpleVisual.hide': ( 'b1. training.html#simplevisual.hide',
|
421 |
+
'whisperspeech/train.py'),
|
422 |
+
'whisperspeech.train.SimpleVisual.on_iter': ( 'b1. training.html#simplevisual.on_iter',
|
423 |
+
'whisperspeech/train.py'),
|
424 |
+
'whisperspeech.train.SimpleVisual.plot': ( 'b1. training.html#simplevisual.plot',
|
425 |
+
'whisperspeech/train.py'),
|
426 |
+
'whisperspeech.train.SimpleVisual.show': ( 'b1. training.html#simplevisual.show',
|
427 |
+
'whisperspeech/train.py'),
|
428 |
+
'whisperspeech.train.train': ('b1. training.html#train', 'whisperspeech/train.py'),
|
429 |
+
'whisperspeech.train.validate': ('b1. training.html#validate', 'whisperspeech/train.py')},
|
430 |
+
'whisperspeech.train_multi': { 'whisperspeech.train_multi.TrainingTask': ( 'b2. training (lightning).html#trainingtask',
|
431 |
+
'whisperspeech/train_multi.py'),
|
432 |
+
'whisperspeech.train_multi.TrainingTask.__init__': ( 'b2. training '
|
433 |
+
'(lightning).html#trainingtask.__init__',
|
434 |
+
'whisperspeech/train_multi.py'),
|
435 |
+
'whisperspeech.train_multi.TrainingTask.configure_optimizers': ( 'b2. training '
|
436 |
+
'(lightning).html#trainingtask.configure_optimizers',
|
437 |
+
'whisperspeech/train_multi.py'),
|
438 |
+
'whisperspeech.train_multi.TrainingTask.on_fit_start': ( 'b2. training '
|
439 |
+
'(lightning).html#trainingtask.on_fit_start',
|
440 |
+
'whisperspeech/train_multi.py'),
|
441 |
+
'whisperspeech.train_multi.TrainingTask.on_validation_epoch_end': ( 'b2. training '
|
442 |
+
'(lightning).html#trainingtask.on_validation_epoch_end',
|
443 |
+
'whisperspeech/train_multi.py'),
|
444 |
+
'whisperspeech.train_multi.TrainingTask.test_step': ( 'b2. training '
|
445 |
+
'(lightning).html#trainingtask.test_step',
|
446 |
+
'whisperspeech/train_multi.py'),
|
447 |
+
'whisperspeech.train_multi.TrainingTask.training_step': ( 'b2. training '
|
448 |
+
'(lightning).html#trainingtask.training_step',
|
449 |
+
'whisperspeech/train_multi.py'),
|
450 |
+
'whisperspeech.train_multi.TrainingTask.validation_step': ( 'b2. training '
|
451 |
+
'(lightning).html#trainingtask.validation_step',
|
452 |
+
'whisperspeech/train_multi.py'),
|
453 |
+
'whisperspeech.train_multi.parse_and_call': ( 'b2. training (lightning).html#parse_and_call',
|
454 |
+
'whisperspeech/train_multi.py')},
|
455 |
+
'whisperspeech.vad': { 'whisperspeech.vad.extract_segments': ( '1b. voice activity detection.html#extract_segments',
|
456 |
+
'whisperspeech/vad.py'),
|
457 |
+
'whisperspeech.vad.fix_dots_in_names': ( '1b. voice activity detection.html#fix_dots_in_names',
|
458 |
+
'whisperspeech/vad.py'),
|
459 |
+
'whisperspeech.vad.flac_to_vad_name': ( '1b. voice activity detection.html#flac_to_vad_name',
|
460 |
+
'whisperspeech/vad.py'),
|
461 |
+
'whisperspeech.vad.load_dataset': ( '1b. voice activity detection.html#load_dataset',
|
462 |
+
'whisperspeech/vad.py'),
|
463 |
+
'whisperspeech.vad.process_shard': ( '1b. voice activity detection.html#process_shard',
|
464 |
+
'whisperspeech/vad.py'),
|
465 |
+
'whisperspeech.vad.segment_audio': ( '1b. voice activity detection.html#segment_audio',
|
466 |
+
'whisperspeech/vad.py')},
|
467 |
+
'whisperspeech.verify_wds': { 'whisperspeech.verify_wds.process_shard': ( '0. verify webdataset archives.html#process_shard',
|
468 |
+
'whisperspeech/verify_wds.py')},
|
469 |
+
'whisperspeech.vq_stoks': { 'whisperspeech.vq_stoks.RQBottleneckTransformer': ( '2b. whisper quantization (semantic token) '
|
470 |
+
'model.html#rqbottlenecktransformer',
|
471 |
+
'whisperspeech/vq_stoks.py'),
|
472 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.__init__': ( '2b. whisper quantization (semantic '
|
473 |
+
'token) '
|
474 |
+
'model.html#rqbottlenecktransformer.__init__',
|
475 |
+
'whisperspeech/vq_stoks.py'),
|
476 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.decode_text': ( '2b. whisper quantization '
|
477 |
+
'(semantic token) '
|
478 |
+
'model.html#rqbottlenecktransformer.decode_text',
|
479 |
+
'whisperspeech/vq_stoks.py'),
|
480 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.dequantize': ( '2b. whisper quantization (semantic '
|
481 |
+
'token) '
|
482 |
+
'model.html#rqbottlenecktransformer.dequantize',
|
483 |
+
'whisperspeech/vq_stoks.py'),
|
484 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.device': ( '2b. whisper quantization (semantic '
|
485 |
+
'token) '
|
486 |
+
'model.html#rqbottlenecktransformer.device',
|
487 |
+
'whisperspeech/vq_stoks.py'),
|
488 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.downsample_embeddings': ( '2b. whisper '
|
489 |
+
'quantization (semantic '
|
490 |
+
'token) '
|
491 |
+
'model.html#rqbottlenecktransformer.downsample_embeddings',
|
492 |
+
'whisperspeech/vq_stoks.py'),
|
493 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_audio': ( '2b. whisper quantization '
|
494 |
+
'(semantic token) '
|
495 |
+
'model.html#rqbottlenecktransformer.encode_audio',
|
496 |
+
'whisperspeech/vq_stoks.py'),
|
497 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_mel': ( '2b. whisper quantization (semantic '
|
498 |
+
'token) '
|
499 |
+
'model.html#rqbottlenecktransformer.encode_mel',
|
500 |
+
'whisperspeech/vq_stoks.py'),
|
501 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.ensure_whisper': ( '2b. whisper quantization '
|
502 |
+
'(semantic token) '
|
503 |
+
'model.html#rqbottlenecktransformer.ensure_whisper',
|
504 |
+
'whisperspeech/vq_stoks.py'),
|
505 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.extract_teacher': ( '2b. whisper quantization '
|
506 |
+
'(semantic token) '
|
507 |
+
'model.html#rqbottlenecktransformer.extract_teacher',
|
508 |
+
'whisperspeech/vq_stoks.py'),
|
509 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.forward': ( '2b. whisper quantization (semantic '
|
510 |
+
'token) '
|
511 |
+
'model.html#rqbottlenecktransformer.forward',
|
512 |
+
'whisperspeech/vq_stoks.py'),
|
513 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.get_metrics': ( '2b. whisper quantization '
|
514 |
+
'(semantic token) '
|
515 |
+
'model.html#rqbottlenecktransformer.get_metrics',
|
516 |
+
'whisperspeech/vq_stoks.py'),
|
517 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.init_transformer': ( '2b. whisper quantization '
|
518 |
+
'(semantic token) '
|
519 |
+
'model.html#rqbottlenecktransformer.init_transformer',
|
520 |
+
'whisperspeech/vq_stoks.py'),
|
521 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.load_checkpoint': ( '2b. whisper quantization '
|
522 |
+
'(semantic token) '
|
523 |
+
'model.html#rqbottlenecktransformer.load_checkpoint',
|
524 |
+
'whisperspeech/vq_stoks.py'),
|
525 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.load_model': ( '2b. whisper quantization (semantic '
|
526 |
+
'token) '
|
527 |
+
'model.html#rqbottlenecktransformer.load_model',
|
528 |
+
'whisperspeech/vq_stoks.py'),
|
529 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.quantize': ( '2b. whisper quantization (semantic '
|
530 |
+
'token) '
|
531 |
+
'model.html#rqbottlenecktransformer.quantize',
|
532 |
+
'whisperspeech/vq_stoks.py'),
|
533 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.save_model': ( '2b. whisper quantization (semantic '
|
534 |
+
'token) '
|
535 |
+
'model.html#rqbottlenecktransformer.save_model',
|
536 |
+
'whisperspeech/vq_stoks.py'),
|
537 |
+
'whisperspeech.vq_stoks.RQBottleneckTransformer.setup': ( '2b. whisper quantization (semantic '
|
538 |
+
'token) '
|
539 |
+
'model.html#rqbottlenecktransformer.setup',
|
540 |
+
'whisperspeech/vq_stoks.py'),
|
541 |
+
'whisperspeech.vq_stoks.Tunables': ( '2b. whisper quantization (semantic token) '
|
542 |
+
'model.html#tunables',
|
543 |
+
'whisperspeech/vq_stoks.py'),
|
544 |
+
'whisperspeech.vq_stoks.Tunables.__post_init__': ( '2b. whisper quantization (semantic token) '
|
545 |
+
'model.html#tunables.__post_init__',
|
546 |
+
'whisperspeech/vq_stoks.py'),
|
547 |
+
'whisperspeech.vq_stoks.Tunables.upgrade': ( '2b. whisper quantization (semantic token) '
|
548 |
+
'model.html#tunables.upgrade',
|
549 |
+
'whisperspeech/vq_stoks.py'),
|
550 |
+
'whisperspeech.vq_stoks.add_masks': ( '2b. whisper quantization (semantic token) '
|
551 |
+
'model.html#add_masks',
|
552 |
+
'whisperspeech/vq_stoks.py'),
|
553 |
+
'whisperspeech.vq_stoks.derived_dataset': ( '2b. whisper quantization (semantic token) '
|
554 |
+
'model.html#derived_dataset',
|
555 |
+
'whisperspeech/vq_stoks.py'),
|
556 |
+
'whisperspeech.vq_stoks.load_datasets': ( '2b. whisper quantization (semantic token) '
|
557 |
+
'model.html#load_datasets',
|
558 |
+
'whisperspeech/vq_stoks.py'),
|
559 |
+
'whisperspeech.vq_stoks.logrand': ( '2b. whisper quantization (semantic token) model.html#logrand',
|
560 |
+
'whisperspeech/vq_stoks.py'),
|
561 |
+
'whisperspeech.vq_stoks.make_model': ( '2b. whisper quantization (semantic token) '
|
562 |
+
'model.html#make_model',
|
563 |
+
'whisperspeech/vq_stoks.py'),
|
564 |
+
'whisperspeech.vq_stoks.merge_in': ( '2b. whisper quantization (semantic token) '
|
565 |
+
'model.html#merge_in',
|
566 |
+
'whisperspeech/vq_stoks.py'),
|
567 |
+
'whisperspeech.vq_stoks.rand': ( '2b. whisper quantization (semantic token) model.html#rand',
|
568 |
+
'whisperspeech/vq_stoks.py'),
|
569 |
+
'whisperspeech.vq_stoks.tokenize_text': ( '2b. whisper quantization (semantic token) '
|
570 |
+
'model.html#tokenize_text',
|
571 |
+
'whisperspeech/vq_stoks.py')},
|
572 |
+
'whisperspeech.wer_metrics': { 'whisperspeech.wer_metrics.DfBuilder': ( 'c. word error rate metrics.html#dfbuilder',
|
573 |
+
'whisperspeech/wer_metrics.py'),
|
574 |
+
'whisperspeech.wer_metrics.DfBuilder.__init__': ( 'c. word error rate '
|
575 |
+
'metrics.html#dfbuilder.__init__',
|
576 |
+
'whisperspeech/wer_metrics.py'),
|
577 |
+
'whisperspeech.wer_metrics.DfBuilder.df': ( 'c. word error rate metrics.html#dfbuilder.df',
|
578 |
+
'whisperspeech/wer_metrics.py'),
|
579 |
+
'whisperspeech.wer_metrics.DfBuilder.push': ( 'c. word error rate metrics.html#dfbuilder.push',
|
580 |
+
'whisperspeech/wer_metrics.py'),
|
581 |
+
'whisperspeech.wer_metrics.WERStats': ( 'c. word error rate metrics.html#werstats',
|
582 |
+
'whisperspeech/wer_metrics.py'),
|
583 |
+
'whisperspeech.wer_metrics.WERStats.__init__': ( 'c. word error rate '
|
584 |
+
'metrics.html#werstats.__init__',
|
585 |
+
'whisperspeech/wer_metrics.py'),
|
586 |
+
'whisperspeech.wer_metrics.WERStats.push_sample': ( 'c. word error rate '
|
587 |
+
'metrics.html#werstats.push_sample',
|
588 |
+
'whisperspeech/wer_metrics.py'),
|
589 |
+
'whisperspeech.wer_metrics.librispeech_data': ( 'c. word error rate '
|
590 |
+
'metrics.html#librispeech_data',
|
591 |
+
'whisperspeech/wer_metrics.py'),
|
592 |
+
'whisperspeech.wer_metrics.whisper_normalize': ( 'c. word error rate '
|
593 |
+
'metrics.html#whisper_normalize',
|
594 |
+
'whisperspeech/wer_metrics.py')},
|
595 |
+
'whisperspeech.wh_transcribe': { 'whisperspeech.wh_transcribe.chunk_merger': ( '2a. whisper quantization dataset '
|
596 |
+
'preparation.html#chunk_merger',
|
597 |
+
'whisperspeech/wh_transcribe.py'),
|
598 |
+
'whisperspeech.wh_transcribe.flac_to_txt_name': ( '2a. whisper quantization dataset '
|
599 |
+
'preparation.html#flac_to_txt_name',
|
600 |
+
'whisperspeech/wh_transcribe.py'),
|
601 |
+
'whisperspeech.wh_transcribe.merge_in': ( '2a. whisper quantization dataset '
|
602 |
+
'preparation.html#merge_in',
|
603 |
+
'whisperspeech/wh_transcribe.py'),
|
604 |
+
'whisperspeech.wh_transcribe.process_shard': ( '2a. whisper quantization dataset '
|
605 |
+
'preparation.html#process_shard',
|
606 |
+
'whisperspeech/wh_transcribe.py'),
|
607 |
+
'whisperspeech.wh_transcribe.random_cutter': ( '2a. whisper quantization dataset '
|
608 |
+
'preparation.html#random_cutter',
|
609 |
+
'whisperspeech/wh_transcribe.py'),
|
610 |
+
'whisperspeech.wh_transcribe.split_to_chunks': ( '2a. whisper quantization dataset '
|
611 |
+
'preparation.html#split_to_chunks',
|
612 |
+
'whisperspeech/wh_transcribe.py'),
|
613 |
+
'whisperspeech.wh_transcribe.wds_compose': ( '2a. whisper quantization dataset '
|
614 |
+
'preparation.html#wds_compose',
|
615 |
+
'whisperspeech/wh_transcribe.py')}}}
|
whisperspeech/a2wav.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. Quality-boosting vocoder.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['Vocoder']
|
5 |
+
|
6 |
+
# %% ../nbs/6. Quality-boosting vocoder.ipynb 1
|
7 |
+
from vocos import Vocos
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
|
11 |
+
# %% ../nbs/6. Quality-boosting vocoder.ipynb 2
|
12 |
+
class Vocoder:
|
13 |
+
def __init__(self, repo_id="charactr/vocos-encodec-24khz"):
|
14 |
+
self.vocos = Vocos.from_pretrained(repo_id).cuda()
|
15 |
+
|
16 |
+
def is_notebook(self):
|
17 |
+
try:
|
18 |
+
return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
|
19 |
+
except:
|
20 |
+
return False
|
21 |
+
|
22 |
+
@torch.no_grad()
|
23 |
+
def decode(self, atoks):
|
24 |
+
if len(atoks.shape) == 3:
|
25 |
+
b,q,t = atoks.shape
|
26 |
+
atoks = atoks.permute(1,0,2)
|
27 |
+
else:
|
28 |
+
q,t = atoks.shape
|
29 |
+
|
30 |
+
features = self.vocos.codes_to_features(atoks)
|
31 |
+
bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()
|
32 |
+
return self.vocos.decode(features, bandwidth_id=bandwidth_id)
|
33 |
+
|
34 |
+
def decode_to_file(self, fname, atoks):
|
35 |
+
audio = self.decode(atoks)
|
36 |
+
torchaudio.save(fname, audio.cpu(), 24000)
|
37 |
+
if self.is_notebook():
|
38 |
+
from IPython.display import display, HTML, Audio
|
39 |
+
display(HTML(f'<a href="{fname}" target="_blank">Listen to {fname}</a>'))
|
40 |
+
|
41 |
+
def decode_to_notebook(self, atoks):
|
42 |
+
from IPython.display import display, HTML, Audio
|
43 |
+
|
44 |
+
audio = self.decode(atoks)
|
45 |
+
display(Audio(audio.cpu().numpy(), rate=24000))
|
whisperspeech/extract_acoustic.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1. Acoustic token extraction.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['load', 'load_model', 'extract_Atoks', 'extract_acoustic']
|
5 |
+
|
6 |
+
# %% ../nbs/1. Acoustic token extraction.ipynb 2
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
import gc
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from fastcore.script import *
|
13 |
+
from fastprogress import progress_bar, master_bar
|
14 |
+
|
15 |
+
# %% ../nbs/1. Acoustic token extraction.ipynb 5
|
16 |
+
def load(fname, newsr=24000):
|
17 |
+
"""Load an audio file to the GPU and resample to `newsr`."""
|
18 |
+
x, sr = torchaudio.load(fname)
|
19 |
+
_tform = torchaudio.transforms.Resample(sr, newsr)
|
20 |
+
return _tform(x).cuda().unsqueeze(0)
|
21 |
+
|
22 |
+
# %% ../nbs/1. Acoustic token extraction.ipynb 6
|
23 |
+
def load_model():
|
24 |
+
"Load the pretrained EnCodec model"
|
25 |
+
from encodec.model import EncodecModel
|
26 |
+
model = EncodecModel.encodec_model_24khz()
|
27 |
+
model.set_target_bandwidth(1.5)
|
28 |
+
model.cuda().eval();
|
29 |
+
return model
|
30 |
+
|
31 |
+
# %% ../nbs/1. Acoustic token extraction.ipynb 7
|
32 |
+
def extract_Atoks(model, audio):
|
33 |
+
"""Extract EnCodec tokens for the given `audio` tensor (or file path)
|
34 |
+
using the given `model` (see `load_model`)."""
|
35 |
+
if isinstance(audio, (Path, str)):
|
36 |
+
audio = load(audio)
|
37 |
+
with torch.no_grad():
|
38 |
+
frames = torch.cat([model.encode(segment)[0][0]
|
39 |
+
for segment in torch.split(audio, 320*20000, dim=-1)], dim=-1)
|
40 |
+
return frames
|
41 |
+
|
42 |
+
# %% ../nbs/1. Acoustic token extraction.ipynb 8
|
43 |
+
@call_parse
|
44 |
+
def extract_acoustic(
|
45 |
+
srcdir:Path, # source dir, should contain *.flac files
|
46 |
+
outdir:Path, # output dir, will get the *.encodec files
|
47 |
+
):
|
48 |
+
"Convert audio files to .encodec files with tensors of tokens"
|
49 |
+
model = load_model()
|
50 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
51 |
+
for name in progress_bar(list(srcdir.rglob('*.flac'))):
|
52 |
+
outname = outdir/name.with_suffix('.encodec').name
|
53 |
+
tokens = extract_Atoks(model, name)
|
54 |
+
torch.save(tokens, outname)
|
55 |
+
del tokens
|
56 |
+
gc.collect()
|
whisperspeech/fetch_models.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/0. Download models.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = []
|
5 |
+
|
6 |
+
# %% ../nbs/0. Download models.ipynb 1
|
7 |
+
from fastcore.script import call_parse
|
8 |
+
import whisperx
|
9 |
+
import whisper
|
10 |
+
|
11 |
+
# %% ../nbs/0. Download models.ipynb 3
|
12 |
+
@call_parse
|
13 |
+
def main():
|
14 |
+
whisper.load_model('base.en')
|
15 |
+
whisper.load_model('small.en')
|
16 |
+
whisperx.vad.load_vad_model('cpu')
|
17 |
+
whisperx.asr.load_model('medium.en', "cpu", compute_type="float16", language='en')
|
whisperspeech/languages.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B. Languages.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['to_id']
|
5 |
+
|
6 |
+
# %% ../nbs/B. Languages.ipynb 3
|
7 |
+
LANGUAGES = {
|
8 |
+
"en": "english",
|
9 |
+
"zh": "chinese",
|
10 |
+
"de": "german",
|
11 |
+
"es": "spanish",
|
12 |
+
"ru": "russian",
|
13 |
+
"ko": "korean",
|
14 |
+
"fr": "french",
|
15 |
+
"ja": "japanese",
|
16 |
+
"pt": "portuguese",
|
17 |
+
"tr": "turkish",
|
18 |
+
"pl": "polish",
|
19 |
+
"ca": "catalan",
|
20 |
+
"nl": "dutch",
|
21 |
+
"ar": "arabic",
|
22 |
+
"sv": "swedish",
|
23 |
+
"it": "italian",
|
24 |
+
"id": "indonesian",
|
25 |
+
"hi": "hindi",
|
26 |
+
"fi": "finnish",
|
27 |
+
"vi": "vietnamese",
|
28 |
+
"he": "hebrew",
|
29 |
+
"uk": "ukrainian",
|
30 |
+
"el": "greek",
|
31 |
+
"ms": "malay",
|
32 |
+
"cs": "czech",
|
33 |
+
"ro": "romanian",
|
34 |
+
"da": "danish",
|
35 |
+
"hu": "hungarian",
|
36 |
+
"ta": "tamil",
|
37 |
+
"no": "norwegian",
|
38 |
+
"th": "thai",
|
39 |
+
"ur": "urdu",
|
40 |
+
"hr": "croatian",
|
41 |
+
"bg": "bulgarian",
|
42 |
+
"lt": "lithuanian",
|
43 |
+
"la": "latin",
|
44 |
+
"mi": "maori",
|
45 |
+
"ml": "malayalam",
|
46 |
+
"cy": "welsh",
|
47 |
+
"sk": "slovak",
|
48 |
+
"te": "telugu",
|
49 |
+
"fa": "persian",
|
50 |
+
"lv": "latvian",
|
51 |
+
"bn": "bengali",
|
52 |
+
"sr": "serbian",
|
53 |
+
"az": "azerbaijani",
|
54 |
+
"sl": "slovenian",
|
55 |
+
"kn": "kannada",
|
56 |
+
"et": "estonian",
|
57 |
+
"mk": "macedonian",
|
58 |
+
"br": "breton",
|
59 |
+
"eu": "basque",
|
60 |
+
"is": "icelandic",
|
61 |
+
"hy": "armenian",
|
62 |
+
"ne": "nepali",
|
63 |
+
"mn": "mongolian",
|
64 |
+
"bs": "bosnian",
|
65 |
+
"kk": "kazakh",
|
66 |
+
"sq": "albanian",
|
67 |
+
"sw": "swahili",
|
68 |
+
"gl": "galician",
|
69 |
+
"mr": "marathi",
|
70 |
+
"pa": "punjabi",
|
71 |
+
"si": "sinhala",
|
72 |
+
"km": "khmer",
|
73 |
+
"sn": "shona",
|
74 |
+
"yo": "yoruba",
|
75 |
+
"so": "somali",
|
76 |
+
"af": "afrikaans",
|
77 |
+
"oc": "occitan",
|
78 |
+
"ka": "georgian",
|
79 |
+
"be": "belarusian",
|
80 |
+
"tg": "tajik",
|
81 |
+
"sd": "sindhi",
|
82 |
+
"gu": "gujarati",
|
83 |
+
"am": "amharic",
|
84 |
+
"yi": "yiddish",
|
85 |
+
"lo": "lao",
|
86 |
+
"uz": "uzbek",
|
87 |
+
"fo": "faroese",
|
88 |
+
"ht": "haitian creole",
|
89 |
+
"ps": "pashto",
|
90 |
+
"tk": "turkmen",
|
91 |
+
"nn": "nynorsk",
|
92 |
+
"mt": "maltese",
|
93 |
+
"sa": "sanskrit",
|
94 |
+
"lb": "luxembourgish",
|
95 |
+
"my": "myanmar",
|
96 |
+
"bo": "tibetan",
|
97 |
+
"tl": "tagalog",
|
98 |
+
"mg": "malagasy",
|
99 |
+
"as": "assamese",
|
100 |
+
"tt": "tatar",
|
101 |
+
"haw": "hawaiian",
|
102 |
+
"ln": "lingala",
|
103 |
+
"ha": "hausa",
|
104 |
+
"ba": "bashkir",
|
105 |
+
"jw": "javanese",
|
106 |
+
"su": "sundanese",
|
107 |
+
}
|
108 |
+
|
109 |
+
# %% ../nbs/B. Languages.ipynb 4
|
110 |
+
# language code lookup by name, with a few language aliases
|
111 |
+
TO_LANGUAGE_CODE = {
|
112 |
+
**{language: code for code, language in LANGUAGES.items()},
|
113 |
+
"burmese": "my",
|
114 |
+
"valencian": "ca",
|
115 |
+
"flemish": "nl",
|
116 |
+
"haitian": "ht",
|
117 |
+
"letzeburgesch": "lb",
|
118 |
+
"pushto": "ps",
|
119 |
+
"panjabi": "pa",
|
120 |
+
"moldavian": "ro",
|
121 |
+
"moldovan": "ro",
|
122 |
+
"sinhalese": "si",
|
123 |
+
"castilian": "es",
|
124 |
+
}
|
125 |
+
|
126 |
+
# %% ../nbs/B. Languages.ipynb 5
|
127 |
+
languages = tuple(LANGUAGES.keys())
|
128 |
+
|
129 |
+
# %% ../nbs/B. Languages.ipynb 6
|
130 |
+
def to_id(lang):
|
131 |
+
return languages.index(TO_LANGUAGE_CODE.get(lang, lang))
|
whisperspeech/modules.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
|
5 |
+
'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']
|
6 |
+
|
7 |
+
# %% ../nbs/A. Neural modules.ipynb 2
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import math
|
11 |
+
|
12 |
+
from torch import Tensor, nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from typing import Dict, Iterable, Optional
|
15 |
+
|
16 |
+
# import xformers.ops as xops
|
17 |
+
|
18 |
+
# %% ../nbs/A. Neural modules.ipynb 3
|
19 |
+
# Code in this file is mostly borrowed from
|
20 |
+
# https://github.com/openai/whisper/blob/main/whisper/model.py
|
21 |
+
# and is under the MIT License
|
22 |
+
|
23 |
+
class LayerNorm(nn.LayerNorm):
|
24 |
+
def forward(self, x):
|
25 |
+
return super().forward(x.float()).type(x.dtype)
|
26 |
+
|
27 |
+
# Used in ΞΌP to initialize the weights and configure the optimizer
|
28 |
+
# These two layers map the transformer width into a fixed dimension
|
29 |
+
class LinearHead(nn.Linear):
|
30 |
+
pass
|
31 |
+
|
32 |
+
class QueryHead(nn.Linear):
|
33 |
+
pass
|
34 |
+
|
35 |
+
# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
|
36 |
+
def init_transformer(m):
|
37 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
38 |
+
torch.nn.init.trunc_normal_(m.weight, std=.02)
|
39 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
40 |
+
torch.nn.init.constant_(m.bias, 0)
|
41 |
+
elif isinstance(m, nn.LayerNorm):
|
42 |
+
torch.nn.init.constant_(m.bias, 0)
|
43 |
+
torch.nn.init.constant_(m.weight, 1.0)
|
44 |
+
|
45 |
+
# %% ../nbs/A. Neural modules.ipynb 4
|
46 |
+
def sinusoids(length, channels, max_timescale=10000):
|
47 |
+
"""Returns sinusoids for positional embedding"""
|
48 |
+
assert channels % 2 == 0
|
49 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
50 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
51 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
52 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
53 |
+
|
54 |
+
# %% ../nbs/A. Neural modules.ipynb 5
|
55 |
+
class MultiHeadAttention(nn.Module):
|
56 |
+
def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
|
57 |
+
super().__init__()
|
58 |
+
self.n_state = n_state
|
59 |
+
self.n_head = n_head
|
60 |
+
self.sqrt_qk_scale = math.sqrt(qk_scale)
|
61 |
+
self.query = QueryHead(n_state, n_state)
|
62 |
+
self.key = nn.Linear(n_state, n_state, bias=False)
|
63 |
+
self.value = nn.Linear(n_state, n_state)
|
64 |
+
self.out = nn.Linear(n_state, n_state)
|
65 |
+
self.cross = cross
|
66 |
+
self.query_subsampling = 1
|
67 |
+
self.key_subsampling = 1
|
68 |
+
|
69 |
+
self.cached_kvx = None
|
70 |
+
self.register_buffer('k_cache', None)
|
71 |
+
self.register_buffer('v_cache', None)
|
72 |
+
|
73 |
+
self.rotary = None
|
74 |
+
if rope:
|
75 |
+
self.rotary = Rotary(n_state // n_head)
|
76 |
+
self.qkv = None
|
77 |
+
self.kv = None
|
78 |
+
|
79 |
+
def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
|
80 |
+
cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
|
81 |
+
self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
|
82 |
+
self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)
|
83 |
+
|
84 |
+
def merge_linears(self, layers, mults):
|
85 |
+
bias = [x.bias for x in layers if x.bias is not None][0]
|
86 |
+
din, dout = layers[0].weight.shape
|
87 |
+
new = nn.Linear(din, len(layers) * dout).to(layers[0].weight.device)
|
88 |
+
with torch.no_grad():
|
89 |
+
new.weight[:] = torch.cat([x.weight * m for x,m in zip(layers, mults)])
|
90 |
+
new.bias[:] = torch.cat([torch.zeros_like(bias) if x.bias is None else x.bias * m for x, m in zip(layers, mults)])
|
91 |
+
return new
|
92 |
+
|
93 |
+
def convert_for_eval(self):
|
94 |
+
if self.qkv or self.kv: raise AttributeError("already converted")
|
95 |
+
|
96 |
+
self.odim = self.key.weight.shape[1]
|
97 |
+
if self.cross:
|
98 |
+
self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
|
99 |
+
self.kv = self.merge_linears([self.key, self.value],
|
100 |
+
[self.sqrt_qk_scale, 1])
|
101 |
+
else:
|
102 |
+
self.qkv = self.merge_linears([self.query, self.key, self.value],
|
103 |
+
[self.sqrt_qk_scale, self.sqrt_qk_scale, 1])
|
104 |
+
|
105 |
+
def split_heads(self, x, x_positions, rope=False, subsampling=1):
|
106 |
+
x = x.view(*x.shape[:2], self.n_head, -1)
|
107 |
+
if rope:
|
108 |
+
x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
|
109 |
+
return x.permute(0, 2, 1, 3)
|
110 |
+
|
111 |
+
def forward(
|
112 |
+
self,
|
113 |
+
qx,
|
114 |
+
q_positions,
|
115 |
+
kvx,
|
116 |
+
kv_positions,
|
117 |
+
causal = False,
|
118 |
+
mask=None,
|
119 |
+
):
|
120 |
+
if self.qkv:
|
121 |
+
q,k,v = self.qkv(qx).split(self.odim, dim=-1)
|
122 |
+
elif self.kv:
|
123 |
+
q = self.q(qx)
|
124 |
+
k,v = self.kv(kvx).split(self.odim, dim=-1)
|
125 |
+
else:
|
126 |
+
q,k,v = None,None,None
|
127 |
+
|
128 |
+
if q is None: q = self.query(qx) * self.sqrt_qk_scale
|
129 |
+
q = self.split_heads(q, q_positions, rope = self.rotary, subsampling = self.query_subsampling)
|
130 |
+
|
131 |
+
if kvx is not self.cached_kvx:
|
132 |
+
if k is None: k = self.key(kvx) * self.sqrt_qk_scale
|
133 |
+
k = self.split_heads(k, kv_positions, rope = self.rotary, subsampling = self.key_subsampling)
|
134 |
+
if v is None: v = self.value(kvx)
|
135 |
+
v = self.split_heads(v, kv_positions)
|
136 |
+
if self.k_cache is not None:
|
137 |
+
self.k_cache[:,:,kv_positions] = k
|
138 |
+
self.v_cache[:,:,kv_positions] = v
|
139 |
+
|
140 |
+
if self.k_cache is not None:
|
141 |
+
k, v = self.k_cache, self.v_cache
|
142 |
+
|
143 |
+
if mask is not None:
|
144 |
+
mask = mask[q_positions]
|
145 |
+
|
146 |
+
wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)
|
147 |
+
|
148 |
+
return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
|
149 |
+
|
150 |
+
# %% ../nbs/A. Neural modules.ipynb 6
|
151 |
+
# modified from https://blog.eleuther.ai/rotary-embeddings/
|
152 |
+
|
153 |
+
import torch
|
154 |
+
|
155 |
+
class Rotary(torch.nn.Module):
|
156 |
+
def __init__(self, dim, base=10000):
|
157 |
+
super().__init__()
|
158 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
159 |
+
self.register_buffer("inv_freq", inv_freq)
|
160 |
+
self.seq_len_cached = None
|
161 |
+
self.cos_cached = None
|
162 |
+
self.sin_cached = None
|
163 |
+
|
164 |
+
def forward(self, x, seq_dim=1):
|
165 |
+
seq_len = x.shape[seq_dim]
|
166 |
+
if not self.seq_len_cached or seq_len > self.seq_len_cached:
|
167 |
+
self.seq_len_cached = 2500
|
168 |
+
# self.seq_len_cached = seq_len
|
169 |
+
|
170 |
+
t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
|
171 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
172 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
173 |
+
self.cos_cached = emb.cos()[None, :, None, :]
|
174 |
+
self.sin_cached = emb.sin()[None, :, None, :]
|
175 |
+
return self.cos_cached, self.sin_cached
|
176 |
+
|
177 |
+
|
178 |
+
# rotary pos emb helpers:
|
179 |
+
def rotate_half(x):
|
180 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
181 |
+
return torch.cat(
|
182 |
+
(-x2, x1), dim=len(x.shape)-1
|
183 |
+
)
|
184 |
+
|
185 |
+
def rope_rotate(x, positions, cos, sin):
|
186 |
+
return x * cos[:,positions] + rotate_half(x) * sin[:,positions]
|
187 |
+
|
188 |
+
# %% ../nbs/A. Neural modules.ipynb 7
|
189 |
+
class ResidualAttentionBlock(nn.Module):
|
190 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
|
191 |
+
qk_scale: float = 1, ffn_mult: int = 4):
|
192 |
+
super().__init__()
|
193 |
+
self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
|
194 |
+
self.attn_ln = LayerNorm(n_state)
|
195 |
+
|
196 |
+
self.cross_attn = (
|
197 |
+
MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
|
198 |
+
)
|
199 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
200 |
+
|
201 |
+
n_mlp = n_state * ffn_mult
|
202 |
+
self.mlp = nn.Sequential(
|
203 |
+
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
204 |
+
)
|
205 |
+
self.mlp_ln = LayerNorm(n_state)
|
206 |
+
|
207 |
+
def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None):
|
208 |
+
self.attn.setup_kv_cache(max_batch_size, max_seq_len)
|
209 |
+
if self.cross_attn:
|
210 |
+
self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len)
|
211 |
+
|
212 |
+
def forward(
|
213 |
+
self,
|
214 |
+
x: Tensor,
|
215 |
+
x_positions: Tensor = None,
|
216 |
+
xa: Optional[Tensor] = None,
|
217 |
+
xa_positions: Optional[Tensor] = None,
|
218 |
+
causal = False,
|
219 |
+
mask=None,
|
220 |
+
):
|
221 |
+
lnx = self.attn_ln(x)
|
222 |
+
x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
|
223 |
+
if self.cross_attn:
|
224 |
+
lnx = self.cross_attn_ln(x)
|
225 |
+
x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
|
226 |
+
x = x + self.mlp(self.mlp_ln(x))
|
227 |
+
return x
|
228 |
+
|
229 |
+
# %% ../nbs/A. Neural modules.ipynb 8
|
230 |
+
class BaseDecoder(nn.Module):
|
231 |
+
def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
|
232 |
+
super().__init__()
|
233 |
+
self.length = length
|
234 |
+
self.width = width
|
235 |
+
self.layers = nn.ModuleList([
|
236 |
+
ResidualAttentionBlock(
|
237 |
+
self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
|
238 |
+
) for _ in range(math.floor(depth))
|
239 |
+
])
|
240 |
+
|
241 |
+
self.ln_post = LayerNorm(width)
|
242 |
+
|
243 |
+
mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
|
244 |
+
self.register_buffer("mask", mask, persistent=False)
|
245 |
+
|
246 |
+
def forward(self, x, x_positions, xenc, xenc_positions):
|
247 |
+
for i,l in enumerate(self.layers):
|
248 |
+
x = l(x, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
|
249 |
+
|
250 |
+
x = self.ln_post(x)
|
251 |
+
|
252 |
+
return x
|
253 |
+
|
254 |
+
# %% ../nbs/A. Neural modules.ipynb 9
|
255 |
+
class EmbeddingProjector(nn.Linear):
|
256 |
+
pass
|
257 |
+
|
258 |
+
class FlexEmbeddings(nn.Module):
|
259 |
+
def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
|
260 |
+
super().__init__()
|
261 |
+
self.codes = codes
|
262 |
+
self.special_codes = special_codes
|
263 |
+
if frozen_width is None: frozen_width = width
|
264 |
+
|
265 |
+
self.main = nn.Embedding(codes, frozen_width or width)
|
266 |
+
self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
|
267 |
+
self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
|
268 |
+
if special_codes:
|
269 |
+
self.special = special_embedding or nn.Embedding(special_codes, width)
|
270 |
+
|
271 |
+
self.register_buffer('merged_in', None)
|
272 |
+
self.register_buffer('merged_out', None)
|
273 |
+
self.register_buffer('bias_out', None)
|
274 |
+
|
275 |
+
def set_frozen_embeddings(self, values):
|
276 |
+
with torch.no_grad():
|
277 |
+
self.main.weight[:] = values
|
278 |
+
self.main.lr_scale = 0
|
279 |
+
|
280 |
+
@torch.no_grad()
|
281 |
+
def convert_for_eval(self):
|
282 |
+
if not self.special_codes: return
|
283 |
+
# in
|
284 |
+
main_w = self.main.weight
|
285 |
+
if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
|
286 |
+
weight = torch.cat([main_w, self.special.weight], dim=0)
|
287 |
+
self.merged_in = nn.Embedding(*weight.shape, _weight=weight)
|
288 |
+
|
289 |
+
# out
|
290 |
+
weight = self.main.weight
|
291 |
+
if self.hidden_to_emb: weight = weight @ self.hidden_to_emb.weight
|
292 |
+
self.merged_out = torch.cat([weight.T, self.special.weight.T], dim=1).T.contiguous() # T is for F.linear
|
293 |
+
if self.hidden_to_emb:
|
294 |
+
self.bias_out = torch.cat([
|
295 |
+
self.hidden_to_emb.bias @ self.main.weight.T,
|
296 |
+
torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
|
297 |
+
], dim=0)
|
298 |
+
else:
|
299 |
+
self.bias_out = None
|
300 |
+
|
301 |
+
def forward(self, toks):
|
302 |
+
if not self.training and self.merged_in is not None:
|
303 |
+
return self.merged_in(toks)
|
304 |
+
|
305 |
+
if self.special_codes:
|
306 |
+
special_mask = toks >= self.codes
|
307 |
+
embs = self.main(torch.where(special_mask, 0, toks))
|
308 |
+
else:
|
309 |
+
embs = self.main(toks)
|
310 |
+
|
311 |
+
if self.emb_to_hidden: embs = self.emb_to_hidden(embs)
|
312 |
+
|
313 |
+
if self.special_codes:
|
314 |
+
embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)
|
315 |
+
|
316 |
+
return embs
|
317 |
+
|
318 |
+
def unembed(self, embs):
|
319 |
+
if not self.training and self.merged_out is not None:
|
320 |
+
return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out
|
321 |
+
|
322 |
+
orig_embs = embs
|
323 |
+
if self.hidden_to_emb: embs = self.hidden_to_emb(embs)
|
324 |
+
|
325 |
+
main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()
|
326 |
+
|
327 |
+
if not self.special_codes:
|
328 |
+
return main_logits
|
329 |
+
|
330 |
+
special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
|
331 |
+
return torch.cat([main_logits, special_logits], dim=-1)
|
whisperspeech/pipeline.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/7. Pipeline.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['Pipeline']
|
5 |
+
|
6 |
+
# %% ../nbs/7. Pipeline.ipynb 1
|
7 |
+
import torch
|
8 |
+
from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer
|
9 |
+
from whisperspeech.s2a_delar_mup_wds_mlang import SADelARTransformer
|
10 |
+
from whisperspeech.a2wav import Vocoder
|
11 |
+
import traceback
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
# %% ../nbs/7. Pipeline.ipynb 2
|
15 |
+
class Pipeline:
|
16 |
+
default_speaker = torch.tensor(
|
17 |
+
[-0.2929, -0.4503, 0.4155, -0.1417, 0.0473, -0.1624, -0.2322, 0.7071,
|
18 |
+
0.4800, 0.5496, 0.0410, 0.6236, 0.4729, 0.0587, 0.2194, -0.0466,
|
19 |
+
-0.3036, 0.0497, 0.5028, -0.1703, 0.5039, -0.6464, 0.3857, -0.7350,
|
20 |
+
-0.1605, 0.4808, 0.5397, -0.4851, 0.1774, -0.8712, 0.5789, 0.1785,
|
21 |
+
-0.1417, 0.3039, 0.4232, -0.0186, 0.2685, 0.6153, -0.3103, -0.5706,
|
22 |
+
-0.4494, 0.3394, -0.6184, -0.3617, 1.1041, -0.1178, -0.1885, 0.1997,
|
23 |
+
0.5571, -0.2906, -0.0477, -0.4048, -0.1062, 1.4779, 0.1639, -0.3712,
|
24 |
+
-0.1776, -0.0568, -0.6162, 0.0110, -0.0207, -0.1319, -0.3854, 0.7248,
|
25 |
+
0.0343, 0.5724, 0.0670, 0.0486, -0.3813, 0.1738, 0.3017, 1.0502,
|
26 |
+
0.1550, 0.5708, 0.0366, 0.5093, 0.0294, -0.7091, -0.8220, -0.1583,
|
27 |
+
-0.2343, 0.1366, 0.7372, -0.0631, 0.1505, 0.4600, -0.1252, -0.5245,
|
28 |
+
0.7523, -0.0386, -0.2587, 1.0066, -0.2037, 0.1617, -0.3800, 0.2790,
|
29 |
+
0.0184, -0.5111, -0.7291, 0.1627, 0.2367, -0.0192, 0.4822, -0.4458,
|
30 |
+
0.1457, -0.5884, 0.1909, 0.2563, -0.2035, -0.0377, 0.7771, 0.2139,
|
31 |
+
0.3801, 0.6047, -0.6043, -0.2563, -0.0726, 0.3856, 0.3217, 0.0823,
|
32 |
+
-0.1302, 0.3287, 0.5693, 0.2453, 0.8231, 0.0072, 1.0327, 0.6065,
|
33 |
+
-0.0620, -0.5572, 0.5220, 0.2485, 0.1520, 0.0222, -0.2179, -0.7392,
|
34 |
+
-0.3855, 0.1822, 0.1042, 0.7133, 0.3583, 0.0606, -0.0424, -0.9189,
|
35 |
+
-0.4882, -0.5480, -0.5719, -0.1660, -0.3439, -0.5814, -0.2542, 0.0197,
|
36 |
+
0.4942, 0.0915, -0.0420, -0.0035, 0.5578, 0.1051, -0.0891, 0.2348,
|
37 |
+
0.6876, -0.6685, 0.8215, -0.3692, -0.3150, -0.0462, -0.6806, -0.2661,
|
38 |
+
-0.0308, -0.0050, 0.6756, -0.1647, 1.0734, 0.0049, 0.4969, 0.0259,
|
39 |
+
-0.8949, 0.0731, 0.0886, 0.3442, -0.1433, -0.6804, 0.2204, 0.1859,
|
40 |
+
0.2702, 0.1699, -0.1443, -0.9614, 0.3261, 0.1718, 0.3545, -0.0686]
|
41 |
+
)
|
42 |
+
|
43 |
+
def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=False):
|
44 |
+
args = dict()
|
45 |
+
try:
|
46 |
+
if t2s_ref:
|
47 |
+
args["ref"] = t2s_ref
|
48 |
+
self.t2s = TSARTransformer.load_model(**args).cuda()
|
49 |
+
if optimize: self.t2s.optimize(torch_compile=torch_compile)
|
50 |
+
except:
|
51 |
+
print("Failed to load the T2S model:")
|
52 |
+
print(traceback.format_exc())
|
53 |
+
try:
|
54 |
+
if s2a_ref:
|
55 |
+
args["ref"] = s2a_ref
|
56 |
+
self.s2a = SADelARTransformer.load_model(**args).cuda()
|
57 |
+
if optimize: self.s2a.optimize(torch_compile=torch_compile)
|
58 |
+
except:
|
59 |
+
print("Failed to load the S2A model:")
|
60 |
+
print(traceback.format_exc())
|
61 |
+
self.vocoder = Vocoder()
|
62 |
+
self.encoder = None
|
63 |
+
|
64 |
+
def extract_spk_emb(self, fname):
|
65 |
+
"""Extracts a speaker embedding from the first 30 seconds of the give audio file.
|
66 |
+
"""
|
67 |
+
import torchaudio
|
68 |
+
if self.encoder is None:
|
69 |
+
from speechbrain.pretrained import EncoderClassifier
|
70 |
+
self.encoder = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
|
71 |
+
savedir="~/.cache/speechbrain/",
|
72 |
+
run_opts={"device": "cuda"})
|
73 |
+
samples, sr = torchaudio.load(fname)
|
74 |
+
samples = self.encoder.audio_normalizer(samples[0,:30*sr], sr)
|
75 |
+
spk_emb = self.encoder.encode_batch(samples)
|
76 |
+
return spk_emb[0,0]
|
77 |
+
|
78 |
+
def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):
|
79 |
+
if speaker is None: speaker = self.default_speaker
|
80 |
+
elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)
|
81 |
+
text = text.replace("\n", " ")
|
82 |
+
stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)
|
83 |
+
atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)
|
84 |
+
return atoks
|
85 |
+
|
86 |
+
def generate(self, text, speaker=None, lang='en', cps=15, step_callback=None):
|
87 |
+
return self.vocoder.decode(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))
|
88 |
+
|
89 |
+
def generate_to_file(self, fname, text, speaker=None, lang='en', cps=15, step_callback=None):
|
90 |
+
self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
|
91 |
+
|
92 |
+
def generate_to_notebook(self, text, speaker=None, lang='en', cps=15, step_callback=None):
|
93 |
+
self.vocoder.decode_to_notebook(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
|
whisperspeech/prepare_s2a_dataset.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['flac_to_s2a_name']
|
5 |
+
|
6 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 2
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import itertools
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torchaudio
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
17 |
+
|
18 |
+
from fastprogress import progress_bar
|
19 |
+
from fastcore.script import *
|
20 |
+
|
21 |
+
import whisper
|
22 |
+
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
|
23 |
+
import webdataset as wds
|
24 |
+
|
25 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 4
|
26 |
+
def flac_to_s2a_name(input):
|
27 |
+
if '-flac-' in input:
|
28 |
+
return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
|
29 |
+
else:
|
30 |
+
return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
|
31 |
+
|
32 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 6
|
33 |
+
def resampler(newsr = 24000, key = 'samples_24k'):
|
34 |
+
_last_sr = None
|
35 |
+
tform = None
|
36 |
+
|
37 |
+
def _resample(samples):
|
38 |
+
for s in samples:
|
39 |
+
sr = s['sample_rate']
|
40 |
+
if sr != newsr:
|
41 |
+
if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
|
42 |
+
s[key] = tform(s['samples'])
|
43 |
+
else:
|
44 |
+
s[key] = s['samples']
|
45 |
+
yield s
|
46 |
+
|
47 |
+
return _resample
|
48 |
+
|
49 |
+
# %% ../nbs/4A. S2A dataset preparation.ipynb 9
|
50 |
+
@call_parse
|
51 |
+
def prepare_s2a(
|
52 |
+
input:str, # FLAC webdataset file path (or - to read the names from stdin)
|
53 |
+
proc_dataset_path:Path, # processed VAD files path
|
54 |
+
output:str=None, # output file name
|
55 |
+
vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
|
56 |
+
n_samples:int=None, # process a limited amount of samples
|
57 |
+
batch_size:int=1, # process several segments at once
|
58 |
+
fix_dots:bool=False, # fix dots in file names
|
59 |
+
):
|
60 |
+
if ":" in vq_model:
|
61 |
+
repo, fname = vq_model.split(":", 1)
|
62 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
|
63 |
+
else:
|
64 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
|
65 |
+
amodel = extract_acoustic.load_model()
|
66 |
+
amodel.set_target_bandwidth(3)
|
67 |
+
|
68 |
+
if input == "-":
|
69 |
+
input = [f.strip() for f in sys.stdin.readlines()]
|
70 |
+
assert output, "please provide the output shard name"
|
71 |
+
else:
|
72 |
+
if output is None: output = flac_to_s2a_name(input)
|
73 |
+
input = [input]
|
74 |
+
|
75 |
+
total = n_samples//batch_size if n_samples else 'noinfer'
|
76 |
+
|
77 |
+
ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
|
78 |
+
wds.decode(wds.torch_audio),
|
79 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x),
|
80 |
+
vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
|
81 |
+
wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
|
82 |
+
lambda x: wh_transcribe.split_to_chunks(x),
|
83 |
+
resampler(),
|
84 |
+
resampler(16000, 'samples_16k'),
|
85 |
+
wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
|
86 |
+
wds.batched(64),
|
87 |
+
)
|
88 |
+
|
89 |
+
dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
|
90 |
+
|
91 |
+
speakers = set()
|
92 |
+
tmp = output+".tmp"
|
93 |
+
with wds.TarWriter(tmp) as sink:
|
94 |
+
for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
|
95 |
+
with record_function('to_cuda'):
|
96 |
+
samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
|
97 |
+
with record_function('encodec'):
|
98 |
+
atoks = amodel.encode(samples24k)[0][0]
|
99 |
+
with record_function('vq_stoks'):
|
100 |
+
stoks = vq_model.encode_audio(samples)
|
101 |
+
with record_function('from_cuda'):
|
102 |
+
atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
|
103 |
+
for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
|
104 |
+
speakers.add(key.split('/')[1])
|
105 |
+
sink.write({
|
106 |
+
"__key__": key,
|
107 |
+
"atoks.npy": _atoks[:,:int(-rpad_s * 75)],
|
108 |
+
"stoks.npy": _stoks[:int(-rpad_s * 25)],
|
109 |
+
})
|
110 |
+
with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
|
111 |
+
if not n_samples:
|
112 |
+
os.rename(tmp, output)
|
whisperspeech/prepare_t2s_dataset.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5A. T2S dataset preparation.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = []
|
5 |
+
|
6 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 2
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import itertools
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torchaudio
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
17 |
+
|
18 |
+
from fastprogress import progress_bar
|
19 |
+
from fastcore.script import *
|
20 |
+
|
21 |
+
import whisper, whisperx
|
22 |
+
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
|
23 |
+
import webdataset as wds
|
24 |
+
|
25 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 4
|
26 |
+
def flac_to_t2s_name(input):
|
27 |
+
return input.rsplit("/", 1)[1].replace('flac', 't2s') + ".gz"
|
28 |
+
|
29 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 6
|
30 |
+
class Transcriber:
|
31 |
+
"""
|
32 |
+
A helper class to transcribe a batch of 30 second audio chunks.
|
33 |
+
"""
|
34 |
+
def __init__(self, model_size, lang=False):
|
35 |
+
self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
|
36 |
+
# without calling vad_model at least once the rest segfaults for some reason...
|
37 |
+
self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})
|
38 |
+
|
39 |
+
def transcribe(self, batch):
|
40 |
+
batch = whisper.log_mel_spectrogram(batch)
|
41 |
+
embs = self.model.model.encode(batch.cpu().numpy())
|
42 |
+
return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
|
43 |
+
self.model.model.model.generate(
|
44 |
+
embs,
|
45 |
+
[self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
|
46 |
+
)])
|
47 |
+
|
48 |
+
# %% ../nbs/5A. T2S dataset preparation.ipynb 7
|
49 |
+
@call_parse
|
50 |
+
def prepare_t2s(
|
51 |
+
input:str, # FLAC webdataset file path (or - to read the names from stdin)
|
52 |
+
proc_dataset_path:Path, # processed VAD files path
|
53 |
+
output:str=None, # output file name
|
54 |
+
vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
|
55 |
+
n_samples:int=None, # process a limited amount of samples
|
56 |
+
batch_size:int=1, # process several segments at once
|
57 |
+
transcription_model:str="small.en",
|
58 |
+
):
|
59 |
+
if ":" in vq_model:
|
60 |
+
repo, fname = vq_model.split(":", 1)
|
61 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
|
62 |
+
else:
|
63 |
+
vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
|
64 |
+
transcriber = Transcriber(transcription_model)
|
65 |
+
|
66 |
+
if input == "-":
|
67 |
+
input = [f.strip() for f in sys.stdin.readlines()]
|
68 |
+
assert output, "please provide the output shard name"
|
69 |
+
else:
|
70 |
+
if output is None: output = flac_to_t2s_name(input)
|
71 |
+
input = [input]
|
72 |
+
|
73 |
+
total = n_samples//batch_size if n_samples else 'noinfer'
|
74 |
+
if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
|
75 |
+
|
76 |
+
ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
|
77 |
+
wds.decode(wds.torch_audio),
|
78 |
+
vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
|
79 |
+
wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
|
80 |
+
lambda x: wh_transcribe.split_to_chunks(x),
|
81 |
+
# drop the first and last segment because they tend to be inaccurate
|
82 |
+
# (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
|
83 |
+
wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
|
84 |
+
wds.to_tuple('__key__', 'rpad', 'samples'),
|
85 |
+
wds.batched(64),
|
86 |
+
)
|
87 |
+
|
88 |
+
dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
|
89 |
+
|
90 |
+
speakers = set()
|
91 |
+
tmp = output+".tmp"
|
92 |
+
with wds.TarWriter(tmp) as sink:
|
93 |
+
for keys, rpads, samples in progress_bar(dl, total=total):
|
94 |
+
with record_function('to_cuda'):
|
95 |
+
csamples = samples.cuda()
|
96 |
+
with record_function('transcribe'):
|
97 |
+
txts = transcriber.transcribe(csamples)
|
98 |
+
with record_function('vq_stoks'):
|
99 |
+
stoks = vq_model.encode_audio(csamples)
|
100 |
+
with record_function('from_cuda'):
|
101 |
+
stoks = stoks.cpu().numpy().astype(np.int16)
|
102 |
+
for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
|
103 |
+
speakers.add(key.split('/')[1])
|
104 |
+
sink.write({
|
105 |
+
"__key__": key,
|
106 |
+
"txt": txt,
|
107 |
+
"stoks.npy": _stoks[:int(-rpad/16000 * 25)],
|
108 |
+
})
|
109 |
+
with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
|
110 |
+
if not n_samples:
|
111 |
+
os.rename(tmp, output)
|
whisperspeech/s2a_delar_mup_wds.py
ADDED
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
|
5 |
+
'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']
|
6 |
+
|
7 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
|
8 |
+
import io
|
9 |
+
import time
|
10 |
+
import math
|
11 |
+
import random
|
12 |
+
import dataclasses
|
13 |
+
|
14 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.profiler import profile, record_function, ProfilerActivity, schedule
|
19 |
+
from fastcore.basics import store_attr
|
20 |
+
from huggingface_hub import hf_hub_download
|
21 |
+
|
22 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
|
23 |
+
from pathlib import Path
|
24 |
+
import json
|
25 |
+
from fastprogress import progress_bar, master_bar
|
26 |
+
import webdataset as wds
|
27 |
+
|
28 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
|
29 |
+
from .train import *
|
30 |
+
from .modules import *
|
31 |
+
from . import vq_stoks
|
32 |
+
|
33 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
|
34 |
+
def rand(start, end):
|
35 |
+
return random.random() * (end - start) + start
|
36 |
+
|
37 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
|
38 |
+
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
|
39 |
+
atoks_per_second = atoks_len / 30
|
40 |
+
def _trunc(samples):
|
41 |
+
for s in samples:
|
42 |
+
if random.random() < random_trunc_p:
|
43 |
+
seconds = rand(0.3, 30)
|
44 |
+
s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
|
45 |
+
s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
|
46 |
+
yield s
|
47 |
+
return _trunc
|
48 |
+
|
49 |
+
def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
|
50 |
+
def _pad(samples):
|
51 |
+
for s in samples:
|
52 |
+
s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
|
53 |
+
s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
|
54 |
+
yield s
|
55 |
+
return _pad
|
56 |
+
|
57 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
|
58 |
+
def speaker_id_extractor(speaker_map):
|
59 |
+
def _extractor(samples):
|
60 |
+
for s in samples:
|
61 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
62 |
+
yield s
|
63 |
+
return _extractor
|
64 |
+
|
65 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
|
66 |
+
def load_datasets(
|
67 |
+
input:str, # webdataset folder
|
68 |
+
samples:int, # samples per epoch
|
69 |
+
subsample:float=1, # use a fraction of the files
|
70 |
+
val_samples:int=512,
|
71 |
+
random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
|
72 |
+
stoks_pad_token=4096,
|
73 |
+
):
|
74 |
+
|
75 |
+
if isinstance(input, (Path, str)):
|
76 |
+
path = Path(input)
|
77 |
+
if path.is_dir():
|
78 |
+
glob = '*-s2a-*.tar.gz'
|
79 |
+
else:
|
80 |
+
glob = path.name
|
81 |
+
path = path.parent
|
82 |
+
input = Path(path).glob(glob)
|
83 |
+
elif isinstance(input, list):
|
84 |
+
pass
|
85 |
+
else:
|
86 |
+
raise ArgumentError("input should be either a list or a path with an optional glob specifier")
|
87 |
+
shards = [str(x) for x in input]
|
88 |
+
|
89 |
+
speakers = set()
|
90 |
+
for shard in shards:
|
91 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
92 |
+
speakers = {id:i for i,id in enumerate(sorted(speakers))}
|
93 |
+
|
94 |
+
def ds(shards, length):
|
95 |
+
ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
|
96 |
+
wds.decode(),
|
97 |
+
speaker_id_extractor(speakers),
|
98 |
+
random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
|
99 |
+
pad_samples(stoks_pad_token=stoks_pad_token),
|
100 |
+
wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
|
101 |
+
wds.batched(64),
|
102 |
+
)
|
103 |
+
ds.speakers = speakers
|
104 |
+
ds.total_samples = length
|
105 |
+
return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
|
106 |
+
|
107 |
+
return (
|
108 |
+
ds(shards[1:], samples),
|
109 |
+
ds(shards[:1], val_samples),
|
110 |
+
)
|
111 |
+
|
112 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
|
113 |
+
import pylab as plt
|
114 |
+
import fastprogress
|
115 |
+
import IPython
|
116 |
+
import numpy as np
|
117 |
+
|
118 |
+
class CMLMVisual:
|
119 |
+
"""Visualize training progress"""
|
120 |
+
def __init__ (self, model, masterbar, total_steps):
|
121 |
+
self.model = model
|
122 |
+
self.masterbar = masterbar
|
123 |
+
self.total_steps = total_steps
|
124 |
+
self.epochs = total_steps // masterbar.main_bar.total
|
125 |
+
|
126 |
+
gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
|
127 |
+
graph_fig = plt.figure(figsize=(10,6))
|
128 |
+
self.graph_fig = graph_fig
|
129 |
+
self.loss_p = graph_fig.add_subplot(gs[0])
|
130 |
+
self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
|
131 |
+
self.acc_p.tick_params('x', labelbottom=False)
|
132 |
+
self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
|
133 |
+
self.lr_p.tick_params('x', labelbottom=False)
|
134 |
+
self.graph_out = None
|
135 |
+
|
136 |
+
self.its = []
|
137 |
+
self.train_losses = []
|
138 |
+
self.val_losses = []
|
139 |
+
self.lr_history = []
|
140 |
+
self.acc = np.nan
|
141 |
+
self.acc_history = []
|
142 |
+
self.pacc_history = []
|
143 |
+
|
144 |
+
def show(self):
|
145 |
+
self.start_t = time.time()
|
146 |
+
self.masterbar.write(["samples", "train", "val", "time"], table=True)
|
147 |
+
self.graph_out = display(self.graph_fig, display_id=True)
|
148 |
+
self.acc_out = display(IPython.display.HTML(''), display_id=True)
|
149 |
+
|
150 |
+
def hide(self):
|
151 |
+
if self.graph_out is not None:
|
152 |
+
self.graph_out.update(IPython.display.HTML(''))
|
153 |
+
|
154 |
+
def plot(self):
|
155 |
+
loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
|
156 |
+
loss_p.clear()
|
157 |
+
loss_p.plot(self.its, self.train_losses)
|
158 |
+
loss_p.plot(self.its, self.val_losses)
|
159 |
+
loss_p.set_xlim(0, self.total_steps)
|
160 |
+
loss_p.set_yscale('log')
|
161 |
+
acc_p.clear()
|
162 |
+
for k in self.acc_history[-1].keys():
|
163 |
+
acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
|
164 |
+
# acc_p.plot(self.its, np.stack(self.pacc_history), label=range(len(self.pacc_history[0])))
|
165 |
+
lr_p.clear()
|
166 |
+
lrs = np.array(self.lr_history)
|
167 |
+
lr_p.plot(self.its, lrs)
|
168 |
+
self.graph_out.update(self.graph_fig)
|
169 |
+
|
170 |
+
def add_data(self, it, lr, train_loss, val_los):
|
171 |
+
self.its.append(it)
|
172 |
+
self.train_losses.append(train_loss)
|
173 |
+
self.val_losses.append(val_los)
|
174 |
+
self.lr_history.append(lr)
|
175 |
+
metrics = self.model.get_metrics()
|
176 |
+
self.acc_history.append(metrics)
|
177 |
+
# self.acc_out.update(f"Accuracy: {self.entropy_history[-1]:.2f}")
|
178 |
+
# self.pacc_history.append((self.model.pval_true / self.model.pval_total).cpu().numpy())
|
179 |
+
# if self.acc_history:
|
180 |
+
html = "<h5>Accuracies:</h5><table>"
|
181 |
+
html += "<thead>"+(''.join([f"<td>{k}<td>" for k,x in metrics.items()]))+"</thead>"
|
182 |
+
html += "<tr>"+(''.join([f"<td>{x*100:.1f}%<td>" for k,x in metrics.items()]))+"</tr>"
|
183 |
+
html += "</table>"
|
184 |
+
self.acc_out.update(IPython.display.HTML(html))
|
185 |
+
self.plot()
|
186 |
+
|
187 |
+
def add_table_row(self, it, avg_train_loss, val_loss):
|
188 |
+
elapsed_t = time.time() - self.start_t
|
189 |
+
self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
|
190 |
+
|
191 |
+
def on_iter(self, bar, it, avg_train_loss, val_loss):
|
192 |
+
epoch = math.ceil(it / self.total_steps * self.epochs)
|
193 |
+
bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
|
194 |
+
|
195 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
|
196 |
+
# modified from https://blog.eleuther.ai/rotary-embeddings/
|
197 |
+
import torch
|
198 |
+
|
199 |
+
class Rotary(torch.nn.Module):
|
200 |
+
def __init__(self, dim, base=10000):
|
201 |
+
super().__init__()
|
202 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
203 |
+
self.register_buffer("inv_freq", inv_freq)
|
204 |
+
self.seq_len_cached = None
|
205 |
+
self.cos_cached = None
|
206 |
+
self.sin_cached = None
|
207 |
+
|
208 |
+
def forward(self, x, seq_dim=1):
|
209 |
+
seq_len = x.shape[seq_dim]
|
210 |
+
if seq_len != self.seq_len_cached:
|
211 |
+
self.seq_len_cached = seq_len
|
212 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
|
213 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
214 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
215 |
+
self.cos_cached = emb.cos()[None, :, None, :]
|
216 |
+
self.sin_cached = emb.sin()[None, :, None, :]
|
217 |
+
return self.cos_cached, self.sin_cached
|
218 |
+
|
219 |
+
|
220 |
+
# rotary pos emb helpers:
|
221 |
+
def rotate_half(x):
|
222 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
223 |
+
return torch.cat(
|
224 |
+
(-x2, x1), dim=-1
|
225 |
+
)
|
226 |
+
|
227 |
+
#@torch.jit.script
|
228 |
+
def apply_rotary_pos_emb(q, k, cos, sin):
|
229 |
+
return (q * cos[:,:q.shape[1]]) + (rotate_half(q) * sin[:,:q.shape[1]]), (k * cos) + (rotate_half(k) * sin)
|
230 |
+
|
231 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
|
232 |
+
from torch import Tensor, nn
|
233 |
+
import torch.nn.functional as F
|
234 |
+
from typing import Dict, Iterable, Optional
|
235 |
+
|
236 |
+
class ResidualAttentionBlock(nn.Module):
|
237 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
|
238 |
+
qk_scale: float = 1, ffn_mult: int = 4):
|
239 |
+
super().__init__()
|
240 |
+
|
241 |
+
self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
|
242 |
+
self.attn_ln = LayerNorm(n_state)
|
243 |
+
|
244 |
+
self.cross_attn = (
|
245 |
+
MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope) if cross_attention else None
|
246 |
+
)
|
247 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
248 |
+
|
249 |
+
n_mlp = n_state * ffn_mult
|
250 |
+
self.mlp = nn.Sequential(
|
251 |
+
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
252 |
+
)
|
253 |
+
self.mlp_ln = LayerNorm(n_state)
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
x: Tensor,
|
258 |
+
xa: Optional[Tensor] = None,
|
259 |
+
causal = False,
|
260 |
+
kv_cache: Optional[dict] = None,
|
261 |
+
):
|
262 |
+
x = x + self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)[0]
|
263 |
+
if self.cross_attn:
|
264 |
+
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
265 |
+
x = x + self.mlp(self.mlp_ln(x))
|
266 |
+
return x
|
267 |
+
|
268 |
+
class MultiHeadAttention(nn.Module):
|
269 |
+
def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
|
270 |
+
super().__init__()
|
271 |
+
self.n_head = n_head
|
272 |
+
self.sqrt_qk_scale = math.sqrt(qk_scale)
|
273 |
+
self.query = QueryHead(n_state, n_state)
|
274 |
+
self.key = nn.Linear(n_state, n_state, bias=False)
|
275 |
+
self.value = nn.Linear(n_state, n_state)
|
276 |
+
self.out = nn.Linear(n_state, n_state)
|
277 |
+
|
278 |
+
self.rotary = None
|
279 |
+
if rope:
|
280 |
+
self.rotary = Rotary(n_state // n_head)
|
281 |
+
|
282 |
+
def forward(
|
283 |
+
self,
|
284 |
+
x: Tensor,
|
285 |
+
xa: Optional[Tensor] = None,
|
286 |
+
causal = False,
|
287 |
+
kv_cache: Optional[dict] = None,
|
288 |
+
):
|
289 |
+
q = self.query(x)
|
290 |
+
|
291 |
+
if kv_cache is None or xa is None or self.key not in kv_cache:
|
292 |
+
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
293 |
+
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
294 |
+
k = self.key(x if xa is None else xa)
|
295 |
+
v = self.value(x if xa is None else xa)
|
296 |
+
else:
|
297 |
+
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
298 |
+
k = kv_cache[self.key]
|
299 |
+
v = kv_cache[self.value]
|
300 |
+
|
301 |
+
if self.sqrt_qk_scale != 1:
|
302 |
+
q *= self.sqrt_qk_scale
|
303 |
+
k *= self.sqrt_qk_scale
|
304 |
+
|
305 |
+
wv, qk = self.qkv_attention_pth20(q, k, v, causal)
|
306 |
+
# wv, qk = self.qkv_attention_xformers(q, k, v, causal)
|
307 |
+
|
308 |
+
return self.out(wv), qk
|
309 |
+
|
310 |
+
def qkv_attention_pth20(
|
311 |
+
self, q: Tensor, k: Tensor, v: Tensor, causal = False
|
312 |
+
):
|
313 |
+
n_batch, n_ctx, n_state = q.shape
|
314 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
315 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
316 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
317 |
+
|
318 |
+
#print('before rot:', q.shape, k.shape)
|
319 |
+
if self.rotary:
|
320 |
+
q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
|
321 |
+
#print(' after rot:', q.shape, k.shape)
|
322 |
+
|
323 |
+
k = k.permute(0, 2, 1, 3)
|
324 |
+
q = q.permute(0, 2, 1, 3)
|
325 |
+
# modified for better performance under PyTorch 2.0
|
326 |
+
wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)
|
327 |
+
|
328 |
+
# previously we've returned q@k which we don't have now
|
329 |
+
# since it's not actually used anywhere else, let's just keep two return values for compatibility
|
330 |
+
return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None
|
331 |
+
|
332 |
+
def qkv_attention_xformers(
|
333 |
+
self, q: Tensor, k: Tensor, v: Tensor, causal = False
|
334 |
+
):
|
335 |
+
n_batch, n_ctx, n_state = q.shape
|
336 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
337 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
338 |
+
v = v.view(*v.shape[:2], self.n_head, -1)
|
339 |
+
|
340 |
+
if self.rotary:
|
341 |
+
q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
|
342 |
+
|
343 |
+
bias = xops.LowerTriangularMask() if causal else None
|
344 |
+
wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)
|
345 |
+
|
346 |
+
# previously we've returned q@k which we don't have now
|
347 |
+
# since it's not actually used anywhere else, let's just keep two return values for compatibility
|
348 |
+
return wv.flatten(start_dim=2), None
|
349 |
+
|
350 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
|
351 |
+
class DelSumDecoder(nn.Module):
|
352 |
+
def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
|
353 |
+
super().__init__()
|
354 |
+
self.length = length
|
355 |
+
width = n_head * head_width
|
356 |
+
self.width = width
|
357 |
+
self.codes = codes
|
358 |
+
self.quantizers = quantizers
|
359 |
+
self.linear_heads = linear_heads
|
360 |
+
|
361 |
+
self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
|
362 |
+
if pos_embs is not None:
|
363 |
+
self.register_buffer("positional_embedding", pos_embs)
|
364 |
+
|
365 |
+
self.layers = nn.ModuleList([
|
366 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
|
367 |
+
])
|
368 |
+
|
369 |
+
self.ln_post = LayerNorm(width)
|
370 |
+
|
371 |
+
if self.linear_heads:
|
372 |
+
self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
|
373 |
+
else:
|
374 |
+
self.splitter = nn.Sequential(
|
375 |
+
nn.Linear(width, width * quantizers),
|
376 |
+
nn.GELU(),
|
377 |
+
)
|
378 |
+
self.heads = nn.ModuleList([
|
379 |
+
LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
|
380 |
+
])
|
381 |
+
|
382 |
+
def forward(self, toks, xenc):
|
383 |
+
b,_,n = toks.shape
|
384 |
+
newn = min(n+1, self.length)
|
385 |
+
embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
|
386 |
+
for i in range(self.quantizers):
|
387 |
+
embs[:,:i+1] += self.embeddings[i](torch.tensor([self.codes], device=xenc.device))
|
388 |
+
if i < n:
|
389 |
+
embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])
|
390 |
+
|
391 |
+
x = embs.to(xenc.dtype)
|
392 |
+
|
393 |
+
for l in self.layers:
|
394 |
+
x = l(x, xenc, causal=True)
|
395 |
+
x = self.ln_post(x)
|
396 |
+
|
397 |
+
if self.linear_heads:
|
398 |
+
logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
|
399 |
+
else:
|
400 |
+
split = self.splitter(x).view(b,newn,self.quantizers,self.width)
|
401 |
+
logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)
|
402 |
+
|
403 |
+
return logits
|
404 |
+
|
405 |
+
class EmbeddingProjector(nn.Linear):
|
406 |
+
pass
|
407 |
+
|
408 |
+
def rand(start, end):
|
409 |
+
return random.random() * (end - start) + start
|
410 |
+
|
411 |
+
@dataclasses.dataclass
|
412 |
+
class Tunables:
|
413 |
+
init_std :float = 9
|
414 |
+
embeddings_std :float = 0.2
|
415 |
+
embeddings_lr_scale: float = 10
|
416 |
+
output_mult :float = 5.6
|
417 |
+
# FIXME: try separate mults for self and cross attention
|
418 |
+
query_mult :float = .3
|
419 |
+
encoder_depth_ratio :float = 0.25
|
420 |
+
linear_heads :bool = False
|
421 |
+
rope :bool = True
|
422 |
+
|
423 |
+
lr0 :float = 3e-3
|
424 |
+
clip_gradient_norm :float = 2
|
425 |
+
weight_decay :float = 1e-3
|
426 |
+
warmup_steps :float = 2000
|
427 |
+
|
428 |
+
random :bool = False
|
429 |
+
|
430 |
+
def __post_init__(self):
|
431 |
+
# randomize the hyperparams if requested
|
432 |
+
if self.random:
|
433 |
+
self.init_std = 2*10**rand(0,1)
|
434 |
+
self.embeddings_std = 10**rand(-1.7,-0.22)
|
435 |
+
self.embeddings_lr_scale = 2**rand(2,4)
|
436 |
+
self.output_mult = 2**rand(1.5,3)
|
437 |
+
self.query_mult = 2**rand(-3,-1.3)
|
438 |
+
self.encoder_depth_ratio = random.choice([0.25,0.5])
|
439 |
+
self.linear_heads = False
|
440 |
+
self.rope = True
|
441 |
+
|
442 |
+
self.lr0 = 3e-3
|
443 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
444 |
+
self.warmup_steps = 100*(10**rand(1.18,1.3))
|
445 |
+
|
446 |
+
@staticmethod
|
447 |
+
def upgrade(args):
|
448 |
+
args = {k:v for k,v in args.items()}
|
449 |
+
def old_default(name, value):
|
450 |
+
if name not in args: args[name] = value
|
451 |
+
old_default('rope', False)
|
452 |
+
old_default('linear_heads', True)
|
453 |
+
return args
|
454 |
+
|
455 |
+
class SADelARTransformer(nn.Module):
|
456 |
+
def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
|
457 |
+
quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
|
458 |
+
super().__init__()
|
459 |
+
self.quantizers = quantizers
|
460 |
+
width = n_head * head_width
|
461 |
+
store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
|
462 |
+
self.width = width
|
463 |
+
self.base_width = 3 * head_width
|
464 |
+
self.tunables = tunables
|
465 |
+
|
466 |
+
if stoks_width is None: stoks_width = width
|
467 |
+
if spk_width is None: spk_width = width
|
468 |
+
self.emb_factor = width != stoks_width
|
469 |
+
self.spk_factor = width != spk_width
|
470 |
+
|
471 |
+
if tunables.rope:
|
472 |
+
self.positional_embeddings = None
|
473 |
+
else:
|
474 |
+
self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
|
475 |
+
|
476 |
+
self.speaker_embedding = nn.Embedding(len(speaker_map), width)
|
477 |
+
self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
|
478 |
+
if self.emb_factor:
|
479 |
+
self.emb_to_hidden = nn.Linear(stoks_width, width)
|
480 |
+
|
481 |
+
if self.spk_factor:
|
482 |
+
self.spk_to_hidden = EmbeddingProjector(spk_width, width)
|
483 |
+
|
484 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
485 |
+
|
486 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
487 |
+
decoder_depth = depth * 2 - encoder_depth
|
488 |
+
self.encoder = nn.Sequential(*[
|
489 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
|
490 |
+
])
|
491 |
+
self.ln_post = LayerNorm(width)
|
492 |
+
|
493 |
+
self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
|
494 |
+
length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
|
495 |
+
depth=decoder_depth, quantizers=quantizers,
|
496 |
+
linear_heads=tunables.linear_heads, rope=tunables.rope)
|
497 |
+
|
498 |
+
self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
|
499 |
+
self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
|
500 |
+
self.apply(self.init_transformer)
|
501 |
+
|
502 |
+
def setup(self, device):
|
503 |
+
pass
|
504 |
+
|
505 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
506 |
+
with torch.no_grad():
|
507 |
+
self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
508 |
+
self.semantic_embedding.lr_scale = 0
|
509 |
+
|
510 |
+
def init_transformer(self, m):
|
511 |
+
if isinstance(m, LinearHead):
|
512 |
+
m.no_weight_decay = True
|
513 |
+
torch.nn.init.constant_(m.weight, 0)
|
514 |
+
elif isinstance(m, QueryHead):
|
515 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
516 |
+
torch.nn.init.constant_(m.weight, 0)
|
517 |
+
elif isinstance(m, nn.Embedding):
|
518 |
+
m.no_weight_decay = True
|
519 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
520 |
+
std = self.tunables.embeddings_std
|
521 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
522 |
+
elif isinstance(m, EmbeddingProjector):
|
523 |
+
m.lr_scale = self.tunables.embeddings_lr_scale/2
|
524 |
+
std = self.tunables.init_std
|
525 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
526 |
+
elif isinstance(m, nn.Linear):
|
527 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
528 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
529 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
530 |
+
if m.bias is not None:
|
531 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
532 |
+
elif isinstance(m, nn.LayerNorm):
|
533 |
+
m.no_weight_decay = True
|
534 |
+
torch.nn.init.constant_(m.bias, 0)
|
535 |
+
torch.nn.init.constant_(m.weight, 1)
|
536 |
+
|
537 |
+
def embed_stoks(self, Stoks):
|
538 |
+
b,n = Stoks.shape
|
539 |
+
if self.stoks_len == 1500:
|
540 |
+
# converts 50 toks/s to 75 toks/s by adding padding between every two tokens
|
541 |
+
x = Stoks.reshape(b,n//2,2)
|
542 |
+
x = x.repeat_interleave(2, -1)[:,:,:3]
|
543 |
+
x[:,:,1] = 1024
|
544 |
+
x = x.reshape(b,n//2*3)
|
545 |
+
else:
|
546 |
+
# it's a lot easier with 25 toks/s
|
547 |
+
x = Stoks.repeat_interleave(3, -1)
|
548 |
+
# embed semantic tokens
|
549 |
+
Sembs = self.semantic_embedding(x.to(torch.long))
|
550 |
+
if self.emb_factor:
|
551 |
+
Sembs = self.emb_to_hidden(Sembs)
|
552 |
+
return Sembs
|
553 |
+
|
554 |
+
def forward(self, Stoks, Atoks, speakers, noloss=False):
|
555 |
+
Atoks = Atoks.to(torch.long)
|
556 |
+
semb = self.embed_stoks(Stoks)
|
557 |
+
with record_function("encoder"):
|
558 |
+
if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
|
559 |
+
xenc = self.ln_post(self.encoder(semb))
|
560 |
+
# xenc = torch.zeros_like(xenc)
|
561 |
+
with record_function("decoder"):
|
562 |
+
Atoks_gt = Atoks.clone()
|
563 |
+
Atoks_gt[Atoks == -100] = 1024
|
564 |
+
# we can randomize speaker ids during validation to measure
|
565 |
+
# the importance of the speaker embedding vs. just the acoustic prompt/prefix
|
566 |
+
# if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
|
567 |
+
spk_embs = self.speaker_embedding(speakers)
|
568 |
+
if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
|
569 |
+
logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
|
570 |
+
logits *= self.tunables.output_mult / (self.width / self.base_width)
|
571 |
+
|
572 |
+
if noloss:
|
573 |
+
return logits
|
574 |
+
|
575 |
+
with record_function("loss"):
|
576 |
+
N = Atoks.shape[-1]
|
577 |
+
loss = 0
|
578 |
+
for i in range(self.quantizers):
|
579 |
+
loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
|
580 |
+
loss /= self.quantizers
|
581 |
+
|
582 |
+
if not self.training:
|
583 |
+
for i in range(self.quantizers):
|
584 |
+
Atoks_i = Atoks[:,i,:N-i]
|
585 |
+
valid_Atoks = Atoks_i != -100
|
586 |
+
self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
|
587 |
+
self.val_total[i] += valid_Atoks.float().sum()
|
588 |
+
|
589 |
+
return logits, loss
|
590 |
+
|
591 |
+
def get_metrics(self):
|
592 |
+
metrics = {
|
593 |
+
f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
|
594 |
+
}
|
595 |
+
self.val_true[:] = 0
|
596 |
+
self.val_total[:] = 0
|
597 |
+
return metrics
|
598 |
+
|
599 |
+
#
|
600 |
+
# inference
|
601 |
+
#
|
602 |
+
@classmethod
|
603 |
+
def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
|
604 |
+
if not local_filename:
|
605 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
606 |
+
spec = torch.load(local_filename)
|
607 |
+
if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
|
608 |
+
model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
|
609 |
+
model.load_state_dict(spec['state_dict'])
|
610 |
+
model.eval()
|
611 |
+
return model
|
612 |
+
|
613 |
+
def get_extra_state(self):
|
614 |
+
return { 'speaker_map': self.speaker_map }
|
615 |
+
|
616 |
+
def set_extra_state(self, st):
|
617 |
+
self.speaker_map = st['speaker_map']
|
618 |
+
|
619 |
+
def load_checkpoint(self, local_filename):
|
620 |
+
spec = torch.load(local_filename, map_location='cpu')
|
621 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
622 |
+
state_dict = {k.replace('model.', ''):v
|
623 |
+
for k,v in spec['state_dict'].items()}
|
624 |
+
self.load_state_dict(state_dict)
|
625 |
+
return self
|
626 |
+
|
627 |
+
def save_model(self, fname):
|
628 |
+
torch.save(dict(config = self.__stored_args__,
|
629 |
+
tunables = dataclasses.asdict(self.tunables),
|
630 |
+
state_dict = self.state_dict()), fname)
|
631 |
+
|
632 |
+
@property
|
633 |
+
def device(self):
|
634 |
+
return next(self.parameters()).device
|
635 |
+
|
636 |
+
@torch.no_grad()
|
637 |
+
def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
|
638 |
+
dev = self.device
|
639 |
+
if self.stoks_len == 1500:
|
640 |
+
N = N or len(stoks) * 3 // 2
|
641 |
+
else:
|
642 |
+
N = N or len(stoks) * 3
|
643 |
+
stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
|
644 |
+
speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
|
645 |
+
toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
|
646 |
+
it = range(0,N)
|
647 |
+
if show_progress_bar: it = progress_bar(it)
|
648 |
+
for i in it:
|
649 |
+
p = self(stoks, toks[:,:,:i], speakers, noloss=True)
|
650 |
+
last_p = p[0,:,-1]
|
651 |
+
if top_k:
|
652 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
653 |
+
for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
|
654 |
+
toks[0,j,max(0,i-j)] = tok
|
655 |
+
if toks[0,0,i] == 1024: return toks[0,:,:i]
|
656 |
+
return toks[0]
|
657 |
+
|
658 |
+
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
|
659 |
+
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
|
660 |
+
assert(dataset is not None)
|
661 |
+
kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
|
662 |
+
if size == 'micro':
|
663 |
+
return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
|
664 |
+
if size == 'tiny-narrow':
|
665 |
+
return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
|
666 |
+
if size == 'tiny':
|
667 |
+
return SADelARTransformer(depth=4, n_head=6, **kwargs)
|
668 |
+
if size == 'base':
|
669 |
+
return SADelARTransformer(depth=6, n_head=8, **kwargs)
|
670 |
+
if size == 'base-deep':
|
671 |
+
return SADelARTransformer(depth=9, n_head=8, **kwargs)
|
672 |
+
if size == 'base-wide':
|
673 |
+
return SADelARTransformer(depth=6, n_head=12, **kwargs)
|
674 |
+
if size == 'small/2':
|
675 |
+
return SADelARTransformer(depth=9, n_head=12, **kwargs)
|
676 |
+
if size == 'small':
|
677 |
+
return SADelARTransformer(depth=12, n_head=12, **kwargs)
|
678 |
+
if size == 'medium':
|
679 |
+
return SADelARTransformer(depth=24, n_head=16, **kwargs)
|
680 |
+
|
681 |
+
def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
682 |
+
if frozen_embeddings_model:
|
683 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
684 |
+
model = _make_model(size, quantizers, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
685 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
686 |
+
else:
|
687 |
+
model = _make_model(size, quantizers, tunables, dataset)
|
688 |
+
return model
|
whisperspeech/s2a_delar_mup_wds_mlang.py
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['load_dataset', 'DelSumEmbedding', 'DelSumHead', 'rand', 'Tunables', 'SADelARTransformer']
|
5 |
+
|
6 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 1
|
7 |
+
import io
|
8 |
+
import time
|
9 |
+
import math
|
10 |
+
import random
|
11 |
+
import dataclasses
|
12 |
+
|
13 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 2
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import numpy as np
|
18 |
+
from torch.profiler import profile, record_function, ProfilerActivity, schedule
|
19 |
+
from fastcore.basics import store_attr
|
20 |
+
from huggingface_hub import hf_hub_download
|
21 |
+
|
22 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 3
|
23 |
+
from pathlib import Path
|
24 |
+
import json
|
25 |
+
from fastprogress import progress_bar, master_bar
|
26 |
+
|
27 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 4
|
28 |
+
from .modules import *
|
29 |
+
|
30 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 8
|
31 |
+
def rand(start, end):
|
32 |
+
return random.random() * (end - start) + start
|
33 |
+
|
34 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 9
|
35 |
+
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
|
36 |
+
atoks_per_second = atoks_len / 30
|
37 |
+
def _trunc(samples):
|
38 |
+
for s in samples:
|
39 |
+
if random.random() < random_trunc_p:
|
40 |
+
seconds = rand(0.3, 30)
|
41 |
+
s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
|
42 |
+
s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
|
43 |
+
yield s
|
44 |
+
return _trunc
|
45 |
+
|
46 |
+
def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
|
47 |
+
def _pad(samples):
|
48 |
+
for s in samples:
|
49 |
+
s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (1, stoks_len - s['stoks.npy'].shape[-1]-1), value=stoks_pad_token)
|
50 |
+
s['out_stoks'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
|
51 |
+
s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
|
52 |
+
yield s
|
53 |
+
return _pad
|
54 |
+
|
55 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 10
|
56 |
+
def make_speaker_map(shards):
|
57 |
+
speakers = set()
|
58 |
+
for shard in shards:
|
59 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
60 |
+
return {id:i for i,id in enumerate(sorted(speakers))}
|
61 |
+
|
62 |
+
def speaker_id_extractor(speaker_map):
|
63 |
+
def _extractor(samples):
|
64 |
+
for s in samples:
|
65 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
66 |
+
yield s
|
67 |
+
return _extractor
|
68 |
+
|
69 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 27
|
70 |
+
def load_dataset(
|
71 |
+
atoks_shard_spec:str, # webdataset folder
|
72 |
+
stoks_shard_dir:str, # stoks webdataset base dir
|
73 |
+
samples:int, # samples per epoch
|
74 |
+
random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
|
75 |
+
vq_codes:int=4096,
|
76 |
+
language:str='en',
|
77 |
+
weight:float=1,
|
78 |
+
validation:bool=False,
|
79 |
+
exclude_files:str=None,
|
80 |
+
randomize_speakers:bool=False,
|
81 |
+
):
|
82 |
+
import webdataset as wds
|
83 |
+
from whisperspeech import utils
|
84 |
+
|
85 |
+
shards = utils.shard_glob(atoks_shard_spec)
|
86 |
+
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
|
87 |
+
|
88 |
+
def check_for_nan(s):
|
89 |
+
if torch.tensor(s['spk_emb.npy']).isnan().any(): print("found NaN:", s['__key__'])
|
90 |
+
return s
|
91 |
+
|
92 |
+
def set_language(x):
|
93 |
+
x['language'] = language
|
94 |
+
return x
|
95 |
+
|
96 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
97 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
98 |
+
wds.decode(),
|
99 |
+
utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),
|
100 |
+
wds.map(check_for_nan),
|
101 |
+
wds.select(lambda s: s['__key__'] not in excludes),
|
102 |
+
wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model
|
103 |
+
random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
|
104 |
+
pad_samples(stoks_pad_token=vq_codes-1),
|
105 |
+
wds.map(set_language),
|
106 |
+
wds.to_tuple('stoks.npy', 'atoks.npy', 'spk_emb.npy', 'language', 'out_stoks'),
|
107 |
+
wds.shuffle(20000, initial=20000),
|
108 |
+
wds.batched(64),
|
109 |
+
)
|
110 |
+
if randomize_speakers:
|
111 |
+
rng = np.random.default_rng()
|
112 |
+
ds = ds.compose(
|
113 |
+
wds.map_tuple(None, None, lambda x: rng.permutation(x), None),
|
114 |
+
)
|
115 |
+
if validation:
|
116 |
+
ds = ds.slice(samples // 64)
|
117 |
+
ds.total_samples = samples
|
118 |
+
ds.weight = weight
|
119 |
+
|
120 |
+
return ds
|
121 |
+
|
122 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 37
|
123 |
+
class DelSumEmbedding(nn.Module):
|
124 |
+
def __init__(self, n_head=6, head_width=64, atoks_width=None, length=2250, codes=1024, quantizers=8, pos_embs=None):
|
125 |
+
super().__init__()
|
126 |
+
self.length = length
|
127 |
+
width = n_head * head_width
|
128 |
+
if atoks_width is None: atoks_width = width
|
129 |
+
self.width = width
|
130 |
+
self.quantizers = quantizers
|
131 |
+
|
132 |
+
emb = None
|
133 |
+
embs = []
|
134 |
+
for _ in range(quantizers):
|
135 |
+
emb = FlexEmbeddings(codes, width, special_codes=2, frozen_width=atoks_width,
|
136 |
+
special_embedding=emb and emb.special)
|
137 |
+
embs.append(emb)
|
138 |
+
self.embeddings = nn.ModuleList(embs)
|
139 |
+
if pos_embs is not None:
|
140 |
+
self.register_buffer("positional_embedding", pos_embs)
|
141 |
+
|
142 |
+
def forward(self, toks, xenc):
|
143 |
+
with record_function("embeddings"):
|
144 |
+
b,_,n = toks.shape
|
145 |
+
newn = min(n, self.length)
|
146 |
+
|
147 |
+
embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
|
148 |
+
for i in range(self.quantizers):
|
149 |
+
embs[:, :] += self.embeddings[i](toks[:,i,:])
|
150 |
+
|
151 |
+
x = embs.to(xenc.dtype)
|
152 |
+
return x
|
153 |
+
|
154 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 38
|
155 |
+
class DelSumHead(nn.Module):
|
156 |
+
def __init__(self, quantizers=8, n_head=6, head_width=64):
|
157 |
+
super().__init__()
|
158 |
+
self.width = n_head * head_width
|
159 |
+
self.quantizers = quantizers
|
160 |
+
self.splitter = nn.Sequential(
|
161 |
+
nn.Linear(self.width, self.width * quantizers),
|
162 |
+
nn.GELU(),
|
163 |
+
)
|
164 |
+
|
165 |
+
def forward(self, x, embeddings=None):
|
166 |
+
b, newn, _ = x.shape
|
167 |
+
with record_function("splitter"):
|
168 |
+
split = self.splitter(x).view(b,newn,self.quantizers,self.width)
|
169 |
+
with record_function("unembed"):
|
170 |
+
logits = torch.stack([embeddings[q].unembed(split[:,:,q]) for q in range(self.quantizers)], dim=1)
|
171 |
+
return logits
|
172 |
+
|
173 |
+
def rand(start, end):
|
174 |
+
return random.random() * (end - start) + start
|
175 |
+
|
176 |
+
@dataclasses.dataclass
|
177 |
+
class Tunables:
|
178 |
+
init_std :float = 9
|
179 |
+
embeddings_std :float = 0.2
|
180 |
+
embeddings_lr_scale: float = 10
|
181 |
+
output_mult :float = 5.6
|
182 |
+
# FIXME: try separate mults for self and cross attention
|
183 |
+
query_mult :float = .3
|
184 |
+
encoder_depth_ratio :float = 0.25
|
185 |
+
linear_heads :bool = False
|
186 |
+
rope :bool = True
|
187 |
+
|
188 |
+
lr0 :float = 3e-3
|
189 |
+
clip_gradient_norm :float = 2
|
190 |
+
weight_decay :float = 1e-3
|
191 |
+
warmup_steps :float = 2000
|
192 |
+
|
193 |
+
random :bool = False
|
194 |
+
|
195 |
+
def __post_init__(self):
|
196 |
+
# randomize the hyperparams if requested
|
197 |
+
if self.random:
|
198 |
+
self.init_std = 2*10**rand(0,1)
|
199 |
+
self.embeddings_std = 10**rand(-1.7,-0.22)
|
200 |
+
self.embeddings_lr_scale = 2**rand(2,4)
|
201 |
+
self.output_mult = 2**rand(1.5,3)
|
202 |
+
self.query_mult = 2**rand(-3,-1.3)
|
203 |
+
self.encoder_depth_ratio = random.choice([0.25,0.5])
|
204 |
+
self.linear_heads = False
|
205 |
+
self.rope = True
|
206 |
+
|
207 |
+
self.lr0 = 3e-3
|
208 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
209 |
+
self.warmup_steps = 100*(10**rand(1.18,1.3))
|
210 |
+
|
211 |
+
@staticmethod
|
212 |
+
def upgrade(args):
|
213 |
+
args = {k:v for k,v in args.items()}
|
214 |
+
def old_default(name, value):
|
215 |
+
if name not in args: args[name] = value
|
216 |
+
old_default('rope', False)
|
217 |
+
old_default('linear_heads', True)
|
218 |
+
return args
|
219 |
+
|
220 |
+
class SADelARTransformer(nn.Module):
|
221 |
+
def __init__(self, depth=3, ctx_n=2250,
|
222 |
+
stoks_len=750, stoks_codes=4097, stoks_width=None,
|
223 |
+
spk_width=None,
|
224 |
+
atoks_width=None,
|
225 |
+
n_head=3, head_width=64, ffn_mult=4,
|
226 |
+
quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
|
227 |
+
super().__init__()
|
228 |
+
self.quantizers = quantizers
|
229 |
+
self.codes = 1024
|
230 |
+
width = n_head * head_width
|
231 |
+
store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,atoks_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
|
232 |
+
self.width = width
|
233 |
+
self.base_width = 3 * head_width
|
234 |
+
self.tunables = tunables
|
235 |
+
|
236 |
+
if stoks_width is None: stoks_width = width
|
237 |
+
if spk_width is None: spk_width = width
|
238 |
+
self.emb_factor = width != stoks_width
|
239 |
+
self.spk_factor = width != spk_width
|
240 |
+
|
241 |
+
if tunables.rope:
|
242 |
+
self.positional_embeddings = None
|
243 |
+
else:
|
244 |
+
self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
|
245 |
+
|
246 |
+
# self.speaker_embedding = nn.Embedding(len(speaker_map), spk_width)
|
247 |
+
self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
|
248 |
+
if self.emb_factor:
|
249 |
+
self.emb_to_hidden = nn.Linear(stoks_width, width)
|
250 |
+
self.hidden_to_emb = nn.Linear(width, stoks_width)
|
251 |
+
|
252 |
+
if self.spk_factor:
|
253 |
+
self.spk_to_hidden = nn.Linear(spk_width, width)
|
254 |
+
|
255 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
256 |
+
|
257 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
258 |
+
decoder_depth = depth * 2 - encoder_depth
|
259 |
+
self.encoder = nn.Sequential(*[
|
260 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
|
261 |
+
]) # FIXME: enclm requires causal attention here
|
262 |
+
self.ln_post = LayerNorm(width)
|
263 |
+
|
264 |
+
self.embds = DelSumEmbedding(
|
265 |
+
pos_embs=self.positional_embeddings, length=ctx_n,
|
266 |
+
n_head=n_head, head_width=head_width, atoks_width=atoks_width,
|
267 |
+
quantizers=quantizers,
|
268 |
+
)
|
269 |
+
self.decoder = BaseDecoder(qk_scale=qk_scale, length=ctx_n,
|
270 |
+
n_head=n_head, width=n_head * head_width,
|
271 |
+
ffn_mult=ffn_mult, depth=decoder_depth,
|
272 |
+
rope=tunables.rope)
|
273 |
+
self.head = DelSumHead(n_head=n_head, head_width=head_width, quantizers=quantizers)
|
274 |
+
for l in self.decoder.layers:
|
275 |
+
l.cross_attn.key_subsampling = 3
|
276 |
+
# for l in self.encoder:
|
277 |
+
# l.attn.key_subsampling = 3
|
278 |
+
# l.attn.query_subsampling = 3
|
279 |
+
|
280 |
+
self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
|
281 |
+
self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
|
282 |
+
self.apply(self.init_transformer)
|
283 |
+
|
284 |
+
def setup(self, device):
|
285 |
+
pass
|
286 |
+
|
287 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
288 |
+
with torch.no_grad():
|
289 |
+
self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
290 |
+
self.semantic_embedding.lr_scale = 0
|
291 |
+
|
292 |
+
def load_frozen_acoustic_embeddings(self, amodel):
|
293 |
+
for i in range(self.quantizers):
|
294 |
+
self.decoder.embeddings[i].set_frozen_embeddings(amodel.quantizer.vq.layers[i].codebook)
|
295 |
+
|
296 |
+
def init_transformer(self, m):
|
297 |
+
if isinstance(m, LinearHead):
|
298 |
+
m.no_weight_decay = True
|
299 |
+
torch.nn.init.constant_(m.weight, 0)
|
300 |
+
elif isinstance(m, QueryHead):
|
301 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
302 |
+
torch.nn.init.constant_(m.weight, 0)
|
303 |
+
elif isinstance(m, nn.Embedding):
|
304 |
+
m.no_weight_decay = True
|
305 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
306 |
+
std = self.tunables.embeddings_std
|
307 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
308 |
+
# elif isinstance(m, EmbeddingProjector):
|
309 |
+
# m.lr_scale = self.tunables.embeddings_lr_scale #1/(m.weight.shape[1] / self.base_width)
|
310 |
+
# m.lr_scale = 2/(m.weight.shape[1] / self.base_width)
|
311 |
+
# std = self.tunables.init_std / m.weight.shape[1]
|
312 |
+
# torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
313 |
+
elif isinstance(m, nn.Linear):
|
314 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
315 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
316 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
317 |
+
if m.bias is not None:
|
318 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
319 |
+
elif isinstance(m, nn.LayerNorm):
|
320 |
+
m.no_weight_decay = True
|
321 |
+
torch.nn.init.constant_(m.bias, 0)
|
322 |
+
torch.nn.init.constant_(m.weight, 1)
|
323 |
+
|
324 |
+
def embed_stoks(self, Stoks):
|
325 |
+
b,n = Stoks.shape
|
326 |
+
if self.stoks_len == 1500:
|
327 |
+
# converts 50 toks/s to 75 toks/s by adding padding between every two tokens
|
328 |
+
x = Stoks.reshape(b,n//2,2)
|
329 |
+
x = x.repeat_interleave(2, -1)[:,:,:3]
|
330 |
+
x[:,:,1] = 1024
|
331 |
+
x = x.reshape(b,n//2*3)
|
332 |
+
else:
|
333 |
+
# it's a lot easier with 25 toks/s
|
334 |
+
# x = Stoks.repeat_interleave(3, -1)
|
335 |
+
x = Stoks
|
336 |
+
# embed semantic tokens
|
337 |
+
Sembs = self.semantic_embedding(x.to(torch.long))
|
338 |
+
if self.emb_factor:
|
339 |
+
Sembs = self.emb_to_hidden(Sembs)
|
340 |
+
return Sembs
|
341 |
+
|
342 |
+
def _encoder(self, semb, positions):
|
343 |
+
x = semb
|
344 |
+
for l in self.encoder: x = l(x, positions)
|
345 |
+
return self.ln_post(x)
|
346 |
+
|
347 |
+
def run_encoder(self, Stoks, speakers):
|
348 |
+
semb = self.embed_stoks(Stoks)
|
349 |
+
with record_function("encoder"):
|
350 |
+
if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
|
351 |
+
positions = torch.arange(0, semb.shape[1], device=semb.device)
|
352 |
+
xenc = self._encoder(semb, positions)
|
353 |
+
if self.training:
|
354 |
+
enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
|
355 |
+
enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
|
356 |
+
else:
|
357 |
+
enc_logits = None
|
358 |
+
# print(xenc.shape, speakers.shape)
|
359 |
+
spk_embs = F.normalize(speakers, dim=-1) # use extracted embeddings
|
360 |
+
if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
|
361 |
+
return xenc + spk_embs.unsqueeze(1), positions, enc_logits
|
362 |
+
|
363 |
+
def forward(self, Stoks, Atoks, speakers, langs=None, out_stoks=None, noloss=False, xenc=None, xenc_positions=None, atoks_positions=None):
|
364 |
+
if xenc is None:
|
365 |
+
Atoks = Atoks.to(torch.long)
|
366 |
+
out_stoks = out_stoks.to(torch.long)
|
367 |
+
Atoks_gt = Atoks.clone()
|
368 |
+
Atoks_gt[Atoks == -100] = 1024
|
369 |
+
xenc, enc_logits = self.run_encoder(Stoks, speakers)
|
370 |
+
else:
|
371 |
+
Atoks_gt = Atoks
|
372 |
+
with record_function("decoder"):
|
373 |
+
embs = self.embds(Atoks, xenc)
|
374 |
+
if atoks_positions is None: atoks_positions = torch.arange(0, embs.shape[1], device=embs.device)
|
375 |
+
x = self.decoder(embs, atoks_positions, xenc, xenc_positions)
|
376 |
+
logits = self.head(x, embeddings=self.embds.embeddings)
|
377 |
+
logits *= self.tunables.output_mult / (self.width / self.base_width)
|
378 |
+
|
379 |
+
if noloss:
|
380 |
+
return logits
|
381 |
+
|
382 |
+
with record_function("loss"):
|
383 |
+
N = Atoks.shape[-1]
|
384 |
+
loss = 0
|
385 |
+
for i in range(self.quantizers):
|
386 |
+
loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
|
387 |
+
if self.training and i == 0:
|
388 |
+
loss *= 5
|
389 |
+
loss /= self.quantizers
|
390 |
+
if self.training:
|
391 |
+
loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_stoks)
|
392 |
+
|
393 |
+
if not self.training:
|
394 |
+
for i in range(self.quantizers):
|
395 |
+
Atoks_i = Atoks[:,i,:N-i]
|
396 |
+
valid_Atoks = Atoks_i != -100
|
397 |
+
self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
|
398 |
+
self.val_total[i] += valid_Atoks.float().sum()
|
399 |
+
|
400 |
+
return logits, loss
|
401 |
+
|
402 |
+
def get_metrics(self):
|
403 |
+
metrics = {
|
404 |
+
f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
|
405 |
+
}
|
406 |
+
self.val_true[:] = 0
|
407 |
+
self.val_total[:] = 0
|
408 |
+
return metrics
|
409 |
+
|
410 |
+
#
|
411 |
+
# inference
|
412 |
+
#
|
413 |
+
@classmethod
|
414 |
+
def load_model(cls, ref="collabora/whisperspeech:s2a-q4-small-en+pl.model",
|
415 |
+
repo_id=None, filename=None, local_filename=None):
|
416 |
+
if repo_id is None and filename is None and local_filename is None:
|
417 |
+
if ":" in ref:
|
418 |
+
repo_id, filename = ref.split(":", 1)
|
419 |
+
else:
|
420 |
+
local_filename = ref
|
421 |
+
if not local_filename:
|
422 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
423 |
+
spec = torch.load(local_filename)
|
424 |
+
if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
|
425 |
+
model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
|
426 |
+
model.load_state_dict(spec['state_dict'])
|
427 |
+
model.eval()
|
428 |
+
return model
|
429 |
+
|
430 |
+
def get_extra_state(self):
|
431 |
+
return { 'speaker_map': self.speaker_map }
|
432 |
+
|
433 |
+
def set_extra_state(self, st):
|
434 |
+
self.speaker_map = st['speaker_map']
|
435 |
+
|
436 |
+
def load_checkpoint(self, local_filename):
|
437 |
+
spec = torch.load(local_filename, map_location='cpu')
|
438 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
439 |
+
state_dict = {k.replace('model.', ''):v
|
440 |
+
for k,v in spec['state_dict'].items()}
|
441 |
+
self.load_state_dict(state_dict)
|
442 |
+
return self
|
443 |
+
|
444 |
+
def save_model(self, fname):
|
445 |
+
torch.save(dict(config = self.__stored_args__,
|
446 |
+
tunables = dataclasses.asdict(self.tunables),
|
447 |
+
state_dict = self.state_dict()), fname)
|
448 |
+
|
449 |
+
def switch_dtypes(self, dtype=torch.float16):
|
450 |
+
self.dtype = dtype
|
451 |
+
for n,m in self.named_modules():
|
452 |
+
# convert every leaf layer apart from the LayerNorms
|
453 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
454 |
+
m.to(dtype)
|
455 |
+
# take care of buffers ([kv]_cache, masks) that are not in the leaf layers
|
456 |
+
for bn,b in m.named_buffers(recurse=False):
|
457 |
+
setattr(m,bn,b.to(dtype))
|
458 |
+
|
459 |
+
def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
|
460 |
+
for emb in self.embds.embeddings:
|
461 |
+
emb.convert_for_eval()
|
462 |
+
for l in self.encoder:
|
463 |
+
l.attn.convert_for_eval()
|
464 |
+
for l in self.decoder.layers:
|
465 |
+
l.attn.convert_for_eval()
|
466 |
+
l.cross_attn.convert_for_eval()
|
467 |
+
l.setup_kv_cache(max_batch_size, self.ctx_n, self.stoks_len)
|
468 |
+
self.switch_dtypes(dtype)
|
469 |
+
if torch_compile:
|
470 |
+
self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
|
471 |
+
|
472 |
+
@property
|
473 |
+
def device(self):
|
474 |
+
return next(self.parameters()).device
|
475 |
+
|
476 |
+
# from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
477 |
+
def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
|
478 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
479 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
480 |
+
|
481 |
+
def logits_to_probs(self, logits, T=1.0, top_k=None):
|
482 |
+
logits = logits / max(T, 1e-5)
|
483 |
+
|
484 |
+
if top_k is not None:
|
485 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
486 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
487 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
488 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
489 |
+
return probs
|
490 |
+
|
491 |
+
def sample(self, logits, T=1.0, top_k=None):
|
492 |
+
probs = self.logits_to_probs(logits[0,:,-1], T, top_k)
|
493 |
+
idx_next = self.multinomial_sample_one_no_sync(probs)
|
494 |
+
return idx_next
|
495 |
+
|
496 |
+
def generate_one(self, toks, positions, langs, xenc, xenc_positions, T, top_k):
|
497 |
+
probs = self(None, toks, None, langs, noloss=True, xenc=xenc, xenc_positions=xenc_positions, atoks_positions=positions)
|
498 |
+
return self.sample(probs, T, top_k)
|
499 |
+
|
500 |
+
def generate_next(self, *args, **kwargs):
|
501 |
+
return self.generate_one(*args, **kwargs)
|
502 |
+
|
503 |
+
@torch.no_grad()
|
504 |
+
def generate(self, stoks, speakers, langs=None, N=None, T=0.7, top_k=None, show_progress_bar=True, step=None, subsample_enc=False):
|
505 |
+
dev = self.device
|
506 |
+
N = N or len(stoks) * 3
|
507 |
+
stoks = F.pad(stoks.to(dev), (1, self.stoks_len - len(stoks)-1), value=self.stoks_codes-1).unsqueeze(0)
|
508 |
+
speakers = speakers.to(device=dev, dtype=self.dtype)
|
509 |
+
toks = torch.full((1,self.quantizers,2250), self.codes+1, dtype=torch.long, device=dev)
|
510 |
+
it = range(1,min(N,2250-1))
|
511 |
+
if show_progress_bar: it = progress_bar(it)
|
512 |
+
with record_function("encode"):
|
513 |
+
xenc, xenc_positions, _ = self.run_encoder(stoks, speakers)
|
514 |
+
toks_positions = torch.arange(N, device=dev)
|
515 |
+
with record_function("prefill"):
|
516 |
+
toks[0,0,1] = self.generate_one(toks[:,:,:1], toks_positions[:1], langs, xenc, xenc_positions, T, top_k)[0,0]
|
517 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
518 |
+
for i in it:
|
519 |
+
with record_function("generate_one"):
|
520 |
+
toks[0,:i+1,i+1] = self.generate_next(toks[:,:,i:i+1], toks_positions[i:i+1], langs, xenc, xenc_positions, T, top_k)[:i+1,0]
|
521 |
+
|
522 |
+
# for profiling, debugging or early exit
|
523 |
+
if step is not None: step()
|
524 |
+
# shift tokens
|
525 |
+
toks = toks[:,:,1:N]
|
526 |
+
for j in range(self.quantizers):
|
527 |
+
toks[0, j] = torch.roll(toks[0, j], -j)
|
528 |
+
return toks[0]
|
529 |
+
|
530 |
+
# %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 39
|
531 |
+
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), **kwargs):
|
532 |
+
kwargs = dict(quantizers=quantizers, tunables=tunables, **kwargs)
|
533 |
+
if size == 'micro':
|
534 |
+
return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
|
535 |
+
if size == 'tiny-narrow':
|
536 |
+
return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
|
537 |
+
if size == 'tiny':
|
538 |
+
return SADelARTransformer(depth=4, n_head=6, **kwargs)
|
539 |
+
if size == 'base':
|
540 |
+
return SADelARTransformer(depth=6, n_head=8, **kwargs)
|
541 |
+
if size == 'base-deep':
|
542 |
+
return SADelARTransformer(depth=9, n_head=8, **kwargs)
|
543 |
+
if size == 'base-wide':
|
544 |
+
return SADelARTransformer(depth=6, n_head=12, **kwargs)
|
545 |
+
if size == 'small/2':
|
546 |
+
return SADelARTransformer(depth=9, n_head=12, **kwargs)
|
547 |
+
if size == 'small':
|
548 |
+
return SADelARTransformer(depth=12, n_head=12, **kwargs)
|
549 |
+
if size == 'medium':
|
550 |
+
return SADelARTransformer(depth=24, n_head=16, **kwargs)
|
551 |
+
|
552 |
+
def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, frozen_acoustic_embeddings:bool=False, spk_width:int=None, tunables:Tunables=Tunables(), dataset=None):
|
553 |
+
from encodec.model import EncodecModel
|
554 |
+
from whisperspeech import vq_stoks
|
555 |
+
|
556 |
+
amodel = EncodecModel.encodec_model_24khz() if frozen_acoustic_embeddings else None
|
557 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) if frozen_embeddings_model else None
|
558 |
+
model = _make_model(size, quantizers, tunables,
|
559 |
+
spk_width=spk_width,
|
560 |
+
atoks_width=amodel and amodel.quantizer.vq.layers[0]._codebook.embed.shape[-1],
|
561 |
+
stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
562 |
+
if vqmodel: model.load_frozen_semantic_embeddings(vqmodel)
|
563 |
+
if amodel: model.load_frozen_acoustic_embeddings(amodel)
|
564 |
+
return model
|
whisperspeech/t2s_up_wds.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Text to semantic token modeling.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['load_datasets', 'rand', 'Tunables', 'Encoder', 'Decoder', 'TSARTransformer', 'make_model']
|
5 |
+
|
6 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 1
|
7 |
+
import dataclasses
|
8 |
+
import random
|
9 |
+
import math
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.profiler import record_function
|
14 |
+
|
15 |
+
from huggingface_hub import hf_hub_download
|
16 |
+
from fastcore.basics import store_attr
|
17 |
+
from fastprogress import progress_bar
|
18 |
+
|
19 |
+
import webdataset as wds
|
20 |
+
|
21 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 2
|
22 |
+
from pathlib import Path
|
23 |
+
import pylab as plt
|
24 |
+
import pandas as pd
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 3
|
28 |
+
import whisper
|
29 |
+
from whisperspeech.train import *
|
30 |
+
from whisperspeech.modules import *
|
31 |
+
from whisperspeech import vq_stoks
|
32 |
+
|
33 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 8
|
34 |
+
import re
|
35 |
+
|
36 |
+
class CharTokenizer:
|
37 |
+
"""Trivial tokenizer β just use UTF-8 bytes"""
|
38 |
+
eot = 0
|
39 |
+
|
40 |
+
def encode(self, txt):
|
41 |
+
return list(bytes(txt.strip(), 'utf-8'))
|
42 |
+
|
43 |
+
def decode(self, tokens):
|
44 |
+
return bytes(tokens).decode('utf-8')
|
45 |
+
|
46 |
+
def tokenizer(ikey, okey, length):
|
47 |
+
"""Tokenizes a transcript"""
|
48 |
+
tok = CharTokenizer()
|
49 |
+
def _tokenizer(samples):
|
50 |
+
for s in samples:
|
51 |
+
toks = torch.tensor(tok.encode(s[ikey]))
|
52 |
+
s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
|
53 |
+
yield s
|
54 |
+
return _tokenizer
|
55 |
+
|
56 |
+
def ar_padder(ikey, okey, length, pad_token):
|
57 |
+
"""Pads the tokens for autoregresive training"""
|
58 |
+
def _ar_padder(samples):
|
59 |
+
for s in samples:
|
60 |
+
toks = s[ikey]
|
61 |
+
if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
|
62 |
+
toks = toks.to(torch.long)
|
63 |
+
s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
|
64 |
+
s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
|
65 |
+
yield s
|
66 |
+
return _ar_padder
|
67 |
+
|
68 |
+
def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
|
69 |
+
"""Adds the characters per second metric to the input data"""
|
70 |
+
def _char_per_seconder(samples):
|
71 |
+
for s in samples:
|
72 |
+
secs = s[stoks_key].shape[-1] / stoks_per_second
|
73 |
+
s[cps_key] = len(s[txt_key]) / secs
|
74 |
+
yield s
|
75 |
+
return _char_per_seconder
|
76 |
+
|
77 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 9
|
78 |
+
def build_speaker_map(shards):
|
79 |
+
speakers = set()
|
80 |
+
for shard in shards:
|
81 |
+
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
|
82 |
+
return {id:i for i,id in enumerate(speakers)}
|
83 |
+
|
84 |
+
def speaker_id_extractor(speaker_map):
|
85 |
+
def _extractor(samples):
|
86 |
+
for s in samples:
|
87 |
+
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
|
88 |
+
yield s
|
89 |
+
return _extractor
|
90 |
+
|
91 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 10
|
92 |
+
def load_datasets(
|
93 |
+
input:str, # webdataset folder or shard list
|
94 |
+
samples:int, # samples per epoch
|
95 |
+
subsample:float=1, # use a fraction of the files
|
96 |
+
val_samples:int=512,
|
97 |
+
vq_codes:int=4096,
|
98 |
+
):
|
99 |
+
if isinstance(input, (Path, str)):
|
100 |
+
path = Path(input)
|
101 |
+
if path.is_dir():
|
102 |
+
glob = '*-t2s-*.tar.gz'
|
103 |
+
else:
|
104 |
+
glob = path.name
|
105 |
+
path = path.parent
|
106 |
+
input = Path(path).glob(glob)
|
107 |
+
elif isinstance(input, list):
|
108 |
+
pass
|
109 |
+
else:
|
110 |
+
raise ArgumentError("input should be either a list of a path with an optional glob specifier")
|
111 |
+
shards = [str(x) for x in input]
|
112 |
+
|
113 |
+
speaker_map = build_speaker_map(shards)
|
114 |
+
|
115 |
+
def ds(shards, length):
|
116 |
+
ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
|
117 |
+
wds.decode(),
|
118 |
+
speaker_id_extractor(speaker_map),
|
119 |
+
wds.select(lambda s: s['stoks.npy'].shape[-1] > 12), # select samples > .5s
|
120 |
+
tokenizer('txt', 'ttoks', length=550),
|
121 |
+
ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
|
122 |
+
char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
|
123 |
+
wds.to_tuple('ttoks', 'speaker', 'cps', 'in_stoks', 'out_stoks'),
|
124 |
+
wds.batched(64)
|
125 |
+
)
|
126 |
+
ds.speakers = speaker_map
|
127 |
+
ds.total_samples = length
|
128 |
+
ds.stoks_len = 750
|
129 |
+
ds.stoks_codes = vq_codes
|
130 |
+
ds.ttoks_len = 550
|
131 |
+
return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
|
132 |
+
|
133 |
+
return (
|
134 |
+
ds(shards[1:], samples),
|
135 |
+
ds(shards[:1], val_samples),
|
136 |
+
)
|
137 |
+
|
138 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 14
|
139 |
+
def rand(start, end):
|
140 |
+
return random.random() * (end - start) + start
|
141 |
+
|
142 |
+
@dataclasses.dataclass
|
143 |
+
class Tunables:
|
144 |
+
init_std :float = 1
|
145 |
+
embeddings_std :float = .01
|
146 |
+
embeddings_lr_scale: float = 5
|
147 |
+
embedding_projector_lr_scale: float = 2.5
|
148 |
+
output_mult :float = .35
|
149 |
+
query_mult :float = 1
|
150 |
+
encoder_depth_ratio :float = 0.25
|
151 |
+
eot_dropout_p :float = .5
|
152 |
+
cps_input: bool = True
|
153 |
+
cps_bins: int = 32
|
154 |
+
|
155 |
+
lr0 :float = 1.5e-3
|
156 |
+
clip_gradient_norm :float = .2
|
157 |
+
weight_decay :float = 1e-1
|
158 |
+
warmup_steps :float = 4000
|
159 |
+
|
160 |
+
random :bool = False
|
161 |
+
|
162 |
+
def __post_init__(self):
|
163 |
+
# randomize the hyperparams if requested
|
164 |
+
if self.random:
|
165 |
+
self.init_std = 10**rand(-1,1)
|
166 |
+
self.embeddings_std = 10**rand(-3,-.7)
|
167 |
+
self.embeddings_lr_scale = rand(2,6)
|
168 |
+
self.output_mult = rand(0.25,0.65)
|
169 |
+
self.query_mult = 2**rand(-2,3)
|
170 |
+
self.encoder_depth_ratio = 0.25
|
171 |
+
|
172 |
+
self.lr0 = rand(1,5)*1e-3
|
173 |
+
self.clip_gradient_norm = 10**rand(-3,0)
|
174 |
+
self.warmup_steps = 100*(10**rand(1,1.85))
|
175 |
+
|
176 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 15
|
177 |
+
class EmbeddingProjector(nn.Linear):
|
178 |
+
pass
|
179 |
+
|
180 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 16
|
181 |
+
class Encoder(nn.Module):
|
182 |
+
def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
183 |
+
super().__init__()
|
184 |
+
self.emb_width = emb_width
|
185 |
+
|
186 |
+
self.emb_factor = width != emb_width
|
187 |
+
|
188 |
+
self.embedding = nn.Embedding(codes, emb_width)
|
189 |
+
if self.emb_factor:
|
190 |
+
self.emb_to_hidden = EmbeddingProjector(emb_width, width)
|
191 |
+
|
192 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
193 |
+
self.register_buffer("positional_embedding", pos_embs)
|
194 |
+
|
195 |
+
self.layers = nn.Sequential(*[
|
196 |
+
ResidualAttentionBlock(width, n_head,
|
197 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
198 |
+
])
|
199 |
+
|
200 |
+
self.ln_post = LayerNorm(width)
|
201 |
+
|
202 |
+
def forward(self, Stoks):
|
203 |
+
xin = self.embedding(Stoks)
|
204 |
+
if self.emb_factor:
|
205 |
+
xin = self.emb_to_hidden(xin)
|
206 |
+
|
207 |
+
assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
|
208 |
+
xin = (xin + self.positional_embedding).to(xin.dtype)
|
209 |
+
|
210 |
+
return self.ln_post(self.layers(xin))
|
211 |
+
|
212 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 17
|
213 |
+
class Decoder(nn.Module):
|
214 |
+
def __init__(self, depth=6, stoks_width=384, width=384, n_head=6, length=1500, codes=1024, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
215 |
+
super().__init__()
|
216 |
+
self.length = length
|
217 |
+
self.codes = codes
|
218 |
+
self.width = width
|
219 |
+
self.stoks_width = stoks_width
|
220 |
+
|
221 |
+
self.emb_factor = width != stoks_width
|
222 |
+
|
223 |
+
# embed semantic tokens
|
224 |
+
self.embedding = nn.Embedding(codes, stoks_width)
|
225 |
+
if self.emb_factor:
|
226 |
+
self.emb_to_hidden = EmbeddingProjector(stoks_width, width)
|
227 |
+
self.hidden_to_emb = EmbeddingProjector(width, stoks_width)
|
228 |
+
|
229 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
230 |
+
self.register_buffer("positional_embedding", pos_embs)
|
231 |
+
|
232 |
+
self.layers = nn.ModuleList([
|
233 |
+
ResidualAttentionBlock(width, n_head, cross_attention=True,
|
234 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
235 |
+
])
|
236 |
+
self.ln_post = LayerNorm(width)
|
237 |
+
|
238 |
+
def forward(self, Stoks, xenc, cps=None):
|
239 |
+
Sembs = self.embedding(Stoks)
|
240 |
+
|
241 |
+
if self.emb_factor:
|
242 |
+
Sembs = self.emb_to_hidden(Sembs)
|
243 |
+
|
244 |
+
xin = (Sembs + self.positional_embedding[:Sembs.shape[1]]).to(xenc.dtype)
|
245 |
+
if cps is not None: xin = xin + cps
|
246 |
+
|
247 |
+
x = xin
|
248 |
+
for l in self.layers: x = l(x, xenc, causal=True)
|
249 |
+
|
250 |
+
x = self.ln_post(x)
|
251 |
+
|
252 |
+
if self.emb_factor:
|
253 |
+
x = self.hidden_to_emb(x)
|
254 |
+
|
255 |
+
logits = (x @ self.embedding.weight.to(x.dtype).T).float()
|
256 |
+
return logits
|
257 |
+
|
258 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 18
|
259 |
+
class TSARTransformer(nn.Module):
|
260 |
+
def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4, language='en',
|
261 |
+
ttoks_len=200, ttoks_codes=50364, ttoks_width=None,
|
262 |
+
stoks_len=1500, stoks_codes=1024, stoks_width=None,
|
263 |
+
tunables=Tunables()):
|
264 |
+
assert language == 'en', "only english is supported right now"
|
265 |
+
super().__init__()
|
266 |
+
store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes,language")
|
267 |
+
|
268 |
+
width = n_head * head_width
|
269 |
+
self.width = width
|
270 |
+
self.base_width = 3 * head_width
|
271 |
+
self.tunables = tunables
|
272 |
+
if self.stoks_width is None: self.stoks_width = self.width
|
273 |
+
if self.ttoks_width is None: self.ttoks_width = self.width
|
274 |
+
|
275 |
+
if tunables.cps_input:
|
276 |
+
self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
|
277 |
+
else:
|
278 |
+
self.cps_embeddings = None
|
279 |
+
|
280 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
281 |
+
decoder_depth = depth * 2 - encoder_depth
|
282 |
+
tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
|
283 |
+
self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
|
284 |
+
self.decoder = Decoder(length=stoks_len, codes=stoks_codes, stoks_width=self.stoks_width, depth=decoder_depth, **tformer_args)
|
285 |
+
|
286 |
+
self.tokenizer = None
|
287 |
+
|
288 |
+
self.apply(self.init_transformer)
|
289 |
+
|
290 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
291 |
+
with torch.no_grad():
|
292 |
+
self.decoder.embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
|
293 |
+
self.decoder.embedding.lr_scale = 0
|
294 |
+
|
295 |
+
def setup(self, device):
|
296 |
+
pass
|
297 |
+
|
298 |
+
def init_transformer(self, m):
|
299 |
+
if isinstance(m, LinearHead):
|
300 |
+
m.no_weight_decay = True
|
301 |
+
torch.nn.init.constant_(m.weight, 0)
|
302 |
+
elif isinstance(m, QueryHead):
|
303 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
304 |
+
torch.nn.init.constant_(m.weight, 0)
|
305 |
+
elif isinstance(m, nn.Embedding):
|
306 |
+
m.no_weight_decay = True
|
307 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
308 |
+
std = self.tunables.embeddings_std
|
309 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
310 |
+
elif isinstance(m, EmbeddingProjector):
|
311 |
+
m.lr_scale = self.tunables.embedding_projector_lr_scale
|
312 |
+
std = self.tunables.init_std
|
313 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
314 |
+
elif isinstance(m, nn.Linear):
|
315 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
316 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
317 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
318 |
+
if m.bias is not None:
|
319 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
320 |
+
elif isinstance(m, nn.LayerNorm):
|
321 |
+
m.no_weight_decay = True
|
322 |
+
torch.nn.init.constant_(m.bias, 0)
|
323 |
+
torch.nn.init.constant_(m.weight, 1)
|
324 |
+
|
325 |
+
def forward(self, Ttoks, speakers, cpss, in_stoks, out_stoks=None, loss=True):
|
326 |
+
with record_function("encoder"):
|
327 |
+
xenc = self.encoder(Ttoks.to(torch.long))
|
328 |
+
with record_function("decoder"):
|
329 |
+
if self.cps_embeddings:
|
330 |
+
cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
|
331 |
+
cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
|
332 |
+
cps_embs = self.cps_embeddings(cps_bin).unsqueeze(1)
|
333 |
+
else:
|
334 |
+
cps_embs = None
|
335 |
+
logits = self.decoder(in_stoks, xenc, cps=cps_embs) * self.tunables.output_mult / (self.width / self.base_width)
|
336 |
+
if loss is not None:
|
337 |
+
with record_function("loss"):
|
338 |
+
loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)#, reduction='none')
|
339 |
+
return logits, loss
|
340 |
+
|
341 |
+
#
|
342 |
+
# inference
|
343 |
+
#
|
344 |
+
@classmethod
|
345 |
+
def load_model(cls, repo_id="collabora/whisperspeech", filename="t2s_up_wds.model", local_filename=None):
|
346 |
+
if not local_filename:
|
347 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
348 |
+
spec = torch.load(local_filename)
|
349 |
+
model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
|
350 |
+
model.load_state_dict(spec['state_dict'])
|
351 |
+
model.eval()
|
352 |
+
return model
|
353 |
+
|
354 |
+
def load_checkpoint(self, local_filename):
|
355 |
+
spec = torch.load(local_filename, map_location='cpu')
|
356 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
357 |
+
state_dict = {k.replace('model.', ''):v
|
358 |
+
for k,v in spec['state_dict'].items()}
|
359 |
+
self.load_state_dict(state_dict)
|
360 |
+
return self
|
361 |
+
|
362 |
+
def save_model(self, fname):
|
363 |
+
torch.save(dict(config = self.__stored_args__,
|
364 |
+
tunables = dataclasses.asdict(self.tunables),
|
365 |
+
state_dict = self.state_dict()), fname)
|
366 |
+
|
367 |
+
def ensure_tokenizer(self):
|
368 |
+
assert not self.training
|
369 |
+
if self.tokenizer is None: self.tokenizer = CharTokenizer()
|
370 |
+
#whisper.tokenizer.get_tokenizer(multilingual=True)
|
371 |
+
|
372 |
+
@property
|
373 |
+
def device(self):
|
374 |
+
return next(self.parameters()).device
|
375 |
+
|
376 |
+
@torch.no_grad()
|
377 |
+
def generate(self, txt, cps=15, N=None, T=0.7, top_k=None, show_progress_bar=True):
|
378 |
+
self.ensure_tokenizer()
|
379 |
+
N = N or self.stoks_len
|
380 |
+
dev = self.device
|
381 |
+
ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
382 |
+
ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
|
383 |
+
cpss = torch.tensor([cps], device=dev)
|
384 |
+
toks = torch.zeros((1,N), dtype=torch.long, device=dev)
|
385 |
+
toks[0,0] = self.stoks_codes-1
|
386 |
+
it = range(1,N)
|
387 |
+
if show_progress_bar: it = progress_bar(it)
|
388 |
+
for i in it:
|
389 |
+
p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None)
|
390 |
+
last_p = p[0,-1]
|
391 |
+
if top_k:
|
392 |
+
last_p[last_p < torch.topk(last_p, top_k).values[-1,None]] = -torch.inf
|
393 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
394 |
+
toks[0,i] = tok
|
395 |
+
if toks[0,i] == self.stoks_codes-1: return toks[0,1:i]
|
396 |
+
return toks[0,1:]
|
397 |
+
|
398 |
+
@torch.no_grad()
|
399 |
+
def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
|
400 |
+
self.ensure_tokenizer()
|
401 |
+
N = self.stoks_len
|
402 |
+
dev = self.device
|
403 |
+
ttoks = []
|
404 |
+
for txt in txts:
|
405 |
+
ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
406 |
+
ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
|
407 |
+
ttoks.append(ttoks_)
|
408 |
+
ttoks = torch.cat(ttoks, dim=0)
|
409 |
+
toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
|
410 |
+
it = range(N)
|
411 |
+
if show_progress_bar: it = progress_bar(it)
|
412 |
+
for i in it:
|
413 |
+
p, _ = self(ttoks, toks[:,:i], loss=None)
|
414 |
+
last_p = p[:,-1]
|
415 |
+
if top_k:
|
416 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
417 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
418 |
+
toks[:,i] = tok[:,0]
|
419 |
+
if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
|
420 |
+
return toks
|
421 |
+
|
422 |
+
# %% ../nbs/5B. Text to semantic token modeling.ipynb 19
|
423 |
+
def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
|
424 |
+
kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
|
425 |
+
if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
|
426 |
+
if size == 'micro':
|
427 |
+
return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
|
428 |
+
if size == 'tiny':
|
429 |
+
return TSARTransformer(depth=4, n_head=6, **kwargs)
|
430 |
+
if size == 'base':
|
431 |
+
return TSARTransformer(depth=6, n_head=8, **kwargs)
|
432 |
+
if size == 'small':
|
433 |
+
return TSARTransformer(depth=12, n_head=16, **kwargs)
|
434 |
+
|
435 |
+
def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
436 |
+
if frozen_embeddings_model:
|
437 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
438 |
+
model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
439 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
440 |
+
else:
|
441 |
+
model = _make_model(size, quantizers, tunables, dataset)
|
442 |
+
return model
|
whisperspeech/t2s_up_wds_mlang_enclm.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Multi-lang text to semantic token modeling.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['load_dataset', 'rand', 'Tunables', 'T2SEmbedding', 'Encoder', 'TSARTransformer', 'make_model']
|
5 |
+
|
6 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 1
|
7 |
+
import dataclasses
|
8 |
+
import random
|
9 |
+
import math
|
10 |
+
import itertools
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch.profiler import record_function
|
15 |
+
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
from fastcore.basics import store_attr
|
18 |
+
from fastprogress import progress_bar
|
19 |
+
|
20 |
+
from pathlib import Path
|
21 |
+
|
22 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 2
|
23 |
+
from whisperspeech.modules import *
|
24 |
+
from whisperspeech import languages
|
25 |
+
|
26 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 6
|
27 |
+
import re
|
28 |
+
|
29 |
+
class CharTokenizer:
|
30 |
+
"""Trivial tokenizer β just use UTF-8 bytes"""
|
31 |
+
eot = 0
|
32 |
+
|
33 |
+
def encode(self, txt):
|
34 |
+
return list(bytes(txt.strip(), 'utf-8'))
|
35 |
+
|
36 |
+
def decode(self, tokens):
|
37 |
+
return bytes(tokens).decode('utf-8')
|
38 |
+
|
39 |
+
def tokenizer(ikey, okey, length):
|
40 |
+
"""Tokenizes a transcript"""
|
41 |
+
tok = CharTokenizer()
|
42 |
+
def _tokenizer(samples):
|
43 |
+
for s in samples:
|
44 |
+
toks = torch.tensor(tok.encode(s[ikey]))
|
45 |
+
s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
|
46 |
+
yield s
|
47 |
+
return _tokenizer
|
48 |
+
|
49 |
+
def ar_padder(ikey, okey, length, pad_token):
|
50 |
+
"""Pads the tokens for autoregresive training"""
|
51 |
+
import numpy as np
|
52 |
+
|
53 |
+
def _ar_padder(samples):
|
54 |
+
for s in samples:
|
55 |
+
toks = s[ikey]
|
56 |
+
if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
|
57 |
+
toks = toks.to(torch.long)
|
58 |
+
s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
|
59 |
+
s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
|
60 |
+
yield s
|
61 |
+
return _ar_padder
|
62 |
+
|
63 |
+
def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
|
64 |
+
"""Adds the characters per second metric to the input data"""
|
65 |
+
def _char_per_seconder(samples):
|
66 |
+
for s in samples:
|
67 |
+
secs = s[stoks_key].shape[-1] / stoks_per_second
|
68 |
+
s[cps_key] = len(s[txt_key]) / secs
|
69 |
+
yield s
|
70 |
+
return _char_per_seconder
|
71 |
+
|
72 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 7
|
73 |
+
def load_dataset(
|
74 |
+
txt_shard_spec:str, # transcription webdataset shards
|
75 |
+
stoks_shard_dir:str, # stoks webdataset base dir
|
76 |
+
samples:int, # samples per epoch
|
77 |
+
txt_kind:str='small.en-txt',
|
78 |
+
vq_codes:int=4096,
|
79 |
+
language:str='en',
|
80 |
+
weight:float=1,
|
81 |
+
validation:bool=False,
|
82 |
+
exclude_files:str=None,
|
83 |
+
):
|
84 |
+
import webdataset as wds
|
85 |
+
from whisperspeech import utils
|
86 |
+
|
87 |
+
shards = utils.shard_glob(txt_shard_spec)
|
88 |
+
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
|
89 |
+
|
90 |
+
language = languages.to_id(language)
|
91 |
+
|
92 |
+
def set_language(x):
|
93 |
+
x['language'] = language
|
94 |
+
return x
|
95 |
+
|
96 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
97 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
98 |
+
wds.decode(),
|
99 |
+
utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),
|
100 |
+
# discard validation samples, select samples > .5s
|
101 |
+
wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),
|
102 |
+
tokenizer('txt', 'ttoks', length=550),
|
103 |
+
ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
|
104 |
+
ar_padder('ttoks', 'ttoks', length=550, pad_token=CharTokenizer.eot),
|
105 |
+
char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
|
106 |
+
wds.map(set_language),
|
107 |
+
wds.to_tuple('in_ttoks', 'out_ttoks', 'language', 'cps', 'in_stoks', 'out_stoks'),
|
108 |
+
wds.shuffle(20000, initial=20000),
|
109 |
+
wds.batched(64)
|
110 |
+
)
|
111 |
+
if validation:
|
112 |
+
ds = ds.slice(samples // 64)
|
113 |
+
ds.total_samples = samples
|
114 |
+
ds.stoks_len = 750
|
115 |
+
ds.stoks_codes = vq_codes
|
116 |
+
ds.ttoks_len = 550
|
117 |
+
ds.weight = weight
|
118 |
+
|
119 |
+
return ds
|
120 |
+
|
121 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 14
|
122 |
+
def rand(start, end):
|
123 |
+
return random.random() * (end - start) + start
|
124 |
+
|
125 |
+
@dataclasses.dataclass
|
126 |
+
class Tunables:
|
127 |
+
init_std :float = 1
|
128 |
+
embeddings_std :float = .01
|
129 |
+
embeddings_lr_scale: float = 5
|
130 |
+
embedding_projector_lr_scale: float = 2.5
|
131 |
+
output_mult :float = .35
|
132 |
+
query_mult :float = 1
|
133 |
+
encoder_depth_ratio :float = 0.25
|
134 |
+
eot_dropout_p :float = .5
|
135 |
+
cps_input: bool = True
|
136 |
+
cps_bins: int = 32
|
137 |
+
|
138 |
+
lr0 :float = 1.5e-3
|
139 |
+
clip_gradient_norm :float = .2
|
140 |
+
weight_decay :float = 1e-1
|
141 |
+
warmup_steps :float = 4000
|
142 |
+
|
143 |
+
random :bool = False
|
144 |
+
|
145 |
+
def __post_init__(self):
|
146 |
+
# randomize the hyperparams if requested
|
147 |
+
if self.random:
|
148 |
+
self.init_std = 10**rand(-1,1)
|
149 |
+
self.embeddings_std = 10**rand(-3,-.7)
|
150 |
+
self.embeddings_lr_scale = rand(2,6)
|
151 |
+
self.output_mult = rand(0.25,0.65)
|
152 |
+
self.query_mult = 2**rand(-2,3)
|
153 |
+
self.encoder_depth_ratio = 0.25
|
154 |
+
|
155 |
+
self.lr0 = rand(1,5)*1e-3
|
156 |
+
self.clip_gradient_norm = 10**rand(-3,0)
|
157 |
+
self.warmup_steps = 100*(10**rand(1,1.85))
|
158 |
+
|
159 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 15
|
160 |
+
class T2SEmbedding(nn.Module):
|
161 |
+
def __init__(self, length=1500, codes=1024, width=384, pos_embs=None, stoks_width=384):
|
162 |
+
super().__init__()
|
163 |
+
self.embedding = FlexEmbeddings(codes, width, special_codes=1, frozen_width=stoks_width)
|
164 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
165 |
+
self.register_buffer("positional_embedding", pos_embs)
|
166 |
+
|
167 |
+
def forward(self, Stoks, xenc, cps=None, offset=0):
|
168 |
+
Sembs = self.embedding(Stoks)
|
169 |
+
xin = (Sembs + self.positional_embedding[offset : offset + Sembs.shape[1]]).to(xenc.dtype)
|
170 |
+
if cps is not None: xin = xin + cps
|
171 |
+
return xin, offset
|
172 |
+
|
173 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 16
|
174 |
+
class Encoder(nn.Module):
|
175 |
+
def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
|
176 |
+
super().__init__()
|
177 |
+
self.emb_width = emb_width
|
178 |
+
|
179 |
+
self.embedding = FlexEmbeddings(codes, width, frozen_width=emb_width)
|
180 |
+
|
181 |
+
if pos_embs is None: pos_embs = sinusoids(length, width)
|
182 |
+
self.register_buffer("positional_embedding", pos_embs)
|
183 |
+
|
184 |
+
self.layers = nn.ModuleList([
|
185 |
+
ResidualAttentionBlock(width, n_head,
|
186 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
|
187 |
+
])
|
188 |
+
|
189 |
+
self.ln_post = LayerNorm(width)
|
190 |
+
|
191 |
+
mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
|
192 |
+
self.register_buffer("mask", mask, persistent=False)
|
193 |
+
|
194 |
+
def forward(self, Stoks, positions, lang_emb=None):
|
195 |
+
xin = self.embedding(Stoks)
|
196 |
+
|
197 |
+
if lang_emb is not None: xin += lang_emb
|
198 |
+
|
199 |
+
# assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
|
200 |
+
x = (xin +
|
201 |
+
self.positional_embedding[positions]).to(xin.dtype)
|
202 |
+
|
203 |
+
for l in self.layers: x = l(x, positions, causal=False, mask=self.mask)
|
204 |
+
|
205 |
+
return self.ln_post(x)
|
206 |
+
|
207 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 17
|
208 |
+
class TSARTransformer(nn.Module):
|
209 |
+
def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4,
|
210 |
+
ttoks_len=200, ttoks_codes=256, ttoks_width=None,
|
211 |
+
stoks_len=1500, stoks_codes=1024, stoks_width=None,
|
212 |
+
tunables=Tunables()):
|
213 |
+
super().__init__()
|
214 |
+
store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes")
|
215 |
+
|
216 |
+
width = n_head * head_width
|
217 |
+
self.width = width
|
218 |
+
self.base_width = 3 * head_width
|
219 |
+
self.tunables = tunables
|
220 |
+
if self.stoks_width is None: self.stoks_width = self.width
|
221 |
+
if self.ttoks_width is None: self.ttoks_width = self.width
|
222 |
+
|
223 |
+
self.lang_embeddings = nn.Embedding(len(languages.languages), width)
|
224 |
+
if tunables.cps_input:
|
225 |
+
self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
|
226 |
+
else:
|
227 |
+
self.cps_embeddings = None
|
228 |
+
|
229 |
+
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
|
230 |
+
decoder_depth = depth * 2 - encoder_depth
|
231 |
+
tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
|
232 |
+
self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
|
233 |
+
self.embeddings = T2SEmbedding(length=stoks_len, codes=stoks_codes, width=width, stoks_width=self.stoks_width)
|
234 |
+
|
235 |
+
self.decoder = BaseDecoder(
|
236 |
+
length=stoks_len,
|
237 |
+
depth=decoder_depth,
|
238 |
+
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head),
|
239 |
+
width=width, n_head=n_head, ffn_mult=ffn_mult,
|
240 |
+
)
|
241 |
+
self.tokenizer = None
|
242 |
+
|
243 |
+
self.apply(self.init_transformer)
|
244 |
+
|
245 |
+
def load_frozen_semantic_embeddings(self, vqmodel):
|
246 |
+
self.embeddings.embedding.set_frozen_embeddings(vqmodel.rq.layers[0]._codebook.embed[0])
|
247 |
+
|
248 |
+
def setup(self, device):
|
249 |
+
pass
|
250 |
+
|
251 |
+
def init_transformer(self, m):
|
252 |
+
if isinstance(m, LinearHead):
|
253 |
+
m.no_weight_decay = True
|
254 |
+
torch.nn.init.constant_(m.weight, 0)
|
255 |
+
elif isinstance(m, QueryHead):
|
256 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
257 |
+
torch.nn.init.constant_(m.weight, 0)
|
258 |
+
elif isinstance(m, nn.Embedding):
|
259 |
+
m.no_weight_decay = True
|
260 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
261 |
+
std = self.tunables.embeddings_std
|
262 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
263 |
+
elif isinstance(m, EmbeddingProjector):
|
264 |
+
m.lr_scale = self.tunables.embedding_projector_lr_scale
|
265 |
+
std = self.tunables.init_std
|
266 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
267 |
+
elif isinstance(m, nn.Linear):
|
268 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
269 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
270 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
271 |
+
if m.bias is not None:
|
272 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
273 |
+
elif isinstance(m, nn.LayerNorm):
|
274 |
+
m.no_weight_decay = True
|
275 |
+
torch.nn.init.constant_(m.bias, 0)
|
276 |
+
torch.nn.init.constant_(m.weight, 1)
|
277 |
+
|
278 |
+
def _embed_cps(self, cpss):
|
279 |
+
if self.cps_embeddings is None: return None
|
280 |
+
|
281 |
+
cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
|
282 |
+
cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
|
283 |
+
return self.cps_embeddings(cps_bin).unsqueeze(1)
|
284 |
+
|
285 |
+
def run_encoder(self, in_ttoks, languages, cpss):
|
286 |
+
if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
|
287 |
+
else: lang_embs = languages
|
288 |
+
if len(lang_embs.shape) == 2: lang_embs = lang_embs.unsqueeze(1)
|
289 |
+
|
290 |
+
cps_emb = self._embed_cps(cpss)
|
291 |
+
|
292 |
+
with record_function("encoder"):
|
293 |
+
positions = torch.arange(0, in_ttoks.shape[1], device=in_ttoks.device)
|
294 |
+
xenc = self.encoder(in_ttoks.to(torch.long), positions, lang_emb=lang_embs)
|
295 |
+
|
296 |
+
return xenc, positions, cps_emb
|
297 |
+
|
298 |
+
def forward(self, in_ttoks, out_ttoks, languages, cpss, in_stoks, in_stoks_positions, out_stoks=None, loss=True, offset=None, xenc=None, xenc_positions=None, cps_emb=None):
|
299 |
+
if xenc is None:
|
300 |
+
xenc, cps_emb = self.run_encoder(in_ttoks, languages, cpss)
|
301 |
+
|
302 |
+
with record_function("decoder"):
|
303 |
+
x = (self.embeddings.embedding(in_stoks) +
|
304 |
+
self.embeddings.positional_embedding[in_stoks_positions] +
|
305 |
+
cps_emb).to(xenc[0].dtype)
|
306 |
+
x = self.decoder(x, in_stoks_positions, xenc, xenc_positions)
|
307 |
+
logits = self.embeddings.embedding.unembed(x)
|
308 |
+
logits = logits * self.tunables.output_mult / (self.width / self.base_width)
|
309 |
+
|
310 |
+
if loss is not None:
|
311 |
+
enc_logits = self.encoder.embedding.unembed(xenc[0])
|
312 |
+
enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
|
313 |
+
with record_function("loss"):
|
314 |
+
loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)
|
315 |
+
if self.training:
|
316 |
+
loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_ttoks)
|
317 |
+
|
318 |
+
return logits, loss
|
319 |
+
|
320 |
+
#
|
321 |
+
# inference
|
322 |
+
#
|
323 |
+
@classmethod
|
324 |
+
def load_model(cls, ref="collabora/whisperspeech:t2s-small-en+pl.model",
|
325 |
+
repo_id=None, filename=None, local_filename=None):
|
326 |
+
if repo_id is None and filename is None and local_filename is None:
|
327 |
+
if ":" in ref:
|
328 |
+
repo_id, filename = ref.split(":", 1)
|
329 |
+
else:
|
330 |
+
local_filename = ref
|
331 |
+
if not local_filename:
|
332 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
333 |
+
spec = torch.load(local_filename)
|
334 |
+
model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
|
335 |
+
model.load_state_dict(spec['state_dict'])
|
336 |
+
model.eval()
|
337 |
+
return model
|
338 |
+
|
339 |
+
def load_checkpoint(self, local_filename):
|
340 |
+
spec = torch.load(local_filename, map_location='cpu')
|
341 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
342 |
+
state_dict = {k.replace('model.', ''):v
|
343 |
+
for k,v in spec['state_dict'].items()}
|
344 |
+
self.load_state_dict(state_dict)
|
345 |
+
return self
|
346 |
+
|
347 |
+
def save_model(self, fname):
|
348 |
+
torch.save(dict(config = self.__stored_args__,
|
349 |
+
tunables = dataclasses.asdict(self.tunables),
|
350 |
+
state_dict = self.state_dict()), fname)
|
351 |
+
|
352 |
+
def ensure_tokenizer(self):
|
353 |
+
assert not self.training
|
354 |
+
if self.tokenizer is None: self.tokenizer = CharTokenizer()
|
355 |
+
|
356 |
+
def switch_dtypes(self, dtype=torch.float16):
|
357 |
+
self.dtype = dtype
|
358 |
+
for n,m in self.named_modules():
|
359 |
+
# convert every leaf layer apart from the LayerNorms
|
360 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
361 |
+
m.to(dtype)
|
362 |
+
# take care of buffers ([kv]_cache, masks) that are not in the leaf layers
|
363 |
+
for bn,b in m.named_buffers(recurse=False):
|
364 |
+
setattr(m,bn,b.to(dtype))
|
365 |
+
|
366 |
+
def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
|
367 |
+
for emb in [self.embeddings.embedding, self.embeddings.embedding]:
|
368 |
+
emb.convert_for_eval()
|
369 |
+
for l in self.encoder.layers:
|
370 |
+
l.attn.convert_for_eval()
|
371 |
+
for l in self.decoder.layers:
|
372 |
+
l.attn.convert_for_eval()
|
373 |
+
l.cross_attn.convert_for_eval()
|
374 |
+
l.setup_kv_cache(max_batch_size, self.stoks_len, self.ttoks_len)
|
375 |
+
self.switch_dtypes(dtype)
|
376 |
+
if torch_compile:
|
377 |
+
self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
|
378 |
+
|
379 |
+
@property
|
380 |
+
def device(self):
|
381 |
+
return next(self.parameters()).device
|
382 |
+
|
383 |
+
# from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
384 |
+
def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
|
385 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
386 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
387 |
+
|
388 |
+
def logits_to_probs(self, logits, T=1.0, top_k=None):
|
389 |
+
logits = logits / max(T, 1e-5)
|
390 |
+
|
391 |
+
logits[self.embeddings.embedding.codes:] = -torch.inf
|
392 |
+
if top_k is not None:
|
393 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
394 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
395 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
396 |
+
|
397 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
398 |
+
return probs
|
399 |
+
|
400 |
+
def sample(self, logits, T=1.0, top_k=None):
|
401 |
+
probs = self.logits_to_probs(logits[0,-1], T, top_k)
|
402 |
+
idx_next = self.multinomial_sample_one_no_sync(probs)
|
403 |
+
return idx_next
|
404 |
+
|
405 |
+
def generate_one(self, toks, toks_positions, cps_emb, xenc, xenc_positions, T, top_k):
|
406 |
+
probs, _ = self(None, None, None, None, toks, toks_positions, loss=None, xenc=xenc, xenc_positions=xenc_positions, cps_emb=cps_emb)
|
407 |
+
return self.sample(probs, T, top_k)
|
408 |
+
|
409 |
+
def generate_next(self, *args, **kwargs):
|
410 |
+
return self.generate_one(*args, **kwargs)
|
411 |
+
|
412 |
+
@torch.no_grad()
|
413 |
+
def prep(self, txt, cps=15, lang="en"):
|
414 |
+
dev = self.device
|
415 |
+
ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
416 |
+
ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
|
417 |
+
cpss = torch.tensor([cps], device=dev)
|
418 |
+
langs = torch.tensor([languages.to_id(lang)], device=dev)
|
419 |
+
return ttoks, cpss, langs
|
420 |
+
|
421 |
+
@torch.no_grad()
|
422 |
+
def generate(self, txt, cps=15, lang="en", N=None, T=0.7, top_k=None, step=None, show_progress_bar=True):
|
423 |
+
self.ensure_tokenizer()
|
424 |
+
N = N or self.stoks_len
|
425 |
+
dev = self.device
|
426 |
+
ttoks = []
|
427 |
+
langs = []
|
428 |
+
if isinstance(lang, list):
|
429 |
+
lang0 = lang[0]
|
430 |
+
assert isinstance(txt, list), "lang and txt have to be both lists or strings"
|
431 |
+
for txt, lang in zip(txt, lang):
|
432 |
+
tt = self.tokenizer.encode(txt)
|
433 |
+
ttoks += tt
|
434 |
+
langs += [languages.to_id(lang)] * len(tt)
|
435 |
+
elif isinstance(lang, torch.Tensor):
|
436 |
+
langs = lang
|
437 |
+
ttoks = self.tokenizer.encode(txt)
|
438 |
+
else:
|
439 |
+
lang0 = lang
|
440 |
+
ttoks = self.tokenizer.encode(txt)
|
441 |
+
langs = torch.tensor([languages.to_id(lang)], device=dev).unsqueeze(0)
|
442 |
+
ttoks = torch.tensor(ttoks, device=dev)
|
443 |
+
ttoks = F.pad(ttoks, (1, self.ttoks_len - len(ttoks) - 1), value=self.tokenizer.eot).unsqueeze(0)
|
444 |
+
cpss = torch.tensor([cps], device=dev)
|
445 |
+
if not isinstance(langs, torch.Tensor):
|
446 |
+
langs = torch.tensor(langs, device=dev)
|
447 |
+
langs = F.pad(langs, (1, self.ttoks_len - len(langs) - 1), value=languages.to_id(lang0)).unsqueeze(0)
|
448 |
+
it = range(0,N-1)
|
449 |
+
if show_progress_bar: it = progress_bar(it)
|
450 |
+
|
451 |
+
toks = torch.zeros((1,N), dtype=torch.long, device=dev)
|
452 |
+
toks[:,0] = self.stoks_codes-1
|
453 |
+
toks_positions = torch.arange(N, device=dev)
|
454 |
+
with record_function("encode"):
|
455 |
+
xenc, xenc_positions, cps_emb = self.run_encoder(ttoks, langs, cpss)
|
456 |
+
toks_positions = torch.arange(N+1, device=dev)
|
457 |
+
# contrary to S2A this model works without prefill and is actually a tiny bit faster
|
458 |
+
# with record_function("prefill"):
|
459 |
+
# toks[0,1] = self.generate_one(toks[:,:1], toks_positions[:1], cps_emb, xenc, xenc_positions, T, top_k)
|
460 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
461 |
+
for i in it:
|
462 |
+
toks[0,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)
|
463 |
+
if i % 25 == 0 and toks[0,i+1] == self.stoks_codes-1: return toks[0,:i+1]
|
464 |
+
|
465 |
+
# for profiling, debugging or early exit
|
466 |
+
if step is not None: step()
|
467 |
+
return toks[0,:]
|
468 |
+
|
469 |
+
@torch.no_grad()
|
470 |
+
def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
|
471 |
+
self.ensure_tokenizer()
|
472 |
+
N = self.stoks_len
|
473 |
+
dev = self.device
|
474 |
+
ttoks = []
|
475 |
+
for txt in txts:
|
476 |
+
ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
|
477 |
+
ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
|
478 |
+
ttoks.append(ttoks_)
|
479 |
+
ttoks = torch.cat(ttoks, dim=0)
|
480 |
+
toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
|
481 |
+
it = range(N)
|
482 |
+
if show_progress_bar: it = progress_bar(it)
|
483 |
+
for i in it:
|
484 |
+
p, _ = self(ttoks, toks[:,:i], loss=None)
|
485 |
+
last_p = p[:,-1]
|
486 |
+
if top_k:
|
487 |
+
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
|
488 |
+
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
|
489 |
+
toks[:,i] = tok[:,0]
|
490 |
+
if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
|
491 |
+
return toks
|
492 |
+
|
493 |
+
# %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 18
|
494 |
+
def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
|
495 |
+
kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
|
496 |
+
if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
|
497 |
+
if size == 'micro':
|
498 |
+
return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
|
499 |
+
if size == 'tiny':
|
500 |
+
return TSARTransformer(depth=4, n_head=6, **kwargs)
|
501 |
+
if size == 'base':
|
502 |
+
return TSARTransformer(depth=6, n_head=8, **kwargs)
|
503 |
+
if size == 'small':
|
504 |
+
return TSARTransformer(depth=12, n_head=12, **kwargs)
|
505 |
+
if size == 'small+':
|
506 |
+
return TSARTransformer(depth=12, n_head=16, **kwargs)
|
507 |
+
if size == 'medium':
|
508 |
+
return TSARTransformer(depth=24, n_head=16, **kwargs)
|
509 |
+
|
510 |
+
def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
511 |
+
from whisperspeech import vq_stoks
|
512 |
+
|
513 |
+
if frozen_embeddings_model:
|
514 |
+
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
|
515 |
+
model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
|
516 |
+
model.load_frozen_semantic_embeddings(vqmodel)
|
517 |
+
else:
|
518 |
+
model = _make_model(size, tunables, dataset, mode=mode)
|
519 |
+
return model
|
whisperspeech/train.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['SimpleVisual', 'validate', 'train']
|
5 |
+
|
6 |
+
# %% ../nbs/B1. Training.ipynb 2
|
7 |
+
import io
|
8 |
+
import time
|
9 |
+
import random
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from fastprogress import progress_bar, master_bar
|
13 |
+
import fastprogress
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import pylab as plt
|
17 |
+
import math
|
18 |
+
|
19 |
+
import IPython
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from torch.profiler import record_function
|
25 |
+
|
26 |
+
import webdataset as wds
|
27 |
+
|
28 |
+
torch.backends.cudnn.benchmark = True
|
29 |
+
torch.backends.cudnn.enabled = True
|
30 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
31 |
+
torch.set_float32_matmul_precision('medium')
|
32 |
+
|
33 |
+
# %% ../nbs/B1. Training.ipynb 3
|
34 |
+
class SimpleVisual:
|
35 |
+
def __init__ (self, model, masterbar, total_steps):
|
36 |
+
self.model = model
|
37 |
+
self.masterbar = masterbar
|
38 |
+
self.total_steps = total_steps
|
39 |
+
self.epochs = total_steps // masterbar.main_bar.total
|
40 |
+
|
41 |
+
gs = plt.GridSpec(2, 1, height_ratios=[3,1])
|
42 |
+
graph_fig = plt.figure(figsize=(10,6))
|
43 |
+
self.graph_fig = graph_fig
|
44 |
+
self.loss_p = graph_fig.add_subplot(gs[0])
|
45 |
+
self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
|
46 |
+
self.lr_p.tick_params('x', labelbottom=False)
|
47 |
+
self.graph_out = None
|
48 |
+
|
49 |
+
self.its = []
|
50 |
+
self.train_losses = []
|
51 |
+
self.val_losses = []
|
52 |
+
self.lr_history = []
|
53 |
+
|
54 |
+
def show(self):
|
55 |
+
self.start_t = time.time()
|
56 |
+
self.masterbar.write(["samples", "train", "val", "time"], table=True)
|
57 |
+
self.graph_out = display(self.graph_fig, display_id=True, clear=True)
|
58 |
+
|
59 |
+
def hide(self):
|
60 |
+
if self.graph_out is not None:
|
61 |
+
self.graph_out.update(IPython.display.HTML(''))
|
62 |
+
|
63 |
+
def plot(self):
|
64 |
+
loss_p, lr_p = self.loss_p, self.lr_p
|
65 |
+
loss_p.clear()
|
66 |
+
loss_p.plot(self.its, self.train_losses)
|
67 |
+
loss_p.plot(self.its, self.val_losses)
|
68 |
+
loss_p.set_xlim(0, self.total_steps)
|
69 |
+
loss_p.set_yscale('log')
|
70 |
+
lr_p.clear()
|
71 |
+
lrs = np.array(self.lr_history)
|
72 |
+
lr_p.plot(self.its, lrs)
|
73 |
+
self.graph_out.update(self.graph_fig)
|
74 |
+
|
75 |
+
def add_data(self, it, lr, train_loss, val_los):
|
76 |
+
self.its.append(it)
|
77 |
+
self.train_losses.append(train_loss)
|
78 |
+
self.val_losses.append(val_los)
|
79 |
+
self.lr_history.append(lr)
|
80 |
+
self.plot()
|
81 |
+
|
82 |
+
def add_table_row(self, it, avg_train_loss, val_loss):
|
83 |
+
elapsed_t = time.time() - self.start_t
|
84 |
+
self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
|
85 |
+
|
86 |
+
def on_iter(self, bar, it, avg_train_loss, val_loss):
|
87 |
+
epoch = math.ceil(it / self.total_steps * self.epochs)
|
88 |
+
bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
|
89 |
+
|
90 |
+
# %% ../nbs/B1. Training.ipynb 4
|
91 |
+
# FIXME: we need to keep this synchronised with the validation code below...
|
92 |
+
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
|
93 |
+
if isinstance(val, torch.utils.data.IterableDataset):
|
94 |
+
val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
95 |
+
.unbatched().shuffle(1024).batched(bs)
|
96 |
+
else:
|
97 |
+
val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
|
98 |
+
|
99 |
+
with torch.no_grad():
|
100 |
+
val_loss = 0
|
101 |
+
val_samples = 0
|
102 |
+
for args in val_loader:
|
103 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
104 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
105 |
+
ps, loss = model(*args)
|
106 |
+
N = args[0].shape[0]
|
107 |
+
val_loss += loss.mean().item() * N
|
108 |
+
val_samples += N
|
109 |
+
val_loss = val_loss / val_samples
|
110 |
+
|
111 |
+
return val_loss
|
112 |
+
|
113 |
+
# %% ../nbs/B1. Training.ipynb 5
|
114 |
+
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
|
115 |
+
weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
|
116 |
+
dl_workers=8, visual_class = SimpleVisual, profiler=None,
|
117 |
+
run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
|
118 |
+
device="cuda", trainable_params=None):
|
119 |
+
if chkpt_every_iters is None:
|
120 |
+
chkpt_every_iters = table_row_every_iters
|
121 |
+
|
122 |
+
mb = master_bar(range(epochs))
|
123 |
+
if isinstance(train, torch.utils.data.IterableDataset):
|
124 |
+
pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
|
125 |
+
visual = visual_class(model, mb, epochs * train.total_samples)
|
126 |
+
# pct_start = min(0.3, warmup_steps / (epochs * len(train)))
|
127 |
+
# visual = visual_class(model, mb, epochs*len(train)*bs)
|
128 |
+
else:
|
129 |
+
pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
|
130 |
+
visual = visual_class(model, mb, epochs*len(train))
|
131 |
+
model.visual = visual
|
132 |
+
|
133 |
+
Path(checkpoint_path).mkdir(exist_ok=True)
|
134 |
+
|
135 |
+
if isinstance(train, torch.utils.data.IterableDataset):
|
136 |
+
# train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
|
137 |
+
# val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
|
138 |
+
train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
139 |
+
.unbatched().shuffle(1024).batched(bs, partial=False)
|
140 |
+
val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
|
141 |
+
.unbatched().shuffle(1024).batched(bs)
|
142 |
+
else:
|
143 |
+
train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
|
144 |
+
val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
|
145 |
+
|
146 |
+
val_loss = torch.nan
|
147 |
+
avg_train_loss = torch.nan
|
148 |
+
|
149 |
+
if hasattr(model, 'setup'):
|
150 |
+
model.setup(device)
|
151 |
+
|
152 |
+
try:
|
153 |
+
scheduler = None
|
154 |
+
|
155 |
+
if trainable_params is None: trainable_params = model.parameters()
|
156 |
+
all_params = set(trainable_params)
|
157 |
+
customized_params = set()
|
158 |
+
groups = []
|
159 |
+
group_map = {}
|
160 |
+
for name,m in model.named_modules():
|
161 |
+
if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
|
162 |
+
m_trainable = [x for x in m.parameters() if x in all_params]
|
163 |
+
if not m_trainable: continue
|
164 |
+
customized_params |= set(m_trainable)
|
165 |
+
m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
|
166 |
+
m_lr = lr * getattr(m, 'lr_scale', 1)
|
167 |
+
group = group_map.get((m_wd, m_lr), None)
|
168 |
+
if not group:
|
169 |
+
group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
|
170 |
+
groups.append(group)
|
171 |
+
group_map[(m_wd, m_lr)] = group
|
172 |
+
group['params'] += m_trainable
|
173 |
+
group['names'].append(name)
|
174 |
+
|
175 |
+
other_params = all_params - customized_params
|
176 |
+
|
177 |
+
if other_params:
|
178 |
+
groups = groups + [
|
179 |
+
{"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
|
180 |
+
]
|
181 |
+
|
182 |
+
optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
|
183 |
+
model._optimizer = optimizer
|
184 |
+
scaler = torch.cuda.amp.GradScaler(enabled=half)
|
185 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
186 |
+
optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
|
187 |
+
max_lr=[pg.get('lr', lr) for pg in groups],
|
188 |
+
final_div_factor=25)
|
189 |
+
|
190 |
+
it = 0
|
191 |
+
next_val_it = it + 50
|
192 |
+
next_chkpt_it = chkpt_every_iters
|
193 |
+
next_table_it = table_row_every_iters
|
194 |
+
|
195 |
+
visual.show()
|
196 |
+
|
197 |
+
running_loss = [0]
|
198 |
+
|
199 |
+
for epoch in mb:
|
200 |
+
bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
|
201 |
+
for args in bar:
|
202 |
+
with record_function("forward"):
|
203 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
204 |
+
|
205 |
+
# zero the parameter gradients
|
206 |
+
optimizer.zero_grad(set_to_none=True)
|
207 |
+
|
208 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
209 |
+
ps, loss = model(*args)
|
210 |
+
loss = loss.mean()
|
211 |
+
|
212 |
+
with record_function("backward"):
|
213 |
+
scaler.scale(loss).backward()
|
214 |
+
|
215 |
+
if clip_gradient_norm:
|
216 |
+
scaler.unscale_(optimizer)
|
217 |
+
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
|
218 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
|
219 |
+
|
220 |
+
scaler.step(optimizer)
|
221 |
+
scaler.update()
|
222 |
+
|
223 |
+
scheduler.step()
|
224 |
+
|
225 |
+
if profiler is not None: profiler.step()
|
226 |
+
|
227 |
+
with record_function("running_loss"):
|
228 |
+
running_loss.append(loss.item())
|
229 |
+
running_loss = running_loss[-5:]
|
230 |
+
avg_train_loss = sum(running_loss)/len(running_loss)
|
231 |
+
|
232 |
+
if it >= next_chkpt_it:
|
233 |
+
with record_function("checkpoint"):
|
234 |
+
next_chkpt_it += chkpt_every_iters
|
235 |
+
torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')
|
236 |
+
|
237 |
+
if it >= next_val_it:
|
238 |
+
next_val_it += run_valid_every_iters
|
239 |
+
with record_function("validation"):
|
240 |
+
with record_function("model.eval"):
|
241 |
+
model.eval()
|
242 |
+
with torch.no_grad():
|
243 |
+
val_loss = 0
|
244 |
+
val_samples = 0
|
245 |
+
for args in val_loader:
|
246 |
+
args = [x.to(device, non_blocking=True) for x in args]
|
247 |
+
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
|
248 |
+
ps, loss = model(*args)
|
249 |
+
N = args[0].shape[0]
|
250 |
+
val_loss += loss.mean().item() * N
|
251 |
+
val_samples += N
|
252 |
+
val_loss = val_loss / val_samples
|
253 |
+
with record_function("model.train"):
|
254 |
+
model.train()
|
255 |
+
with record_function("plotting"):
|
256 |
+
visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)
|
257 |
+
|
258 |
+
if it >= next_table_it:
|
259 |
+
visual.add_table_row(it, avg_train_loss, val_loss)
|
260 |
+
next_table_it += table_row_every_iters
|
261 |
+
|
262 |
+
it += bs
|
263 |
+
visual.on_iter(bar, it, avg_train_loss, val_loss)
|
264 |
+
except KeyboardInterrupt:
|
265 |
+
mb.write(f"interrupted")
|
266 |
+
mb.show()
|
267 |
+
pass
|
268 |
+
finally:
|
269 |
+
visual.add_table_row(it, avg_train_loss, val_loss)
|
270 |
+
mb.show()
|
271 |
+
visual.hide()
|
whisperspeech/train_multi.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = []
|
5 |
+
|
6 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 2
|
7 |
+
import io
|
8 |
+
import time
|
9 |
+
import random
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from fastprogress import progress_bar, master_bar
|
13 |
+
import fastprogress
|
14 |
+
import wandb
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import pylab as plt
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.utils.data.dataloader import DataLoader
|
22 |
+
from torch.profiler import record_function
|
23 |
+
|
24 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 3
|
25 |
+
import lightning.pytorch as pl
|
26 |
+
import math
|
27 |
+
|
28 |
+
class TrainingTask(pl.LightningModule):
|
29 |
+
def __init__(self, model, model_hparams=None):
|
30 |
+
super().__init__()
|
31 |
+
self.model = model
|
32 |
+
self.model_hparams = model_hparams
|
33 |
+
|
34 |
+
def on_fit_start(self):
|
35 |
+
if getattr(self.model, 'setup'):
|
36 |
+
self.model.setup(self.device)
|
37 |
+
|
38 |
+
def configure_optimizers(self):
|
39 |
+
""" Initialize AdamW optimizer"""
|
40 |
+
lr = self.model_hparams['lr0']
|
41 |
+
weight_decay = self.model_hparams['weight_decay']
|
42 |
+
|
43 |
+
all_params = set(model.parameters())
|
44 |
+
customized_params = set()
|
45 |
+
groups = []
|
46 |
+
group_map = {}
|
47 |
+
for name,m in model.named_modules():
|
48 |
+
if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
|
49 |
+
customized_params |= set(m.parameters())
|
50 |
+
m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
|
51 |
+
m_lr = lr * getattr(m, 'lr_scale', 1)
|
52 |
+
group = group_map.get((m_wd, m_lr), None)
|
53 |
+
if not group:
|
54 |
+
group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
|
55 |
+
groups.append(group)
|
56 |
+
group_map[(m_wd, m_lr)] = group
|
57 |
+
group['params'] += m.parameters()
|
58 |
+
group['names'].append(name)
|
59 |
+
|
60 |
+
other_params = all_params - customized_params
|
61 |
+
|
62 |
+
param_groups = groups + [
|
63 |
+
{"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
|
64 |
+
]
|
65 |
+
|
66 |
+
optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)
|
67 |
+
|
68 |
+
# modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
|
69 |
+
def num_steps_per_epoch() -> int:
|
70 |
+
"""Get number of steps"""
|
71 |
+
# Accessing _data_source is flaky and might break
|
72 |
+
dataset = self.trainer.fit_loop._data_source.dataloader()
|
73 |
+
dataset_size = len(dataset)
|
74 |
+
# math.ceil so always overestimate (underestimating throws exceptions)
|
75 |
+
num_steps = math.ceil(dataset_size / self.trainer.accumulate_grad_batches)
|
76 |
+
return num_steps
|
77 |
+
|
78 |
+
total_steps = self.model_hparams['epochs'] * num_steps_per_epoch()
|
79 |
+
self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
|
80 |
+
|
81 |
+
print(f"{self.model_hparams['epochs']=} epochs x {num_steps_per_epoch()=} steps")
|
82 |
+
|
83 |
+
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
84 |
+
optimizer,
|
85 |
+
pct_start=self.model_hparams['pct_start'],
|
86 |
+
max_lr=[pg.get('lr', lr) for pg in param_groups],
|
87 |
+
steps_per_epoch=num_steps_per_epoch(),
|
88 |
+
epochs=int(self.model_hparams['epochs']),
|
89 |
+
final_div_factor=25
|
90 |
+
)
|
91 |
+
|
92 |
+
return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
|
93 |
+
|
94 |
+
def training_step(self, train_batch, batch_idx):
|
95 |
+
train_logits, train_loss = self.model.forward(*train_batch)
|
96 |
+
|
97 |
+
self.log("train_loss", train_loss, sync_dist=True)
|
98 |
+
return train_loss
|
99 |
+
|
100 |
+
def validation_step(self, val_batch, batch_idx):
|
101 |
+
val_logits, val_loss = self.model.forward(*val_batch)
|
102 |
+
|
103 |
+
self.log("val_loss", val_loss, sync_dist=True)
|
104 |
+
return val_loss
|
105 |
+
|
106 |
+
def on_validation_epoch_end(self):
|
107 |
+
if hasattr(self.model, 'get_metrics'):
|
108 |
+
self.log_dict({'metrics/'+k:v for k,v in self.model.get_metrics().items()}, sync_dist=True)
|
109 |
+
|
110 |
+
def test_step(self, val_batch, batch_idx):
|
111 |
+
test_logits, test_loss = self.model.forward(*val_batch)
|
112 |
+
|
113 |
+
self.log("test_loss", test_loss, sync_dist=True)
|
114 |
+
return test_loss
|
115 |
+
|
116 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 4
|
117 |
+
from fastcore.script import anno_parser
|
118 |
+
import shlex
|
119 |
+
|
120 |
+
# watch out: we can only pass Python values as keyword arguments (not positional)
|
121 |
+
# everything else has to be a string
|
122 |
+
def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
|
123 |
+
p = anno_parser(fun)
|
124 |
+
args = p.parse_args(args).__dict__
|
125 |
+
args.pop('xtra'); args.pop('pdb')
|
126 |
+
args.update({k:v for k, v in kwargs.items()})
|
127 |
+
if log_to_wandb and type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
128 |
+
wandb_logger.experiment.config[name] = {k:v for k,v in args.items() if k not in ['dataset', 'tunables']}
|
129 |
+
return fun(**args)
|
130 |
+
|
131 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 8
|
132 |
+
import argparse
|
133 |
+
|
134 |
+
parser = argparse.ArgumentParser()
|
135 |
+
parser.add_argument('--task', type=str, help='Task to train')
|
136 |
+
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
|
137 |
+
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
|
138 |
+
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
139 |
+
parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
|
140 |
+
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
|
141 |
+
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
|
142 |
+
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how training steps to run between validations')
|
143 |
+
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
|
144 |
+
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
|
145 |
+
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
|
146 |
+
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
|
147 |
+
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
|
148 |
+
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
|
149 |
+
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
|
150 |
+
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
|
151 |
+
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
|
152 |
+
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
|
153 |
+
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
|
154 |
+
|
155 |
+
args = parser.parse_args().__dict__
|
156 |
+
|
157 |
+
task_args: list = shlex.split(args.pop("task"))
|
158 |
+
task_name, task_args = task_args[0], task_args[1:]
|
159 |
+
input_args: list = shlex.split(args.pop("input_dir"))
|
160 |
+
checkpoint_dir: str = args.pop("checkpoint_dir")
|
161 |
+
num_workers: int = args.pop("workers")
|
162 |
+
batch_size: int = args.pop("batch_size")
|
163 |
+
epochs: int = args.pop("epochs")
|
164 |
+
tunables_args: list = shlex.split(args.pop("tunables"))
|
165 |
+
|
166 |
+
hyp_params = {}
|
167 |
+
hyp_params['batch_size'] = batch_size
|
168 |
+
hyp_params['warmup_steps'] = args['warmup_steps']
|
169 |
+
hyp_params['weight_decay'] = args['weight_decay']
|
170 |
+
hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
|
171 |
+
hyp_params['accumulate_grad_batches'] = args['accumulate_grad_batches']
|
172 |
+
hyp_params['precision'] = args['precision']
|
173 |
+
hyp_params['lr0'] = args['lr0']
|
174 |
+
hyp_params['epochs'] = epochs
|
175 |
+
hyp_params['strategy'] = args['strategy']
|
176 |
+
|
177 |
+
# %% ../nbs/B2. Training (Lightning).ipynb 9
|
178 |
+
from lightning.pytorch.loggers import WandbLogger
|
179 |
+
from lightning.pytorch.callbacks import LearningRateMonitor
|
180 |
+
import datetime
|
181 |
+
import webdataset as wds
|
182 |
+
import importlib
|
183 |
+
|
184 |
+
torch.set_float32_matmul_precision('medium')
|
185 |
+
|
186 |
+
project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
|
187 |
+
if args['wandb_suffix']:
|
188 |
+
project += "-"+args['wandb_suffix']
|
189 |
+
|
190 |
+
wandb_logger = WandbLogger(project=project)
|
191 |
+
|
192 |
+
ckpt_callback = pl.callbacks.ModelCheckpoint(
|
193 |
+
dirpath=f'{task_name}-{epochs}e',
|
194 |
+
filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
|
195 |
+
monitor="val_loss",
|
196 |
+
save_top_k=4,
|
197 |
+
train_time_interval=datetime.timedelta(minutes=5),
|
198 |
+
)
|
199 |
+
|
200 |
+
lr_monitor_callback = LearningRateMonitor(logging_interval='step')
|
201 |
+
|
202 |
+
from torch.utils.data import DataLoader
|
203 |
+
|
204 |
+
task = importlib.import_module("whisperspeech."+task_name)
|
205 |
+
|
206 |
+
train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)
|
207 |
+
|
208 |
+
tunables = None
|
209 |
+
if hasattr(task, "Tunables"):
|
210 |
+
import dataclasses
|
211 |
+
tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
|
212 |
+
if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
213 |
+
wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)
|
214 |
+
|
215 |
+
for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
|
216 |
+
val = getattr(tunables, name, None)
|
217 |
+
if val is not None: hyp_params[name] = val
|
218 |
+
|
219 |
+
if isinstance(train_ds, torch.utils.data.IterableDataset):
|
220 |
+
dl_batch_size, dl_shuffle = None, False
|
221 |
+
pin_memory = False
|
222 |
+
else:
|
223 |
+
dl_batch_size, dl_shuffle = batch_size, True
|
224 |
+
pin_memory = True
|
225 |
+
|
226 |
+
val_loader = wds.WebLoader(val_ds,
|
227 |
+
batch_size=dl_batch_size,
|
228 |
+
num_workers=num_workers,
|
229 |
+
drop_last=False,
|
230 |
+
pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
|
231 |
+
|
232 |
+
train_loader = wds.WebLoader(train_ds,
|
233 |
+
batch_size=dl_batch_size,
|
234 |
+
num_workers=num_workers,
|
235 |
+
drop_last=False,
|
236 |
+
shuffle=dl_shuffle,
|
237 |
+
pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)
|
238 |
+
|
239 |
+
model_kwargs = dict(dataset=train_ds)
|
240 |
+
if tunables is not None: model_kwargs['tunables'] = tunables
|
241 |
+
model = parse_and_call('model', task.make_model, task_args, model_kwargs)
|
242 |
+
|
243 |
+
task = TrainingTask(model, model_hparams=hyp_params)
|
244 |
+
|
245 |
+
trainer = pl.Trainer(strategy=hyp_params['strategy'],
|
246 |
+
max_epochs=hyp_params['epochs'],
|
247 |
+
accelerator="gpu",
|
248 |
+
profiler="simple",
|
249 |
+
precision=hyp_params['precision'],
|
250 |
+
gradient_clip_val=hyp_params['clip_gradient_norm'],
|
251 |
+
accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
|
252 |
+
val_check_interval=args.pop("validate_every_n_steps"),
|
253 |
+
enable_checkpointing=True,
|
254 |
+
logger=wandb_logger,
|
255 |
+
callbacks=[ckpt_callback, lr_monitor_callback])
|
256 |
+
|
257 |
+
if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
|
258 |
+
wandb_logger.experiment.config.update(hyp_params)
|
259 |
+
|
260 |
+
kwargs = {}
|
261 |
+
if 'resume_from' in args:
|
262 |
+
kwargs['ckpt_path'] = args['resume_from']
|
263 |
+
trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
|
whisperspeech/utils.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
|
5 |
+
'readlines']
|
6 |
+
|
7 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 1
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
from pathlib import Path
|
12 |
+
import webdataset as wds
|
13 |
+
from contextlib import contextmanager
|
14 |
+
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 2
|
18 |
+
def shard_glob(input):
|
19 |
+
if '{' in input:
|
20 |
+
return wds.shardlists.expand_urls(input)
|
21 |
+
if isinstance(input, (Path, str)):
|
22 |
+
path = Path(input)
|
23 |
+
if path.is_dir():
|
24 |
+
glob = '*.tar.gz'
|
25 |
+
else:
|
26 |
+
glob = path.name
|
27 |
+
path = path.parent
|
28 |
+
input = Path(path).glob(glob)
|
29 |
+
else:
|
30 |
+
raise ArgumentError("input should be either a list or a path with an optional glob specifier")
|
31 |
+
return [str(x) for x in input]
|
32 |
+
|
33 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 3
|
34 |
+
class join_datasets(torch.utils.data.IterableDataset):
|
35 |
+
def __init__(self, datasets):
|
36 |
+
self.datasets = datasets
|
37 |
+
|
38 |
+
def __iter__(self):
|
39 |
+
probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
|
40 |
+
its = [iter(ds) for ds in self.datasets]
|
41 |
+
while True:
|
42 |
+
try:
|
43 |
+
yield next(its[torch.multinomial(probs, 1)])
|
44 |
+
except StopIteration:
|
45 |
+
return
|
46 |
+
|
47 |
+
def __len__(self):
|
48 |
+
return sum([ds.total_samples for ds in self.datasets])
|
49 |
+
|
50 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 5
|
51 |
+
def resampler(newsr = 24000, key = 'samples_24k'):
|
52 |
+
_last_sr = None
|
53 |
+
tform = None
|
54 |
+
|
55 |
+
def _resample(samples):
|
56 |
+
for s in samples:
|
57 |
+
sr = s['sample_rate']
|
58 |
+
if sr != newsr:
|
59 |
+
if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
|
60 |
+
s[key] = tform(s['samples'])
|
61 |
+
else:
|
62 |
+
s[key] = s['samples']
|
63 |
+
yield s
|
64 |
+
|
65 |
+
return _resample
|
66 |
+
|
67 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 6
|
68 |
+
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
|
69 |
+
dir = Path(dir) if dir else Path(input).parent
|
70 |
+
return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
|
71 |
+
|
72 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 7
|
73 |
+
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
|
74 |
+
def deriver(url):
|
75 |
+
url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
|
76 |
+
return wds.WebDataset(
|
77 |
+
wds.SimpleShardList([url])
|
78 |
+
).decode(*decoders)
|
79 |
+
return deriver
|
80 |
+
|
81 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 8
|
82 |
+
def merge_in(dataset_fun):
|
83 |
+
"""Merge a dataset into the current one returning samples with the union of keys. Pass in a function
|
84 |
+
that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
|
85 |
+
|
86 |
+
It requires (and validates) that both datasets have the same ordering of keys so you have
|
87 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
88 |
+
"""
|
89 |
+
def merge_loop(main_samples):
|
90 |
+
#print("new merge loop:", dataset_fun)
|
91 |
+
merged_samples = None
|
92 |
+
cur_url = None
|
93 |
+
i = None
|
94 |
+
for s in main_samples:
|
95 |
+
url = s['__url__']
|
96 |
+
if url != cur_url:
|
97 |
+
# this will open a new file when we get the first sample with a new __url__
|
98 |
+
merged_samples = iter(dataset_fun(url))
|
99 |
+
cur_url = url
|
100 |
+
try:
|
101 |
+
merge_s = next(merged_samples)
|
102 |
+
except StopIteration:
|
103 |
+
# if the original shard got repeated we won't observe a __url__ change
|
104 |
+
# in this case restart the dataset from the beginning
|
105 |
+
merged_samples = iter(dataset_fun(url))
|
106 |
+
merge_s = next(merged_samples)
|
107 |
+
assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
|
108 |
+
news = {}
|
109 |
+
news.update(merge_s)
|
110 |
+
news.update(s)
|
111 |
+
yield news
|
112 |
+
return merge_loop
|
113 |
+
|
114 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 9
|
115 |
+
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
|
116 |
+
for s in stream:
|
117 |
+
audio, sr = s['audio']
|
118 |
+
imax = len(s[ikey]) - 1
|
119 |
+
for i,(ts,te) in enumerate(s[ikey]):
|
120 |
+
samples = audio[0,int(ts*sr):int(te*sr)]
|
121 |
+
if pad_to_seconds is not None:
|
122 |
+
padding = pad_to_seconds*sr-samples.shape[-1]
|
123 |
+
lpad = random.randint(0, padding) if random_shift else 0
|
124 |
+
samples = F.pad(samples, (lpad, padding-lpad))
|
125 |
+
subs = {"__key__": s['__key__'] + f"_{i:03d}",
|
126 |
+
"src_key": s['__key__'],
|
127 |
+
"__url__": s['__url__'],
|
128 |
+
"i": i, "imax": imax,
|
129 |
+
"tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
|
130 |
+
"lpad": lpad, "rpad": padding-lpad,
|
131 |
+
"lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
|
132 |
+
"samples": samples, "sample_rate": sr}
|
133 |
+
for k in metakeys:
|
134 |
+
subs[k] = s[k][i]
|
135 |
+
yield subs
|
136 |
+
|
137 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 10
|
138 |
+
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
|
139 |
+
return wds.WebDataset(shards).compose(
|
140 |
+
wds.decode(wds.torch_audio),
|
141 |
+
merge_in(derived_dataset(kind)),
|
142 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
|
143 |
+
wds.rename(audio="flac;mp3;wav;ogg"),
|
144 |
+
lambda x: split_to_chunks(x, ikey=ikey),
|
145 |
+
)
|
146 |
+
|
147 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 11
|
148 |
+
@contextmanager
|
149 |
+
def AtomicTarWriter(name, throwaway=False):
|
150 |
+
tmp = name+".tmp"
|
151 |
+
with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
|
152 |
+
yield sink
|
153 |
+
if not throwaway:
|
154 |
+
os.rename(tmp, name)
|
155 |
+
|
156 |
+
# %% ../nbs/D. Common dataset utilities.ipynb 12
|
157 |
+
def readlines(fname):
|
158 |
+
with open(fname) as file:
|
159 |
+
return [line.rstrip() for line in file]
|
whisperspeech/vad.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = []
|
5 |
+
|
6 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 3
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from fastprogress import progress_bar
|
13 |
+
from fastcore.script import call_parse
|
14 |
+
|
15 |
+
import whisperx
|
16 |
+
import random
|
17 |
+
import numpy as np
|
18 |
+
import webdataset as wds
|
19 |
+
|
20 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 5
|
21 |
+
# some of the original file names have a dot in their name
|
22 |
+
# webdataset does not like it so let's patch it
|
23 |
+
def fix_dots_in_names(name):
|
24 |
+
name, ext = name.rsplit('.', 1)
|
25 |
+
return ".".join((name.replace('.', '_'), ext))
|
26 |
+
|
27 |
+
def load_dataset(url, decode=True, rename_files=None):
|
28 |
+
ds = wds.WebDataset(url, rename_files=rename_files)
|
29 |
+
if not decode: return ds
|
30 |
+
return ds.decode(wds.torch_audio)
|
31 |
+
|
32 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 7
|
33 |
+
def extract_segments(vad_result, max_duration):
|
34 |
+
binarize = whisperx.vad.Binarize(max_duration=max_duration)
|
35 |
+
segments = binarize(vad_result)
|
36 |
+
return [(x.start, x.end) for x in segments.get_timeline()]
|
37 |
+
|
38 |
+
def segment_audio(vad_model, audio, sr=16000):
|
39 |
+
vad_result = vad_model({"waveform": audio, "sample_rate": sr})
|
40 |
+
return extract_segments(vad_result, 30)
|
41 |
+
|
42 |
+
# %% ../nbs/1B. Voice activity detection.ipynb 13
|
43 |
+
def flac_to_vad_name(input):
|
44 |
+
if '-flac-' in input:
|
45 |
+
return input.rsplit("/", 1)[1].replace('flac', 'vad') + ".gz"
|
46 |
+
else:
|
47 |
+
return input.rsplit("/", 1)[1].replace('raw', 'vad') + ".gz"
|
48 |
+
|
49 |
+
@call_parse
|
50 |
+
def process_shard(
|
51 |
+
input:str, # input shard URL/path
|
52 |
+
output:str=None, # output shard URL/path
|
53 |
+
fix_dots:bool=False, # fix dots in LibriLight filenames
|
54 |
+
):
|
55 |
+
if output is None: output = flac_to_vad_name(input)
|
56 |
+
|
57 |
+
ds = torch.utils.data.DataLoader(load_dataset(input, rename_files=fix_dots_in_names if fix_dots else None), num_workers=2, batch_size=None)
|
58 |
+
vad_model = whisperx.vad.load_vad_model('cuda')
|
59 |
+
|
60 |
+
tmp = output+".tmp"
|
61 |
+
with wds.TarWriter(tmp) as sink:
|
62 |
+
for s in progress_bar(ds, total='noinfer'):
|
63 |
+
audio, sr = s.get('flac', s.get('wav', (None, None)))
|
64 |
+
if audio is None:
|
65 |
+
print(f"warning: '{s['__key__']}' does not contain an audio file")
|
66 |
+
continue
|
67 |
+
sink.write({
|
68 |
+
"__key__": s['__key__'],
|
69 |
+
"vad.npy": np.array(segment_audio(vad_model, audio, sr=sr), dtype=np.float16)
|
70 |
+
})
|
71 |
+
os.rename(tmp, output)
|
whisperspeech/vq_stoks.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2B. Whisper quantization (semantic token) model.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['RQBottleneckTransformer', 'make_model']
|
5 |
+
|
6 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 2
|
7 |
+
import io
|
8 |
+
import sys
|
9 |
+
import time
|
10 |
+
import torch
|
11 |
+
import torchaudio
|
12 |
+
|
13 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 3
|
14 |
+
from pathlib import Path
|
15 |
+
import json
|
16 |
+
from fastprogress import progress_bar, master_bar
|
17 |
+
import fastprogress
|
18 |
+
import numpy as np
|
19 |
+
import pylab as plt
|
20 |
+
import pandas as pd
|
21 |
+
import random
|
22 |
+
|
23 |
+
import whisper
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
from fastcore.basics import store_attr
|
26 |
+
|
27 |
+
from torch import nn
|
28 |
+
import torch.optim as optim
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch.utils.data.dataloader import DataLoader
|
31 |
+
import webdataset as wds
|
32 |
+
from . import utils
|
33 |
+
|
34 |
+
from vector_quantize_pytorch import ResidualVQ
|
35 |
+
|
36 |
+
from fastcore.script import *
|
37 |
+
|
38 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
|
39 |
+
def merge_in(dataset_fun):
|
40 |
+
"""Merge a dataset into the current one returning samples with the union of keys. Pass in a function
|
41 |
+
that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
|
42 |
+
|
43 |
+
It requires (and validates) that both datasets have the same ordering of keys so you have
|
44 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
45 |
+
"""
|
46 |
+
def merge_loop(main_samples):
|
47 |
+
#print("new merge loop:", dataset_fun)
|
48 |
+
merged_samples = None
|
49 |
+
cur_url = None
|
50 |
+
i = None
|
51 |
+
for s in main_samples:
|
52 |
+
url = s['__url__']
|
53 |
+
if url != cur_url:
|
54 |
+
# this will open a new file when we get the first sample with a new __url__
|
55 |
+
merged_samples = iter(dataset_fun(url))
|
56 |
+
cur_url = url
|
57 |
+
try:
|
58 |
+
merge_s = next(merged_samples)
|
59 |
+
except StopIteration:
|
60 |
+
# if the original shard got repeated we won't observe a __url__ change
|
61 |
+
# in this case restart the dataset from the beginning
|
62 |
+
merged_samples = iter(dataset_fun(url))
|
63 |
+
merge_s = next(merged_samples)
|
64 |
+
assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
|
65 |
+
news = {}
|
66 |
+
news.update(merge_s)
|
67 |
+
news.update(s)
|
68 |
+
yield news
|
69 |
+
return merge_loop
|
70 |
+
|
71 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 10
|
72 |
+
def derived_dataset(kind, key='audio'):
|
73 |
+
def deriver(url):
|
74 |
+
url = str(Path(url).parent/(Path(url).name.replace(key, kind) + ".gz"))
|
75 |
+
return wds.WebDataset(
|
76 |
+
wds.SimpleShardList([url])
|
77 |
+
).decode()
|
78 |
+
return deriver
|
79 |
+
|
80 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 17
|
81 |
+
def add_masks(samples):
|
82 |
+
for s in samples:
|
83 |
+
seconds = s['tend'] - s['tstart']
|
84 |
+
# a mask (downsampled to the Whisper encoder token rate of 50/s) is used
|
85 |
+
# to teach the model the concept of padding
|
86 |
+
# this let's us decode shorter sequences later
|
87 |
+
mask = torch.zeros(30*16000//320, dtype=torch.bool)
|
88 |
+
mask[:int(seconds * 16000) // 320] = 1
|
89 |
+
s['mask'] = mask
|
90 |
+
yield s
|
91 |
+
|
92 |
+
def tokenize_text(samples, ttoks_size=200, model="base.en", language="en"):
|
93 |
+
multilingual = not model.endswith(".en")
|
94 |
+
tokenizer = whisper.tokenizer.get_tokenizer(multilingual, language=language, task="transcribe")
|
95 |
+
for s in samples:
|
96 |
+
ttoks = tokenizer.encode(s['txt'])
|
97 |
+
tokens = list(tokenizer.sot_sequence) + ttoks
|
98 |
+
rpad = ttoks_size - len(tokens)
|
99 |
+
s['in_ttoks'] = F.pad(torch.tensor(tokens), (0, rpad), value=tokenizer.eot)
|
100 |
+
s['out_ttoks'] = F.pad(torch.tensor(tokens[1:] + [tokenizer.eot]), (0, rpad), value=-100)
|
101 |
+
yield s
|
102 |
+
|
103 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 22
|
104 |
+
def load_dataset(
|
105 |
+
shard_spec:str,
|
106 |
+
proc_dataset_path:Path, # processed VAD and txt files
|
107 |
+
samples:int, # set the per-GPU sample count
|
108 |
+
txt_label:str="base.en-txt", # the label of the files containing transcriptions
|
109 |
+
model:str="base.en",
|
110 |
+
key:str="flac",
|
111 |
+
language:str=None,
|
112 |
+
validation:bool=False,
|
113 |
+
):
|
114 |
+
from . import wh_transcribe
|
115 |
+
shards = utils.shard_glob(shard_spec)
|
116 |
+
|
117 |
+
if not language and model.endswith('en'): language = 'en'
|
118 |
+
assert language, "please provide the dataset language for multilang models"
|
119 |
+
|
120 |
+
same_on_all_nodes = lambda urls: urls # will only be used for validation
|
121 |
+
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
|
122 |
+
wds.decode(wds.torch_audio),
|
123 |
+
wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
|
124 |
+
wds.rename(audio="flac;mp3;wav;ogg"),
|
125 |
+
merge_in(derived_dataset(proc_dataset_path, 'vad', key=key)),
|
126 |
+
wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
|
127 |
+
wh_transcribe.split_to_chunks,
|
128 |
+
utils.resampler(16000, 'samples_16k'),
|
129 |
+
merge_in(derived_dataset(proc_dataset_path, txt_label, key=key)),
|
130 |
+
)
|
131 |
+
if 'librilight' in shards[0]:
|
132 |
+
ds = ds.compose(
|
133 |
+
# drop the first and last segment because they tend to be inaccurate
|
134 |
+
# (the transcriptions don't have the "LibriVox" headers and "end of chapter" suffixes)
|
135 |
+
wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
|
136 |
+
)
|
137 |
+
ds = ds.compose(
|
138 |
+
add_masks,
|
139 |
+
lambda x: tokenize_text(x, model=model, language=language),
|
140 |
+
wds.to_tuple('samples_16k', 'mask', 'in_ttoks', 'out_ttoks'),
|
141 |
+
wds.batched(32),
|
142 |
+
)
|
143 |
+
ds.total_samples = samples
|
144 |
+
|
145 |
+
return ds
|
146 |
+
|
147 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 28
|
148 |
+
from whisperspeech.train import *
|
149 |
+
from whisperspeech.modules import *
|
150 |
+
|
151 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 29
|
152 |
+
import dataclasses
|
153 |
+
|
154 |
+
def rand(start, end):
|
155 |
+
return random.random() * (end - start) + start
|
156 |
+
|
157 |
+
def logrand(start, end):
|
158 |
+
return 10**rand(math.log10(start), math.log10(end))
|
159 |
+
|
160 |
+
@dataclasses.dataclass
|
161 |
+
class Tunables:
|
162 |
+
init_std :float = 1.5
|
163 |
+
embeddings_std :float = 4.5e-2
|
164 |
+
embeddings_lr_scale: float = 1
|
165 |
+
output_mult :float = 1
|
166 |
+
query_mult :float = 2
|
167 |
+
rope :bool = True
|
168 |
+
mask_embs :bool = True # force embeddings corresponding to the input audio padding to a constant value
|
169 |
+
downsample_conv: bool = False
|
170 |
+
downsample_mean: bool = True
|
171 |
+
|
172 |
+
codebook_dim: int = 32
|
173 |
+
codebook_decay: float = 0.9
|
174 |
+
|
175 |
+
lr0 :float = .9e-3
|
176 |
+
clip_gradient_norm :float = 2
|
177 |
+
weight_decay :float = 1e-3
|
178 |
+
warmup_steps :float = 850
|
179 |
+
|
180 |
+
random :bool = False
|
181 |
+
|
182 |
+
def __post_init__(self):
|
183 |
+
# randomize the hyperparams if requested
|
184 |
+
if self.random:
|
185 |
+
self.init_std = logrand(1, 2)
|
186 |
+
self.embeddings_std = logrand(3e-2,6e-2)
|
187 |
+
self.embeddings_lr_scale = 2**rand(0,3)
|
188 |
+
self.output_mult = 2**rand(-3,3)
|
189 |
+
self.query_mult = logrand(1,8)
|
190 |
+
self.codebook_dim = int(logrand(30,50))
|
191 |
+
self.codebook_decay = logrand(0.86,0.95)
|
192 |
+
self.rope = True
|
193 |
+
self.mask_embs = True
|
194 |
+
self.downsample_mean = True
|
195 |
+
|
196 |
+
self.lr0 = logrand(.8e-3,1e-3)
|
197 |
+
self.clip_gradient_norm = 10**rand(-1,1)
|
198 |
+
self.warmup_steps = logrand(700,1000)
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def upgrade(args):
|
202 |
+
args = {k:v for k,v in args.items()}
|
203 |
+
def old_default(name, value):
|
204 |
+
if name not in args: args[name] = value
|
205 |
+
old_default('output_mult', 1)
|
206 |
+
old_default('query_mult', 1)
|
207 |
+
old_default('rope', False)
|
208 |
+
old_default('mask_embs', False)
|
209 |
+
old_default('downsample_conv', False)
|
210 |
+
old_default('downsample_mean', False)
|
211 |
+
if 'encoder_depth_ratio' in args: del args['encoder_depth_ratio']
|
212 |
+
if 'vq_codes' in args: del args['vq_codes']
|
213 |
+
return args
|
214 |
+
|
215 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 30
|
216 |
+
import math
|
217 |
+
|
218 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 31
|
219 |
+
class RQBottleneckTransformer(nn.Module):
|
220 |
+
def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, ffn_mult=4,
|
221 |
+
codebook_dim=2, threshold_ema_dead_code=2, use_cosine_sim = False, kl_loss_mul=1,
|
222 |
+
downsample=1,
|
223 |
+
whisper_model_name='tiny.en', tunables=Tunables()):
|
224 |
+
super().__init__()
|
225 |
+
width = n_head * head_width
|
226 |
+
store_attr("codebook_dim,vq_codes,q_depth,n_head,head_width,ffn_mult,depth,use_cosine_sim,downsample,whisper_model_name")
|
227 |
+
self.width = width
|
228 |
+
self.base_width = 3 * head_width
|
229 |
+
self.vq_codes = vq_codes
|
230 |
+
self.tunables = tunables
|
231 |
+
self.stoks_len = 1500//downsample
|
232 |
+
self.stoks_per_sec = self.stoks_len//30
|
233 |
+
|
234 |
+
qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
|
235 |
+
|
236 |
+
self.kl_loss_mul = kl_loss_mul
|
237 |
+
|
238 |
+
n_mlp = width * ffn_mult
|
239 |
+
self.mlp = nn.Sequential(
|
240 |
+
nn.Linear(width, n_mlp), nn.GELU(), nn.Linear(n_mlp, width)
|
241 |
+
)
|
242 |
+
self.mlp_ln = LayerNorm(width)
|
243 |
+
|
244 |
+
if tunables.downsample_conv:
|
245 |
+
self.downsample_conv = nn.Conv1d(width, width, kernel_size=3, stride=downsample, padding=1)
|
246 |
+
else:
|
247 |
+
self.downsample_conv = None
|
248 |
+
|
249 |
+
if tunables.mask_embs: vq_codes = vq_codes + 1
|
250 |
+
self.rq = ResidualVQ(
|
251 |
+
dim = width,
|
252 |
+
codebook_size = vq_codes, # codebook size
|
253 |
+
decay = tunables.codebook_decay, # the exponential moving average decay, lower means the dictionary will change faster
|
254 |
+
commitment_weight = 1., # the weight on the commitment loss
|
255 |
+
threshold_ema_dead_code = threshold_ema_dead_code,
|
256 |
+
use_cosine_sim = use_cosine_sim,
|
257 |
+
codebook_dim = codebook_dim,
|
258 |
+
num_quantizers= 1,
|
259 |
+
)
|
260 |
+
|
261 |
+
self.ce_lossf = nn.CrossEntropyLoss(ignore_index=-100)
|
262 |
+
self.kl_lossf = nn.KLDivLoss(reduction='batchmean')
|
263 |
+
|
264 |
+
self.positional_embedding = nn.Embedding(1500, width) # FIXME: should be self.stoks_len
|
265 |
+
|
266 |
+
self.out_blocks = nn.Sequential(*[
|
267 |
+
ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(depth)
|
268 |
+
])
|
269 |
+
self.ln_post = LayerNorm(width)
|
270 |
+
|
271 |
+
self.whmodel = None
|
272 |
+
|
273 |
+
self.apply(self.init_transformer)
|
274 |
+
self.register_buffer('val_true', torch.zeros(1).cuda())
|
275 |
+
self.register_buffer('val_total', torch.zeros(1).cuda())
|
276 |
+
|
277 |
+
def setup(self, device):
|
278 |
+
self.ensure_whisper(device)
|
279 |
+
|
280 |
+
def init_transformer(self, m):
|
281 |
+
if isinstance(m, LinearHead):
|
282 |
+
m.no_weight_decay = True
|
283 |
+
torch.nn.init.constant_(m.weight, 0)
|
284 |
+
elif isinstance(m, QueryHead):
|
285 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
286 |
+
torch.nn.init.constant_(m.weight, 0)
|
287 |
+
elif isinstance(m, nn.Embedding):
|
288 |
+
m.no_weight_decay = True
|
289 |
+
m.lr_scale = self.tunables.embeddings_lr_scale
|
290 |
+
std = self.tunables.embeddings_std
|
291 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
292 |
+
elif isinstance(m, nn.Linear):
|
293 |
+
m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
|
294 |
+
std = self.tunables.init_std / m.weight.shape[1]
|
295 |
+
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
|
296 |
+
if m.bias is not None:
|
297 |
+
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
|
298 |
+
elif isinstance(m, nn.LayerNorm):
|
299 |
+
m.no_weight_decay = True
|
300 |
+
torch.nn.init.constant_(m.bias, 0)
|
301 |
+
torch.nn.init.constant_(m.weight, 1)
|
302 |
+
|
303 |
+
@property
|
304 |
+
def device(self):
|
305 |
+
return next(self.parameters()).device
|
306 |
+
|
307 |
+
#
|
308 |
+
# training
|
309 |
+
#
|
310 |
+
@torch.no_grad()
|
311 |
+
def extract_teacher(self, samples, input_toks, output_toks):
|
312 |
+
embs = self.whmodel[0].encoder(whisper.log_mel_spectrogram(samples))
|
313 |
+
teacher_logits = self.whmodel[0].decoder(input_toks, embs)
|
314 |
+
# set teacher logits to 0 for padding positions so KLDivLoss ignores them
|
315 |
+
teacher_logits[output_toks == -100] = 0
|
316 |
+
return embs, teacher_logits
|
317 |
+
|
318 |
+
def downsample_embeddings(self, x):
|
319 |
+
if self.downsample_conv is not None:
|
320 |
+
return x[:,::self.downsample] + self.downsample_conv(x.transpose(-1,-2)).transpose(-2,-1)
|
321 |
+
elif self.tunables.downsample_mean:
|
322 |
+
bs,slen,depth = x.shape
|
323 |
+
return x.reshape(bs,slen//self.downsample,self.downsample,depth).mean(-2)
|
324 |
+
else:
|
325 |
+
return x[:,::self.downsample]
|
326 |
+
|
327 |
+
def forward(self, samples, mask, input_toks, output_toks):
|
328 |
+
embs, teacher_logits = self.extract_teacher(samples, input_toks, output_toks)
|
329 |
+
|
330 |
+
x = self.downsample_embeddings(embs)
|
331 |
+
x = x + self.mlp(self.mlp_ln(x))
|
332 |
+
# VQ bottleneck
|
333 |
+
quantized, self.indices, self.commit_loss = self.rq(x)
|
334 |
+
self.commit_loss = self.commit_loss.mean()
|
335 |
+
|
336 |
+
x = quantized.repeat_interleave(self.downsample, -2)
|
337 |
+
project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
|
338 |
+
if self.tunables.mask_embs: x[~mask] = project_out(self.rq.layers[0]._codebook.embed[0,self.vq_codes])
|
339 |
+
positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
|
340 |
+
x = x + self.positional_embedding(positions)
|
341 |
+
x = self.ln_post(self.out_blocks(x))
|
342 |
+
|
343 |
+
logits = self.whmodel[0].decoder(input_toks, x)
|
344 |
+
self.ce_loss = self.ce_lossf(logits.view(-1,logits.shape[-1]), output_toks.view(-1))
|
345 |
+
self.kl_loss = self.kl_lossf(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
|
346 |
+
loss = self.ce_loss + self.kl_loss_mul * self.kl_loss + self.commit_loss
|
347 |
+
|
348 |
+
if not self.training:
|
349 |
+
valid_toks = output_toks != -100
|
350 |
+
self.val_true += (logits.argmax(-1)[valid_toks] == output_toks[valid_toks]).float().sum()
|
351 |
+
self.val_total += valid_toks.float().sum()
|
352 |
+
|
353 |
+
return x, loss
|
354 |
+
|
355 |
+
def get_metrics(self):
|
356 |
+
metrics = {
|
357 |
+
'acc_0': (self.val_true / self.val_total).item(),
|
358 |
+
}
|
359 |
+
self.val_true[:] = 0
|
360 |
+
self.val_total[:] = 0
|
361 |
+
return metrics
|
362 |
+
|
363 |
+
#
|
364 |
+
# inference
|
365 |
+
#
|
366 |
+
@classmethod
|
367 |
+
def load_model(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
|
368 |
+
repo_id=None, filename=None, local_filename=None):
|
369 |
+
if repo_id is None and filename is None and local_filename is None:
|
370 |
+
if ":" in ref:
|
371 |
+
repo_id, filename = ref.split(":", 1)
|
372 |
+
else:
|
373 |
+
local_filename = ref
|
374 |
+
if not local_filename:
|
375 |
+
local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
|
376 |
+
spec = torch.load(local_filename)
|
377 |
+
vqmodel = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
|
378 |
+
vqmodel.load_state_dict(spec['state_dict'])
|
379 |
+
vqmodel.eval()
|
380 |
+
return vqmodel
|
381 |
+
|
382 |
+
def load_checkpoint(self, local_filename):
|
383 |
+
spec = torch.load(local_filename, map_location='cpu')
|
384 |
+
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
|
385 |
+
state_dict = {k.replace('model.', ''):v
|
386 |
+
for k,v in spec['state_dict'].items()}
|
387 |
+
self.load_state_dict(state_dict)
|
388 |
+
return self
|
389 |
+
|
390 |
+
def save_model(self, fname, store_parameters=True):
|
391 |
+
torch.save(dict(config = self.__stored_args__,
|
392 |
+
tunables = dataclasses.asdict(self.tunables),
|
393 |
+
state_dict = self.state_dict() if store_parameters else None), fname)
|
394 |
+
|
395 |
+
def ensure_whisper(self, device):
|
396 |
+
# the list wrapper is a hack to make sure the whole of Whisper is not sucked into self.parameters()
|
397 |
+
if self.whmodel is None: self.whmodel = [whisper.load_model(self.whisper_model_name, device=device)]
|
398 |
+
self.decoding_options = whisper.DecodingOptions()
|
399 |
+
multilingual = not self.whisper_model_name.endswith('.en')
|
400 |
+
self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
|
401 |
+
|
402 |
+
def quantize(self, embs):
|
403 |
+
x = self.downsample_embeddings(embs)
|
404 |
+
x = x + self.mlp(self.mlp_ln(x))
|
405 |
+
_, stoks, _ = self.rq(x)
|
406 |
+
if self.q_depth == 1:
|
407 |
+
stoks = stoks.squeeze(-1)
|
408 |
+
return stoks
|
409 |
+
|
410 |
+
def dequantize(self, stoks):
|
411 |
+
assert self.q_depth == 1
|
412 |
+
assert len(stoks.shape) == 1, "batch processing is not supported"
|
413 |
+
if isinstance(stoks, np.ndarray): stoks = torch.tensor(stoks)
|
414 |
+
# remove padding
|
415 |
+
padding = torch.nonzero(stoks == self.vq_codes)
|
416 |
+
if padding.any(): stoks = stoks[:padding[0,0]]
|
417 |
+
stoks = F.pad(stoks, (0,self.stoks_len - stoks.shape[-1]), value=self.vq_codes if self.tunables.mask_embs else 0)
|
418 |
+
x = self.rq.layers[0]._codebook.embed[0,stoks.to(torch.long).view(-1)]
|
419 |
+
x = x.repeat_interleave(self.downsample, -2)
|
420 |
+
project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
|
421 |
+
x = project_out(x).unsqueeze(0)
|
422 |
+
positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
|
423 |
+
x = x + self.positional_embedding(positions)
|
424 |
+
return self.ln_post(self.out_blocks(x))
|
425 |
+
|
426 |
+
def encode_audio(self, audio):
|
427 |
+
if isinstance(audio, str):
|
428 |
+
x, sr = torchaudio.load(audio)
|
429 |
+
x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
|
430 |
+
audio = x.unsqueeze(0)
|
431 |
+
return self.encode_mel(whisper.log_mel_spectrogram(audio).to(self.device))
|
432 |
+
|
433 |
+
def encode_mel(self, mel):
|
434 |
+
assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
|
435 |
+
self.ensure_whisper(self.device)
|
436 |
+
n = mel.shape[-1]
|
437 |
+
if n > whisper.audio.N_FRAMES:
|
438 |
+
padding = 0
|
439 |
+
padded = mel[:,:,:whisper.audio.N_FRAMES]
|
440 |
+
else:
|
441 |
+
padding = -n % whisper.audio.N_FRAMES
|
442 |
+
padded = F.pad(mel, (0, padding), value=-1.5)
|
443 |
+
embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
|
444 |
+
stoks = self.quantize(embs)
|
445 |
+
if self.tunables.mask_embs:
|
446 |
+
return stoks[:,:n//2//self.downsample]
|
447 |
+
else:
|
448 |
+
return stoks
|
449 |
+
|
450 |
+
def decode_text(self, stoks, decoding_options=None):
|
451 |
+
self.ensure_whisper(self.device)
|
452 |
+
if decoding_options is None: decoding_options = self.decoding_options
|
453 |
+
embs = self.dequantize(stoks).to(self.whmodel[0].device)
|
454 |
+
return self.whmodel[0].decode(embs, decoding_options)
|
455 |
+
|
456 |
+
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 33
|
457 |
+
def make_model(size:str, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
|
458 |
+
if size == 'base.en-2d-4096c':
|
459 |
+
model = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
|
460 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
461 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
462 |
+
return model
|
463 |
+
if size == 'base.en-2d-512c':
|
464 |
+
model = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
465 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
466 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
467 |
+
return model
|
468 |
+
if size == 'base.en-2d-512c-dim64':
|
469 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
470 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
471 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
472 |
+
return model
|
473 |
+
if size == 'base-2d-512c-dim64':
|
474 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
|
475 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
476 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
477 |
+
return model
|
478 |
+
if size == 'base-2d-1024c-dim64':
|
479 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=8, depth=1,
|
480 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
481 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
482 |
+
return model
|
483 |
+
if size == 'medium-2d-512c-dim64':
|
484 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=16, depth=1,
|
485 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
486 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
487 |
+
return model
|
488 |
+
if size == 'medium-2d-1024c-dim64':
|
489 |
+
model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=16, depth=1,
|
490 |
+
downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
|
491 |
+
whisper_model_name=size.split("-")[0], tunables=tunables)
|
492 |
+
return model
|
493 |
+
raise ArgumentError(f"invalid model size: {size}")
|
whisperspeech/wer_metrics.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/C. Word error rate metrics.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = ['librispeech_data', 'DfBuilder', 'WERStats']
|
5 |
+
|
6 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 2
|
7 |
+
import jiwer
|
8 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
9 |
+
|
10 |
+
import torchaudio
|
11 |
+
from pathlib import Path
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 3
|
15 |
+
engnorm = EnglishTextNormalizer()
|
16 |
+
def whisper_normalize(x):
|
17 |
+
if type(x) == list:
|
18 |
+
return [engnorm(y) for y in x]
|
19 |
+
else:
|
20 |
+
return engnorm(x)
|
21 |
+
|
22 |
+
default_transform = jiwer.transforms.Compose([
|
23 |
+
jiwer.transforms.ToLowerCase(),
|
24 |
+
jiwer.transforms.ExpandCommonEnglishContractions(),
|
25 |
+
whisper_normalize,
|
26 |
+
jiwer.transforms.RemoveMultipleSpaces(),
|
27 |
+
jiwer.transforms.Strip(),
|
28 |
+
jiwer.transforms.RemovePunctuation(),
|
29 |
+
jiwer.transforms.ReduceToListOfListOfWords(),
|
30 |
+
])
|
31 |
+
|
32 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 5
|
33 |
+
def librispeech_data(datadir, sample_rate=16000):
|
34 |
+
for file in Path(datadir).rglob('*.txt'):
|
35 |
+
for line in file.read_text().split('\n'):
|
36 |
+
if not line: continue
|
37 |
+
idx, text = line.split(" ", 1)
|
38 |
+
x, sr = torchaudio.load((file.parent/idx).with_suffix('.flac'))
|
39 |
+
if sr != sample_rate:
|
40 |
+
x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
|
41 |
+
yield x, text
|
42 |
+
|
43 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 6
|
44 |
+
class DfBuilder:
|
45 |
+
def __init__(self):
|
46 |
+
self.data = {}
|
47 |
+
|
48 |
+
def push(self, **kwargs):
|
49 |
+
for k,v in kwargs.items():
|
50 |
+
if k not in self.data:
|
51 |
+
self.data[k] = [v]
|
52 |
+
else:
|
53 |
+
self.data[k].append(v)
|
54 |
+
|
55 |
+
def df(self):
|
56 |
+
return pd.DataFrame(self.data)
|
57 |
+
|
58 |
+
# %% ../nbs/C. Word error rate metrics.ipynb 7
|
59 |
+
class WERStats(DfBuilder):
|
60 |
+
def __init__(self, transform=default_transform):
|
61 |
+
super().__init__()
|
62 |
+
self.reference_transform = transform
|
63 |
+
self.hypothesis_transform = transform
|
64 |
+
|
65 |
+
def push_sample(self, snd, gt_text, text, idx=None):
|
66 |
+
if snd is not None: self.push(secs = snd.shape[-1]/16000)
|
67 |
+
diff = jiwer.process_words(gt_text, text, reference_transform=self.reference_transform, hypothesis_transform=self.hypothesis_transform)
|
68 |
+
self.push(
|
69 |
+
idx = idx,
|
70 |
+
gt_text = gt_text,
|
71 |
+
text = text,
|
72 |
+
wer = diff.wer,
|
73 |
+
mer = diff.mer,
|
74 |
+
wil = diff.wil,
|
75 |
+
wip = diff.wip,
|
76 |
+
)
|
77 |
+
return diff
|
whisperspeech/wh_transcribe.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.
|
2 |
+
|
3 |
+
# %% auto 0
|
4 |
+
__all__ = []
|
5 |
+
|
6 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
|
7 |
+
import os
|
8 |
+
import io
|
9 |
+
import time
|
10 |
+
import torch
|
11 |
+
import torchaudio
|
12 |
+
|
13 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
|
14 |
+
from pathlib import Path
|
15 |
+
import json
|
16 |
+
from fastprogress import progress_bar, master_bar
|
17 |
+
import numpy as np
|
18 |
+
import random
|
19 |
+
|
20 |
+
import whisper
|
21 |
+
|
22 |
+
from torch import nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch.utils.data.dataloader import DataLoader
|
25 |
+
|
26 |
+
from fastcore.script import *
|
27 |
+
|
28 |
+
from . import vad
|
29 |
+
import webdataset as wds
|
30 |
+
|
31 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
|
32 |
+
# let's make it a bit more conservative
|
33 |
+
# with full 30 second chunks it sometimes misses a small part of the transcript
|
34 |
+
def random_cutter(dur):
|
35 |
+
if random.random() < 0.5:
|
36 |
+
return dur > 28 * (random.random()*0.95+0.05)
|
37 |
+
else:
|
38 |
+
return dur > 28
|
39 |
+
|
40 |
+
def chunk_merger(segments, should_cut=lambda x: x > 28):
|
41 |
+
if len(segments) == 0: return segments
|
42 |
+
curr_start = segments[0][0]
|
43 |
+
curr_end = 0
|
44 |
+
merged = []
|
45 |
+
|
46 |
+
for ts,te in segments:
|
47 |
+
if should_cut(te - curr_start) and curr_end - curr_start > 0:
|
48 |
+
merged.append((curr_start, curr_end))
|
49 |
+
curr_start = ts
|
50 |
+
curr_end = te
|
51 |
+
merged.append((curr_start, curr_end))
|
52 |
+
return merged
|
53 |
+
|
54 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
|
55 |
+
def merge_in(*datasets):
|
56 |
+
"""Merge multiple datasets into the current one returning samples with the union of keys.
|
57 |
+
|
58 |
+
It requires (and validates) all datasets to have the same ordering of keys so you have
|
59 |
+
to use it before any sample shuffling. Shard shuffling is ok.
|
60 |
+
"""
|
61 |
+
def merge_loop(main_samples):
|
62 |
+
for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
|
63 |
+
key = samples[0]['__key__']
|
64 |
+
news = {}
|
65 |
+
for s in samples:
|
66 |
+
assert s['__key__'] == key
|
67 |
+
news.update(s)
|
68 |
+
yield news
|
69 |
+
return merge_loop
|
70 |
+
|
71 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
|
72 |
+
import copy
|
73 |
+
|
74 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
|
75 |
+
# a workaround for https://github.com/webdataset/webdataset/issues/297
|
76 |
+
# should be possible to use ds.compose here
|
77 |
+
def wds_compose(ds, *args):
|
78 |
+
ds = copy.copy(ds)
|
79 |
+
ds.pipeline = copy.copy(ds.pipeline)
|
80 |
+
for f in args:
|
81 |
+
ds.append(f)
|
82 |
+
return ds
|
83 |
+
|
84 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
|
85 |
+
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
|
86 |
+
for s in stream:
|
87 |
+
audio, sr = s.get('flac', s.get('wav', (None, None)))
|
88 |
+
if audio is None:
|
89 |
+
print(f"warning: '{s['__key__']}' does not contain an audio file")
|
90 |
+
continue
|
91 |
+
imax = len(s['vad.npy']) - 1
|
92 |
+
for i,(ts,te) in enumerate(s['vad.npy']):
|
93 |
+
samples = audio[0,int(ts*sr):int(te*sr)]
|
94 |
+
if pad_to_seconds is not None:
|
95 |
+
padding = pad_to_seconds*sr-samples.shape[-1]
|
96 |
+
lpad = random.randint(0, padding) if random_shift else 0
|
97 |
+
samples = F.pad(samples, (lpad, padding-lpad))
|
98 |
+
yield {"__key__": s['__key__'] + f"_{i:03d}",
|
99 |
+
"__url__": s['__url__'],
|
100 |
+
"i": i, "imax": imax,
|
101 |
+
"tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
|
102 |
+
"lpad": lpad, "rpad": padding-lpad,
|
103 |
+
"lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
|
104 |
+
"samples": samples, "sample_rate": sr}
|
105 |
+
|
106 |
+
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
|
107 |
+
def flac_to_txt_name(input, model_size):
|
108 |
+
return input.rsplit("/", 1)[1].replace('flac', f'{model_size}-txt') + ".gz"
|
109 |
+
|
110 |
+
@call_parse
|
111 |
+
def process_shard(
|
112 |
+
input:str, # input shard URL/path
|
113 |
+
output:str=None, # output shard URL/path
|
114 |
+
bs:int=None, # batch size (16 uses around 11GB of VRAM)
|
115 |
+
n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
|
116 |
+
whisper_model:str="base.en" # Whisper model size
|
117 |
+
):
|
118 |
+
if output is None: output = flac_to_txt_name(input, whisper_model)
|
119 |
+
if bs is None: bs = 16
|
120 |
+
if n_samples is None: n_samples = 'noinfer'
|
121 |
+
else: n_samples = n_samples // bs
|
122 |
+
|
123 |
+
ds = wds_compose(vad.load_dataset(input),
|
124 |
+
merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
|
125 |
+
wds.map_dict(**{"vad.npy":chunk_merger}),
|
126 |
+
split_to_chunks,
|
127 |
+
wds.to_tuple('__key__', 'samples'),
|
128 |
+
wds.batched(bs),
|
129 |
+
)
|
130 |
+
dl = DataLoader(ds, num_workers=2, batch_size=None)
|
131 |
+
|
132 |
+
whmodel = whisper.load_model(whisper_model)
|
133 |
+
decoding_options = whisper.DecodingOptions(language='en')
|
134 |
+
|
135 |
+
tmp = output+".tmp"
|
136 |
+
with wds.TarWriter(tmp) as sink:
|
137 |
+
for keys, samples in progress_bar(dl, total=n_samples):
|
138 |
+
with torch.no_grad():
|
139 |
+
embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
|
140 |
+
decs = whmodel.decode(embs, decoding_options)
|
141 |
+
for key, dec in zip(keys, decs):
|
142 |
+
sink.write({
|
143 |
+
"__key__": key,
|
144 |
+
"txt": dec.text,
|
145 |
+
})
|
146 |
+
os.rename(tmp, output)
|