# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
import os
import io
import time
import torch
import torchaudio

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import numpy as np
import random

import whisper

from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

from fastcore.script import *

from . import vad
import webdataset as wds

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
# let's make the cutting a bit more conservative: with full 30-second chunks
# Whisper sometimes misses a small part of the transcript
def random_cutter(dur):
    # half of the time cut at a random threshold between ~1.4 and 28 seconds,
    # otherwise cut at the full 28 seconds, to diversify the chunk lengths
    if random.random() < 0.5:
        return dur > 28 * (random.random()*0.95+0.05)
    else:
        return dur > 28

def chunk_merger(segments, should_cut=lambda x: x > 28):
    """Merge consecutive VAD segments into longer chunks, starting a new chunk
    whenever `should_cut` decides the accumulated duration is long enough."""
    if len(segments) == 0: return segments
    curr_start = segments[0][0]
    curr_end = 0
    merged = []

    for ts,te in segments:
        # close the current chunk before it grows past the cutting threshold
        if should_cut(te - curr_start) and curr_end - curr_start > 0:
            merged.append((curr_start, curr_end))
            curr_start = ts
        curr_end = te
    merged.append((curr_start, curr_end))
    return merged
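# An illustration of the merging above (values hypothetical): segments are
# accumulated until the running span would exceed 28 seconds:
#   chunk_merger([(0.0, 10.0), (10.5, 20.0), (20.5, 35.0)])
#   -> [(0.0, 20.0), (20.5, 35.0)]
# Passing should_cut=random_cutter instead randomizes the cutting threshold.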

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
def merge_in(*datasets):
    """Merge multiple datasets into the current one, returning samples with the union of keys.

    It requires (and validates) that all datasets share the same ordering of keys, so it has
    to be applied before any sample shuffling. Shard shuffling is fine.
    """
    def merge_loop(main_samples):
        # iterate all datasets in lockstep and merge the key/value pairs of
        # the aligned samples into a single dict
        for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
            key = samples[0]['__key__']
            news = {}
            for s in samples:
                assert s['__key__'] == key
                news.update(s)
            yield news
    return merge_loop

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
import copy

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
# a workaround for https://github.com/webdataset/webdataset/issues/297;
# once that is fixed it should be possible to use ds.compose here
def wds_compose(ds, *args):
    # shallow-copy the dataset and its pipeline so the original stays untouched
    ds = copy.copy(ds)
    ds.pipeline = copy.copy(ds.pipeline)
    for f in args:
        ds.append(f)
    return ds
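# Usage sketch (shard paths hypothetical): compose extra pipeline stages onto
# a dataset, here merging in a shard of precomputed VAD results so that every
# sample carries both its audio and its 'vad.npy' segment list:
#   ds = wds_compose(vad.load_dataset("audio-flac-000000.tar"),
#                    merge_in(wds.WebDataset("audio-vad-000000.tar").decode()))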

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
    for s in stream:
        audio, sr = s.get('flac', s.get('wav', (None, None)))
        if audio is None:
            print(f"warning: '{s['__key__']}' does not contain an audio file")
            continue
        imax = len(s['vad.npy']) - 1
        for i,(ts,te) in enumerate(s['vad.npy']):
            samples = audio[0,int(ts*sr):int(te*sr)]
            # pad every chunk to a fixed length; with random_shift the speech is
            # placed at a random offset instead of always starting at sample 0
            lpad = padding = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr-samples.shape[-1]
                lpad = random.randint(0, padding) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            yield {"__key__": s['__key__'] + f"_{i:03d}",
                   "__url__": s['__url__'],
                   "i": i, "imax": imax,
                   "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                   "lpad": lpad, "rpad": padding-lpad,
                   "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                   "samples": samples, "sample_rate": sr}

# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
def flac_to_txt_name(input, model_size):
    # derive the transcript shard name from the audio shard name by swapping
    # 'flac' for '<model_size>-txt' and appending .gz; [-1] also handles bare
    # file names without a directory part
    return input.rsplit("/", 1)[-1].replace('flac', f'{model_size}-txt') + ".gz"
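# e.g. (hypothetical shard name):
#   flac_to_txt_name("path/to/librilight-flac-000000.tar", "base.en")
#   -> 'librilight-base.en-txt-000000.tar.gz'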

@call_parse
def process_shard(
    input:str,          # input shard URL/path
    output:str=None,    # output shard URL/path
    bs:int=None,        # batch size (16 uses around 11GB of VRAM)
    n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
    whisper_model:str="base.en" # Whisper model size
):
    if output is None: output = flac_to_txt_name(input, whisper_model)
    if bs is None: bs = 16
    if n_samples is None: n_samples = 'noinfer' # tell the progress bar not to infer the dataset length
    else: n_samples = n_samples // bs           # the progress bar counts batches, not samples

    # build the pipeline: load the audio shard, merge in the matching shard of
    # precomputed VAD segments, merge those into ~28 s chunks, then cut the
    # audio into padded chunks and batch them
    ds = wds_compose(vad.load_dataset(input),
        merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
        wds.map_dict(**{"vad.npy":chunk_merger}),
        split_to_chunks,
        wds.to_tuple('__key__', 'samples'),
        wds.batched(bs),
    )
    dl = DataLoader(ds, num_workers=2, batch_size=None)
    
    whmodel = whisper.load_model(whisper_model)
    decoding_options = whisper.DecodingOptions(language='en')
    
    # write to a temporary file and rename it at the end so an interrupted run
    # cannot leave behind a truncated output shard
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, samples in progress_bar(dl, total=n_samples):
            with torch.no_grad():
                # run the whole batch through the encoder, then decode transcripts
                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
                decs = whmodel.decode(embs, decoding_options)
            for key, dec in zip(keys, decs):
                sink.write({
                    "__key__": key,
                    "txt": dec.text,
                })
    os.rename(tmp, output)
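# Example invocation from Python (shard name hypothetical); fastcore's
# @call_parse also exposes the same arguments as CLI flags when the package
# registers this function as a console script:
#   process_shard("librilight-flac-000000.tar", bs=16, whisper_model="base.en")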