import torch.nn as nn

from s3prl import Output


class PredictorIdentity(nn.Module):
    """Pass-through predictor used as a placeholder in some SSL setups.

    This module defines no parameters and performs no computation.
    ``forward`` returns the exact object it was given.
    """

    def __init__(self, **kwargs):
        # Any keyword arguments are accepted and ignored. Presumably this lets
        # the placeholder be built from the same config as real predictors.
        # NOTE(review): confirm against the callers.
        super().__init__()

    def forward(self, output: Output):
        """Return the input unchanged.

        Args:
            output (s3prl.Output): the Output object to pass through.

        Returns:
            s3prl.Output: the very same object as ``output``. It is not copied
            or modified.
        """
        return output