#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from typing import Optional, Tuple

import torch

from funasr_detach.register import tables
from funasr_detach.models.rwkv_bat.rwkv import RWKV
from funasr_detach.models.transformer.layer_norm import LayerNorm
from funasr_detach.models.transformer.utils.nets_utils import make_source_mask
from funasr_detach.models.rwkv_bat.rwkv_subsampling import RWKVConvInput


@tables.register("encoder_classes", "RWKVEncoder")
class RWKVEncoder(torch.nn.Module):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Encoder input/output size.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size.
        attention_size: SelfAttention hidden size.
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate for the RWKV blocks.
        subsampling_factor: Subsampling factor of the convolution frontend.
        time_reduction_factor: Additional time reduction factor applied to the output.
        kernel: Kernel size of the frontend convolutions.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
        **kwargs,
    ) -> None:
"""Construct a RWKVEncoder object.""" | |
super().__init__() | |
self.embed = RWKVConvInput( | |
input_size, | |
[output_size // 4, output_size // 2, output_size], | |
subsampling_factor, | |
conv_kernel_size=kernel, | |
output_size=output_size, | |
) | |
self.subsampling_factor = subsampling_factor | |
linear_size = output_size * 4 if linear_size is None else linear_size | |
attention_size = output_size if attention_size is None else attention_size | |
self.rwkv_blocks = torch.nn.ModuleList( | |
[ | |
RWKV( | |
output_size, | |
linear_size, | |
attention_size, | |
context_size, | |
block_id, | |
num_blocks, | |
att_dropout_rate=att_dropout_rate, | |
ffn_dropout_rate=ffn_dropout_rate, | |
dropout_rate=dropout_rate, | |
) | |
for block_id in range(num_blocks) | |
] | |
) | |
self.embed_norm = LayerNorm(output_size) | |
self.final_norm = LayerNorm(output_size) | |
self._output_size = output_size | |
self.context_size = context_size | |
self.num_blocks = num_blocks | |
self.time_reduction_factor = time_reduction_factor | |

    def output_size(self) -> int:
        """Return the encoder output size."""
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode feature sequences.

        Args:
            x: Encoder input sequences. (B, T, D_feat)
            x_len: Input sequence lengths. (B,)

        Returns:
            x: Encoder output sequences. (B, U, D_enc)
            olens: Output sequence lengths. (B,)
        """
        _, length, _ = x.size()

        assert length <= self.context_size * self.subsampling_factor, (
            "Input length is too long for the context size: %d versus %d"
            % (length, self.context_size * self.subsampling_factor)
        )

        mask = make_source_mask(x_len).to(x.device)

        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)

        if self.training:
            # Training: each block processes the full sequence in parallel.
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            # Inference: process the sequence step by step with recurrent states.
            x = self.rwkv_infer(x)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Encode a batch step by step with the RWKV recurrent states.

        Args:
            xs_pad: Embedded input sequences. (B, T, D_enc)

        Returns:
            xs_out: Encoder output sequences. (B, T, D_enc)
        """
        batch_size = xs_pad.shape[0]
        hidden_sizes = [self._output_size] * 5

        # Per-layer recurrent state: previous token features plus the WKV
        # numerator, denominator, and max statistics.
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for i in range(5)
        ]
        # The WKV max statistic must start at (effectively) -inf for the
        # numerically stable recursion, hence the large negative offset.
        state[4] -= 1e30

        xs_out = []

        for t in range(xs_pad.shape[1]):
            # Keep the time dimension so each step stays (B, 1, D_enc).
            x_t = xs_pad[:, t : t + 1, :]

            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)

            xs_out.append(x_t)

        return torch.cat(xs_out, dim=1)
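

# Minimal usage sketch (illustrative, not part of the FunASR API): it assumes
# the funasr_detach package and the custom WKV kernels used by the RWKV blocks
# are available in the environment; shapes follow the forward() contract above.
if __name__ == "__main__":
    encoder = RWKVEncoder(input_size=80, output_size=512, num_blocks=4)
    encoder.eval()  # eval mode routes through the step-by-step rwkv_infer path

    feats = torch.randn(2, 100, 80)  # (B, T, D_feat), e.g. 80-dim fbank features
    feats_len = torch.tensor([100, 80])  # valid frames per utterance

    with torch.no_grad():
        out, out_lens, _ = encoder(feats, feats_len)

    # With subsampling_factor=4, the 100 input frames yield roughly 25 steps.
    print(out.shape, out_lens)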