# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Utility functions for Transformer."""

import math
import time
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from whisper.tokenizer import LANGUAGES as WhisperLanguages

WHISPER_LANGS = tuple(WhisperLanguages.keys())
IGNORE_ID = -1


def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    max_len = max([len(item) for item in xs])
    batchs = len(xs)
    ndim = xs[0].ndim
    if ndim == 1:
        pad_res = torch.zeros(batchs,
                              max_len,
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    elif ndim == 2:
        pad_res = torch.zeros(batchs,
                              max_len,
                              xs[0].shape[1],
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    elif ndim == 3:
        pad_res = torch.zeros(batchs,
                              max_len,
                              xs[0].shape[1],
                              xs[0].shape[2],
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    else:
        raise ValueError(f"Unsupported ndim: {ndim}")
    pad_res.fill_(pad_value)
    for i in range(batchs):
        pad_res[i, :len(xs[i])] = xs[i]
    return pad_res


def add_blank(ys_pad: torch.Tensor, blank: int,
              ignore_id: int) -> torch.Tensor:
    """Prepend blank for the transducer predictor.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    bs = ys_pad.size(0)
    _blank = torch.tensor([blank],
                          dtype=torch.long,
                          requires_grad=False,
                          device=ys_pad.device)
    _blank = _blank.repeat(bs).unsqueeze(1)  # [bs, 1]
    out = torch.cat([_blank, ys_pad], dim=1)  # [bs, Lmax + 1]
    return torch.where(out == ignore_id, blank, out)
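
# The doctest above only covers 1-D token sequences, but `pad_list` also
# handles 2-D and 3-D tensors (e.g. per-utterance feature matrices of shape
# (T, D)). A minimal sketch with made-up shapes, runnable as a script:
if __name__ == "__main__":
    feats = [torch.ones(4, 2), torch.ones(2, 2)]  # two utterances, D = 2
    padded = pad_list(feats, 0)
    assert padded.shape == (2, 4, 2)  # (B, Tmax, D)
    assert (padded[1, 2:] == 0).all()  # trailing frames are padding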
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor([sos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def add_whisper_tokens(special_tokens, ys_pad: torch.Tensor, ignore_id: int,
                       tasks: List[str], no_timestamp: bool,
                       langs: List[str],
                       use_prev: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.

    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲
      ↓
    [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]    # noqa
        |          |                |-------> [no timestamps] -> [text tokens] ----------------------↑       # noqa
        |          |                                                                                 |       # noqa
        |          |--------> [translate]  -> [begin time] -> [text tokens] -> [end time] -> ... --->|       # noqa
        |                           |-------> [no timestamps] -> [text tokens] --------------------->|       # noqa
        |                                                                                            |       # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|       # noqa

    Args:
        special_tokens (dict): mapping from special-token names
            (e.g. "sot", "eot", "transcribe") to token IDs
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        ignore_id (int): index of padding
        tasks (List[str]): list of task tags
        no_timestamp (bool): whether to add timestamp tokens
        langs (List[str]): list of language tags
        use_prev (bool): whether to prepend previous-context tokens ([PREV])

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?)
        ys_out (torch.Tensor) : (B, Lmax + ?)
    """
    assert len(langs) == ys_pad.size(0)
    assert len(tasks) == ys_pad.size(0)
    if use_prev:
        # i.e., hotword list
        _prev = [special_tokens["sot_prev"]]
        # append hotword list to _prev
        # ...
        raise NotImplementedError
    else:
        _prev = []

    _sot = []
    for task, lang in zip(tasks, langs):
        if task == "transcribe":
            task_id = special_tokens["transcribe"]
        elif task == "translate":
            task_id = special_tokens["translate"]
        elif task == "vad":
            task_id = special_tokens["no_speech"]
        elif task in special_tokens:
            task_id = special_tokens[task]
        else:
            raise NotImplementedError("unsupported task {}".format(task))
        language_id = special_tokens["sot"] + 1 + WHISPER_LANGS.index(lang)
        prefix = _prev + [special_tokens["sot"], language_id, task_id]
        if task != "vad":
            if no_timestamp:
                prefix.append(special_tokens["no_timestamps"])
            else:
                prefix.append(special_tokens["timestamp_begin"])
                # add subsequent tokens
                # ...
                raise NotImplementedError
        else:
            prefix.append(special_tokens["no_speech"])
        prefix = torch.tensor(prefix,
                              dtype=torch.long,
                              requires_grad=False,
                              device=ys_pad.device)
        _sot.append(prefix)

    _eot = torch.tensor([special_tokens["eot"]],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([prefix, y], dim=0) for prefix, y in zip(_sot, ys)]
    ys_out = [
        torch.cat([prefix[1:], y, _eot], dim=0) for prefix, y in zip(_sot, ys)
    ]
    return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)
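
# A minimal sketch of how `add_whisper_tokens` builds decoder inputs and
# targets. The token IDs below are made up for illustration; real IDs come
# from the whisper tokenizer. `no_timestamp=True` is used because the
# timestamped path above is not implemented.
if __name__ == "__main__":
    special_tokens = {
        "sot": 100,  # language IDs implicitly follow sot (sot + 1 + index)
        "eot": 130,
        "transcribe": 131,
        "translate": 132,
        "no_timestamps": 133,
        "timestamp_begin": 134,
        "no_speech": 135,
        "sot_prev": 136,
    }
    ys_pad = torch.tensor([[1, 2, 3], [4, 5, IGNORE_ID]])
    ys_in, ys_out = add_whisper_tokens(special_tokens,
                                       ys_pad,
                                       ignore_id=IGNORE_ID,
                                       tasks=["transcribe", "transcribe"],
                                       no_timestamp=True,
                                       langs=["en", "en"],
                                       use_prev=False)
    # ys_in rows start with [sot, language id, transcribe, no_timestamps];
    # ys_out is shifted left by one prefix token and ends with eot.
    print(ys_in)
    print(ys_out)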
def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse padding for the list of tensors.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])
    """
    r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
                             for y, i in zip(ys_pad, ys_lens)], True,
                            pad_value)
    return r_ys_pad


def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return (numerator / denominator).detach()


def get_subsample(config):
    """Return the subsampling rate implied by the encoder input layer."""
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8


def log_add(*args) -> float:
    """Stable log add."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean attention mask to an additive bias."""
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * -1.0e+10
    return mask
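
# A quick sketch of what `mask_to_bias` produces: True (keep) positions map
# to 0.0 and False (masked) positions to a large negative value, so the
# result can simply be added to attention logits before the softmax.
if __name__ == "__main__":
    attn_mask = torch.tensor([[True, True, False]])
    bias = mask_to_bias(attn_mask, torch.float32)
    assert bias[0, 0] == 0.0 and bias[0, 2] == -1.0e+10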
def get_nested_attribute(obj, attr_path):
    if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
        obj = obj.module
    attributes = attr_path.split('.')
    for attr in attributes:
        obj = getattr(obj, attr)
    return obj


def lrs_to_str(lrs: List):
    return " ".join(["{:.4e}".format(lr) for lr in lrs])


class StepTimer:
    """Utility class for measuring steps/second."""

    def __init__(self, step=0.0):
        self.last_iteration = step
        self.start()

    def start(self):
        self.last_time = time.time()

    def steps_per_second(self, cur_step, restart=True):
        value = ((float(cur_step) - self.last_iteration) /
                 (time.time() - self.last_time))
        if restart:
            self.start()
            self.last_iteration = float(cur_step)
        return value


def tensor_to_scalar(x):
    if torch.is_tensor(x):
        return x.item()
    return x


def is_torch_npu_available() -> bool:
    """Check if torch_npu is available.

    torch_npu is an NPU adapter for PyTorch.
    """
    try:
        import torch_npu  # noqa
        return True
    except ImportError:
        if not torch.cuda.is_available():
            print('Module "torch_npu" not found. "pip install torch_npu" '
                  'if you are using Ascend NPU; otherwise, ignore it.')
    return False


TORCH_NPU_AVAILABLE = is_torch_npu_available()
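
# A minimal sketch of StepTimer usage in a training loop; the sleep and the
# step counts are placeholders standing in for real training work.
if __name__ == "__main__":
    timer = StepTimer()
    for step in (100, 200, 300):
        time.sleep(0.01)  # stand-in for one logging interval of training
        print("steps/sec: {:.1f}".format(timer.steps_per_second(step)))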