File size: 2,584 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from torch import nn

from s3prl import Output
from s3prl.nn.transformer_mockingjay import (
    ACT2FN,
    TransformerConfig,
    TransformerLayerNorm,
)


class PredictorMockingjay(nn.Module):
    """
    The predictor model for SSL pre-training tasks.
    Currently supporting SSL problems of Mockingjay, Tera, and Audio Albert.

    The head is a single feed-forward stack:
    ``dense -> activation -> LayerNorm -> output projection``.
    """

    def __init__(self, config, output_dim, input_dim=None, **kwargs):
        """
        Args:
            config (TransformerConfig):
                A `TransformerConfig` class instance with the configuration to build a new model,
                can also be a `dict` (or any `dict` subclass) that initializes the TransformerConfig class
            output_dim (int):
                The output dimension of predictor
            input_dim (int):
                The input dimension of predictor, if `None` is given, then use the `hidden_size` defined in `config`.
                Default: None
            **kwargs:
                Accepted for call-site compatibility; currently ignored.
        """
        super().__init__()
        # isinstance (instead of an exact `type(...) is dict` check) also accepts
        # dict subclasses such as OrderedDict.
        if isinstance(config, dict):
            config = TransformerConfig(**config)
        self.output_size = output_dim

        # Project the incoming features into the transformer hidden size.
        in_features = config.hidden_size if input_dim is None else input_dim
        self.dense = nn.Linear(in_features, config.hidden_size)

        # `hidden_act` is either a name looked up in ACT2FN or an already-callable activation.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = TransformerLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(config.hidden_size, self.output_size)

    def forward(self, inputs, output_states=False):
        """
        Args:
            inputs:
                Either an object exposing a ``hidden_states`` attribute (e.g. the
                `s3prl.Output` of an upstream model), or the feature tensor itself.
                The features must be a floating-point tensor whose last dimension is
                ``input_dim`` (or ``config.hidden_size`` when ``input_dim`` was None),
                presumably shaped [batch_size, sequence_length, input_dim].
            output_states (bool):
                A boolean which controls whether to return the `hidden_states` of the predictor.
                Default: False
        Returns:
            Output (s3prl.Output):
                An Output that always contains `prediction`; when `output_states`
                is True it also contains the LayerNorm-ed `hidden_states`.
        """
        # Unwrap upstream outputs; a bare tensor has no `hidden_states` attribute
        # and is used directly, matching the documented tensor input.
        if hasattr(inputs, "hidden_states"):
            hidden_states = inputs.hidden_states
        else:
            hidden_states = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        prediction = self.output(hidden_states)

        if output_states:
            return Output(hidden_states=hidden_states, prediction=prediction)
        return Output(prediction=prediction)