# OSUM / wenet/utils/common.py
# Last commit: 568e264 by tomxxie, "适配zeroGPU" (adapt for ZeroGPU)
# File size: 12.9 kB
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet (https://github.com/espnet/espnet)
"""Utility functions for Transformer."""
import math
import time
from typing import List, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from whisper.tokenizer import LANGUAGES as WhiserLanguages
# Language codes from whisper's LANGUAGES mapping, kept in mapping order.
# add_whisper_tokens derives a language token id from a code's index here.
WHISPER_LANGS = tuple(WhiserLanguages)
# Label value marking padded / ignored target positions.
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    All tensors are padded along their first dimension to the longest one.
    Any number of trailing dimensions is supported, as long as they match
    across the list.

    Args:
        xs (List[torch.Tensor]): Tensors [(T_1, `*`), (T_2, `*`), ...,
            (T_B, `*`)]. They must share dtype, device and the trailing
            dimensions `*`; only the first dimension may differ.
        pad_value (int or float): Value written into padded positions
            (cast to the dtype of ``xs[0]``).

    Returns:
        torch.Tensor: Padded tensor (B, Tmax, `*`) with the dtype and
        device of ``xs[0]``.

    Raises:
        ValueError: If ``xs`` is empty.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    if len(xs) == 0:
        raise ValueError("pad_list() requires a non-empty list of tensors")
    first = xs[0]
    max_len = max([len(item) for item in xs])
    # Output shape is (B, Tmax) followed by the shared trailing dims, so any
    # ndim >= 1 works without per-rank branches.
    out_shape = [len(xs), max_len] + list(first.shape[1:])
    pad_res = torch.full(out_shape,
                         pad_value,
                         dtype=first.dtype,
                         device=first.device)
    for i, item in enumerate(xs):
        pad_res[i, :len(item)] = item
    return pad_res
def add_blank(ys_pad: torch.Tensor, blank: int,
              ignore_id: int) -> torch.Tensor:
    """Prepend a blank column for the transducer predictor.

    A column filled with ``blank`` is put in front of every row, and any
    entry equal to ``ignore_id`` is then replaced by ``blank`` as well.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): padding value to be overwritten with ``blank``

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    batch_size = ys_pad.size(0)
    # One leading blank per row: shape [batch_size, 1].
    blank_col = torch.full((batch_size, 1),
                           blank,
                           dtype=torch.long,
                           device=ys_pad.device)
    with_blank = torch.cat([blank_col, ys_pad], dim=1)  # [B, Lmax + 1]
    is_pad = with_blank == ignore_id
    return torch.where(is_pad, blank, with_blank)
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Padding (``ignore_id``) is stripped from each row first. <sos> is then
    prepended to form the decoder input, and <eos> is appended to form the
    decoder target.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1), right-padded with ``eos``
        ys_out (torch.Tensor) : (B, Lmax + 1), right-padded with ``ignore_id``

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    device = ys_pad.device
    sos_tok = torch.tensor([sos],
                           dtype=torch.long,
                           requires_grad=False,
                           device=device)
    eos_tok = torch.tensor([eos],
                           dtype=torch.long,
                           requires_grad=False,
                           device=device)
    ys_in, ys_out = [], []
    for row in ys_pad:
        tokens = row[row != ignore_id]  # drop padding
        ys_in.append(torch.cat([sos_tok, tokens], dim=0))
        ys_out.append(torch.cat([tokens, eos_tok], dim=0))
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def add_whisper_tokens(special_tokens, ys_pad: torch.Tensor, ignore_id: int,
                       tasks: List[str], no_timestamp: bool, langs: List[str],
                       use_prev: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.

    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲
      [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]    # noqa
        |          |                |-------> [no timestamps] -> [text tokens] ----------------------↑       # noqa
        |          |                                                                                 |       # noqa
        |          |--------> [translate]  -> [begin time] -> [text tokens] -> [end time] -> ... --->|       # noqa
        |                           |-------> [no timestamps] -> [text tokens] --------------------->|       # noqa
        |                                                                                            |       # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|       # noqa

    Args:
        special_tokens: mapping from special-token name to token id. Keys
            read here: "sot", "eot", "no_speech", "no_timestamps",
            "timestamp_begin", "transcribe", "translate", "sot_prev", plus
            any custom task name that appears in ``tasks``.
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        ignore_id (int): index of padding
        tasks (List[str]): per-utterance task tag: "transcribe",
            "translate", "vad", or any other key of ``special_tokens``
        no_timestamp (bool): if True, the no_timestamps token follows the
            task token; timestamp mode (False) is not implemented yet.
        langs (List[str]): per-utterance language code; must be in
            WHISPER_LANGS
        use_prev (bool): prepend previous-text / hotword context; not
            implemented yet.

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?), prefix + text, padded with eot
        ys_out (torch.Tensor) : (B, Lmax + ?), prefix[1:] + text + eot,
            padded with ``ignore_id``

    Raises:
        NotImplementedError: if ``use_prev`` is True, if timestamps are
            requested for a non-vad task, or if a task is neither built in
            nor a key of ``special_tokens``.
    """
    assert len(langs) == ys_pad.size(0)
    assert len(tasks) == ys_pad.size(0)
    if use_prev:
        # i.e., hotword list
        _prev = [special_tokens["sot_prev"]]
        # append hotword list to _prev
        # ...
        raise NotImplementedError
    else:
        _prev = []

    _sot = []
    for task, lang in zip(tasks, langs):
        if task == "vad":
            task_id = special_tokens["no_speech"]
        elif task in ("transcribe", "translate") or task in special_tokens:
            task_id = special_tokens[task]
        else:
            raise NotImplementedError("unsupported task {}".format(task))

        # Language tokens directly follow <sot> in the vocabulary, in
        # WHISPER_LANGS order.
        language_id = special_tokens["sot"] + 1 + WHISPER_LANGS.index(lang)
        prefix = _prev + [special_tokens["sot"], language_id, task_id]
        if task == "vad":
            # NOTE(review): task_id is already no_speech for vad, so the
            # prefix contains no_speech twice; confirm this is intended.
            prefix.append(special_tokens["no_speech"])
        elif no_timestamp:
            prefix.append(special_tokens["no_timestamps"])
        else:
            prefix.append(special_tokens["timestamp_begin"])
            # add subsequent tokens
            # ...
            raise NotImplementedError
        _sot.append(
            torch.tensor(prefix,
                         dtype=torch.long,
                         requires_grad=False,
                         device=ys_pad.device))

    _eot = torch.tensor([special_tokens["eot"]],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([p, y], dim=0) for p, y in zip(_sot, ys)]
    # Targets are the inputs shifted left by one, terminated with <eot>.
    ys_out = [torch.cat([p[1:], y, _eot], dim=0) for p, y in zip(_sot, ys)]
    return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)
