# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.

# %% auto 0
__all__ = ['flac_to_s2a_name']

# %% ../nbs/4A. S2A dataset preparation.ipynb 2
import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

import whisper
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds

# %% ../nbs/4A. S2A dataset preparation.ipynb 4
def flac_to_s2a_name(input):
    if '-flac-' in input:
        return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
    else:
        return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
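
# A quick sanity check of the name mapping (hypothetical shard names, assuming
# the usual '<corpus>-flac-<id>.tar' / '<corpus>-raw-<id>.tar' convention):
#
#   flac_to_s2a_name('shards/librilight-flac-000.tar')  # -> 'librilight-s2a-000.tar.gz'
#   flac_to_s2a_name('shards/librilight-raw-000.tar')   # -> 'librilight-s2a-000.tar.gz'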

# %% ../nbs/4A. S2A dataset preparation.ipynb 6
def resampler(newsr = 24000, key = 'samples_24k'):
    _last_sr = None
    tform = None
    
    def _resample(samples):
        # without `nonlocal` the cached transform would be rebuilt for every
        # sample, since `_last_sr` would never be updated
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                if sr != _last_sr:
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    _last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s
    
    return _resample
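
# A minimal standalone sketch of resampler() (the one-second dummy sample is
# hypothetical; in the pipeline the dicts come from the webdataset stages and
# carry the same 'samples'/'sample_rate' keys):
#
#   to24k = resampler(24000, 'samples_24k')
#   dummy = [{'samples': torch.zeros(1, 16000), 'sample_rate': 16000}]
#   out = next(to24k(dummy))
#   out['samples_24k'].shape  # -> torch.Size([1, 24000])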

# %% ../nbs/4A. S2A dataset preparation.ipynb 9
@call_parse
def prepare_s2a(
    input:str,  # FLAC webdataset file path (or - to read the names from stdin)
    proc_dataset_path:Path, # processed VAD files path
    output:str=None, # output file name
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from Hugging Face)
    n_samples:int=None, # process a limited amount of samples
    batch_size:int=1, # process several segments at once
    fix_dots:bool=False, # fix dots in file names
):
    if ":" in vq_model:
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    amodel = extract_acoustic.load_model()
    amodel.set_target_bandwidth(3) # 3 kbps => 4 EnCodec quantizer streams at 75 tokens/s

    if input == "-":
        input = [f.strip() for f in sys.stdin.readlines()]
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_s2a_name(input)
        input = [input]
        
    total = n_samples//batch_size if n_samples else 'noinfer'

    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
        wds.decode(wds.torch_audio),
        # keep only the samples that carry audio
        wds.select(lambda x: 'wav' in x or 'flac' in x),
        # merge in the precomputed VAD segmentation from the processed dataset
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
        # split each recording into padded chunks along the merged VAD segments
        lambda x: wh_transcribe.split_to_chunks(x),
        # EnCodec consumes 24kHz audio, the Whisper-based VQ model 16kHz
        resampler(),
        resampler(16000, 'samples_16k'),
        wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
        # batch inside the workers; the loader reshuffles and rebatches below
        wds.batched(64),
    )

    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
    # honor n_samples by stopping after `total` rebatched batches
    if n_samples: dl = itertools.islice(dl, total)

    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
            with record_function('encodec'):
                atoks = amodel.encode(samples24k)[0][0]
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(samples)
            with record_function('from_cuda'):
                atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
            for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
                # the speaker id is the second component of the sample key
                speakers.add(key.split('/')[1])
                sink.write({
                    "__key__": key,
                    # trim the right padding: EnCodec emits 75 tokens/s and the
                    # semantic tokens come at 25 tokens/s; `or None` keeps the
                    # whole sequence when rpad_s == 0
                    "atoks.npy": _atoks[:,:int(-rpad_s * 75) or None],
                    "stoks.npy": _stoks[:int(-rpad_s * 25) or None],
                })
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
    if not n_samples:
        # only finalize complete runs; limited test runs leave the .tmp shard
        os.rename(tmp, output)
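
# Example shell invocations (a hedged sketch: the `prepare_s2a` entry-point
# name and all paths below are assumptions, not taken from this file):
#
#   prepare_s2a shards/librilight-flac-000000.tar /data/librilight-processed
#
#   ls shards/*-flac-*.tar | prepare_s2a - /data/librilight-processed \
#       --output librilight-s2a-all.tar.gz --batch_size 16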