# Hosting-page residue (not part of the module): topdu's picture | openocr demo | 29f689c | raw | history blame | 1.15 kB
import torch.nn as nn
from .nrtr_decoder import NRTRDecoder
class CAMDecoder(nn.Module):
    """Decoder head that delegates all decoding to an ``NRTRDecoder``.

    Every constructor argument is forwarded unchanged, by keyword, to
    ``NRTRDecoder``; see that class for the meaning of each option.
    The only logic added here is in ``forward``: it picks the decoder input
    out of the incoming mapping and chooses what to return based on the
    module's training flag.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        """Build the wrapped ``NRTRDecoder`` from the given options."""
        super().__init__()
        self.decoder = NRTRDecoder(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )

    def forward(self, x, data=None):
        """Decode ``x['refined_feat']`` and record the result on ``x``.

        Args:
            x: Mutable mapping that must contain the key ``'refined_feat'``.
                It is modified in place: the decoder output is stored under
                ``'rec_output'``.
            data: Optional extra input handed to the wrapped decoder as its
                ``data`` keyword argument.

        Returns:
            The updated mapping ``x`` while the module is in training mode,
            otherwise only the decoder output.

        Raises:
            KeyError: If ``x`` has no ``'refined_feat'`` entry.
        """
        dec_in = x['refined_feat']
        dec_output = self.decoder(dec_in, data=data)
        # Stored unconditionally, so the caller's mapping is updated in
        # evaluation mode too, even though only the raw output is returned.
        x['rec_output'] = dec_output
        if self.training:
            return x
        return dec_output