from dataclasses import dataclass, field
import os

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    BaseFairseqModel,
    register_model,
)
from fairseq.models.roberta.model import RobertaClassificationHead
from fairseq.modules import (
    LayerNorm,
    TransformerSentenceEncoder,
    TransformerSentenceEncoderLayer,
)


ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns())
JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"])
SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"])


def update_init_roberta_model_state(state):
    """
    Update the state_dict of a pretrained RoBERTa model so it can be used to
    initialize the weights of the BertRanker.
    """
    for k in list(state.keys()):
        # drop the LM head and version bookkeeping entries
        if ".lm_head." in k or "version" in k:
            del state[k]
            continue
        # strip the "encoder.sentence_encoder." / "decoder.sentence_encoder."
        # prefix (25 characters) so keys match the TransformerSentenceEncoder
        assert k.startswith("encoder.sentence_encoder.") or k.startswith(
            "decoder.sentence_encoder."
        ), f"Cannot recognize parameter name {k}"
        if "layernorm_embedding" in k:
            new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.")
            state[new_k[25:]] = state[k]
        else:
            state[k[25:]] = state[k]
        del state[k]
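
# Illustrative key renaming performed above (assuming a standard fairseq
# RoBERTa checkpoint layout):
#   "encoder.sentence_encoder.layers.0.fc1.weight"        -> "layers.0.fc1.weight"
#   "encoder.sentence_encoder.layernorm_embedding.weight" -> "emb_layer_norm.weight"
#   "encoder.lm_head.dense.weight"                        -> dropped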
|
|
|
|
|
class BaseRanker(nn.Module):
    def __init__(self, args, task):
        super().__init__()

        self.separator_token = task.dictionary.eos()
        self.padding_idx = task.dictionary.pad()

    def forward(self, src_tokens):
        raise NotImplementedError

    def get_segment_labels(self, src_tokens):
        # separator (eos) tokens delimit the two segments: tokens up to and
        # including the first separator get label 0 (source side), tokens
        # after it get label 1 (hypothesis side)
        segment_boundary = (src_tokens == self.separator_token).long()
        segment_labels = (
            segment_boundary.cumsum(dim=1)
            - segment_boundary
            - (src_tokens == self.padding_idx).long()
        )

        return segment_labels

    def get_positions(self, src_tokens, segment_labels):
        segment_positions = (
            torch.arange(src_tokens.shape[1])
            .to(src_tokens.device)
            .repeat(src_tokens.shape[0], 1)
        )
        segment_boundary = (src_tokens == self.separator_token).long()
        _, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True)
        col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx])
        offset = torch.cat(
            [
                torch.zeros(1).type_as(segment_boundary),
                segment_boundary.sum(dim=1).cumsum(dim=0)[:-1],
            ]
        )
        # for second-segment tokens, count positions relative to the column of
        # the first separator in that row
        segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * (
            segment_labels != 0
        )

        # make positions 1-based and map padding to padding_idx, matching the
        # positional embedding convention of the sentence encoder
        padding_mask = src_tokens.ne(self.padding_idx)
        segment_positions = (segment_positions + 1) * padding_mask.type_as(
            segment_positions
        ) + self.padding_idx

        return segment_positions
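
# Worked example (hypothetical vocabulary with pad=1 and eos/separator=2):
#   src_tokens          = [[5, 6, 2, 7, 8, 9, 2, 1]]   # src </s> hyp </s> <pad>
#   get_segment_labels -> [[0, 0, 0, 1, 1, 1, 1, 1]]
#   get_positions      -> [[2, 3, 4, 3, 4, 5, 6, 1]]
# i.e. the hypothesis segment restarts its position count relative to the
# first separator, and padding positions are mapped to padding_idx.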
|
|
|
|
|
class BertRanker(BaseRanker):
    def __init__(self, args, task):
        super(BertRanker, self).__init__(args, task)

        init_model = getattr(args, "pretrained_model", "")
        self.joint_layers = nn.ModuleList()
        if os.path.isfile(init_model):
            # initialize the encoder from a pretrained (RoBERTa-style) checkpoint
            print(f"initialize weight from {init_model}")

            from fairseq import hub_utils

            x = hub_utils.from_pretrained(
                os.path.dirname(init_model),
                checkpoint_file=os.path.basename(init_model),
            )

            in_state_dict = x["models"][0].state_dict()
            init_args = x["args"].model

            # position ids are offset by padding_idx (see get_positions), so
            # the embedding table needs padding_idx + max_positions + 1 entries
            num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1

            # follow the setup of the pretrained RoBERTa model
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=getattr(
                    args, "encoder_layers", init_args.encoder_layers
                ),
                embedding_dim=init_args.encoder_embed_dim,
                ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
                num_attention_heads=init_args.encoder_attention_heads,
                dropout=init_args.dropout,
                attention_dropout=init_args.attention_dropout,
                activation_dropout=init_args.activation_dropout,
                num_segments=2,  # source and hypothesis segments
                max_seq_len=num_positional_emb,
                offset_positions_by_padding=False,
                encoder_normalize_before=True,
                apply_bert_init=True,
                activation_fn=init_args.activation_fn,
                freeze_embeddings=args.freeze_embeddings,
                n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
            )

            if args.freeze_embeddings:
                for p in self.model.segment_embeddings.parameters():
                    p.requires_grad = False

            update_init_roberta_model_state(in_state_dict)
            print("loading weights from the pretrained model")
            self.model.load_state_dict(
                in_state_dict, strict=False
            )  # strict=False tolerates keys present in only one side (e.g. segment embeddings)

            ffn_embedding_dim = init_args.encoder_ffn_embed_dim
            num_attention_heads = init_args.encoder_attention_heads
            dropout = init_args.dropout
            attention_dropout = init_args.attention_dropout
            activation_dropout = init_args.activation_dropout
            activation_fn = init_args.activation_fn

            classifier_embed_dim = getattr(
                args, "embed_dim", init_args.encoder_embed_dim
            )
            if classifier_embed_dim != init_args.encoder_embed_dim:
                self.transform_layer = nn.Linear(
                    init_args.encoder_embed_dim, classifier_embed_dim
                )
        else:
            # no pretrained checkpoint: build the encoder from the model config
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=args.encoder_layers,
                embedding_dim=args.embed_dim,
                ffn_embedding_dim=args.ffn_embed_dim,
                num_attention_heads=args.attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                max_seq_len=task.max_positions()
                if task.max_positions()
                else args.tokens_per_sample,
                num_segments=2,
                offset_positions_by_padding=False,
                encoder_normalize_before=args.encoder_normalize_before,
                apply_bert_init=args.apply_bert_init,
                activation_fn=args.activation_fn,
            )

            classifier_embed_dim = args.embed_dim
            ffn_embedding_dim = args.ffn_embed_dim
            num_attention_heads = args.attention_heads
            dropout = args.dropout
            attention_dropout = args.attention_dropout
            activation_dropout = args.activation_dropout
            activation_fn = args.activation_fn

        self.joint_classification = args.joint_classification
        if args.joint_classification == "sent":
            if args.joint_normalize_before:
                self.joint_layer_norm = LayerNorm(classifier_embed_dim)
            else:
                self.joint_layer_norm = None

            self.joint_layers = nn.ModuleList(
                [
                    TransformerSentenceEncoderLayer(
                        embedding_dim=classifier_embed_dim,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                    )
                    for _ in range(args.num_joint_layers)
                ]
            )

        self.classifier = RobertaClassificationHead(
            classifier_embed_dim,
            classifier_embed_dim,
            1,  # num_classes: a single ranking score
            "tanh",
            args.classifier_dropout,
        )

    def forward(self, src_tokens, src_lengths):
        segment_labels = self.get_segment_labels(src_tokens)
        positions = self.get_positions(src_tokens, segment_labels)

        inner_states, _ = self.model(
            tokens=src_tokens,
            segment_labels=segment_labels,
            last_state_only=True,
            positions=positions,
        )

        # inner_states[-1] is T x B x C; return B x T x C
        return inner_states[-1].transpose(0, 1)

    def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
        # extract a sentence-level representation from the B x T x C encoder output
        if sentence_rep == "head":
            # use the first (head) token
            x = encoder_out[:, :1, :]
        else:  # "meanpool" or "maxpool"
            assert src_tokens is not None, "meanpool/maxpool requires src_tokens input"
            segment_labels = self.get_segment_labels(src_tokens)
            padding_mask = src_tokens.ne(self.padding_idx)
            # restrict pooling to non-padding tokens of the hypothesis segment (label 1)
            encoder_mask = segment_labels * padding_mask.type_as(segment_labels)

            if sentence_rep == "meanpool":
                ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
                x = torch.sum(
                    encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
                ) / ntokens.unsqueeze(2).type_as(encoder_out)
            else:  # "maxpool"; note: masks encoder_out in place before pooling
                encoder_out[
                    (encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
                ] = -float("inf")
                x, _ = torch.max(encoder_out, dim=1, keepdim=True)

        if hasattr(self, "transform_layer"):
            x = self.transform_layer(x)

        return x  # B x 1 x C

    def joint_forward(self, x):
        # x: T x B x C
        if self.joint_layer_norm:
            x = self.joint_layer_norm(x.transpose(0, 1))
            x = x.transpose(0, 1)

        for layer in self.joint_layers:
            x, _ = layer(x, self_attn_padding_mask=None)
        return x

    def classification_forward(self, x):
        # x: B x T x C
        return self.classifier(x)
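
# Informal sketch of how BertRanker's pieces are chained when scoring a batch
# of (source, hypothesis) pairs; the exact call order and any reshaping are
# driven by the training criterion / scoring code, not by this module:
#   encoder_out = forward(src_tokens, src_lengths)           # B x T x C
#   sent_rep = sentence_forward(encoder_out, src_tokens)     # B x 1 x C
#   sent_rep = joint_forward(...)  # only if joint_classification == "sent"
#   scores = classification_forward(sent_rep)                # B x 1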
|
|
|
|
|
@dataclass
class DiscriminativeNMTRerankerConfig(FairseqDataclass):
    pretrained_model: str = field(
        default="", metadata={"help": "pretrained model to load"}
    )
    sentence_rep: SENTENCE_REP_CHOICES = field(
        default="head",
        metadata={
            "help": "method to transform the output of the transformer stack to a sentence-level representation"
        },
    )

    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN"}
    )
    classifier_dropout: float = field(
        default=0.0, metadata={"help": "classifier dropout probability"}
    )
    embed_dim: int = field(default=768, metadata={"help": "embedding dimension"})
    ffn_embed_dim: int = field(
        default=2048, metadata={"help": "embedding dimension for FFN"}
    )
    encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"})
    attention_heads: int = field(default=8, metadata={"help": "num attention heads"})
    encoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each encoder block"}
    )
    apply_bert_init: bool = field(
        default=False, metadata={"help": "use custom param initialization for BERT"}
    )
    activation_fn: ACTIVATION_FN_CHOICES = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    freeze_embeddings: bool = field(
        default=False, metadata={"help": "freeze embeddings in the pretrained model"}
    )
    n_trans_layers_to_freeze: int = field(
        default=0,
        metadata={
            "help": "number of layers to freeze in the pretrained transformer model"
        },
    )

    joint_classification: JOINT_CLASSIFICATION_CHOICES = field(
        default="none",
        metadata={"help": "method to compute joint features for classification"},
    )
    num_joint_layers: int = field(
        default=1, metadata={"help": "number of joint layers"}
    )
    joint_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layer norm on the input to the joint layer"},
    )
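
# Note: as with other FairseqDataclass configs, these fields are typically
# exposed as command-line options with underscores converted to dashes
# (e.g. ``--pretrained-model``, ``--sentence-rep``); see the fairseq dataclass
# utilities for the exact conversion rules.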
|
|
|
|
|
@register_model(
    "discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig
)
class DiscriminativeNMTReranker(BaseFairseqModel):
    """Thin BaseFairseqModel wrapper around BertRanker that exposes its
    encoder, sentence, joint, and classification forward passes."""

    @classmethod
    def build_model(cls, args, task):
        model = BertRanker(args, task)
        return DiscriminativeNMTReranker(args, model)

    def __init__(self, args, model):
        super().__init__()

        self.model = model
        self.sentence_rep = args.sentence_rep
        self.joint_classification = args.joint_classification

    def forward(self, src_tokens, src_lengths, **kwargs):
        return self.model(src_tokens, src_lengths)

    def sentence_forward(self, encoder_out, src_tokens):
        return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep)

    def joint_forward(self, x):
        return self.model.joint_forward(x)

    def classification_forward(self, x):
        return self.model.classification_forward(x)
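
# The @register_model decorator above registers this architecture under the
# name "discriminative_nmt_reranker", with DiscriminativeNMTRerankerConfig
# supplying its options. A minimal (hypothetical) way to instantiate it, given
# a compatible task and model config:
#
#   model = DiscriminativeNMTReranker.build_model(model_cfg, task)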