# Spaces: Sleeping / Sleeping
# NOTE(review): the three lines above were stray non-code residue at the top of
# this file (likely extraction/export artifacts); preserved here as a comment so
# the module parses.
import math
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
from tqdm import tqdm

from model.memory import BaseMemory
from pytorch_utils.modules import MLP
class EntityMemory(BaseMemory):
    """Module for clustering proposed mention spans using the Entity-Ranking paradigm.

    Memory cells are initialized from entity representations (``rep_emb_list``)
    and — unless the memory type is "static" — updated online as mentions are
    assigned to clusters.
    """

    def __init__(
        self, config: DictConfig, span_emb_size: int, drop_module: nn.Module
    ) -> None:
        super(EntityMemory, self).__init__(config, span_emb_size, drop_module)
        self.mem_type: DictConfig = config.mem_type

    def forward_training(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
    ) -> List[Tensor]:
        """
        Forward pass during coreference model training where we use teacher-forcing.

        Args:
            ment_boundaries: Mention boundaries of proposed mentions.
            mention_emb_list: Embedding list of proposed mentions.
            rep_emb_list: Entity representations used to initialize the memory.
            gt_actions: Ground truth clustering actions as (cell_idx, action_str)
                pairs; action_str "c" denotes coreference with cell cell_idx.
            metadata: Metadata such as document genre.

        Returns:
            coref_new_list: Logit scores for ground truth actions.
        """
        assert (
            len(rep_emb_list) != 0
        ), "There are no entity representations, should not happen."

        # Initialize memory from the provided entity representations.
        coref_new_list = []
        mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
            self.initialize_memory(rep=rep_emb_list)
        )

        for ment_idx, (ment_emb, (gt_cell_idx, gt_action_str)) in enumerate(
            zip(mention_emb_list, gt_actions)
        ):
            ment_start, ment_end = ment_boundaries[ment_idx]
            if self.config.num_feats != 0:
                feature_embs = self.get_feature_embs(
                    ment_start, last_mention_start, ent_counter, metadata
                )
            else:
                # No features configured: empty [num_cells, 0] placeholder.
                feature_embs = torch.empty(mem_vectors.shape[0], 0, device=self.device)

            coref_new_scores = self.get_coref_new_scores(
                ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
            )
            coref_new_list.append(coref_new_scores)

            # Teacher forcing: follow the ground-truth action.
            action_str, cell_idx = gt_action_str, gt_cell_idx

            # Update memory only if the action is coreference ("c") and the
            # memory is not static.
            if action_str == "c" and self.config.type != "static":
                num_ents: int = int(torch.sum((ent_counter > 0).long()).item())
                # One-hot mask over active cells selecting the updated cell.
                # Comparing against the plain int (instead of torch.tensor(cell_idx))
                # avoids a CPU/CUDA device mismatch in the comparison.
                cell_mask: Tensor = (
                    torch.arange(start=0, end=num_ents, device=self.device) == cell_idx
                ).float()
                mask = torch.unsqueeze(cell_mask, dim=1)
                mask = mask.repeat(1, self.mem_size)

                coref_vec = self.coref_update(
                    ment_emb, mem_vectors, cell_idx, ent_counter
                )
                # Out-of-place blend keeps the graph differentiable w.r.t. all cells.
                mem_vectors = mem_vectors * (1 - mask) + mask * coref_vec
                ent_counter[cell_idx] = ent_counter[cell_idx] + 1
                last_mention_start[cell_idx] = ment_start

        return coref_new_list

    def forward(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
        teacher_force: bool = False,
        memory_init=None,
    ):
        """Forward pass for clustering entity mentions during inference/evaluation.

        Args:
            ment_boundaries: Start and end token indices for the proposed mentions.
            mention_emb_list: Embedding list of proposed mentions.
            rep_emb_list: Entity representations used to initialize the memory.
            gt_actions: Ground-truth clustering actions (used when teacher forcing).
            metadata: Metadata features such as document genre embedding.
            teacher_force: If True, follow ``gt_actions`` instead of the predicted
                actions. (Originally annotated as ``teacher_force: False``, which
                was a bare annotation making the argument required — fixed to a
                proper keyword default.)
            memory_init: Initializer for memory. For streaming coreference, we can
                pass the previous memory state via this dictionary.

        Returns:
            pred_actions: List of predicted clustering actions.
            mem_state: Current memory state.
            coref_scores_list: Per-mention coreference score tensors (on CPU).
        """
        # Every proposed mention must have a corresponding ground-truth action.
        assert len(mention_emb_list) == len(gt_actions)

        # Initialize memory, optionally resuming from a previous (streaming) state.
        if memory_init is not None:
            mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
                self.initialize_memory(**memory_init, rep=rep_emb_list)
            )
        else:
            mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
                self.initialize_memory(rep=rep_emb_list)
            )

        pred_actions = []  # argmax actions
        coref_scores_list = []

        if self.config.type == "static":
            # Tensorized approach: static memory is never updated, so all mentions
            # in a batch can be scored against the same memory state.
            batch_size = self.config.batch_size
            num_batches = math.ceil(len(mention_emb_list) / batch_size)
            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(mention_emb_list))
                num_elements = end_idx - start_idx

                # NOTE(review): the original guard `ent_counter.size() == 0`
                # compared a torch.Size to an int and was always False; the
                # intended empty-memory check is on the first dimension.
                if ent_counter.shape[0] == 0:
                    # No entities to link to: default every mention to (0, "o").
                    # (Original `(tuple) * num_elements` repeated the tuple's
                    # elements, not the tuple — fixed.)
                    next_cell_idx, next_action_str = 0, "o"
                    pred_actions.extend(
                        [(next_cell_idx, next_action_str)] * num_elements
                    )
                    continue

                ment_emb_tensor = torch.stack(
                    mention_emb_list[start_idx:end_idx], dim=0
                )
                ment_start, ment_end = (
                    ment_boundaries[start_idx:end_idx, 0],
                    ment_boundaries[start_idx:end_idx, 1],
                )
                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs_tensorized(
                        ment_start, last_mention_start, ent_counter, metadata
                    )  # [B, D, num_feat_dims]
                else:
                    feature_embs = torch.empty(
                        ment_start.shape[0], mem_vectors.shape[0], 0, device=self.device
                    )  # [B, D, 0]
                coref_new_scores = self.get_coref_new_scores_tensorized(
                    ment_emb_tensor,
                    mem_vectors,
                    mem_vectors_init,
                    ent_counter,
                    feature_embs,
                )
                # Keep a detached CPU copy of the scores for the caller.
                coref_copy = coref_new_scores.clone().detach().cpu()
                coref_scores_list.extend(coref_copy)

                assigned_cluster = self.assign_cluster_tensorized(coref_new_scores)
                gt_actions_batch = gt_actions[start_idx:end_idx]
                if teacher_force:
                    pred_actions.extend(gt_actions_batch)
                else:
                    pred_actions.extend(assigned_cluster)
        else:
            # Sequential clustering: memory is updated after each mention.
            for ment_idx, ment_emb in enumerate(mention_emb_list):
                if ent_counter.shape[0] == 0:
                    # Empty memory: fall back to the default "other" action.
                    next_cell_idx, next_action_str = 0, "o"
                    pred_actions.append((next_cell_idx, next_action_str))
                    continue

                ment_start, ment_end = ment_boundaries[ment_idx]
                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs(
                        ment_start, last_mention_start, ent_counter, metadata
                    )
                else:
                    feature_embs = torch.empty(
                        mem_vectors.shape[0], 0, device=self.device
                    )
                coref_new_scores = self.get_coref_new_scores(
                    ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
                )
                coref_copy = coref_new_scores.clone().detach().cpu()
                coref_scores_list.append(coref_copy)

                pred_cell_idx, pred_action_str = self.assign_cluster(coref_new_scores)
                if teacher_force:
                    next_cell_idx, next_action_str = gt_actions[ment_idx]
                    pred_actions.append(gt_actions[ment_idx])
                else:
                    next_cell_idx, next_action_str = pred_cell_idx, pred_action_str
                    pred_actions.append((pred_cell_idx, pred_action_str))

                # NOTE(review): source indentation was lost; the counter and
                # last-mention updates are grouped under the coreference branch
                # to mirror forward_training — confirm against upstream history.
                if next_action_str == "c":
                    coref_vec = self.coref_update(
                        ment_emb, mem_vectors, next_cell_idx, ent_counter
                    )
                    mem_vectors[next_cell_idx] = coref_vec
                    ent_counter[next_cell_idx] = ent_counter[next_cell_idx] + 1
                    last_mention_start[next_cell_idx] = ment_start

        mem_state = {
            "mem": mem_vectors,
            "mem_init": mem_vectors_init,
            "ent_counter": ent_counter,
            "last_mention_start": last_mention_start,
        }
        return pred_actions, mem_state, coref_scores_list