File size: 1,151 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch.nn as nn

from .nrtr_decoder import NRTRDecoder


class CAMDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super().__init__()

        self.decoder = NRTRDecoder(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )

    def forward(self, x, data=None):
        dec_in = x['refined_feat']
        dec_output = self.decoder(dec_in, data=data)
        x['rec_output'] = dec_output
        if self.training:
            return x
        else:
            return dec_output