from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from model.document_encoder import IndependentDocEncoder
from model.mention_proposal.utils import sort_mentions
from pytorch_utils.modules import MLP


class MentionProposalModule(nn.Module):
    """Module to propose candidate mention spans.

    This module performs the first two steps of the coreference pipeline:
    (1) encode the document, and
    (2) score candidate spans and filter through the high-scoring ones.
    """

    def __init__(self, config, train_config, drop_module=None):
        super(MentionProposalModule, self).__init__()

        self.config = config
        self.train_config = train_config
        self.drop_module = drop_module

        # Encoder
        self.doc_encoder = IndependentDocEncoder(config.doc_encoder)
        self._build_model(hidden_size=self.doc_encoder.hidden_size)

        self.loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

    @property
    def device(self) -> torch.device:
        return next(self.doc_encoder.parameters()).device

    def _build_model(self, hidden_size: int) -> None:
        mention_params = self.config.mention_params

        self.span_width_embeddings = nn.Embedding(
            mention_params.max_span_width, mention_params.emb_size
        )
        self.span_width_prior_embeddings = nn.Embedding(
            mention_params.max_span_width, mention_params.emb_size
        )

        ment_emb_type = mention_params.ment_emb
        ment_emb_to_size_factor = mention_params.ment_emb_to_size_factor[ment_emb_type]
        if ment_emb_type == "attn":
            self.mention_attn = nn.Linear(hidden_size, 1).to(self.device)

        self.span_emb_size = (
            ment_emb_to_size_factor * hidden_size + mention_params.emb_size
        )
        self.mention_mlp = MLP(
            input_size=self.span_emb_size,
            hidden_size=mention_params.mlp_size,
            output_size=1,
            bias=True,
            drop_module=self.drop_module,
            num_hidden_layers=mention_params.mlp_depth,
        )
        self.span_width_mlp = MLP(
            input_size=mention_params.emb_size,
            hidden_size=mention_params.mlp_size,
            output_size=1,
            num_hidden_layers=mention_params.mlp_depth,
            bias=True,
            drop_module=self.drop_module,
        )
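    # Sizing sketch (illustrative numbers, not from any particular config):
    # with ment_emb == "attn", ment_emb_to_size_factor is typically 3
    # (start token + end token + attention-weighted sum of span tokens), so
    # for hidden_size=1024 and emb_size=20 the span embedding computed by
    # get_span_embeddings below is 3 * 1024 + 20 = 3092-dimensional.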
""" span_emb_list = [encoded_doc[ment_starts, :], encoded_doc[ment_ends, :]] # Add span width embeddings span_width_indices = torch.clamp( ment_ends - ment_starts, max=self.config.mention_params.max_span_width - 1 ) span_width_embs = self.drop_module( self.span_width_embeddings(span_width_indices) ) span_emb_list.append(span_width_embs) if self.config.mention_params.ment_emb == "attn": num_words = encoded_doc.shape[0] # num_tokens (T) num_c = ment_starts.shape[0] # num_candidates (C) doc_range = torch.unsqueeze( torch.arange(num_words, device=self.device), 0 ).repeat( num_c, 1 ) # [C x T] ment_masks = (doc_range >= torch.unsqueeze(ment_starts, dim=1)) & ( doc_range <= torch.unsqueeze(ment_ends, dim=1) ) # [C x T] word_attn = torch.squeeze(self.mention_attn(encoded_doc), dim=1) # [T] mention_word_attn = nn.functional.softmax( (1 - ment_masks.float()) * (-1e10) + torch.unsqueeze(word_attn, dim=0), dim=1, ) # [C x T] attention_term = torch.matmul(mention_word_attn, encoded_doc) # K x H span_emb_list.append(attention_term) span_embs = torch.cat(span_emb_list, dim=1) return span_embs def get_mention_width_scores( self, cand_starts: Tensor, cand_ends: Tensor ) -> Tensor: """Scores for candidate mention based solely on their length. This prior score is necessary because most mention spans tend to be shorter in width. """ span_width_idx = torch.clamp( cand_ends - cand_starts, max=self.config.mention_params.max_span_width - 1 ) span_width_embs = self.span_width_prior_embeddings(span_width_idx) width_scores = torch.squeeze(self.span_width_mlp(span_width_embs), dim=-1) return width_scores def get_flat_gold_mentions( self, clusters: List, num_tokens: int, flat_cand_mask: Tensor ) -> Tensor: """Represent the gold mentions in a binary flattened tensor. This flat representation of gold mentions is useful for calculating the mention prediction loss. Note that we filter out gold mentions longer than the max_span_width. """ gold_ments = torch.zeros( num_tokens, self.config.mention_params.max_span_width, device=self.device ) for cluster in clusters: for mention in cluster: span_start, span_end = mention[:2] span_width = span_end - span_start + 1 if span_width <= self.config.mention_params.max_span_width: span_width_idx = span_width - 1 gold_ments[span_start, span_width_idx] = 1 filt_gold_ments = gold_ments.reshape(-1)[flat_cand_mask].float() return filt_gold_ments def get_candidate_endpoints( self, encoded_doc: Tensor, document: Dict ) -> Tuple[Tensor, Tensor, Tensor]: """Propose the candidate endpoints given the max span width constraints. This method proposes the candidate spans while filtering out spans that cross sentence boundaries. This method could also use a constraint on not starting or ending in the middle of a word. 
""" num_words: int = encoded_doc.shape[0] sent_map: Tensor = document["sentence_map"].to(self.device) # num_words x max_span_width cand_starts = torch.unsqueeze( torch.arange(num_words, device=self.device), dim=1 ).repeat(1, self.config.mention_params.max_span_width) cand_ends = cand_starts + torch.unsqueeze( torch.arange(self.config.mention_params.max_span_width, device=self.device), dim=0, ) cand_start_sent_indices: Tensor = sent_map[cand_starts] # Avoid getting sentence indices for cand_ends >= num_words corr_cand_ends: Tensor = torch.min( cand_ends, torch.ones_like(cand_ends, device=self.device) * (num_words - 1) ) cand_end_sent_indices: Tensor = sent_map[corr_cand_ends] # End before document ends & same sentence constraint1: Tensor = cand_ends < num_words constraint2: Tensor = cand_start_sent_indices == cand_end_sent_indices # Follows word_boundary # Padding the subtoken_map because it will be useful for end of span check. subtoken_map: Tensor = torch.tensor( document["subtoken_map"] + [-1] * (self.config.mention_params.max_span_width + 1), device=self.device, ) # Check that the word corresponding to the previous subword is not the same at span start constraint3 = subtoken_map[cand_starts] != subtoken_map[cand_starts - 1] # Check that the word corresponding to the next subword is not the same at span end constraint4 = subtoken_map[cand_ends] != subtoken_map[cand_ends + 1] cand_mask: Tensor = constraint1 & constraint2 & constraint3 & constraint4 flat_cand_mask = cand_mask.reshape(-1) # Filter and flatten the candidate end points filt_cand_starts = cand_starts.reshape(-1)[flat_cand_mask] # (num_candidates,) filt_cand_ends = cand_ends.reshape(-1)[flat_cand_mask] # (num_candidates,) return filt_cand_starts, filt_cand_ends, flat_cand_mask def pred_mentions( self, document: Dict, encoded_doc: Tensor, eval_loss=False, ment_threshold=0.0 ) -> Dict: """ Predict mentions for the encoded document. Args: document: Dictionary with the processed document attributes encoded_doc: Encoded document outputted by the document encoder. ment_threshold: Score threshold beyond which mention spans are filtered through. Returns: output_dict: Output dictionary with endpoints of proposed mentions, scores, and loss. 
""" mention_params = self.config.mention_params num_tokens = encoded_doc.shape[0] num_words = document["subtoken_map"][-1] - document["subtoken_map"][0] + 1 cand_starts, cand_ends, cand_mask = self.get_candidate_endpoints( encoded_doc, document ) span_embs = self.get_span_embeddings(encoded_doc, cand_starts, cand_ends) mention_logits = torch.squeeze(self.mention_mlp(span_embs), dim=-1) mention_logits += self.get_mention_width_scores(cand_starts, cand_ends) del span_embs # Span embeddings not required anymore output_dict = {} if self.training or eval_loss: k = int(mention_params.top_span_ratio * num_words) topk_indices = torch.topk(mention_logits, k)[1] filt_gold_mentions = self.get_flat_gold_mentions( document["clusters"], num_tokens, cand_mask ) if self.train_config.ment_loss_mode == "all": mention_loss = self.loss_fn(mention_logits, filt_gold_mentions) else: mention_loss = self.loss_fn( mention_logits[topk_indices], filt_gold_mentions[topk_indices] ) if not mention_params.use_topk: mentions_proposed = mention_logits >= ment_threshold # Calculate accuracy correct = (mentions_proposed == filt_gold_mentions).sum().item() total = filt_gold_mentions.size(0) # Calculate true positives, predicted positives, and precision true_positives = ( ((mentions_proposed == 1) & (filt_gold_mentions == 1)).sum().item() ) predicted_positives = (mentions_proposed == 1).sum().item() # Calculate true positives, actual positives, and recall actual_positives = (filt_gold_mentions == 1).sum().item() output_dict["ment_correct"] = correct output_dict["ment_total"] = total output_dict["ment_tp"] = true_positives output_dict["ment_pp"] = predicted_positives output_dict["ment_ap"] = actual_positives # Add mention loss to output output_dict["ment_loss"] = mention_loss ignore_non_gold = mention_params.get("ignore_non_gold", True) if not mention_params.use_topk and ignore_non_gold: # Ignore invalid mentions even during training topk_indices = topk_indices[ torch.nonzero(filt_gold_mentions[topk_indices], as_tuple=True)[0] ] elif not ignore_non_gold: # print("Not ignoring non-gold mentions. Adding an additional 'check'. If an invalid mention it should be mapped to others") topk_indices = torch.squeeze( (mention_logits >= ment_threshold).nonzero(as_tuple=False), dim=1 ) else: if mention_params.use_topk: k = int(mention_params.top_span_ratio * num_words) topk_indices = torch.topk(mention_logits, k)[1] else: topk_indices = torch.squeeze( (mention_logits >= ment_threshold).nonzero(as_tuple=False), dim=1 ) topk_starts = cand_starts[topk_indices] topk_ends = cand_ends[topk_indices] topk_scores = mention_logits[topk_indices] ( output_dict["ment_starts"], output_dict["ment_ends"], sorted_indices, ) = sort_mentions(topk_starts, topk_ends, return_sorted_indices=True) output_dict["ment_scores"] = topk_scores[sorted_indices] return output_dict def transform_gold_mentions(self, document: Dict) -> Dict: """Transform gold mentions to a format similar to predicted mentions. This method is useful for running ablation experiments where we experiment with using the gold mentions i.e. skipping any errors of the mention proposal module. 
""" mentions = [] # print(document) for cluster in document["clusters"]: for ment_start, ment_end in cluster: mentions.append((ment_start, ment_end)) if len(mentions): topk_starts, topk_ends = zip(*mentions) else: raise ValueError topk_starts = torch.tensor(topk_starts, device=self.device) topk_ends = torch.tensor(topk_ends, device=self.device) topk_starts, topk_ends = sort_mentions(topk_starts, topk_ends) output_dict = { "ment_starts": topk_starts, "ment_ends": topk_ends, # Fake mention score "ment_scores": torch.tensor([1.0] * len(mentions), device=self.device), } return output_dict def get_specific_reps(self, document: Dict) -> List: pass def forward(self, document: Dict, eval_loss=False, gold_mentions=False) -> Dict: """Given the document return proposed mentions and their embeddings.""" encoded_doc: Tensor = self.doc_encoder(document) # .float() LLAMA if self.config.mention_params.use_gold_ments or gold_mentions: # Process gold mentions to a format similar to mentions obtained after prediction output_dict: Dict = self.transform_gold_mentions(document) else: if len(document.get("ext_predicted_mentions", [])) != 0: output_dict = {} ment_starts, ment_ends = zip(*document["ext_predicted_mentions"]) output_dict["ment_starts"] = torch.tensor( ment_starts, device=self.device ) output_dict["ment_ends"] = torch.tensor(ment_ends, device=self.device) output_dict["ment_scores"] = torch.tensor( [1.0] * len(ment_starts), device=self.device ) else: # print("Predicting mentions") output_dict = self.pred_mentions(document, encoded_doc, eval_loss) pred_starts: Tensor = output_dict["ment_starts"] pred_ends: Tensor = output_dict["ment_ends"] # Stack the starts and ends to get the mention tuple output_dict["ments"] = torch.stack((pred_starts, pred_ends), dim=1) # Get mention embeddings mention_embs: Tensor = self.get_span_embeddings( encoded_doc, pred_starts, pred_ends ) ## Representative Processing Code if document["representatives"]: rep_start, rep_end = zip(*document["representatives"]) rep_embs = self.get_span_embeddings( encoded_doc, torch.tensor(rep_start, device=self.device), torch.tensor(rep_end, device=self.device), ) output_dict["rep_emb_list"] = torch.unbind(rep_embs, dim=0) else: output_dict["rep_emb_list"] = () output_dict["ment_emb_list"] = torch.unbind(mention_embs, dim=0) return output_dict