import os
import io
import torch
import numpy as np
import torchaudio
from torch.nn.utils.rnn import pad_sequence

try:
    from funasr_detach.download.file import download_from_url
except ImportError:
    print(
        "Failed to import download_from_url; if you need to load data from a "
        "URL, please install the download dependencies first."
    )


def load_audio_text_image_video(
    data_or_path_or_list,
    fs: int = 16000,
    audio_fs: int = 16000,
    data_type="sound",
    tokenizer=None,
    **kwargs
):
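    """Load one sample (or a nested batch) of audio/text/image/video input.

    ``data_or_path_or_list`` may be a local path, a URL, an ``io.BytesIO``
    object, a numpy array of samples, a text string, or a (nested) list of
    these. Audio is loaded with torchaudio and resampled from ``audio_fs``
    to ``fs`` when the two rates differ; text is encoded with ``tokenizer``.
    """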
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            # Multimodal batch: each sample is a tuple of fields whose types
            # are given by data_type, e.g. data_type=("sound", "text"). Load
            # every field recursively and regroup the results by field index,
            # so the output is one list per modality rather than one per sample.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, data_j) in enumerate(zip(data_type, sample)):
                    data_j = load_audio_text_image_video(
                        data_j,
                        fs=fs,
                        audio_fs=audio_fs,
                        data_type=data_type_j,
                        tokenizer=tokenizer,
                        **kwargs
                    )
                    data_or_path_or_list_ret[j].append(data_j)

            return data_or_path_or_list_ret
        else:
            # Homogeneous batch: load each item with the same data_type.
            return [
                load_audio_text_image_video(
                    item,
                    fs=fs,
                    audio_fs=audio_fs,
                    data_type=data_type,
                    tokenizer=tokenizer,
                    **kwargs
                )
                for item in data_or_path_or_list
            ]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith(
        "http"
    ):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, io.BytesIO):
        data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
        if kwargs.get("reduce_channels", True):
            data_or_path_or_list = data_or_path_or_list.mean(0)  # downmix to mono
    elif isinstance(data_or_path_or_list, str) and os.path.exists(
        data_or_path_or_list
    ):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            if kwargs.get("reduce_channels", True):
                data_or_path_or_list = data_or_path_or_list.mean(0)  # downmix to mono
        elif data_type == "text" and tokenizer is not None:
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # TODO: image loading is not implemented yet
            pass
        elif data_type == "video":  # TODO: video loading is not implemented yet
            pass

        # The input came from a file or URL, so mark the cache as a final,
        # non-streaming input.
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
            kwargs["cache"]["is_streaming_input"] = False
    elif (
        isinstance(data_or_path_or_list, str)
        and data_type == "text"
        and tokenizer is not None
    ):
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # raw audio samples
        data_or_path_or_list = torch.from_numpy(
            data_or_path_or_list
        ).squeeze()  # [n_samples,]
    else:
        # Unsupported input types are returned unchanged.
        pass

    # Resample to the target sample rate when the source rate differs.
    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list


def load_bytes(input):
    """Decode raw 16-bit PCM bytes into a float32 array scaled to [-1.0, 1.0)."""
    middle_data = np.frombuffer(input, dtype=np.int16)

    # Map the integer range of the source dtype onto [-1.0, 1.0). For int16
    # the offset is 0 and abs_max is 32768.
    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = (middle_data.astype(np.float32) - offset) / abs_max
    return array


def extract_fbank(
    data, data_len=None, data_type: str = "sound", frontend=None, **kwargs
):
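    """Run waveform(s) through an acoustic frontend and return float32 features.

    ``data`` may be a numpy array, a torch tensor, or a list/tuple of either;
    lists are padded into a single [batch, N] tensor before ``frontend`` is
    applied. Returns the features and their lengths.
    """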
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if isinstance(data, torch.Tensor):
        if data.dim() < 2:
            data = data[None, :]  # data: [batch, N]
        data_len = [data.shape[1]] if data_len is None else data_len
    elif isinstance(data, (list, tuple)):
        # Variable-length inputs: pad to a common length and record true lengths.
        data_list, data_len = [], []
        for data_i in data:
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
        data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    # Apply the acoustic frontend (e.g., fbank extraction) to the batch.
    data, data_len = frontend(data, data_len, **kwargs)

    if isinstance(data_len, (list, tuple)):
        data_len = torch.tensor(data_len)  # [batch]
    return data.to(torch.float32), data_len.to(torch.int32)
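

# --- Minimal usage sketch (not part of the original module) -----------------
# A hedged illustration of how the three helpers compose: decode raw PCM
# bytes, normalize them to a tensor, then extract fbank features.
# `dummy_frontend` below is a hypothetical stand-in for a real FunASR
# frontend (e.g., WavFrontend); it is not defined anywhere in this codebase.
if __name__ == "__main__":
    import torchaudio.compliance.kaldi as kaldi

    # 1) Decode one second of synthetic 16-bit PCM (a 440 Hz tone) to float32.
    t = np.arange(16000) / 16000.0
    pcm = (np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)
    samples = load_bytes(pcm.tobytes())

    # 2) Normalize the numpy samples to a 1-D torch tensor (no resampling
    #    happens here, since audio_fs already matches fs).
    waveform = load_audio_text_image_video(samples, fs=16000, audio_fs=16000)

    # 3) Extract 80-dim fbank features through the stand-in frontend, which
    #    returns features as [batch, frames, mel_bins] plus frame counts.
    def dummy_frontend(data, data_len, **kwargs):
        feats = kaldi.fbank(data, num_mel_bins=80, sample_frequency=16000.0)
        return feats[None, :, :], torch.tensor([feats.shape[0]])

    feats, feats_len = extract_fbank(waveform, frontend=dummy_frontend)
    print(feats.shape, feats_len)  # e.g. torch.Size([1, 98, 80]), tensor([98])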