import math

import torch
import torch.nn.functional as F
from torch import nn

from uniperceiver.config import configurable, kfg

from .build import PREDICTOR_REGISTRY
# NOTE: BertPooler is referenced by RobertaRegressionHead.from_config but was
# never imported in the original file; the relative path below is an assumed location.
from ..layers.bert import BertPooler

__all__ = ["BasePredictor", "RobertaLMHead", "TwoLayerPredictor", "RobertaRegressionHead"]


@PREDICTOR_REGISTRY.register()
class BasePredictor(nn.Module):
    """Single linear projection from decoder hidden states to vocabulary logits."""

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        vocab_size: int,   # includes special tokens such as <BOS>/<EOS>
        dropout: float,
    ):
        super(BasePredictor, self).__init__()
        self.logits = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    @classmethod
    def from_config(cls, cfg):
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "dropout": cfg.MODEL.PRED_DROPOUT,
        }

    @classmethod
    def add_config(cls, cfg):
        pass

    def forward(self, batched_inputs):
        hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
        if isinstance(hidden_states, list):
            # Keep only the last layer's hidden states.
            hidden_states = hidden_states[-1]
        if self.dropout:
            hidden_states = self.dropout(hidden_states)
        logits = self.logits(hidden_states)
        return {kfg.G_LOGITS: logits}


def gelu_accurate(x):
    # Tanh approximation of GELU; caches sqrt(2 / pi) on the function object.
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    # Compute GELU in fp32 for numerical stability, then cast back to the input dtype.
    return torch.nn.functional.gelu(x.float()).type_as(x)


@PREDICTOR_REGISTRY.register()
class TwoLayerPredictor(nn.Module):
    """dense -> GELU -> LayerNorm -> linear vocabulary projection."""

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        vocab_size: int,   # includes special tokens such as <BOS>/<EOS>
        dropout: float,
    ):
        super(TwoLayerPredictor, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation_fn = gelu
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.logits = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    def replace_logits(self, shared_weights):
        # Tie the output projection to an externally shared weight (e.g. the token embedding).
        self.logits.weight = shared_weights

    @classmethod
    def from_config(cls, cfg):
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "dropout": cfg.MODEL.PRED_DROPOUT,
        }

    @classmethod
    def add_config(cls, cfg):
        pass

    def forward(self, batched_inputs):
        hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
        if isinstance(hidden_states, list):
            hidden_states = hidden_states[-1]
        x = self.dense(hidden_states)
        x = self.activation_fn(x)
        x = self.layer_norm(x)
        logits = self.logits(x)
        return {kfg.G_LOGITS: logits}


@PREDICTOR_REGISTRY.register()
class RobertaLMHead(nn.Module):
    """RoBERTa-style LM head whose hidden transform and output weight can be
    private or shared in from other modules (see replace_weight / replace_module_hidden)."""

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        vocab_size: int,   # includes special tokens such as <BOS>/<EOS>
        dropout: float,
        untie_weight_embedding: bool,
        use_bias: bool,
        share_hidden: bool,
    ):
        super(RobertaLMHead, self).__init__()
        self.activation_fn = gelu
        # Output projection: a private (untied) weight, or None until the token
        # embedding weight is shared in via replace_weight().
        if untie_weight_embedding is True:
            self.weight = nn.Linear(hidden_size, vocab_size, bias=False).weight
        else:
            self.weight = None
        # Hidden transform: private dense + layer norm, or None until shared
        # modules are provided via replace_module_hidden().
        if share_hidden:
            self.dense = None
            self.layer_norm = None
        else:
            self.dense = nn.Linear(hidden_size, hidden_size)
            self.layer_norm = nn.LayerNorm(hidden_size)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(vocab_size))
        else:
            self.bias = None
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
        # print("dropout: {}".format(self.dropout))

    def replace_weight(self, weight):
        if self.weight is None:
            self.weight = weight
        else:
            print('already has weight, please set UNTIE_WEIGHT_EMBEDDING to False')

    def replace_module_hidden(self, dense, layer_norm):
        if (self.dense is None) and (self.layer_norm is None):
            self.dense = dense
            self.layer_norm = layer_norm
        else:
            print('already has hidden layers!')
            raise ValueError('RobertaLMHead already owns its dense/layer_norm modules')
    @classmethod
    def from_config(cls, cfg):
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "dropout": cfg.MODEL.PRED_DROPOUT,
            "untie_weight_embedding": cfg.MODEL.UNTIE_WEIGHT_EMBEDDING,
            "use_bias": cfg.MODEL.USE_PREDICTOR_BIAS,
            "share_hidden": cfg.MODEL.SHARE_PREDICTOR_HIDDEN,
        }

    @classmethod
    def add_config(cls, cfg):
        pass

    def forward(self, batched_inputs):
        if kfg.G_HIDDEN_STATES in batched_inputs:
            hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
            if isinstance(hidden_states, list):
                hidden_states = hidden_states[-1]
            if kfg.G_TARGET_IDS in batched_inputs:
                # Score only the positions whose target id is not the ignore index (-1).
                mask_tokens = batched_inputs[kfg.G_TARGET_IDS].ne(-1)
                hidden_states = hidden_states[mask_tokens]
                batched_inputs[kfg.G_TARGET_IDS] = batched_inputs[kfg.G_TARGET_IDS][mask_tokens]
            logits = self._forward(hidden_states)
            return {kfg.G_LOGITS: logits}
        elif kfg.U_HIDDEN_STATES in batched_inputs:
            hidden_states = batched_inputs[kfg.U_HIDDEN_STATES]
            if isinstance(hidden_states, list):
                hidden_states = hidden_states[-1]
            mask_tokens = batched_inputs[kfg.U_TARGET_IDS].ne(-1)
            hidden_states = hidden_states[mask_tokens]
            batched_inputs[kfg.U_TARGET_IDS] = batched_inputs[kfg.U_TARGET_IDS][mask_tokens]
            logits = self._forward(hidden_states)
            return {kfg.U_LOGITS: logits}

    def _forward(self, x):
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.layer_norm(x)
        if self.dropout:
            x = self.dropout(x)
        if self.bias is not None:
            logits = F.linear(x, self.weight) + self.bias
        else:
            logits = F.linear(x, self.weight)
        return logits


@PREDICTOR_REGISTRY.register()
class RobertaRegressionHead(nn.Module):
    """Pools the hidden states with a transform (BertPooler) and projects them
    to a feature/label dimension, optionally through a sigmoid."""

    @configurable
    def __init__(
        self,
        *,
        hidden_size,
        feat_dim,
        transform,
        sigmoid,
    ):
        super(RobertaRegressionHead, self).__init__()
        self.transform = transform
        self.decoder = nn.Linear(hidden_size, feat_dim)
        self.output_sigmoid = sigmoid

    @classmethod
    def from_config(cls, cfg):
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "feat_dim": cfg.MODEL.LABELS_NUM,
            "sigmoid": cfg.MODEL.SIGMOID,
            "transform": BertPooler(cfg),
        }

    @classmethod
    def add_config(cls, cfg):
        pass

    def test_forward(self, u_logits):
        # For single-stream similarity: expose the logits as the inference output.
        return {kfg.OUTPUT: u_logits}

    def forward(self, batched_inputs):
        ret = {}
        if kfg.G_HIDDEN_STATES in batched_inputs:
            hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
            if isinstance(hidden_states, list):
                hidden_states = hidden_states[-1]
            hidden_states = self.transform(hidden_states)
            logits = self.decoder(hidden_states)
            if self.output_sigmoid:
                logits = torch.sigmoid(logits)
            ret.update({kfg.G_LOGITS: logits})
            if not self.training:
                ret_test = self.test_forward(logits)
                ret.update(ret_test)
            return ret
        elif kfg.U_HIDDEN_STATES in batched_inputs:
            hidden_states = batched_inputs[kfg.U_HIDDEN_STATES]
            if isinstance(hidden_states, list):
                hidden_states = hidden_states[-1]
            hidden_states = self.transform(hidden_states)
            logits = self.decoder(hidden_states)
            if self.output_sigmoid:
                logits = torch.sigmoid(logits)
            ret.update({kfg.U_LOGITS: logits})
            if not self.training:
                ret_test = self.test_forward(logits)
                ret.update(ret_test)
            return ret
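

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). Assuming the
# detectron2-style ``@configurable`` decorator also accepts plain keyword
# arguments, BasePredictor can be exercised on a dummy batch as below; in the
# normal pipeline, predictors are built from a cfg through PREDICTOR_REGISTRY.
# The sizes used here are arbitrary example values.
#
#   predictor = BasePredictor(hidden_size=512, vocab_size=10000, dropout=0.1)
#   batch = {kfg.G_HIDDEN_STATES: torch.randn(2, 16, 512)}
#   logits = predictor(batch)[kfg.G_LOGITS]   # shape: (2, 16, 10000)
# ---------------------------------------------------------------------------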