# wuxulong19950206
# First model version
# 14d1720
# raw
# history blame contribute delete
# 717 Bytes
# from mtts.models.fs2_model import ENCODERS
import torch.nn as nn
from mtts.utils.logging import get_logger
from .fs2_transformer_encoder import FS2TransformerEncoder
# Module-level logger, obtained from the project's get_logger helper.
logger = get_logger(__file__)
# Registry of encoder classes that Encoder is allowed to build; Encoder
# checks the requested encoder_type against this list. To add an encoder,
# import its class in this module and append it here.
ENCODERS = [FS2TransformerEncoder]
class Encoder(nn.Module):
    """Build an encoder selected by name and delegate all calls to it.

    The concrete encoder class is looked up by class name among the classes
    listed in the module-level ``ENCODERS`` registry, then constructed with
    the remaining keyword arguments.

    Args:
        encoder_type: class name of the encoder to build; must equal the
            ``__name__`` of one of the classes in ``ENCODERS``.
        **kwargs: forwarded unchanged to the selected encoder's constructor
            and also stored on ``self.config``.

    Raises:
        ValueError: if ``encoder_type`` does not name a class in ``ENCODERS``.
    """

    def __init__(self, encoder_type: str = 'FS2TransformerEncoder', **kwargs):
        super(Encoder, self).__init__()
        logger.info(f'building encoder with type:{encoder_type}')
        # Resolve the name through an explicit registry rather than eval():
        # eval() would execute any expression placed in the config, and the
        # previous assert-based check is stripped when Python runs with -O.
        registry = {cls.__name__: cls for cls in ENCODERS}
        try:
            encoder_class = registry[encoder_type]
        except KeyError:
            raise ValueError(
                f'unknown encoder_type {encoder_type!r}; '
                f'expected one of {sorted(registry)}'
            ) from None
        self.config = kwargs
        self.encoder = encoder_class(**kwargs)

    def forward(self, *args, **kwargs):
        """Run the wrapped encoder, passing all arguments through unchanged."""
        return self.encoder(*args, **kwargs)