File size: 515 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch.nn as nn
from s3prl import Output
class PredictorIdentity(nn.Module):
    """
    Placeholder predictor for certain SSL problems.

    The module performs no computation: whatever is passed to ``forward``
    is handed straight back to the caller.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but never read, so the module can
        # be constructed with any configuration without effect.
        super().__init__()

    def forward(self, output: Output):
        """
        Return the given object as-is.

        Args:
            output (s3prl.Output): the Output object to pass through

        Return:
            output (s3prl.Output): the very same object that was passed in
        """
        return output
|