import argparse
import importlib
import json
import os
from distutils.util import strtobool as dist_strtobool
import torch
import yaml

IGNORE_ID = -1  # sentinel id for padded target labels (typically used as an ignore index)

def assign_args_from_yaml(args, yaml_path, prefix_key=None):
    """Overwrite attributes of ``args`` with matching entries from a YAML file."""
    with open(yaml_path) as f:
        ydict = yaml.load(f, Loader=yaml.FullLoader)
    if prefix_key is not None:
        ydict = ydict[prefix_key]
    for k, v in ydict.items():
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, v)
    return args
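
# Illustrative usage sketch (the YAML path and key below are hypothetical):
#   >>> args = argparse.Namespace(batch_size=16)
#   >>> # a file conf/train.yaml containing "batch-size: 32" would update
#   >>> # args.batch_size, since hyphens are mapped to underscores:
#   >>> # args = assign_args_from_yaml(args, "conf/train.yaml")
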
def get_model_conf(model_path):
    """Read the model.json stored next to ``model_path`` and return its args."""
    model_conf = os.path.dirname(model_path) + "/model.json"
    with open(model_conf, "rb") as f:
        print("reading a config file from " + model_conf)
        confs = json.load(f)
    # for asr, tts, mt
    idim, odim, args = confs
    return argparse.Namespace(**args)
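
# The expected model.json layout, inferred from the unpacking above, is a
# three-element list [idim, odim, {argument dict}]; the concrete values below
# are only an illustration:
#   [83, 5002, {"adim": 256, "elayers": 12}]
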
def strtobool(x):
    """Convert a truthy/falsy string (e.g. "yes", "0") to a real bool."""
    return bool(dist_strtobool(x))
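
# distutils' strtobool returns 0/1, so the bool() wrapper yields proper booleans:
#   >>> strtobool("yes"), strtobool("0")
#   (True, False)
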
def dynamic_import(import_path, alias=dict()):
    """Dynamically import a module attribute such as a class.

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'espnet.transform.add_deltas:AddDeltas'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
            "{}".format(set(alias), import_path)
        )
    if ":" not in import_path:
        import_path = alias[import_path]
    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
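
# Usage sketch with a standard-library target, so no extra package is required:
#   >>> OrderedDict = dynamic_import("collections:OrderedDict")
#   >>> OrderedDict is importlib.import_module("collections").OrderedDict
#   True
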
def set_deterministic_pytorch(args):
    """Seed PyTorch's RNG and disable cuDNN benchmarking."""
    # seed setting
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False
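
# Usage sketch; only the ``seed`` attribute of the namespace is read here:
#   >>> set_deterministic_pytorch(argparse.Namespace(seed=1))
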
def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into one (batch, max_len, ...) tensor."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad
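
# Example: two sequences of lengths 4 and 2 padded with zeros into a (2, 4) batch:
#   >>> pad_list([torch.tensor([1., 2., 3., 4.]), torch.tensor([5., 6.])], 0)
#   tensor([[1., 2., 3., 4.],
#           [5., 6., 0., 0.]])
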
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a (batch, max_len) boolean mask that is True at padded positions."""
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
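
# Example: positions at or beyond each sequence length are flagged as padding:
#   >>> make_pad_mask(torch.tensor([2, 3, 1]))
#   tensor([[False, False,  True],
#           [False, False, False],
#           [False,  True,  True]])
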
def subsequent_chunk_mask(
    size: int,
    ck_size: int,
    num_l_cks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a (size, size) chunk-wise attention mask.

    Each position can attend to its own chunk of ``ck_size`` frames plus all
    previous chunks when ``num_l_cks`` is negative, or at most ``num_l_cks``
    left chunks otherwise.
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_l_cks < 0:
            start = 0
        else:
            start = max((i // ck_size - num_l_cks) * ck_size, 0)
        ending = min((i // ck_size + 1) * ck_size, size)
        ret[i, start:ending] = True
    return ret
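
# Example: four frames with a chunk size of 2 and unlimited left context:
#   >>> subsequent_chunk_mask(4, 2)
#   tensor([[ True,  True, False, False],
#           [ True,  True, False, False],
#           [ True,  True,  True,  True],
#           [ True,  True,  True,  True]])
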
def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
):
    """Combine the padding mask with a dynamic or static chunk mask."""
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_l_cks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_l_cks = num_decoding_left_chunks
        else:
            # sample a random chunk size for dynamic-chunk training; a sample
            # above half the sequence keeps full context, otherwise it is
            # folded into the range [1, 25]
            chunk_size = torch.randint(1, max_len, (1,)).item()
            num_l_cks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
        ck_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_l_cks = num_decoding_left_chunks
        ck_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    else:
        ck_masks = masks
    return ck_masks
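
# Usage sketch for the static-chunk branch (batch of one, four frames; the
# feature size 80 is an arbitrary placeholder):
#   >>> xs = torch.zeros(1, 4, 80)
#   >>> masks = torch.ones(1, 1, 4, dtype=torch.bool)  # no padded frames
#   >>> add_optional_chunk_mask(xs, masks, False, False, 0, 2, -1).shape
#   torch.Size([1, 4, 4])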