Spaces:
Running
Running
File size: 1,151 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch.nn as nn
from .nrtr_decoder import NRTRDecoder
class CAMDecoder(nn.Module):
    """Recognition head that wraps an ``NRTRDecoder``.

    Every constructor argument is forwarded unchanged to ``NRTRDecoder``.
    This class only adapts the dict-based calling convention used by
    ``forward``: it reads ``x['refined_feat']`` and writes the decoder's
    result to ``x['rec_output']``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        """Build the wrapped ``NRTRDecoder``.

        All arguments are passed through verbatim to ``NRTRDecoder`` under
        the same keyword names; their interpretation is defined there.

        Args:
            in_channels: forwarded as ``in_channels``.
            out_channels: forwarded as ``out_channels``.
            nhead: forwarded as ``nhead``. ``None`` is passed through as-is;
                NOTE(review): presumably ``NRTRDecoder`` derives a head count
                when this is ``None`` — confirm in nrtr_decoder.
            num_encoder_layers: forwarded as ``num_encoder_layers``.
            beam_size: forwarded as ``beam_size``.
            num_decoder_layers: forwarded as ``num_decoder_layers``.
            max_len: forwarded as ``max_len``.
            attention_dropout_rate: forwarded as ``attention_dropout_rate``.
            residual_dropout_rate: forwarded as ``residual_dropout_rate``.
            scale_embedding: forwarded as ``scale_embedding``.
        """
        super().__init__()
        self.decoder = NRTRDecoder(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )

    def forward(self, x, data=None):
        """Decode ``x['refined_feat']`` and record the result in ``x``.

        Args:
            x: dict that must contain the key ``'refined_feat'``. It is
                mutated in place: ``x['rec_output']`` is set in both
                training and eval mode.
            data: optional value forwarded as ``data=`` to the wrapped
                decoder (presumably targets/labels during training —
                TODO confirm against callers).

        Returns:
            In training mode (``self.training`` is True), the same dict
            ``x``, now holding ``'rec_output'``. Otherwise, the raw decoder
            output.

        Raises:
            KeyError: if ``x`` has no ``'refined_feat'`` entry.
        """
        dec_output = self.decoder(x['refined_feat'], data=data)
        x['rec_output'] = dec_output
        # NOTE(review): training returns the whole dict, presumably so
        # downstream losses can read other keys of x; confirm with caller.
        return x if self.training else dec_output
|