Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Utility functions for Transformer."""
import math | |
import time | |
from typing import List, Tuple | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from whisper.tokenizer import LANGUAGES as WhiserLanguages | |
# Whisper language codes frozen into a tuple, in the iteration order of the
# imported LANGUAGES mapping. add_whisper_tokens() uses the position of a
# code in this tuple to derive the corresponding language token id.
WHISPER_LANGS = tuple(WhiserLanguages.keys())
# NOTE(review): presumably the default label id used to mark padded target
# positions; it is not referenced by the code visible in this file.
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Stack variable-length tensors into one batch, padding the first axis.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
            Items must be 1-D, 2-D or 3-D; the trailing dimensions, dtype
            and device of the result are taken from ``xs[0]``.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Raises:
        ValueError: If ``xs[0]`` has more than three dimensions.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    longest = max([len(seq) for seq in xs])
    num_seqs = len(xs)
    first = xs[0]
    rank = first.ndim
    if rank not in (1, 2, 3):
        raise ValueError(f"Unsupported ndim: {rank}")
    # (B, Tmax) followed by whatever trailing dims the first item carries.
    out_shape = [num_seqs, longest] + list(first.shape[1:])
    padded = torch.zeros(out_shape, dtype=first.dtype, device=first.device)
    padded.fill_(pad_value)
    # Copy every sequence into the leading slice of its own row.
    for row, seq in enumerate(xs):
        padded[row, :len(seq)] = seq
    return padded
def add_blank(ys_pad: torch.Tensor, blank: int,
              ignore_id: int) -> torch.Tensor:
    """Prepend a blank column and replace padding by blank.

    Used to build the transducer predictor input.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index of padding; every such entry becomes ``blank``

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,   4,   5],
                [ 4,  5,  6,  -1,  -1],
                [ 7,  8,  9,  -1,  -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0,  1,  2,  3,  4,  5],
                [0,  4,  5,  6,  0,  0],
                [0,  7,  8,  9,  0,  0]])
    """
    batch_size = ys_pad.size(0)
    # One blank per utterance, laid out as a (B, 1) column.
    blank_col = torch.tensor([[blank]],
                             dtype=torch.long,
                             requires_grad=False,
                             device=ys_pad.device).repeat(batch_size, 1)
    ys_in = torch.cat([blank_col, ys_pad], dim=1)  # (B, Lmax + 1)
    is_padding = ys_in == ignore_id
    return torch.where(is_padding, blank, ys_in)
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/output targets by adding <sos> and <eos>.

    Padding entries (``ignore_id``) are stripped from every row first; the
    input gets <sos> in front, the output gets <eos> appended, and both are
    re-padded to a common length.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1), padded with ``eos``
        ys_out (torch.Tensor) : (B, Lmax + 1), padded with ``ignore_id``

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    device = ys_pad.device
    sos_token = torch.tensor([sos],
                             dtype=torch.long,
                             requires_grad=False,
                             device=device)
    eos_token = torch.tensor([eos],
                             dtype=torch.long,
                             requires_grad=False,
                             device=device)
    inputs = []
    outputs = []
    for row in ys_pad:
        tokens = row[row != ignore_id]  # strip the padding of this row
        inputs.append(torch.cat([sos_token, tokens], dim=0))
        outputs.append(torch.cat([tokens, eos_token], dim=0))
    return pad_list(inputs, eos), pad_list(outputs, ignore_id)
def add_whisper_tokens(special_tokens, ys_pad: torch.Tensor, ignore_id: int,
                       tasks: List[str], no_timestamp: bool, langs: List[str],
                       use_prev: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.

    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲
      ↓
    [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]    # noqa
        |          |                |-------> [no timestamps] -> [text tokens] ----------------------↑       # noqa
        |          |                                                                                 |       # noqa
        |          |--------> [translate]  -> [begin time] -> [text tokens] -> [end time] -> ... --->|       # noqa
        |          |                |-------> [no timestamps] -> [text tokens] --------------------->|       # noqa
        |                                                                                            |       # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|       # noqa

    Args:
        special_tokens: get IDs of special tokens
        ys_pad (torch.Tensor): batch of padded target sequences; its first
            dimension must equal ``len(tasks)`` and ``len(langs)``
        ignore_id (int): index of padding
        tasks (List[str]): list of task tags
        no_timestamp (bool): whether to add timestamps tokens
        langs (List[str]): list of language tags
        use_prev (bool): whether to prepend the [PREV] prompt
            (not implemented yet)

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?), prefix + text, padded with
            the ``eot`` id
        ys_out (torch.Tensor) : (B, Lmax + ?), prefix without its first
            token + text + ``eot``, padded with ``ignore_id``

    Raises:
        NotImplementedError: If ``use_prev`` is True, if a task is neither
            built-in nor a key of ``special_tokens``, or if timestamps are
            requested (``no_timestamp`` is False) for a non-VAD task.
        ValueError: If a language tag is not in ``WHISPER_LANGS``.
    """
    assert len(langs) == ys_pad.size(0)
    assert len(tasks) == ys_pad.size(0)
    if use_prev:
        # i.e., hotword list
        _prev = [special_tokens["sot_prev"]]
        # append hotword list to _prev
        # ...
        raise NotImplementedError
    else:
        _prev = []

    _sot = []
    for task, lang in zip(tasks, langs):
        # "vad" maps to the no_speech token; every other task is looked up
        # under its own name.
        if task == "vad":
            task_id = special_tokens["no_speech"]
        elif task in ("transcribe", "translate") or task in special_tokens:
            task_id = special_tokens[task]
        else:
            raise NotImplementedError("unsupported task {}".format(task))
        # Language ids directly follow the sot id, in WHISPER_LANGS order.
        language_id = special_tokens["sot"] + 1 + WHISPER_LANGS.index(lang)
        prefix = _prev + [special_tokens["sot"], language_id, task_id]
        if task == "vad":
            # NOTE(review): task_id is already no_speech for "vad", so the
            # token ends up in the prefix twice -- confirm this is intended.
            prefix.append(special_tokens["no_speech"])
        elif no_timestamp:
            prefix.append(special_tokens["no_timestamps"])
        else:
            prefix.append(special_tokens["timestamp_begin"])
            # add subsequent tokens
            # ...
            raise NotImplementedError
        prefix = torch.tensor(prefix,
                              dtype=torch.long,
                              requires_grad=False,
                              device=ys_pad.device)
        _sot.append(prefix)

    _eot = torch.tensor([special_tokens["eot"]],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys

    ys_in = [torch.cat([prefix, y], dim=0) for prefix, y in zip(_sot, ys)]
    # The output drops the first prefix token and ends with eot.
    ys_out = [
        torch.cat([prefix[1:], y, _eot], dim=0) for prefix, y in zip(_sot, ys)
    ]
    return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)
