Spaces:
Sleeping
Sleeping
File size: 717 Bytes
14d1720 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# from mtts.models.fs2_model import ENCODERS
import torch.nn as nn
from mtts.utils.logging import get_logger
from .fs2_transformer_encoder import FS2TransformerEncoder
# Project logger for this module, obtained from the shared get_logger helper.
logger = get_logger(__file__)

# Allow-list of encoder classes that `Encoder` is permitted to build; any
# resolved class not listed here is rejected by `Encoder.__init__`.
ENCODERS = [
    FS2TransformerEncoder,
]
class Encoder(nn.Module):
    """Encoder: builds a registered encoder by name and delegates to it.

    Args:
        encoder_type: class name of the encoder to build; must equal the
            ``__name__`` of one of the classes listed in ``ENCODERS``.
        **kwargs: forwarded unchanged to the selected encoder's constructor
            and also stored on ``self.config``.

    Raises:
        ValueError: if ``encoder_type`` does not name a class in ``ENCODERS``.
    """

    def __init__(self, encoder_type: str = 'FS2TransformerEncoder', **kwargs):
        super().__init__()
        logger.info(f'building encoder with type:{encoder_type}')
        # Resolve the name through the explicit allow-list rather than eval():
        # eval() would execute any expression supplied via configuration, and
        # the previous `assert` membership check disappears under `python -O`.
        registry = {cls.__name__: cls for cls in ENCODERS}
        encoder_class = registry.get(encoder_type)
        if encoder_class is None:
            raise ValueError(
                f'unknown encoder_type {encoder_type!r}; '
                f'expected one of {sorted(registry)}')
        # Raw constructor kwargs, stored as given.
        self.config = kwargs
        self.encoder = encoder_class(**kwargs)

    def forward(self, *args, **kwargs):
        """Run the wrapped encoder, passing all arguments through unchanged."""
        return self.encoder(*args, **kwargs)
|