import torch
from model.memory import BaseMemory
from pytorch_utils.modules import MLP
import torch.nn as nn

from omegaconf import DictConfig
from typing import Dict, Tuple, List
from torch import Tensor
from tqdm import tqdm
import math


class EntityMemory(BaseMemory):
    """Module for clustering proposed mention spans using Entity-Ranking paradigm."""

    def __init__(
        self, config: DictConfig, span_emb_size: int, drop_module: nn.Module
    ) -> None:
        super(EntityMemory, self).__init__(config, span_emb_size, drop_module)
        self.mem_type: DictConfig = config.mem_type

    def forward_training(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
    ) -> List[Tensor]:
        """
        Forward pass during coreference model training where we use teacher-forcing.

        Args:
                ment_boundaries: Mention boundaries of proposed mentions
                mention_emb_list: Embedding list of proposed mentions
                gt_actions: Ground truth clustering actions
                metadata: Metadata such as document genre

        Returns:
                coref_new_list: Logit scores for ground truth actions.
        """
        assert (
            len(rep_emb_list) != 0
        ), "There are no entity representations, should not happen."

        # Initialize memory
        coref_new_list = []

        mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
            self.initialize_memory(rep=rep_emb_list)
        )
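        # Assumed roles of the memory state (inferred from how it is used below):
        #   mem_vectors        - evolving cluster representations, updated on "c" actions
        #   mem_vectors_init   - initial representations built from rep_emb_list
        #   ent_counter        - number of mentions assigned to each cluster
        #   last_mention_start - start token index of the latest mention in each cluster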

        for ment_idx, (ment_emb, (gt_cell_idx, gt_action_str)) in enumerate(
            zip(mention_emb_list, gt_actions)
        ):

            ment_start, ment_end = ment_boundaries[ment_idx]

            if self.config.num_feats != 0:
                feature_embs = self.get_feature_embs(
                    ment_start, last_mention_start, ent_counter, metadata
                )
            else:
                feature_embs = torch.empty(mem_vectors.shape[0], 0, device=self.device)

            coref_new_scores = self.get_coref_new_scores(
                ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
            )

            coref_new_list.append(coref_new_scores)

            # Teacher forcing
            action_str, cell_idx = gt_action_str, gt_cell_idx

            num_ents: int = int(torch.sum((ent_counter > 0).long()).item())
            # One-hot mask over the active memory cells selecting the ground truth cluster
            cell_mask: Tensor = (
                torch.arange(start=0, end=num_ents, device=self.device) == cell_idx
            ).float()

            # Broadcast the cell mask across the memory dimension
            mask = torch.unsqueeze(cell_mask, dim=1)
            mask = mask.repeat(1, self.mem_size)

            ## Update the memory if the action is coreference ("c") and the memory is not static
            if action_str == "c" and self.config.type != "static":
                coref_vec = self.coref_update(
                    ment_emb, mem_vectors, cell_idx, ent_counter
                )
                mem_vectors = mem_vectors * (1 - mask) + mask * coref_vec
                ent_counter[cell_idx] = ent_counter[cell_idx] + 1
                last_mention_start[cell_idx] = ment_start

        return coref_new_list
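
    # Note (assumption, not stated in this file): each tensor in coref_new_list is
    # expected to contain one logit per active memory cell plus one for the "new
    # entity" decision, so the training loss can be computed with, e.g., a
    # cross-entropy over these logits against the ground truth cell index.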

    def forward(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
        teacher_force: bool = False,
        memory_init=None,
    ):
        """Forward pass for clustering entity mentions during inference/evaluation.

        Args:
         ment_boundaries: Start and end token indices for the proposed mentions.
         mention_emb_list: Embedding list of proposed mentions
         metadata: Metadata features such as document genre embedding
         memory_init: Initializer for memory. For streaming coreference, we can pass the previous
                  memory state via this dictionary

        Returns:
                pred_actions: List of predicted clustering actions.
                mem_state: Current memory state.
        """

        ## Sanity check: one ground truth action per proposed mention
        assert len(mention_emb_list) == len(gt_actions)

        # Initialize memory
        if memory_init is not None:
            mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
                self.initialize_memory(**memory_init, rep=rep_emb_list)
            )
        else:
            mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
                self.initialize_memory(rep=rep_emb_list)
            )

        pred_actions = []  # argmax actions
        coref_scores_list = []

        ## Tensorized approach for static method
        if self.config.type == "static":
            batch_size = self.config.batch_size
            ### Process the proposed mentions in batches of `batch_size`
            ### (equivalent to math.ceil(len(mention_emb_list) / batch_size))
            num_batches = len(mention_emb_list) // batch_size + int(
                len(mention_emb_list) % batch_size != 0
            )
            for i in range(num_batches):
                print("Batch Number: ", i)
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, len(mention_emb_list))

                num_elements = end_idx - start_idx

                # No entities in memory; fall back to the default (0, "o") action
                # for every mention in the batch
                if ent_counter.shape[0] == 0:
                    next_cell_idx, next_action_str = 0, "o"
                    pred_actions.extend(
                        [(next_cell_idx, next_action_str)] * num_elements
                    )
                    continue

                ment_emb_tensor = torch.stack(
                    mention_emb_list[start_idx:end_idx], dim=0
                )
                ment_start, ment_end = (
                    ment_boundaries[start_idx:end_idx, 0],
                    ment_boundaries[start_idx:end_idx, 1],
                )
                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs_tensorized(
                        ment_start, last_mention_start, ent_counter, metadata
                    )  ## [B, D, feature dim]
                else:
                    # No features used; keep an empty trailing feature dimension
                    feature_embs = torch.empty(
                        ment_start.shape[0], mem_vectors.shape[0], 0, device=self.device
                    )  ## [B, D, 0]
                coref_new_scores = self.get_coref_new_scores_tensorized(
                    ment_emb_tensor,
                    mem_vectors,
                    mem_vectors_init,
                    ent_counter,
                    feature_embs,
                )
                coref_copy = coref_new_scores.clone().detach().cpu()
                coref_scores_list.extend(coref_copy)
                assigned_cluster = self.assign_cluster_tensorized(coref_new_scores)
                gt_actions_batch = gt_actions[start_idx:end_idx]
                if teacher_force:
                    pred_actions.extend(gt_actions_batch)
                else:
                    pred_actions.extend(assigned_cluster)

        else:
            for ment_idx, ment_emb in enumerate(mention_emb_list):

                # No entities in memory; fall back to the default (0, "o") action
                if ent_counter.shape[0] == 0:
                    next_cell_idx, next_action_str = 0, "o"
                    pred_actions.append((next_cell_idx, next_action_str))
                    continue

                ment_start, ment_end = ment_boundaries[ment_idx]

                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs(
                        ment_start, last_mention_start, ent_counter, metadata
                    )
                else:
                    feature_embs = torch.empty(
                        mem_vectors.shape[0], 0, device=self.device
                    )

                coref_new_scores = self.get_coref_new_scores(
                    ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
                )
                coref_copy = coref_new_scores.clone().detach().cpu()
                coref_scores_list.append(coref_copy)
                pred_cell_idx, pred_action_str = self.assign_cluster(coref_new_scores)

                if teacher_force:
                    next_cell_idx, next_action_str = gt_actions[ment_idx]
                    pred_actions.append(gt_actions[ment_idx])
                else:
                    next_cell_idx, next_action_str = pred_cell_idx, pred_action_str
                    pred_actions.append((pred_cell_idx, pred_action_str))

                if next_action_str == "c":
                    coref_vec = self.coref_update(
                        ment_emb, mem_vectors, next_cell_idx, ent_counter
                    )
                    mem_vectors[next_cell_idx] = coref_vec
                    ent_counter[next_cell_idx] = ent_counter[next_cell_idx] + 1
                    last_mention_start[next_cell_idx] = ment_start

        mem_state = {
            "mem": mem_vectors,
            "mem_init": mem_vectors_init,
            "ent_counter": ent_counter,
            "last_mention_start": last_mention_start,
        }
        return pred_actions, mem_state, coref_scores_list
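

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The shapes and config keys are
# assumptions inferred from the attributes accessed in this module
# (config.type, config.num_feats, config.batch_size, config.mem_type);
# consult BaseMemory and the experiment configs for the authoritative schema.
#
#   memory = EntityMemory(config, span_emb_size, drop_module=nn.Dropout(p=0.3))
#
#   # Training (teacher forcing): per-mention logits for the coreference loss
#   coref_new_list = memory.forward_training(
#       ment_boundaries, mention_emb_list, rep_emb_list, gt_actions, metadata
#   )
#
#   # Inference: predicted clustering actions plus the final memory state, which
#   # can be fed back in via `memory_init` for streaming coreference
#   pred_actions, mem_state, coref_scores = memory(
#       ment_boundaries, mention_emb_list, rep_emb_list, gt_actions, metadata,
#       teacher_force=False,
#   )
# ---------------------------------------------------------------------------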