# OFA/fairseq/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py
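# BERT-style discriminative reranker for NMT: a Transformer sentence encoder,
# optionally initialized from a pretrained RoBERTa checkpoint, scores concatenated
# source/hypothesis pairs through a single-output classification head.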
from dataclasses import dataclass, field
import os
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
BaseFairseqModel,
register_model,
)
from fairseq.models.roberta.model import RobertaClassificationHead
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
TransformerSentenceEncoderLayer,
)
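
# hyper-parameter option sets surfaced through DiscriminativeNMTRerankerConfig below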
ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns())
JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"])
SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"])


def update_init_roberta_model_state(state):
"""
update the state_dict of a Roberta model for initializing
weights of the BertRanker
"""
for k in list(state.keys()):
if ".lm_head." in k or "version" in k:
del state[k]
continue
# remove 'encoder/decoder.sentence_encoder.' from the key
assert k.startswith("encoder.sentence_encoder.") or k.startswith(
"decoder.sentence_encoder."
), f"Cannot recognize parameter name {k}"
if "layernorm_embedding" in k:
new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.")
state[new_k[25:]] = state[k]
else:
state[k[25:]] = state[k]
del state[k]


class BaseRanker(nn.Module):
def __init__(self, args, task):
super().__init__()
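        # the dictionary's eos token doubles as the separator between the source
        # sentence and the hypothesis in the concatenated input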
self.separator_token = task.dictionary.eos()
self.padding_idx = task.dictionary.pad()
def forward(self, src_tokens):
raise NotImplementedError
def get_segment_labels(self, src_tokens):
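        # label tokens of the first segment (up to and including the first separator)
        # with 0 and later tokens with 1. Illustrative sketch with assumed ids
        # eos=2, pad=1:
        #   src_tokens     = [[5, 6, 2, 7, 8, 2, 1, 1]]
        #   segment_labels = [[0, 0, 0, 1, 1, 1, 1, 1]]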
segment_boundary = (src_tokens == self.separator_token).long()
segment_labels = (
segment_boundary.cumsum(dim=1)
- segment_boundary
- (src_tokens == self.padding_idx).long()
)
return segment_labels
def get_positions(self, src_tokens, segment_labels):
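        # position ids are made relative to the most recent separator so each segment
        # gets its own position sequence; padding positions get the padding index itself.
        # Continuing the sketch above (assumed ids eos=2, pad=1, positions starting
        # at pad + 1):
        #   src_tokens = [[5, 6, 2, 7, 8, 2, 1, 1]]
        #   positions  = [[2, 3, 4, 3, 4, 5, 1, 1]]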
segment_positions = (
torch.arange(src_tokens.shape[1])
.to(src_tokens.device)
.repeat(src_tokens.shape[0], 1)
)
segment_boundary = (src_tokens == self.separator_token).long()
_, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True)
col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx])
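        # a dummy zero is prepended so that, combined with the per-row offsets below,
        # col_idx[offset + s] is the column of the s-th separator (s >= 1) in each row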
offset = torch.cat(
[
torch.zeros(1).type_as(segment_boundary),
segment_boundary.sum(dim=1).cumsum(dim=0)[:-1],
]
)
segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * (
segment_labels != 0
)
padding_mask = src_tokens.ne(self.padding_idx)
segment_positions = (segment_positions + 1) * padding_mask.type_as(
segment_positions
) + self.padding_idx
return segment_positions


class BertRanker(BaseRanker):
def __init__(self, args, task):
super(BertRanker, self).__init__(args, task)
init_model = getattr(args, "pretrained_model", "")
self.joint_layers = nn.ModuleList()
if os.path.isfile(init_model):
print(f"initialize weight from {init_model}")
from fairseq import hub_utils
x = hub_utils.from_pretrained(
os.path.dirname(init_model),
checkpoint_file=os.path.basename(init_model),
)
in_state_dict = x["models"][0].state_dict()
init_args = x["args"].model
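            # RoBERTa-style position ids start at padding_idx + 1, so the position
            # embedding table needs max_positions + padding_idx + 1 entries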
num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1
# follow the setup in roberta
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=getattr(
args, "encoder_layers", init_args.encoder_layers
),
embedding_dim=init_args.encoder_embed_dim,
ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
num_attention_heads=init_args.encoder_attention_heads,
dropout=init_args.dropout,
attention_dropout=init_args.attention_dropout,
activation_dropout=init_args.activation_dropout,
num_segments=2, # add language embeddings
max_seq_len=num_positional_emb,
offset_positions_by_padding=False,
encoder_normalize_before=True,
apply_bert_init=True,
activation_fn=init_args.activation_fn,
freeze_embeddings=args.freeze_embeddings,
n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
)
# still need to learn segment embeddings as we added a second language embedding
if args.freeze_embeddings:
for p in self.model.segment_embeddings.parameters():
p.requires_grad = False
update_init_roberta_model_state(in_state_dict)
print("loading weights from the pretrained model")
self.model.load_state_dict(
in_state_dict, strict=False
) # ignore mismatch in language embeddings
ffn_embedding_dim = init_args.encoder_ffn_embed_dim
num_attention_heads = init_args.encoder_attention_heads
dropout = init_args.dropout
attention_dropout = init_args.attention_dropout
activation_dropout = init_args.activation_dropout
activation_fn = init_args.activation_fn
classifier_embed_dim = getattr(
args, "embed_dim", init_args.encoder_embed_dim
)
if classifier_embed_dim != init_args.encoder_embed_dim:
self.transform_layer = nn.Linear(
init_args.encoder_embed_dim, classifier_embed_dim
)
else:
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=args.encoder_layers,
embedding_dim=args.embed_dim,
ffn_embedding_dim=args.ffn_embed_dim,
num_attention_heads=args.attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
max_seq_len=task.max_positions()
if task.max_positions()
else args.tokens_per_sample,
num_segments=2,
offset_positions_by_padding=False,
encoder_normalize_before=args.encoder_normalize_before,
apply_bert_init=args.apply_bert_init,
activation_fn=args.activation_fn,
)
classifier_embed_dim = args.embed_dim
ffn_embedding_dim = args.ffn_embed_dim
num_attention_heads = args.attention_heads
dropout = args.dropout
attention_dropout = args.attention_dropout
activation_dropout = args.activation_dropout
activation_fn = args.activation_fn
self.joint_classification = args.joint_classification
if args.joint_classification == "sent":
if args.joint_normalize_before:
self.joint_layer_norm = LayerNorm(classifier_embed_dim)
else:
self.joint_layer_norm = None
self.joint_layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=classifier_embed_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
)
for _ in range(args.num_joint_layers)
]
)
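        # single-logit head: each concatenated (source, hypothesis) pair is mapped to
        # one scalar score used for ranking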
self.classifier = RobertaClassificationHead(
classifier_embed_dim,
classifier_embed_dim,
1, # num_classes
"tanh",
args.classifier_dropout,
)
def forward(self, src_tokens, src_lengths):
segment_labels = self.get_segment_labels(src_tokens)
positions = self.get_positions(src_tokens, segment_labels)
inner_states, _ = self.model(
tokens=src_tokens,
segment_labels=segment_labels,
last_state_only=True,
positions=positions,
)
return inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
# encoder_out: B x T x C
if sentence_rep == "head":
x = encoder_out[:, :1, :]
else: # 'meanpool', 'maxpool'
            assert src_tokens is not None, "meanpool/maxpool requires src_tokens input"
segment_labels = self.get_segment_labels(src_tokens)
padding_mask = src_tokens.ne(self.padding_idx)
encoder_mask = segment_labels * padding_mask.type_as(segment_labels)
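            # encoder_mask is nonzero only for non-padding tokens of the second
            # (hypothesis) segment, so both pooling variants aggregate over the
            # hypothesis tokens only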
if sentence_rep == "meanpool":
ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
x = torch.sum(
encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
) / ntokens.unsqueeze(2).type_as(encoder_out)
else: # 'maxpool'
encoder_out[
(encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
] = -float("inf")
x, _ = torch.max(encoder_out, dim=1, keepdim=True)
if hasattr(self, "transform_layer"):
x = self.transform_layer(x)
return x # B x 1 x C
def joint_forward(self, x):
# x: T x B x C
if self.joint_layer_norm:
x = self.joint_layer_norm(x.transpose(0, 1))
x = x.transpose(0, 1)
for layer in self.joint_layers:
x, _ = layer(x, self_attn_padding_mask=None)
return x
def classification_forward(self, x):
# x: B x T x C
return self.classifier(x)


@dataclass
class DiscriminativeNMTRerankerConfig(FairseqDataclass):
pretrained_model: str = field(
default="", metadata={"help": "pretrained model to load"}
)
sentence_rep: SENTENCE_REP_CHOICES = field(
default="head",
metadata={
"help": "method to transform the output of the transformer stack to a sentence-level representation"
},
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
attention_dropout: float = field(
default=0.0, metadata={"help": "dropout probability for attention weights"}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN"}
)
classifier_dropout: float = field(
default=0.0, metadata={"help": "classifier dropout probability"}
)
embed_dim: int = field(default=768, metadata={"help": "embedding dimension"})
ffn_embed_dim: int = field(
default=2048, metadata={"help": "embedding dimension for FFN"}
)
encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"})
attention_heads: int = field(default=8, metadata={"help": "num attention heads"})
encoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each encoder block"}
)
apply_bert_init: bool = field(
default=False, metadata={"help": "use custom param initialization for BERT"}
)
activation_fn: ACTIVATION_FN_CHOICES = field(
default="relu", metadata={"help": "activation function to use"}
)
freeze_embeddings: bool = field(
default=False, metadata={"help": "freeze embeddings in the pretrained model"}
)
n_trans_layers_to_freeze: int = field(
default=0,
metadata={
"help": "number of layers to freeze in the pretrained transformer model"
},
)
    # joint classification
joint_classification: JOINT_CLASSIFICATION_CHOICES = field(
default="none",
metadata={"help": "method to compute joint features for classification"},
)
num_joint_layers: int = field(
default=1, metadata={"help": "number of joint layers"}
)
joint_normalize_before: bool = field(
default=False,
metadata={"help": "apply layer norm on the input to the joint layer"},
)


@register_model(
"discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig
)
class DiscriminativeNMTReranker(BaseFairseqModel):
@classmethod
def build_model(cls, args, task):
model = BertRanker(args, task)
return DiscriminativeNMTReranker(args, model)
def __init__(self, args, model):
super().__init__()
self.model = model
self.sentence_rep = args.sentence_rep
self.joint_classification = args.joint_classification
def forward(self, src_tokens, src_lengths, **kwargs):
return self.model(src_tokens, src_lengths)
def sentence_forward(self, encoder_out, src_tokens):
return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep)
def joint_forward(self, x):
return self.model.joint_forward(x)
def classification_forward(self, x):
return self.model.classification_forward(x)
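

# A minimal sketch of how the pieces above fit together, inferred from the shape
# comments in this file; `model`, `src_tokens` and `src_lengths` are assumed inputs
# coming from the surrounding task/criterion (the training criterion may arrange the
# calls differently, e.g. batching the joint step across all hypotheses of a source):
#
#   encoder_out = model(src_tokens, src_lengths)                # B x T x C
#   sent_rep = model.sentence_forward(encoder_out, src_tokens)  # B x 1 x C
#   scores = model.classification_forward(sent_rep)             # B x 1, one score per pair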