# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Union

import torch
import torch.nn.functional as F

from fairseq.models.roberta.hub_interface import RobertaHubInterface


class XMODHubInterface(RobertaHubInterface):
    """Hub interface whose feature extraction and prediction accept a ``lang_id``.

    Both methods pass ``lang_id`` straight through to ``self.model``; how the
    model interprets it is not visible from this file.
    """

    def extract_features(
        self,
        tokens: torch.LongTensor,
        return_all_hiddens: bool = False,
        lang_id=None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the wrapped model in ``features_only`` mode on ``tokens``.

        Args:
            tokens: Token ids. A 1-D tensor is treated as a single sequence
                and gets a leading batch dimension added.
            return_all_hiddens: If true, return every entry of the model's
                ``inner_states`` instead of only the final features.
            lang_id: Forwarded unchanged to the model call.
                NOTE(review): expected type/shape is defined by the model,
                not here -- confirm against the X-MOD model's ``forward``.

        Returns:
            The model's final features when ``return_all_hiddens`` is false;
            otherwise a list with each inner state transposed on its first
            two dimensions.

        Raises:
            ValueError: If the last dimension of ``tokens`` is larger than
                ``self.model.max_positions()``.
        """
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        # Query the limit once; it is needed for both the check and the message.
        max_positions = self.model.max_positions()
        num_tokens = tokens.size(-1)
        if num_tokens > max_positions:
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    num_tokens, max_positions
                )
            )

        features, extra = self.model(
            tokens.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
        )

        if not return_all_hiddens:
            return features  # just the last layer's features

        # convert from T x B x C -> B x T x C
        return [state.transpose(0, 1) for state in extra["inner_states"]]

    def predict(
        self,
        head: str,
        tokens: torch.LongTensor,
        return_logits: bool = False,
        lang_id=None,
    ):
        """Apply the classification head named ``head`` to the features of ``tokens``.

        Args:
            head: Key into ``self.model.classification_heads``.
            tokens: Token ids, handled as in :meth:`extract_features`.
            return_logits: If true, return the head's raw output.
            lang_id: Forwarded unchanged to :meth:`extract_features`.

        Returns:
            The head's logits when ``return_logits`` is true; otherwise the
            log-softmax of those logits over the last dimension.
        """
        features = self.extract_features(
            tokens.to(device=self.device), lang_id=lang_id
        )
        logits = self.model.classification_heads[head](features)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)