def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse the valid part of every padded sequence and pad again.

    Each row is cast with ``.int()``, cut to its own length, flipped, and
    the reversed rows are padded back together on the right.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (int): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    flipped = []
    for seq, length in zip(ys_pad, ys_lens):
        valid = seq.int()[:length]  # keep only the real tokens of this row
        flipped.append(torch.flip(valid, [0]))
    return pad_sequence(flipped, batch_first=True, padding_value=pad_value)
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Compute token accuracy over the non-ignored target positions.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Accuracy value (0.0 - 1.0), detached from the graph.
            There is no guard for a batch in which every target equals
            ``ignore_label``.

    """
    batch, max_len = pad_targets.size(0), pad_targets.size(1)
    # Restore the (B, Lmax, D) layout, then take the best class per position.
    scores = pad_outputs.view(batch, max_len, pad_outputs.size(1))
    predictions = scores.argmax(2)
    valid = pad_targets != ignore_label
    hits = predictions.masked_select(valid) == pad_targets.masked_select(valid)
    correct = torch.sum(hits)
    total = torch.sum(valid)
    return (correct / total).detach()
def get_subsample(config):
    """Map the configured encoder input layer to its integer factor.

    Reads ``config["encoder_conf"]["input_layer"]`` and returns 4 for
    "conv2d", 6 for "conv2d6" and 8 for "conv2d8" (presumably the
    subsampling rate of that front-end). Any other value fails the assert.
    """
    rates = {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    # .get() keeps the implicit None result when asserts are disabled.
    return rates.get(input_layer)
def log_add(*args) -> float:
    """Stable log add: return ``log(sum(exp(a) for a in args))``.

    The largest argument is factored out before exponentiating so that
    large magnitudes do not overflow.

    Args:
        *args: Log-domain values to add.

    Returns:
        float: The log of the summed exponentials. ``-inf`` when called
            without arguments or when every argument is ``-inf``; ``inf``
            when the largest argument is ``inf``.
    """
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    if a_max == float('inf'):
        # Shifting by an infinite maximum would evaluate exp(inf - inf),
        # i.e. exp(nan); the sum is dominated by the infinite term anyway.
        return float('inf')
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean mask into an additive attention bias.

    True entries become zero and False entries become a large negative
    number (``1 * -1.0e+10`` evaluated in ``dtype``).

    Args:
        mask (torch.Tensor): Boolean mask.
        dtype (torch.dtype): Target floating dtype; one of float32,
            bfloat16 or float16.

    Returns:
        torch.Tensor: Bias tensor with the shape of ``mask``.
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    keep = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues, hence the literal constant
    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    blocked = 1.0 - keep
    return blocked * -1.0e+10
def get_nested_attribute(obj, attr_path):
    """Resolve a dotted attribute path such as ``"a.b.c"`` on ``obj``.

    A ``DistributedDataParallel`` wrapper is unwrapped to its ``.module``
    first, so the path is always resolved against the wrapped object.
    """
    is_ddp = isinstance(obj, torch.nn.parallel.DistributedDataParallel)
    target = obj.module if is_ddp else obj
    for name in attr_path.split('.'):
        target = getattr(target, name)
    return target
def lrs_to_str(lrs: List):
    """Render learning rates as space-separated ``.4e`` formatted strings."""
    return " ".join(format(lr, ".4e") for lr in lrs)
class StepTimer:
    """Utility class for measuring steps/second."""

    def __init__(self, step=0.0):
        # Step counter value at the last (re)start of the measurement.
        self.last_iteration = step
        self.start()

    def start(self):
        """Record the current wall-clock time as the measurement start."""
        self.last_time = time.time()

    def steps_per_second(self, cur_step, restart=True):
        """Return steps per second since the last start.

        With ``restart`` set, the timer and the reference step are reset to
        the current time and ``cur_step``.
        """
        current = float(cur_step)
        steps_done = current - self.last_iteration
        elapsed = time.time() - self.last_time
        value = steps_done / elapsed
        if restart:
            self.start()
            self.last_iteration = current
        return value
def tensor_to_scalar(x):
    """Return ``x.item()`` for a tensor; pass any other value through."""
    return x.item() if torch.is_tensor(x) else x
def is_torch_npu_available() -> bool:
    """Check if torch_npu is available.

    torch_npu is a npu adapter of PyTorch. When it cannot be imported and
    CUDA is not available either, an installation hint is printed.

    Returns:
        bool: True if ``torch_npu`` can be imported, False otherwise.
    """
    try:
        import torch_npu  # noqa
    except ImportError:
        if not torch.cuda.is_available():
            # Adjacent literals instead of a backslash continuation inside
            # the string, which would splice the next source line's leading
            # whitespace into the printed message.
            print("Module \"torch_npu\" not found. \"pip install torch_npu\" "
                  "if you are using Ascend NPU, otherwise, ignore it")
        return False
    return True
# Probe once at import time and keep the result as a module-level flag.
TORCH_NPU_AVAILABLE = is_torch_npu_available()