# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5A. T2S dataset preparation.ipynb.

# %% auto 0
__all__ = []
# %% ../nbs/5A. T2S dataset preparation.ipynb 2
import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *
import whisper, whisperx

from . import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds
# %% ../nbs/5A. T2S dataset preparation.ipynb 4
def flac_to_t2s_name(input):
    # [-1] (rather than [1]) keeps this working for bare file names without a "/"
    return input.rsplit("/", 1)[-1].replace('flac', 't2s') + ".gz"
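# For example (hypothetical shard name):
#
#     flac_to_t2s_name("librilight/small-flac-000000.tar") == "small-t2s-000000.tar.gz"
#
# Note that replace() rewrites every occurrence of "flac" in the name, not just
# the extension, so shard names should use "flac" only where the swap is intended.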
# %% ../nbs/5A. T2S dataset preparation.ipynb 6
class Transcriber:
    """
    A helper class to transcribe a batch of 30 second audio chunks.
    """
    def __init__(self, model_size, lang=False):
        self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
        # without calling vad_model at least once the rest segfaults for some reason...
        self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})

    def transcribe(self, batch):
        batch = whisper.log_mel_spectrogram(batch)
        embs = self.model.model.encode(batch.cpu().numpy())
        return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
            self.model.model.model.generate(
                embs,
                [self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
            )])
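# A minimal usage sketch (an illustration, not part of the module): each row of
# the batch is assumed to be a 30 second chunk of 16 kHz mono audio (480000
# samples), matching the full window whisper.log_mel_spectrogram expects.
#
#     transcriber = Transcriber("small.en")
#     texts = transcriber.transcribe(torch.zeros(2, 30 * 16000).cuda())  # two silent chunks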
# %% ../nbs/5A. T2S dataset preparation.ipynb 7
@call_parse
def prepare_t2s(
    input:str,              # FLAC webdataset file path (or - to read the names from stdin)
    proc_dataset_path:Path, # processed VAD files path
    output:str=None,        # output file name
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from Hugging Face)
    n_samples:int=None,     # process only a limited number of samples
    batch_size:int=1,       # process several segments at once
    transcription_model:str="small.en", # Whisper model size used for transcription
):
if ":" in vq_model: | |
repo, fname = vq_model.split(":", 1) | |
vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda() | |
else: | |
vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda() | |
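    # e.g. the default "collabora/spear-tts-pytorch:whisper-vq-stoks.model" downloads
    # the checkpoint from the Hugging Face Hub, while a plain path such as
    # "./whisper-vq-stoks.model" (hypothetical) would be loaded from local disk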
    transcriber = Transcriber(transcription_model)

    if input == "-":
        input = [f.strip() for f in sys.stdin.readlines()]
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_t2s_name(input)
        input = [input]

    # 'noinfer' tells fastprogress not to try to infer the loader's length
    total = n_samples//batch_size if n_samples else 'noinfer'
    if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
        wds.decode(wds.torch_audio),
        # merge in the precomputed VAD segmentation for each sample
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
        lambda x: wh_transcribe.split_to_chunks(x),
        # drop the first and last segment because they tend to be inaccurate
        # (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
        wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
        wds.to_tuple('__key__', 'rpad', 'samples'),
        wds.batched(64),
    )
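    # the pipeline above emits fixed batches of 64 segments per worker; unbatch,
    # shuffle across shards with a 2000-sample buffer, then rebatch to the
    # requested GPU batch size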
    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, rpads, samples in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                csamples = samples.cuda()
            with record_function('transcribe'):
                txts = transcriber.transcribe(csamples)
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(csamples)
            with record_function('from_cuda'):
                stoks = stoks.cpu().numpy().astype(np.int16)
            for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
                speakers.add(key.split('/')[1])
                sink.write({
                    "__key__": key,
                    "txt": txt,
                    # drop the semantic tokens that cover the right padding
                    # (rpad is in samples at 16 kHz, stoks come at 25 tokens per
                    # second; `or None` keeps everything when the trim rounds to 0)
                    "stoks.npy": _stoks[:int(-rpad/16000 * 25) or None],
                })
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
    # benchmarking runs stop early, so only rename the temporary file after a full run
    if not n_samples:
        os.rename(tmp, output)
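# A minimal invocation sketch (hypothetical shard names; assumes the package is
# installed so that @call_parse exposes prepare_t2s as a console script):
#
#     prepare_t2s librilight/small-flac-000000.tar ./librilight-processed/
#
# which writes small-t2s-000000.tar.gz plus small-t2s-000000.tar.gz.speakers.txt
# (the set of speaker ids seen in the shard)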