def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse the valid prefix of every row and re-pad the batch.

    Each row is cast to int32, truncated to its length in ``ys_lens`` and
    flipped. The rows are then right-padded with ``pad_value``.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: int32 padded tensor (B, max(ys_lens)), assuming every length
        is <= Tokenmax.

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]], dtype=torch.int32)
    """
    flipped = []
    for seq, length in zip(ys_pad, ys_lens):
        head = seq.int()[:length]
        flipped.append(torch.flip(head, [0]))
    return pad_sequence(flipped, batch_first=True, padding_value=pad_value)
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Calculate accuracy.

    Compares the argmax of each output row with the target, counting only
    positions whose target differs from ``ignore_label``.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Accuracy value (0.0 - 1.0), detached. If every target
        equals ``ignore_label``, 0.0 is returned instead of NaN.
    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    # Clamp so an all-ignored batch gives 0 / 1 = 0.0 instead of 0 / 0 = NaN.
    # The clamp stays on-device, so no host sync is forced.
    denominator = torch.sum(mask).clamp(min=1)
    return (numerator / denominator).detach()
def get_subsample(config):
    """Return the subsampling rate tied to the encoder's input layer type.

    Reads ``config["encoder_conf"]["input_layer"]``; only "conv2d" (4),
    "conv2d6" (6) and "conv2d8" (8) are accepted. Any other value fails an
    assertion.
    """
    rate_by_layer = {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    return rate_by_layer[input_layer]
def log_add(*args) -> float:
    """
    Stable log add: log(sum(exp(a) for a in args)).

    The largest argument is factored out before exponentiating, to avoid
    overflow. Returns -inf when every argument is -inf, or when no arguments
    are given.
    """
    if not any(a != -float('inf') for a in args):
        return -float('inf')
    peak = max(args)
    total = sum(math.exp(a - peak) for a in args)
    return peak + math.log(total)
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a boolean mask into an additive attention bias of ``dtype``.

    True positions map to 0 and False positions to -1e10, so True presumably
    marks positions that may be attended to.
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    keep = mask.to(dtype)
    # NOTE(Mddct): torch.finfo jit issues, hence a fixed constant instead of
    # torch.finfo(dtype).min.
    # NOTE(review): -1.0e+10 exceeds float16's largest magnitude (65504), so
    # for torch.float16 the False positions become -inf; confirm this is
    # acceptable for fully-masked rows.
    return (1.0 - keep) * -1.0e+10
def get_nested_attribute(obj, attr_path):
    """Resolve a dotted attribute path such as "a.b.c" starting from ``obj``.

    A DistributedDataParallel wrapper is unwrapped to its ``.module`` first,
    so paths are written relative to the wrapped model.
    """
    if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
        target = obj.module
    else:
        target = obj
    for name in attr_path.split('.'):
        target = getattr(target, name)
    return target
def lrs_to_str(lrs: List):
    """Format learning rates in 4-digit scientific notation, space-separated."""
    return " ".join(f"{lr:.4e}" for lr in lrs)
class StepTimer:
    """Utility class for measuring steps/second.

    Intervals are measured with ``time.perf_counter()``. It is monotonic and
    high-resolution, so system clock adjustments cannot produce negative or
    inflated intervals the way ``time.time()`` can. ``last_time`` is
    therefore only meaningful relative to other perf_counter() readings.
    """

    def __init__(self, step=0.0):
        # Step count at the start of the current measurement window.
        self.last_iteration = step
        self.start()

    def start(self):
        """Start a new timing window at the current moment."""
        self.last_time = time.perf_counter()

    def steps_per_second(self, cur_step, restart=True):
        """Return steps completed per second since the window started.

        Args:
            cur_step: current step counter (converted with ``float``).
            restart: if True, open a new window starting at ``cur_step``.

        Returns:
            float: (cur_step - last_iteration) / elapsed seconds.
        """
        value = ((float(cur_step) - self.last_iteration) /
                 (time.perf_counter() - self.last_time))
        if restart:
            self.start()
            self.last_iteration = float(cur_step)
        return value
def tensor_to_scalar(x):
    """Unwrap a tensor into a Python number via ``.item()``; pass anything else through."""
    return x.item() if torch.is_tensor(x) else x
def is_torch_npu_available() -> bool:
    '''
    Check whether torch_npu is available.
    torch_npu is a npu adapter of PyTorch.

    Returns:
        bool: True if ``import torch_npu`` succeeds, otherwise False. When
        the import fails and CUDA is also unavailable, an install hint is
        printed to stdout.
    '''
    try:
        import torch_npu  # noqa
    except ImportError:
        if not torch.cuda.is_available():
            # Implicit concatenation, rather than a backslash continuation
            # inside the literal, keeps the next line's indentation out of
            # the printed message.
            print("Module \"torch_npu\" not found. \"pip install torch_npu\" "
                  "if you are using Ascend NPU, otherwise, ignore it")
        return False
    return True
# Probed once at import time. The probe may print an install hint when
# neither torch_npu nor CUDA is available.
TORCH_NPU_AVAILABLE: bool = is_torch_npu_available()