# Spaces: Running
import torch.nn as nn | |
from .nrtr_decoder import NRTRDecoder | |
class CAMDecoder(nn.Module):
    """Thin wrapper that owns an ``NRTRDecoder`` and routes features to it.

    All constructor arguments are handed to ``NRTRDecoder`` by keyword,
    unchanged; this class defines no parameters or layers of its own beyond
    the wrapped decoder stored on ``self.decoder``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        """Build the wrapped ``NRTRDecoder`` from the given settings.

        The meaning of each argument is defined by ``NRTRDecoder``; they are
        passed through here without modification.
        """
        super().__init__()

        # Gather the pass-through settings once, then expand them as keyword
        # arguments so the decoder receives exactly what the caller supplied.
        decoder_settings = {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'nhead': nhead,
            'num_encoder_layers': num_encoder_layers,
            'beam_size': beam_size,
            'num_decoder_layers': num_decoder_layers,
            'max_len': max_len,
            'attention_dropout_rate': attention_dropout_rate,
            'residual_dropout_rate': residual_dropout_rate,
            'scale_embedding': scale_embedding,
        }
        self.decoder = NRTRDecoder(**decoder_settings)

    def forward(self, x, data=None):
        """Decode ``x['refined_feat']`` and record the result on ``x``.

        Args:
            x: Mapping that provides a ``'refined_feat'`` entry. It is
                modified in place: the decoder output is stored under
                ``'rec_output'`` in both training and evaluation mode.
            data: Optional extra input forwarded to the decoder as the
                ``data`` keyword argument.

        Returns:
            The updated ``x`` while the module is in training mode; otherwise
            only the decoder output.
        """
        prediction = self.decoder(x['refined_feat'], data=data)
        x['rec_output'] = prediction
        return x if self.training else prediction