Spaces:
Runtime error
Runtime error
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import copy | |
from funasr_detach.models.base_model import FunASRModel | |
from funasr_detach.models.encoder.mossformer_encoder import ( | |
MossFormerEncoder, | |
MossFormer_MaskNet, | |
) | |
from funasr_detach.models.decoder.mossformer_decoder import MossFormerDecoder | |
class MossFormer(FunASRModel):
    """MossFormer speech-separation model: encoder, mask network, decoder.

    The input is encoded once, the mask network produces one mask per
    speaker, every masked encoding is decoded with the same decoder, and the
    decoded signals are padded or cropped to the input length.

    Arguments
    ---------
    in_channels : int
        Number of channels produced by the encoder (passed to it as
        ``out_channels``) and received by the mask network.
    out_channels : int
        Passed to the mask network as ``out_channels`` and to the decoder as
        its ``in_channels``.
    num_blocks : int
        Forwarded to the mask network as ``num_blocks``.
    kernel_size : int
        Encoder and decoder kernel size; the decoder stride is
        ``kernel_size // 2``.
    norm : str
        Normalization type, forwarded to the mask network.
    num_spks : int
        Number of sources (speakers) to estimate.
    skip_around_intra : bool
        Forwarded to the mask network.
    use_global_pos_enc : bool
        Forwarded to the mask network.
    max_length : int
        Maximum sequence length, forwarded to the mask network.
    """

    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super().__init__()
        self.num_spks = num_spks

        # Encoder: single-channel input -> `in_channels` feature channels.
        self.enc = MossFormerEncoder(
            in_channels=1,
            out_channels=in_channels,
            kernel_size=kernel_size,
        )

        # Mask estimator: yields the per-speaker masks used in forward().
        mask_net_options = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )
        self.mask_net = MossFormer_MaskNet(**mask_net_options)

        # Decoder: back to a single channel, hopping by half a kernel.
        # NOTE(review): the decoder is sized with `out_channels`, yet
        # forward() feeds it encoder features multiplied by the mask. This is
        # only obviously consistent when in_channels == out_channels (the
        # defaults) -- confirm against MossFormer_MaskNet's output width.
        hop = kernel_size // 2
        self.dec = MossFormerDecoder(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=hop,
            bias=False,
        )

    def forward(self, input):
        """Separate ``input`` into ``num_spks`` estimated sources.

        Arguments
        ---------
        input : torch.Tensor
            Mixture signal; dimension 1 is used as the reference length
            (presumably shaped (batch, time) -- confirm with callers).

        Returns
        -------
        list of torch.Tensor
            ``num_spks`` tensors, one per speaker, each with dimension 1
            equal to ``input.size(1)``.
        """
        n_src = self.num_spks

        encoded = self.enc(input)
        masks = self.mask_net(encoded)

        # One copy of the encoding per speaker, stacked on a new leading
        # axis, then weighted by that speaker's mask.
        replicated = torch.stack([encoded for _ in range(n_src)])
        masked = replicated * masks

        # Decode each speaker separately; speakers end up on the last axis.
        decoded = [self.dec(masked[idx]).unsqueeze(-1) for idx in range(n_src)]
        est_source = torch.cat(decoded, dim=-1)

        # Bring the estimate to the input length: zero-pad the end of
        # dimension 1 when it is short, otherwise crop it.
        target_len = input.size(1)
        shortfall = target_len - est_source.size(1)
        if shortfall > 0:
            est_source = F.pad(est_source, (0, 0, 0, shortfall))
        else:
            est_source = est_source[:, :target_len, :]

        return [est_source[:, :, idx] for idx in range(n_src)]