|
import soundfile |
|
import io |
|
from typing import Any, Tuple, Union, Optional |
|
import numpy as np |
|
import torch |
|
|
|
def preprocess_wav(data: Any, incoming_sample_rate) -> Tuple[np.ndarray, int]:
    """Decode headerless 16-bit PCM mono audio bytes via ``soundfile``.

    Args:
        data: payload wrapped in ``io.BytesIO`` and read as headerless
            ("RAW") signed 16-bit PCM with a single channel.
        incoming_sample_rate: sample rate declared for the stream; RAW data
            has no header, so the rate has to be supplied by the caller.

    Returns:
        The ``(samples, sample_rate)`` pair produced by ``soundfile.read``.
        Samples are requested as float32 with ``always_2d=True``.
        NOTE(review): soundfile documents the 2D layout as frames x channels
        -- confirm against downstream consumers.
    """
    raw_stream = io.BytesIO(data)
    # All decoding parameters gathered in one place; the whole stream is read
    # (start=0, frames=-1).
    read_options = {
        "format": "RAW",
        "subtype": "PCM_16",
        "channels": 1,
        "samplerate": incoming_sample_rate,
        "dtype": "float32",
        "always_2d": True,
        "start": 0,
        "frames": -1,
    }
    samples, rate = soundfile.read(raw_stream, **read_options)
    return samples, rate
|
|
|
def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """Convert a waveform using torchaudio sox effects.

    Supported conversions (applied in this order when requested):
        - volume normalization
        - resampling to a target sample rate
        - down-mixing from multi-channel to mono

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate; ignored when
            ``None`` or equal to ``sample_rate``

    Returns:
        waveform (numpy.ndarray or torch.Tensor): converted 2D waveform
            (channels x length), in the same container type as the input.
            When no conversion is needed the input object itself is returned.
        sample_rate (int): sample rate of the returned waveform

    Raises:
        ImportError: if a conversion is required and torchaudio is not
            installed.
    """
    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])

    # Nothing to do: return the input untouched. This path deliberately does
    # not need torchaudio, so no-op calls work without it installed.
    if not effects:
        return waveform, sample_rate

    # Import lazily so torchaudio is only required when an effect is applied.
    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError as err:
        raise ImportError(
            "Please install torchaudio: pip install torchaudio"
        ) from err

    # The sox backend operates on tensors; wrap numpy input and convert the
    # result back so the caller gets the same container type it passed in.
    is_np_input = isinstance(waveform, np.ndarray)
    _waveform = torch.from_numpy(waveform) if is_np_input else waveform
    converted, converted_sample_rate = ta_sox.apply_effects_tensor(
        _waveform, sample_rate, effects
    )
    if is_np_input:
        converted = converted.numpy()
    return converted, converted_sample_rate