Spaces:
Running
Running
File size: 5,488 Bytes
67c46fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import io
import torch
import numpy as np
import torchaudio
from torch.nn.utils.rnn import pad_sequence
try:
from funasr_detach.download.file import download_from_url
except:
print("urllib is not installed, if you infer from url, please install it first.")
def load_audio_text_image_video(
    data_or_path_or_list,
    fs: int = 16000,
    audio_fs: int = 16000,
    data_type="sound",
    tokenizer=None,
    **kwargs
):
    """Load one input (or a list of inputs) into model-ready data.

    Dispatch on the input value:

    * list/tuple: every element is loaded recursively with the same settings.
      When ``data_type`` is itself a list/tuple, each element is treated as a
      multi-modal sample whose j-th item has type ``data_type[j]``, and the
      results are regrouped into one list per entry of ``data_type``.
    * ``str`` starting with ``"http"``: passed to ``download_from_url`` and the
      returned value is then handled by the rules below.
    * ``io.BytesIO``: decoded with ``torchaudio.load``; ``audio_fs`` is replaced
      by the decoded sample rate.
    * ``str`` naming an existing path: decoded with ``torchaudio.load`` when
      ``data_type`` is ``"sound"`` or None, passed to ``tokenizer.encode`` when
      ``data_type`` is ``"text"`` and a tokenizer is given, otherwise returned
      unchanged. If ``kwargs`` has a ``"cache"`` dict, its ``"is_final"`` entry
      is set to True and ``"is_streaming_input"`` to False.
    * other ``str`` with ``data_type == "text"`` and a tokenizer: encoded.
    * ``np.ndarray``: converted with ``torch.from_numpy`` and squeezed.
    * anything else is returned unchanged.

    Finally, a non-text ``torch.Tensor`` result is resampled from ``audio_fs``
    to ``fs`` when the two rates differ.

    Args:
        data_or_path_or_list: The input, see above.
        fs: Target sample rate.
        audio_fs: Sample rate of the input audio. It is overridden when audio
            is decoded from a file or a BytesIO.
        data_type: ``"sound"``, ``"text"``, ``"image"``, ``"video"``, None,
            or a list/tuple of these for multi-modal list input.
        tokenizer: Object with an ``encode`` method, used for text input.
        **kwargs: Forwarded to recursive calls. ``reduce_channels`` (default
            True) averages decoded audio over dim 0. ``cache`` is a dict that
            is updated for local-file input.

    Returns:
        The loaded value, or a list of loaded values for list/tuple input.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            # Multi-modal input: each element is one sample whose j-th item has
            # type data_type[j]; outputs are regrouped into one list per modality.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, item) in enumerate(zip(data_type, sample)):
                    data_or_path_or_list_ret[j].append(
                        load_audio_text_image_video(
                            item,
                            fs=fs,
                            audio_fs=audio_fs,
                            data_type=data_type_j,
                            tokenizer=tokenizer,
                            **kwargs
                        )
                    )
            return data_or_path_or_list_ret
        # Homogeneous input: the tokenizer must be forwarded too, otherwise a
        # list of text items would come back un-encoded.
        return [
            load_audio_text_image_video(
                item,
                fs=fs,
                audio_fs=audio_fs,
                data_type=data_type,
                tokenizer=tokenizer,
                **kwargs
            )
            for item in data_or_path_or_list
        ]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith(
        "http"
    ):
        # Remote input: download it, then handle the returned value below.
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, io.BytesIO):
        data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
        if kwargs.get("reduce_channels", True):
            data_or_path_or_list = data_or_path_or_list.mean(0)
    elif isinstance(data_or_path_or_list, str) and os.path.exists(
        data_or_path_or_list
    ):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            if kwargs.get("reduce_channels", True):
                data_or_path_or_list = data_or_path_or_list.mean(0)
        elif data_type == "text" and tokenizer is not None:
            # NOTE(review): this encodes the path string itself, not the
            # file's contents; confirm that this is intended.
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        # "image" and "video" loading is not implemented yet: the path is
        # returned unchanged.

        # Input that comes from a file is not streamed, so mark the cache final.
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
            kwargs["cache"]["is_streaming_input"] = False
    elif (
        isinstance(data_or_path_or_list, str)
        and data_type == "text"
        and tokenizer is not None
    ):
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample points
        data_or_path_or_list = torch.from_numpy(
            data_or_path_or_list
        ).squeeze()  # [n_samples,]
    # Any other input (e.g. a torch.Tensor) is passed through unchanged.

    # Only tensors can be resampled. Without the isinstance check, a string
    # (e.g. an image/video path) crashed here whenever audio_fs != fs.
    if (
        audio_fs != fs
        and data_type != "text"
        and isinstance(data_or_path_or_list, torch.Tensor)
    ):
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
def load_bytes(input):
    """Decode raw 16-bit PCM bytes into a float32 waveform.

    The buffer is read as native-endian ``int16`` samples, and each sample is
    divided by ``2**15``, so the output lies in ``[-1.0, 1.0)``.

    Args:
        input: Bytes-like object holding the samples. Its length must be a
            multiple of 2, otherwise ``np.frombuffer`` raises ``ValueError``.

    Returns:
        A 1-D ``np.float32`` array with one value per sample.
    """
    samples = np.frombuffer(input, dtype=np.int16)
    # 32768 == 2 ** (16 - 1): full-scale magnitude of a signed 16-bit sample.
    # int16 is signed, so no offset is needed to centre the data around zero.
    return samples.astype(np.float32) / 32768.0
def extract_fbank(
    data, data_len=None, data_type: str = "sound", frontend=None, **kwargs
):
    """Batch raw waveform data and run it through ``frontend``.

    ``data`` may be:

    * an ``np.ndarray`` or ``torch.Tensor``: 1-D input is promoted to a batch
      of one. When ``data_len`` is None, every row gets length ``shape[1]``.
    * a list/tuple of 1-D arrays/tensors: items are zero-padded to a common
      length with ``pad_sequence``, and ``data_len`` is replaced by the item
      lengths.

    Args:
        data: Waveform data, see above.
        data_len: Optional lengths forwarded to ``frontend``.
        data_type: Not used by this function (kept for API compatibility).
        frontend: Callable invoked as ``frontend(data, data_len, **kwargs)``.
            It must return ``(feats, feats_len)``.
        **kwargs: Forwarded to ``frontend``.

    Returns:
        ``(feats, feats_len)`` cast to float32 and int32. If ``frontend``
        returns the lengths as a list/tuple, they are converted with
        ``torch.tensor([feats_len])`` first.

    Raises:
        TypeError: If ``frontend`` is None.
    """
    if frontend is None:
        raise TypeError("extract_fbank() requires a callable `frontend`")

    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        if data.dim() < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            # One length per batch row, consistent with the list branch below.
            data_len = [data.shape[1]] * data.shape[0]
    elif isinstance(data, (list, tuple)):
        data_list = [
            torch.from_numpy(d) if isinstance(d, np.ndarray) else d for d in data
        ]
        data_len = [d.shape[0] for d in data_list]
        data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    data, data_len = frontend(data, data_len, **kwargs)
    if isinstance(data_len, (list, tuple)):
        # NOTE(review): this yields shape (1, batch) rather than (batch,);
        # kept for compatibility. Confirm what callers expect.
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)
|