# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import librosa
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from gxl_ai_utils.utils import utils_file
from torch.nn.utils.rnn import pad_sequence
from wenet.text.base_tokenizer import BaseTokenizer

# torchaudio.utils.sox_utils.set_buffer_size(16500)
torchaudio.set_audio_backend("soundfile")

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def url_opener(data):
    """ Given a url or local file, return a file descriptor.
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, stream}]
    """
    for sample in data:
        assert 'src' in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample['src']
        try:
            pr = urlparse(url)
            # local file
            if pr.scheme == '' or pr.scheme == 'file':
                stream = open(url, 'rb')
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            else:
                cmd = f'wget -q -O - {url}'
                process = Popen(cmd, shell=True, stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
            sample.update(stream=stream)
            yield sample
        except Exception as ex:
            logging.warning('Failed to open {}'.format(url))


def tar_file_and_group(data):
    """ Expand a stream of open tar files into a stream of tar file
        contents, and group files with the same prefix into one example.

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'stream' in sample
        stream = None
        try:
            stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
            prev_prefix = None
            example = {}
            valid = True
            for tarinfo in stream:
                name = tarinfo.name
                pos = name.rfind('.')
                assert pos > 0
                prefix, postfix = name[:pos], name[pos + 1:]
                if prev_prefix is not None and prefix != prev_prefix:
                    example['key'] = prev_prefix
                    if valid:
                        yield example
                    example = {}
                    valid = True
                with stream.extractfile(tarinfo) as file_obj:
                    try:
                        if postfix == 'txt':
                            example['txt'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix in AUDIO_FORMAT_SETS:
                            waveform, sample_rate = torchaudio.load(file_obj)
                            example['wav'] = waveform
                            example['sample_rate'] = sample_rate
                        else:
                            example[postfix] = file_obj.read()
                    except Exception as ex:
                        valid = False
                        logging.warning('failed to parse {}'.format(name))
                prev_prefix = prefix
            if prev_prefix is not None:
                example['key'] = prev_prefix
                yield example
        except Exception as ex:
            logging.warning(
                'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['src']))
        finally:
            if stream is not None:
                stream.close()
            if 'process' in sample:
                sample['process'].communicate()
            sample['stream'].close()
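
# The readers above are generator transforms meant to be chained. A minimal
# usage sketch (the shard path below is hypothetical, not part of this repo):
#
#   shards = [{'src': 'shards/train_000000.tar'}]
#   for example in tar_file_and_group(url_opener(shards)):
#       print(example['key'], example['sample_rate'], example['txt'])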

def tar_file_and_group_full_data(data):
    """ Expand a stream of open tar files into a stream of tar file
        contents, and group files with the same prefix into one example.
        Compared with tar_file_and_group, this variant also parses the
        extra per-utterance entries (lang/speaker/emotion/gender/task/
        speech_token/duration).

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'stream' in sample
        stream = None
        try:
            stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
            prev_prefix = None
            example = {}
            valid = True
            for tarinfo in stream:
                name = tarinfo.name
                pos = name.rfind('.')
                assert pos > 0
                prefix, postfix = name[:pos], name[pos + 1:]
                if prev_prefix is not None and prefix != prev_prefix:
                    example['key'] = prev_prefix
                    if valid:
                        # assert 'txt' in example
                        if 'txt' not in example:
                            example['txt'] = ''
                        yield example
                    example = {}
                    valid = True
                with stream.extractfile(tarinfo) as file_obj:
                    try:
                        if postfix == 'txt':
                            example['txt'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix == 'lang':
                            example['lang'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix == 'speaker':
                            try:
                                example['speaker'] = file_obj.read().decode(
                                    'utf8').strip()
                            except Exception as ex:
                                example['speaker'] = "none"
                        elif postfix == 'emotion':
                            example['emotion'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix == 'gender':
                            example['gender'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix == 'task':
                            example['task'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix == 'speech_token':
                            example['speech_token'] = file_obj.read()
                        elif postfix == 'duration':
                            duration_str = file_obj.read().decode(
                                'utf8').strip()
                            try:
                                example['duration'] = float(duration_str)
                            except Exception as ex:
                                logging.warning(
                                    f'failed to parse duration {duration_str}')
                                example['duration'] = 0
                        elif postfix in AUDIO_FORMAT_SETS:
                            waveform, sample_rate = torchaudio.load(file_obj)
                            # Average multi-channel audio down to mono.
                            num_channels = waveform.shape[0]
                            if num_channels > 1:
                                waveform = torch.mean(waveform,
                                                      dim=0,
                                                      keepdim=True)
                            example['wav'] = waveform
                            example['sample_rate'] = sample_rate
                        else:
                            example[postfix] = file_obj.read()
                    except Exception as ex:
                        valid = False
                        # logging.warning('failed to parse {}'.format(name))
                prev_prefix = prefix
            if prev_prefix is not None:
                example['key'] = prev_prefix
                if 'txt' in example:
                    yield example
        except Exception as ex:
            logging.warning(
                'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['src']))
        finally:
            if stream is not None:
                stream.close()
            if 'process' in sample:
                sample['process'].communicate()
            sample['stream'].close()


def parse_raw(data):
    """ Parse key/wav/txt from a json line.

        Args:
            data: Iterable[str], str is a json line that has key/wav/txt

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        json_line = sample['src']
        obj = json.loads(json_line)
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        try:
            if 'start' in obj:
                assert 'end' in obj
                sample_rate = torchaudio.info(wav_file).sample_rate
                start_frame = int(obj['start'] * sample_rate)
                end_frame = int(obj['end'] * sample_rate)
                waveform, _ = torchaudio.load(filepath=wav_file,
                                              num_frames=end_frame -
                                              start_frame,
                                              frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
            # Average multi-channel audio down to mono.
            num_channels = waveform.shape[0]
            if num_channels > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            example = copy.deepcopy(obj)  # copy and keep all the fields
            example['wav'] = waveform  # overwrite wav
            example['sample_rate'] = sample_rate
            yield example
        except Exception as ex:
            logging.warning('Failed to read {}'.format(wav_file))
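
# parse_raw consumes json lines; a minimal sketch of the expected input
# (the key and wav path are made-up values for illustration):
#
#   line = '{"key": "utt_0001", "wav": "/data/utt_0001.wav", "txt": "hello"}'
#   for example in parse_raw([{'src': line}]):
#       print(example['key'], example['wav'].shape, example['sample_rate'])
#
# With the optional "start"/"end" fields (in seconds), only that segment of
# the wav file is loaded.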

def parse_speaker(data, speaker_table_path):
    """ Map speaker names to integer ids using a speaker table file with
        one "<speaker> <id>" pair per line; unknown speakers map to 0.
    """
    speaker_dict = {}
    with open(speaker_table_path, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            speaker_dict[arr[0]] = int(arr[1])
    for sample in data:
        assert 'speaker' in sample
        speaker = sample['speaker']
        sample['speaker'] = speaker_dict.get(speaker, 0)
        yield sample


def filter(data,
           max_length=1200,
           min_length=10,
           token_max_length=250,
           token_min_length=1,
           min_output_input_ratio=0.00005,
           max_output_input_ratio=1,
           filter_no_extra_info: bool = False,
           max_seq_len=1000):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances whose label is longer than
                token_max_length, especially when char units are used for
                English modeling
            token_min_length: drop utterances whose label is shorter than
                token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (in 10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (in 10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        try:
            assert 'sample_rate' in sample
            assert 'wav' in sample
            assert 'label' in sample
        except AssertionError:
            continue
        # sample['wav'] is a torch.Tensor; there are 100 frames per second.
        num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
        # filter for shard_in_common
        if filter_no_extra_info:
            if 'lang' not in sample:
                continue
            if 'task' not in sample:
                continue
        if num_frames < min_length:
            continue
        # if "output_type" in sample and sample["output_type"] == "speech2text_token":
        #     max_length = int(max_length / 2)
        # if "output_type" in sample and sample["output_type"] == "text2token":
        #     max_length = int(max_length / 1.5)
        if num_frames > max_length:
            if 'task' in sample and sample['task'] == '':
                # Randomly pick a start point and crop a max_length window.
                start_frame = random.randint(0, int(num_frames - max_length))
                end_frame = start_frame + max_length
                sample['wav'] = sample['wav'][:, int(
                    start_frame / 100 * sample['sample_rate']):int(
                        end_frame / 100 * sample['sample_rate'])]
            else:
                continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        # if num_frames != 0:
        #     if len(sample['label']) / num_frames < min_output_input_ratio:
        #         continue
        #     if len(sample['label']) / num_frames > max_output_input_ratio:
        #         continue
        if sample["output_type"] == "speech2text_token":
            seq_len = (len(sample['prompt']) + num_frames / 8 +
                       len(sample['label']) + len(sample['speech_token']))
        elif sample["output_type"] == "text2token":
            seq_len = (len(sample['prompt']) + len(sample['label']) +
                       len(sample['speech_token']))
        else:
            seq_len = (len(sample['prompt']) + num_frames / 8 +
                       len(sample['label']))
        utils_file.logging_limit_print(
            f'seq_len: {seq_len}, output_type: {sample["output_type"]}, '
            f'len(prompt): {len(sample["prompt"])}, '
            f'num_frames / 8: {num_frames / 8}, '
            f'len(label): {len(sample["label"])}, '
            # 'speech_token' may be absent in the plain speech2text branch,
            # so use .get() to avoid a KeyError while logging.
            f'len(speech_token): {len(sample.get("speech_token", []))}')
        if 0 < max_seq_len < seq_len:
            utils_file.logging_limit_print(
                f'seq_len {seq_len} exceeds max_seq_len {max_seq_len}, skip')
            continue
        yield sample
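
# All length bookkeeping in filter() is in 10 ms frames: a 16 kHz waveform of
# 48000 samples is 48000 / 16000 * 100 = 300 frames (3 seconds). A sketch of
# a sample that survives the default thresholds (field values illustrative;
# note that this module's filter() shadows the Python builtin):
#
#   sample = {'wav': torch.zeros(1, 48000), 'sample_rate': 16000,
#             'label': [1, 2, 3], 'prompt': [0], 'output_type': 'other'}
#   kept = list(filter([sample]))  # seq_len = 1 + 300 / 8 + 3 = 41.5 < 1000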

def resample(data, resample_rate=16000):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """ Apply speed perturb to the data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            speeds(List[float]): optional speed

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        speed = random.choice(speeds)
        if speed != 1.0:
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
            sample['wav'] = wav
        yield sample


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sample_frequency=sample_rate)
        sample['feat'] = mat
        yield sample


def compute_mfcc(data,
                 num_mel_bins=23,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 num_ceps=40,
                 high_freq=0.0,
                 low_freq=20.0):
    """ Extract mfcc

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(waveform,
                         num_mel_bins=num_mel_bins,
                         frame_length=frame_length,
                         frame_shift=frame_shift,
                         dither=dither,
                         num_ceps=num_ceps,
                         high_freq=high_freq,
                         low_freq=low_freq,
                         sample_frequency=sample_rate)
        sample['feat'] = mat
        yield sample
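
# compute_fbank() yields features of shape (num_frames, num_mel_bins). A
# sketch on one second of synthetic 16 kHz audio (values are illustrative):
#
#   sample = {'key': 'utt', 'label': [1], 'sample_rate': 16000,
#             'wav': torch.randn(1, 16000)}
#   feat = next(compute_fbank([sample], num_mel_bins=80))['feat']
#   # feat.shape -> approximately (98, 80): 25 ms windows with a 10 ms shift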

def compute_log_mel_spectrogram(data,
                                n_fft=400,
                                hop_length=160,
                                num_mel_bins=80,
                                padding=0):
    """ Extract log mel spectrogram, modified from openai-whisper, see:
        - https://github.com/openai/whisper/blob/main/whisper/audio.py
        - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav'].squeeze(0)  # (channel=1, sample) -> (sample,)
        if padding > 0:
            waveform = F.pad(waveform, (0, padding))
        window = torch.hann_window(n_fft)
        stft = torch.stft(waveform,
                          n_fft,
                          hop_length,
                          window=window,
                          return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2
        filters = torch.from_numpy(
            librosa.filters.mel(sr=sample_rate,
                                n_fft=n_fft,
                                n_mels=num_mel_bins))
        mel_spec = filters @ magnitudes
        # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        sample['feat'] = log_spec.transpose(0, 1)
        yield sample


def process_text(text):
    # 1. Remove spaces on both sides of CJK characters.
    text = re.sub(r'\s*([\u4e00-\u9fff])\s*', r'\1', text)
    # 2. Lowercase English text.
    text = text.lower()
    # 3. Remove spaces around '<' and '>'.
    text = re.sub(r'\s*<\s*', '<', text)
    text = re.sub(r'\s*>\s*', '>', text)
    return text


# Mapping from fine-grained style tags to coarse style categories; keys and
# values are the Chinese style labels used in the data, kept as-is.
global_style_dict = {
    "朗读": "新闻科普",
    "科普百科": "新闻科普",
    "悬疑恐怖": "恐怖故事",
    "童话故事": "童话故事",
    "客服": "客服",
    "诗歌": "诗歌散文",
    "散文": "诗歌散文",
    "武侠评书": "有声书",
    "小说": "有声书",
    "历史": "有声书",
    "科幻": "有声书",
    "对话": "日常口语",
    "口语": "日常口语",
    "幽默": "其他",
    "其他": "其他",
}


def replace_keys_in_brackets(input_str, key_value_dict):
    for key, value in key_value_dict.items():
        # Build a regex pattern matching the '<key>' form.
        pattern = re.compile(r'<{}>'.format(key))
        input_str = pattern.sub(f"<{value}>", input_str)
    return input_str


def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    for sample in data:
        try:
            assert 'txt' in sample
        except AssertionError:
            print(f'tokenize: {sample}')
            exit()
        if 'task' in sample:
            task_name = sample['task']
            # if "" in task_name:
            #     txt = sample['txt'].replace("", "").replace("", "").replace("", "")
            if "