File size: 2,389 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import copy
import numpy as np
import weakref
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from abc import ABCMeta, abstractmethod

from uniperceiver.config import configurable
from uniperceiver.config import CfgNode as CN
from uniperceiver.functional import pad_tensor, dict_to_cuda, flat_list_of_lists
from ..embedding import build_embeddings
from ..encoder import build_encoder, add_encoder_config
# from ..decoder import build_decoder, add_decoder_config
from ..predictor import build_predictor, add_predictor_config
from ..decode_strategy import build_beam_searcher, build_greedy_decoder

class BaseEncoderDecoder(nn.Module, metaclass=ABCMeta):
    """Abstract base class for encoder-decoder models.

    Wires together a token embedding, a fused encoder, an (optional)
    decoder, and the two decode strategies (greedy / beam search).
    Subclasses must implement :meth:`_forward`, the training/scoring path.
    """

    @configurable
    def __init__(
        self,
        *,
        vocab_size,
        max_seq_len,
        token_embed,
        fused_encoder,
        decoder,
        greedy_decoder,
        beam_searcher,
        **kwargs,
    ):
        """
        Args:
            vocab_size: size of the output vocabulary.
            max_seq_len: maximum sequence length handled by the model.
            token_embed: token embedding module.
            fused_encoder: encoder module shared across modalities.
            decoder: decoder module; presumably may be ``None`` when the
                fused encoder plays both roles — TODO confirm in subclasses.
            greedy_decoder: callable strategy used by :meth:`greedy_decode`.
            beam_searcher: callable strategy used by :meth:`decode_beam_search`.
            **kwargs: absorbed so ``@configurable`` can forward extra config
                keys without breaking construction.
        """
        # Modern zero-argument super() (was the legacy two-argument form).
        super().__init__()
        self.fused_encoder = fused_encoder
        self.decoder = decoder

        self.token_embed = token_embed
        self.greedy_decoder = greedy_decoder
        self.beam_searcher = beam_searcher
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

    @classmethod
    def add_config(cls, cfg, tmp_cfg):
        """Register encoder- and predictor-specific keys on ``cfg``."""
        add_encoder_config(cfg, tmp_cfg)
        add_predictor_config(cfg, tmp_cfg)

    def forward(self, batched_inputs, use_beam_search=None, output_sents=False):
        """Dispatch on ``use_beam_search``: ``None`` -> training/scoring
        via :meth:`_forward`; falsy -> greedy decoding; truthy -> beam search.
        """
        if use_beam_search is None:
            return self._forward(batched_inputs)
        # Was `use_beam_search == False` (flake8 E712); callers pass bools,
        # and the None case is already handled above, so truthiness suffices.
        if not use_beam_search:
            return self.greedy_decode(batched_inputs, output_sents)
        return self.decode_beam_search(batched_inputs, output_sents)

    @abstractmethod
    def _forward(self, batched_inputs):
        """Training / scoring forward pass; implemented by subclasses."""
        pass

    def bind_or_init_weights(self):
        """Hook for weight tying / initialization; no-op by default."""
        pass

    def greedy_decode(self, batched_inputs, output_sents=False):
        """Run the greedy decoding strategy on ``batched_inputs``."""
        # weakref.proxy avoids a reference cycle between strategy and model.
        return self.greedy_decoder(
            batched_inputs,
            output_sents,
            model=weakref.proxy(self)
        )

    def decode_beam_search(self, batched_inputs, output_sents=False):
        """Run the beam-search decoding strategy on ``batched_inputs``."""
        return self.beam_searcher(
            batched_inputs,
            output_sents,
            model=weakref.proxy(self)
        )