File size: 515 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch.nn as nn

from s3prl import Output


class PredictorIdentity(nn.Module):
    """
    Identity predictor used as a placeholder for certain SSL problems.

    The module registers no parameters or submodules, and its ``forward``
    returns its input unchanged. It can be plugged in wherever a predictor
    module is expected but no prediction head should be applied.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): kwargs are intentionally discarded. Presumably this
        # lets a shared config construct any predictor with the same keyword
        # arguments; confirm against the callers.
        super().__init__()

    def forward(self, output: Output):
        """
        Return the input unchanged.

        Args:
            output (s3prl.Output): An Output module

        Returns:
            s3prl.Output: the very same object that was passed in (no copy is
            made and no field is modified).
        """
        return output