Spaces:
Running
Running
File size: 642 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch.nn as nn
def RegressionHead(
d_model: int,
output_dim: int,
hidden_dim: int | None = None,
) -> nn.Module:
"""Single-hidden layer MLP for supervised output.
Args:
d_model: input dimension
output_dim: dimensionality of the output.
hidden_dim: optional dimension of hidden layer, defaults to d_model.
Returns:
output MLP module.
"""
hidden_dim = hidden_dim if hidden_dim is not None else d_model
return nn.Sequential(
nn.Linear(d_model, hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, output_dim),
)
|