# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import librosa import logging import json import random import tarfile from subprocess import PIPE, Popen from urllib.parse import urlparse import torch import torchaudio import torchaudio.compliance.kaldi as kaldi import torch.nn.functional as F from gxl_ai_utils.utils import utils_file from torch.nn.utils.rnn import pad_sequence from wenet.text.base_tokenizer import BaseTokenizer # torchaudio.utils.sox_utils.set_buffer_size(16500) torchaudio.set_audio_backend("soundfile") AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) def url_opener(data): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample # TODO(Binbin Zhang): support HTTP url = sample['src'] try: pr = urlparse(url) # local file if pr.scheme == '' or pr.scheme == 'file': stream = open(url, 'rb') # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP else: cmd = f'wget -q -O - {url}' process = Popen(cmd, shell=True, stdout=PIPE) sample.update(process=process) stream = process.stdout sample.update(stream=stream) yield sample except Exception as ex: logging.warning('Failed to open {}'.format(url)) def tar_file_and_group(data): """ Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix Args: data: Iterable[{src, stream}] Returns: Iterable[{key, wav, txt, sample_rate}] """ for sample in data: assert 'stream' in sample stream = None try: stream = tarfile.open(fileobj=sample['stream'], mode="r:*") prev_prefix = None example = {} valid = True for tarinfo in stream: name = tarinfo.name pos = name.rfind('.') assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if prev_prefix is not None and prefix != prev_prefix: example['key'] = prev_prefix if valid: yield example example = {} valid = True with stream.extractfile(tarinfo) as file_obj: try: if postfix == 'txt': example['txt'] = file_obj.read().decode( 'utf8').strip() elif postfix in AUDIO_FORMAT_SETS: waveform, sample_rate = torchaudio.load(file_obj) example['wav'] = waveform example['sample_rate'] = sample_rate else: example[postfix] = file_obj.read() except Exception as ex: valid = False logging.warning('error to parse {}'.format(name)) prev_prefix = prefix if prev_prefix is not None: example['key'] = prev_prefix yield example except Exception as ex: logging.warning( 'In tar_file_and_group: {} when processing {}'.format( ex, sample['src'])) finally: if stream is not None: stream.close() if 'process' in sample: sample['process'].communicate() sample['stream'].close() def tar_file_and_group_full_data(data): """ Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix Args: data: Iterable[{src, stream}] Returns: Iterable[{key, wav, txt, sample_rate}] """ for sample in data: assert 'stream' in sample stream = None try: stream = tarfile.open(fileobj=sample['stream'], mode="r:*") prev_prefix = None example = {} valid = True for tarinfo in stream: name = tarinfo.name pos = name.rfind('.') assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if prev_prefix is not None and prefix != prev_prefix: example['key'] = prev_prefix if valid: # assert 'txt' in example if 'txt' not in example: example['txt'] = '' yield example example = {} valid = True with stream.extractfile(tarinfo) as file_obj: try: if postfix == 'txt': example['txt'] = file_obj.read().decode( 'utf8').strip() elif postfix == 'lang': example['lang'] = file_obj.read().decode( 'utf8').strip() elif postfix == 'speaker': try: example['speaker'] = file_obj.read().decode( 'utf8').strip() except Exception as ex: example['speaker'] = "none" elif postfix == 'emotion': example['emotion'] = file_obj.read().decode( 'utf8').strip() elif postfix == 'gender': example['gender'] = file_obj.read().decode( 'utf8').strip() elif postfix == 'task': example['task'] = file_obj.read().decode( 'utf8').strip() elif postfix == 'speech_token': example['speech_token'] = file_obj.read() elif postfix == 'duration': duration_str = file_obj.read().decode( 'utf8').strip() try: duration_float = float(duration_str) example['duration'] = duration_float except Exception as ex: logging.warning(f'error to parse duration {duration_str}') example['duration'] = 0 elif postfix in AUDIO_FORMAT_SETS: waveform, sample_rate = torchaudio.load(file_obj) # 检查音频的维度 num_channels = waveform.shape[0] # 如果音频是多通道的,则进行通道平均 if num_channels > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) example['wav'] = waveform example['sample_rate'] = sample_rate else: example[postfix] = file_obj.read() except Exception as ex: valid = False # logging.warning('error to parse {}'.format(name)) prev_prefix = prefix if prev_prefix is not None: example['key'] = prev_prefix if 'txt' in example: yield example except Exception as ex: logging.warning( 'In tar_file_and_group: {} when processing {}'.format( ex, sample['src'])) finally: if stream is not None: stream.close() if 'process' in sample: sample['process'].communicate() sample['stream'].close() def parse_raw(data): """ Parse key/wav/txt from json line Args: data: Iterable[str], str is a json line has key/wav/txt Returns: Iterable[{key, wav, txt, sample_rate}] """ for sample in data: assert 'src' in sample json_line = sample['src'] obj = json.loads(json_line) assert 'key' in obj assert 'wav' in obj assert 'txt' in obj key = obj['key'] wav_file = obj['wav'] txt = obj['txt'] try: if 'start' in obj: assert 'end' in obj sample_rate = torchaudio.info(wav_file).sample_rate start_frame = int(obj['start'] * sample_rate) end_frame = int(obj['end'] * sample_rate) waveform, _ = torchaudio.load(filepath=wav_file, num_frames=end_frame - start_frame, frame_offset=start_frame) else: waveform, sample_rate = torchaudio.load(wav_file) # 检查音频的维度 num_channels = waveform.shape[0] # 如果音频是多通道的,则进行通道平均 if num_channels > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) example = copy.deepcopy(obj) # copy and keep all the fields example['wav'] = waveform # overwrite wav example['sample_rate'] = sample_rate yield example except Exception as ex: logging.warning('Failed to read {}'.format(wav_file)) def parse_speaker(data, speaker_table_path): speaker_dict = {} with open(speaker_table_path, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() speaker_dict[arr[0]] = int(arr[1]) for sample in data: assert 'speaker' in sample speaker = sample['speaker'] sample['speaker'] = speaker_dict.get(speaker, 0) yield sample def filter(data, max_length=1200, min_length=10, token_max_length=250, token_min_length=1, min_output_input_ratio=0.00005, max_output_input_ratio=1, filter_no_extra_info: bool = False, max_seq_len=1000): """ Filter sample according to feature and label length Inplace operation. Args:: data: Iterable[{key, wav, label, sample_rate}] max_length: drop utterance which is greater than max_length(10ms) min_length: drop utterance which is less than min_length(10ms) token_max_length: drop utterance which is greater than token_max_length, especially when use char unit for english modeling token_min_length: drop utterance which is less than token_max_length min_output_input_ratio: minimal ration of token_length / feats_length(10ms) max_output_input_ratio: maximum ration of token_length / feats_length(10ms) Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: try: assert 'sample_rate' in sample assert 'wav' in sample assert 'label' in sample except: continue # sample['wav'] is torch.Tensor, we have 100 frames every second num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 # filter for shard_in_common if filter_no_extra_info: if 'lang' not in sample: continue if 'task' not in sample: continue if num_frames < min_length: continue # if "output_type" in sample and sample["output_type"] == "speech2text_token": # max_length = int(max_length / 2) # if "output_type" in sample and sample["output_type"] == "text2token": # max_length = int(max_length / 1.5) if num_frames > max_length: # continue if 'task' in sample and sample['task'] == '