import torch
import torchaudio
from typing import Callable, List
import torch.nn.functional as F
import warnings

languages = ['ru', 'en', 'de', 'es']


class OnnxWrapper():

    def __init__(self, path, force_onnx_cpu=False):
        global np
        import numpy as np  # deferred import; the global statement makes numpy visible to the other methods
        import onnxruntime

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")

        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._h = np.zeros((2, batch_size, 64)).astype('float32')
        self._c = np.zeros((2, batch_size, 64)).astype('float32')
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        x, sr = self._validate_input(x, sr)
        batch_size = x.shape[0]

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, self._h, self._c = ort_outs
        else:
            raise ValueError(f"Unsupported sampling rate: {sr}")

        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.tensor(out)
        return out

    def audio_forward(self, x, sr: int, num_samples: int = 512):
        outs = []
        x, sr = self._validate_input(x, sr)

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        self.reset_states(x.shape[0])
        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()
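

# A minimal usage sketch for OnnxWrapper (the file names below are illustrative,
# not shipped with this module): load an exported ONNX VAD model, read a wav at
# 16 kHz and get per-window speech probabilities for 512-sample windows.
#
#   model = OnnxWrapper('silero_vad.onnx', force_onnx_cpu=True)
#   wav = read_audio('example.wav', sampling_rate=16000)
#   probs = model.audio_forward(wav, sr=16000, num_samples=512)  # shape: (1, num_chunks)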


class Validator():
    def __init__(self, url, force_onnx_cpu):
        self.onnx = url.endswith('.onnx')
        torch.hub.download_url_to_file(url, 'inf.model')
        if self.onnx:
            import onnxruntime
            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
            else:
                self.model = onnxruntime.InferenceSession('inf.model')
        else:
            self.model = init_jit_model(model_path='inf.model')

    def __call__(self, inputs: torch.Tensor):
        with torch.no_grad():
            if self.onnx:
                ort_inputs = {'input': inputs.cpu().numpy()}
                outs = self.model.run(None, ort_inputs)
                outs = [torch.Tensor(x) for x in outs]
            else:
                outs = self.model(inputs)

        return outs

def read_audio(path: str,
               sampling_rate: int = 16000):

    sox_backends = set(['sox', 'sox_io'])
    audio_backends = torchaudio.list_audio_backends()

    if len(sox_backends.intersection(audio_backends)) > 0:
        effects = [
            ['channels', '1'],
            ['rate', str(sampling_rate)]
        ]

        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
    else:
        wav, sr = torchaudio.load(path)

        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    assert sr == sampling_rate
    return wav.squeeze(0)


def save_audio(path: str,
               tensor: torch.Tensor,
               sampling_rate: int = 16000):
    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
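

# Round-trip sketch for the two helpers above (file names are illustrative):
# read_audio returns a mono 1-D float tensor at the requested rate, and
# save_audio writes it back as 16-bit PCM.
#
#   wav = read_audio('input.wav', sampling_rate=16000)
#   save_audio('output_16k.wav', wav, sampling_rate=16000)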


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def make_visualization(probs, step):
    import pandas as pd
    pd.DataFrame({'probs': probs},
                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
                 kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
                 xlabel='seconds',
                 ylabel='speech probability',
                 colormap='tab20')

def get_speech_timestamps(audio: torch.Tensor,
                          model,
                          threshold: float = 0.5,
                          sampling_rate: int = 16000,
                          min_speech_duration_ms: int = 250,
                          max_speech_duration_s: float = float('inf'),
                          min_silence_duration_ms: int = 100,
                          window_size_samples: int = 512,
                          speech_pad_ms: int = 30,
                          return_seconds: bool = False,
                          visualize_probs: bool = False,
                          progress_tracking_callback: Callable[[float], None] = None):

    """
    This method is used for splitting long audio into speech chunks using silero VAD

    Parameters
    ----------
    audio: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, other types are cast to torch if possible

    model: preloaded .jit silero VAD model

    threshold: float (default - 0.5)
        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered SPEECH.
        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    sampling_rate: int (default - 16000)
        Currently silero VAD models support 8000 and 16000 sample rates

    min_speech_duration_ms: int (default - 250 milliseconds)
        Final speech chunks shorter than min_speech_duration_ms are thrown out

    max_speech_duration_s: float (default - inf)
        Maximum duration of speech chunks in seconds
        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100 ms (if any), to prevent aggressive cutting.
        Otherwise, they will be split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: int (default - 100 milliseconds)
        At the end of each speech chunk, wait for min_silence_duration_ms before separating it

    window_size_samples: int (default - 512 samples)
        Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for the 16000 sample rate and 256, 512, 768 samples for the 8000 sample rate.
        Values other than these may affect model performance!

    speech_pad_ms: int (default - 30 milliseconds)
        Final speech chunks are padded by speech_pad_ms on each side

    return_seconds: bool (default - False)
        whether to return timestamps in seconds (default - samples)

    visualize_probs: bool (default - False)
        whether to draw the probability plot or not

    progress_tracking_callback: Callable[[float], None] (default - None)
        callback function taking progress in percent as an argument

    Returns
    ----------
    speeches: list of dicts
        list containing the beginnings and ends of speech chunks (samples or seconds based on return_seconds)
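
    Examples
    ----------
    A minimal sketch; it assumes a silero VAD .jit model loaded locally via
    init_jit_model (file names are illustrative):

    >>> model = init_jit_model('silero_vad.jit')
    >>> wav = read_audio('example.wav', sampling_rate=16000)
    >>> speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
    >>> # speech_timestamps is a list of {'start': ..., 'end': ...} dicts in samples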
    """

    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception:
            raise TypeError("Audio cannot be cast to tensor. Cast it manually")

    if len(audio.shape) > 1:
        for i in range(len(audio.shape)):  # squeeze away leading singleton dimensions
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
    else:
        step = 1

    if sampling_rate == 8000 and window_size_samples > 768:
        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for the 8000 sample rate!')
    if window_size_samples not in [256, 512, 768, 1024, 1536]:
        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')

    model.reset_states()
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob = model(chunk, sampling_rate).item()
        speech_probs.append(speech_prob)

        # report progress to the optional callback
        progress = current_start_sample + window_size_samples
        if progress > audio_length_samples:
            progress = audio_length_samples
        progress_percent = (progress / audio_length_samples) * 100
        if progress_tracking_callback:
            progress_tracking_callback(progress_percent)

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15
    temp_end = 0  # to save a potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case the maximum segment size is reached

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = window_size_samples * i
            continue

        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
            if prev_end:
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:  # previously reached silence (< neg_threshold) and is still not speech (< threshold)
                    triggered = False
                else:
                    current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech['end'] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech['end'] = temp_end
                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    # pad the detected chunks and resolve overlaps between neighbouring chunks
    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i+1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        for speech_dict in speeches:
            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
    elif step > 1:
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step

    if visualize_probs:
        make_visualization(speech_probs, window_size_samples / sampling_rate)

    return speeches


def get_number_ts(wav: torch.Tensor,
                  model,
                  model_stride=8,
                  hop_length=160,
                  sample_rate=16000):
    wav = torch.unsqueeze(wav, dim=0)
    perframe_logits = model(wav)[0]
    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()
    extended_preds = []
    for i in perframe_preds:
        extended_preds.extend([i.item()] * model_stride)

    triggered = False
    timings = []
    cur_timing = {}
    for i, pred in enumerate(extended_preds):
        if pred == 1:
            if not triggered:
                cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
                triggered = True
        elif pred == 0:
            if triggered:
                cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
                timings.append(cur_timing)
                cur_timing = {}
                triggered = False
    if cur_timing:
        cur_timing['end'] = int(wav.shape[1] / (sample_rate / 1000))  # audio length in ms (wav is (1, num_samples) after unsqueeze)
        timings.append(cur_timing)
    return timings


def get_language(wav: torch.Tensor,
                 model):
    wav = torch.unsqueeze(wav, dim=0)
    lang_logits = model(wav)[2]
    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()
    assert lang_pred < len(languages)
    return languages[lang_pred]


def get_language_and_group(wav: torch.Tensor,
                           model,
                           lang_dict: dict,
                           lang_group_dict: dict,
                           top_n=1):
    wav = torch.unsqueeze(wav, dim=0)
    lang_logits, lang_group_logits = model(wav)

    softm = torch.softmax(lang_logits, dim=1).squeeze()
    softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()

    srtd = torch.argsort(softm, descending=True)
    srtd_group = torch.argsort(softm_group, descending=True)

    outs = []
    outs_group = []
    for i in range(top_n):
        prob = round(softm[srtd[i]].item(), 2)
        prob_group = round(softm_group[srtd_group[i]].item(), 2)
        outs.append((lang_dict[str(srtd[i].item())], prob))
        outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))

    return outs, outs_group


class VADIterator:
    def __init__(self,
                 model,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30
                 ):

        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 100 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """

        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')

        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):

        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    def __call__(self, x, return_seconds=False):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)
        """

        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                self.temp_end = 0
                self.triggered = False
                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}

        return None
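

# A minimal streaming sketch for VADIterator (file name and window size are
# illustrative; 512 samples per step is one of the window sizes the 16 kHz
# models were trained with):
#
#   vad_iterator = VADIterator(model)
#   wav = read_audio('example.wav', sampling_rate=16000)
#   for i in range(0, len(wav), 512):
#       speech_dict = vad_iterator(wav[i: i + 512], return_seconds=True)
#       if speech_dict:
#           print(speech_dict)          # {'start': ...} or {'end': ...}
#   vad_iterator.reset_states()         # reset model states after each audio file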


def collect_chunks(tss: List[dict],
                   wav: torch.Tensor):
    chunks = []
    for i in tss:
        chunks.append(wav[i['start']: i['end']])
    return torch.cat(chunks)


def drop_chunks(tss: List[dict],
                wav: torch.Tensor):
    chunks = []
    cur_start = 0
    for i in tss:
        chunks.append(wav[cur_start: i['start']])
        cur_start = i['end']
    return torch.cat(chunks)
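

# A minimal end-to-end sketch for the chunk helpers (names are illustrative):
# keep only the detected speech, or the audio in between, given timestamps
# expressed in samples (i.e. return_seconds=False).
#
#   speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
#   only_speech = collect_chunks(speech_timestamps, wav)
#   no_speech = drop_chunks(speech_timestamps, wav)
#   save_audio('only_speech.wav', only_speech, sampling_rate=16000)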