Spaces:
Runtime error
Runtime error
File size: 2,865 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from funasr_detach.models.base_model import FunASRModel
from funasr_detach.models.encoder.mossformer_encoder import (
MossFormerEncoder,
MossFormer_MaskNet,
)
from funasr_detach.models.decoder.mossformer_decoder import MossFormerDecoder
class MossFormer(FunASRModel):
    """MossFormer separation model: splits mixed speech into per-speaker speech.

    Three sub-modules are chained: an encoder, a mask network, and a decoder
    that is applied once per speaker to the masked encoder output.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter
        blocks. Also passed to the decoder as its ``in_channels``.
    num_blocks : int
        Number of layers of Dual Computation Block.
    kernel_size : int
        Encoder and decoder kernel size. The decoder stride is set to
        ``kernel_size // 2``.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.
    """

    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super().__init__()
        self.num_spks = num_spks

        # Encoder: single-channel input mapped to `in_channels` feature channels.
        self.enc = MossFormerEncoder(
            in_channels=1,
            out_channels=in_channels,
            kernel_size=kernel_size,
        )

        # Mask estimator; receives every separation-related hyper-parameter.
        self.mask_net = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )

        # Decoder: back to a single output channel.
        # NOTE(review): the decoder is sized by `out_channels` although it is
        # fed the masked encoder output; the defaults are equal -- confirm the
        # mask network's output width before using differing values.
        self.dec = MossFormerDecoder(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False,
        )

    def forward(self, input):
        """Separate ``input`` into ``num_spks`` estimated sources.

        Arguments
        ---------
        input : torch.Tensor
            Mixture to separate. Dimension 1 is taken as the reference length
            (presumably shaped (batch, time) -- TODO confirm against the
            encoder).

        Returns
        -------
        list of torch.Tensor
            ``num_spks`` tensors, one per speaker, whose dimension 1 has the
            same size as dimension 1 of ``input``.
        """
        n_src = self.num_spks

        encoded = self.enc(input)
        masks = self.mask_net(encoded)

        # Replicate the encoding once per speaker, then apply the masks.
        masked = torch.stack([encoded for _ in range(n_src)]) * masks

        # Decode each speaker and collect the results along a new last axis.
        est_source = torch.stack(
            [self.dec(masked[idx]) for idx in range(n_src)],
            dim=-1,
        )

        # Align the estimate's length (dim 1) with the input's: pad the tail
        # with zeros when it is too short, truncate otherwise.
        target_len = input.size(1)
        missing = target_len - est_source.size(1)
        if missing > 0:
            est_source = F.pad(est_source, (0, 0, 0, missing))
        else:
            est_source = est_source[:, :target_len, :]

        return [est_source[:, :, idx] for idx in range(n_src)]
|