"""
Model interfaces

Authors:
  * Leo 2022
"""

from typing import List, Tuple

import torch
import torch.nn as nn

__all__ = [
    "AbsUpstream",
    "AbsFeaturizer",
    "AbsFrameModel",
    "AbsUtteranceModel",
]


class AbsUpstream(nn.Module):
    """
    The upstream model should follow this interface. Please subclass it.
    """

    @property
    def num_layer(self) -> int:
        """
        The number of hidden states.
        """
        raise NotImplementedError

    @property
    def hidden_sizes(self) -> List[int]:
        """
        The hidden size of each hidden state.
        """
        raise NotImplementedError

    @property
    def downsample_rates(self) -> List[int]:
        """
        The downsample rate of each hidden state, relative to the 16 kHz input waveforms.
        """
        raise NotImplementedError

    def forward(
        self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor
    ) -> Tuple[List[torch.FloatTensor], List[torch.LongTensor]]:
        """
        Args:
            wavs (torch.FloatTensor): (batch_size, seq_len, 1)
            wavs_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. all_hs (List[torch.FloatTensor]): all the hidden states
            2. all_hs_len (List[torch.LongTensor]): the lengths for all the hidden states
        """
        raise NotImplementedError
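

# --- Illustrative sketch (not part of the original interface set) ---
# A minimal sketch of an AbsUpstream subclass, shown only to illustrate the
# contract above. The class name, the single Conv1d front-end, and the
# hyperparameters are hypothetical assumptions, not a real upstream model.
class _ExampleConvUpstream(AbsUpstream):
    def __init__(self, hidden_size: int = 256, window: int = 320):
        super().__init__()
        self.window = window
        self.conv = nn.Conv1d(1, hidden_size, kernel_size=window, stride=window)

    @property
    def num_layer(self) -> int:
        # This toy upstream exposes a single hidden state.
        return 1

    @property
    def hidden_sizes(self) -> List[int]:
        return [self.conv.out_channels]

    @property
    def downsample_rates(self) -> List[int]:
        # One frame per `window` waveform samples.
        return [self.window]

    def forward(
        self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor
    ) -> Tuple[List[torch.FloatTensor], List[torch.LongTensor]]:
        # (batch_size, seq_len, 1) -> (batch_size, 1, seq_len) for Conv1d
        hs = self.conv(wavs.transpose(1, 2)).transpose(1, 2)
        hs_len = wavs_len // self.window
        return [hs], [hs_len]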


class AbsFeaturizer(nn.Module):
    """
    The featurizer should follow this interface. Please subclass it.
    The featurizer reduces (standardizes) the multiple hidden states returned
    by :obj:`AbsUpstream` into a single hidden state, so that the downstream
    model can consume it as a conventional representation.
    """

    @property
    def output_size(self) -> int:
        """
        The output size after the hidden-state reduction.
        """
        raise NotImplementedError

    @property
    def downsample_rate(self) -> int:
        """
        The downsample rate of the reduced single hidden state, relative to the 16 kHz input waveform.
        """
        raise NotImplementedError

    def forward(
        self, all_hs: List[torch.FloatTensor], all_hs_len: List[torch.LongTensor]
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Args:
            all_hs (List[torch.FloatTensor]): all the hidden states
            all_hs_len (List[torch.LongTensor]): the lengths for all the hidden states

        Returns:
            tuple:

            1. hs (torch.FloatTensor)
            2. hs_len (torch.LongTensor)
        """
        raise NotImplementedError
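

# --- Illustrative sketch (not part of the original interface set) ---
# A minimal sketch of an AbsFeaturizer subclass: it simply selects the
# upstream's last hidden state. The class name is a hypothetical assumption;
# an actual featurizer (e.g. a weighted sum over layers) would follow the
# same contract.
class _ExampleLastLayerFeaturizer(AbsFeaturizer):
    def __init__(self, upstream: AbsUpstream):
        super().__init__()
        self._output_size = upstream.hidden_sizes[-1]
        self._downsample_rate = upstream.downsample_rates[-1]

    @property
    def output_size(self) -> int:
        return self._output_size

    @property
    def downsample_rate(self) -> int:
        return self._downsample_rate

    def forward(
        self, all_hs: List[torch.FloatTensor], all_hs_len: List[torch.LongTensor]
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        # Reduce the list of hidden states to a single one: take the last.
        return all_hs[-1], all_hs_len[-1]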


class AbsFrameModel(nn.Module):
    """
    The frame-level model interface, which keeps the temporal dimension.
    """

    @property
    def input_size(self) -> int:
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        raise NotImplementedError
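

# --- Illustrative sketch (not part of the original interface set) ---
# A minimal sketch of an AbsFrameModel subclass: a single linear projection
# producing frame-wise outputs. The class name and the choice of a linear
# layer are hypothetical assumptions.
class _ExampleLinearFrameModel(AbsFrameModel):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    @property
    def input_size(self) -> int:
        return self.linear.in_features

    @property
    def output_size(self) -> int:
        return self.linear.out_features

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        # Frame-level models keep the temporal dimension and pass lengths on.
        return self.linear(x), x_len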


class AbsUtteranceModel(nn.Module):
    """
    The utterance-level model interface, which pools the temporal dimension.
    """

    @property
    def input_size(self) -> int:
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> torch.FloatTensor:
        raise NotImplementedError
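

# --- Illustrative sketch (not part of the original interface set) ---
# A minimal sketch of an AbsUtteranceModel subclass: mean-pool the valid
# frames of each utterance, then project. The class name and the pooling
# choice are hypothetical assumptions.
class _ExampleMeanPoolingModel(AbsUtteranceModel):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    @property
    def input_size(self) -> int:
        return self.linear.in_features

    @property
    def output_size(self) -> int:
        return self.linear.out_features

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> torch.FloatTensor:
        # Mask out padded frames before averaging over the temporal dimension.
        mask = torch.arange(x.size(1), device=x.device)[None, :] < x_len[:, None]
        pooled = (x * mask.unsqueeze(-1)).sum(dim=1) / x_len.unsqueeze(-1)
        return self.linear(pooled)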