File size: 4,210 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Transducer decoder interface module."""

from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch


@dataclass
class Hypothesis:
    """Default hypothesis definition for beam search.

    Attributes:
        score: Hypothesis score.
        yseq: Sequence of token ids for this hypothesis.
        dec_state: Decoder state associated with this hypothesis; depending on
            the decoder it is a (tensor, optional tensor) tuple, a list of
            tensors, or a single tensor.
        lm_state: Optional language-model state (dict or list). Defaults to
            ``None``, so the annotation must admit ``None``.

    """

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor], torch.Tensor
    ]
    # Defaults to None, hence Optional (PEP 484 disallows implicit Optional).
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None


@dataclass
class NSCHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search.

    Adds two optional fields on top of ``Hypothesis``. Both default to
    ``None``, so their annotations are ``Optional``.

    Attributes:
        y: List of tensors (presumably decoder outputs attached to the
            hypothesis — confirm against the NSC search implementation).
        lm_scores: Tensor of language-model scores, or ``None``.

    """

    y: Optional[List[torch.Tensor]] = None
    lm_scores: Optional[torch.Tensor] = None


class TransducerDecoderInterface:
    """Decoder interface for transducer models.

    Every method here raises ``NotImplementedError``. Concrete transducer
    decoders are expected to subclass this interface and override them.

    """

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size for initial state
            device: Device for initial state

        Returns:
            state: Initialized state

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("init_state method is not implemented")

    def score(
        self,
        hyp: Union[Hypothesis, NSCHypothesis],
        cache: Dict[str, Any],
    ) -> Tuple[
        # (y, new_state, lm_tokens), matching the Returns section below.
        torch.Tensor,
        Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]],
            torch.Tensor,
            List[Optional[torch.Tensor]],
        ],
        torch.Tensor,
    ]:
        """Forward one hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (y, state) for each token sequence (key)

        Returns:
            y: Decoder outputs
            new_state: New decoder state
            lm_tokens: Token id for LM

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("score method is not implemented")

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[NSCHypothesis]],
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        cache: Dict[str, Any],
    ) -> Tuple[
        # (batch_y, batch_states, lm_tokens), matching the Returns section below.
        torch.Tensor,
        Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]],
            torch.Tensor,
            List[Optional[torch.Tensor]],
        ],
        torch.Tensor,
    ]:
        """Forward batch of hypotheses.

        Args:
            hyps: Batch of hypotheses
            batch_states: Batch of decoder states
            cache: pairs of (y, state) for each token sequence (key)

        Returns:
            batch_y: Decoder outputs
            batch_states: Batch of decoder states
            lm_tokens: Batch of token ids for LM

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("batch_score method is not implemented")

    def select_state(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        idx: int,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Get decoder state from batch for given id.

        Args:
            batch_states: Batch of decoder states
            idx: Index to extract state from batch

        Returns:
            state_idx: Decoder state for given id

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("select_state method is not implemented")

    def create_batch_states(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        l_states: List[
            Union[
                Tuple[torch.Tensor, Optional[torch.Tensor]],
                List[Optional[torch.Tensor]],
            ]
        ],
        l_tokens: List[List[int]],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Create batch of decoder states.

        Args:
            batch_states: Batch of decoder states
            l_states: List of decoder states
            l_tokens: List of token sequences for input batch

        Returns:
            batch_states: Batch of decoder states

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError("create_batch_states method is not implemented")