import argparse
import logging
import sys
from typing import Dict, Optional, Tuple

import torch

from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
from vita.model.multimodal_encoder.whale.utils import make_pad_mask
def add_encoder_args(group):
"""Add Encoder common arguments."""
group.add_argument(
"--encoder-layer-config",
type=str,
default="tdnn-dtc",
help="Layer config of encoder. Format layername-layername-..., default(conv1d-fsmn-rnn)",
)
group.add_argument(
"--encoder-input-dim",
type=int,
default=256,
help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)",
)
group.add_argument(
"--encoder-output-dim",
type=int,
default=256,
help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)",
)
# Add args of all kinds of components.
# If you add a new component, DO NOT forget to add args to add_component_args func.
group = Transformer.add_arguments(group)
group = Subsampling.add_arguments(group)
group = MambaSSM.add_arguments(group)
return group
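
# Illustrative use of the flags above (a sketch; each component registers its
# own flags via add_arguments, so parse_known_args tolerates extras):
#
#     parser = argparse.ArgumentParser()
#     add_encoder_args(parser)
#     args, _ = parser.parse_known_args(
#         ["--encoder-layer-config", "subsampling-transformer",
#          "--encoder-input-dim", "80",
#          "--encoder-output-dim", "256"]
#     )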
def assign_args_from_dict(args, conf, prefix_key=None):
    """Copy matching key/value pairs from a config dict onto an argparse Namespace."""
    if prefix_key is not None:
        conf = conf[prefix_key]
    for k, v in conf.items():
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, v)
    return args
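
# For example (illustrative values): dash-separated config keys map onto the
# underscore-named attributes argparse creates, so
# assign_args_from_dict(args, {"encoder-input-dim": 80}) sets
# args.encoder_input_dim = 80; with prefix_key given, only that sub-dict of
# the config is consulted.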
class whaleEncoder(torch.nn.Module):
def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
super(whaleEncoder, self).__init__()
parser = argparse.ArgumentParser()
add_encoder_args(parser)
args, _ = parser.parse_known_args()
assign_args_from_dict(args, overview_conf)
# assign_args_from_dict(args, para_conf)
self.config = args.encoder_layer_config.split("-")
encoder_input_dim = args.encoder_input_dim
encoder_output_dim = args.encoder_output_dim
prev_output_dim = encoder_input_dim
prev_component_name = "encoder"
self.enc = torch.nn.ModuleList([])
for name in self.config:
assign_args_from_dict(args, para_conf[name])
            parts = name.split("_")
            if len(parts) > 2:
                logging.error("WRONG CONFIG! {} is not a valid encoder component name".format(name))
                sys.exit()
            # "transformer_1" and "transformer" both select the transformer component.
            name = parts[0]
if name == "transformer":
self.enc.append(Transformer(args))
elif name == "subsampling":
self.enc.append(Subsampling(args))
elif name == "mamba":
self.enc.append(MambaSSM(args))
            else:
                logging.error("WRONG CONFIG! {} is not supported yet".format(name))
                sys.exit()
component_input_dim = getattr(args, name + "_input_dim")
if component_input_dim != prev_output_dim:
                logging.error(
                    "WRONG CONFIG! --{}-output-dim ({}) does not equal --{}-input-dim ({})".format(
prev_component_name, prev_output_dim, name, component_input_dim
)
)
sys.exit()
prev_output_dim = getattr(args, name + "_output_dim")
prev_component_name = name
self.global_cmvn = global_cmvn
if prev_output_dim != encoder_output_dim:
logging.error(
"WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)".format(
"encoder", encoder_output_dim, name, prev_output_dim
)
)
sys.exit()
self._output_size = encoder_output_dim
        num_params = sum(p.numel() for p in self.parameters())
        print("the number of whale encoder params: {:.2f}M".format(num_params / 1e6))
def output_size(self) -> int:
return self._output_size
@torch.jit.unused
def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
        # type: (Tensor, Tensor, Optional[int], Optional[int]) -> Tuple[Tensor, Tensor]
        """Encoder forward pass.

        :param torch.Tensor xs: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of input sequence lengths (B)
        :param int decoding_chunk_size: optional chunk size for streaming decoding
        :param int num_decoding_left_chunks: optional number of left chunks for streaming decoding
        :return: batch of hidden state sequences (B, Tmax, eprojs) and the pad mask (B, 1, Tmax)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
for layer in self.enc:
if hasattr(layer, "chunk_size"):
layer.chunk_size = decoding_chunk_size
if hasattr(layer, "left_chunks"):
layer.left_chunks = num_decoding_left_chunks
if hasattr(layer, "transformer_dynamic_chunks"):
layer.transformer_dynamic_chunks = False
        assert xs.dim() == 3  # expect (B, T, D) features
T = xs.size(1)
masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
for module in self.enc:
xs, ilens, masks = module(xs, ilens, masks)
return xs, masks
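
    # Usage sketch (illustrative shapes; `encoder` is a constructed
    # whaleEncoder whose first component accepts 256-dim features):
    #
    #     xs = torch.randn(4, 100, 256)            # (B, Tmax, D)
    #     ilens = torch.tensor([100, 90, 80, 70])  # (B,)
    #     ys, masks = encoder(xs, ilens)           # (B, T', 256), (B, 1, T')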
@torch.jit.export
def infer(self, xs_pad, buffer, buffer_index, buffer_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
for module in self.enc:
xs_pad, buffer, buffer_index, buffer_out = module.infer(
xs_pad, buffer, buffer_index, buffer_out
)
return xs_pad, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
for module in self.enc:
xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
xs_pad, buffer, buffer_index, buffer_out, hidden_out
)
return xs_pad, buffer, buffer_index, buffer_out, hidden_out
    @torch.jit.ignore(drop=True)
    def get_extra_loss(self) -> Optional[Dict[str, torch.Tensor]]:
        return None