import torch
import torch.nn as nn

from utils import enumerate_spans


class SpanEnumerationLayer(nn.Module):
    """Enumerates all candidate spans per sequence and pools the token
    embeddings of each span into a single fixed-size vector."""

    def __init__(self) -> None:
        super().__init__()
    def compute_embeddings(self, embeddings, enumerations, operation='sum'):
        """Pool token embeddings over each enumerated span.

        `enumerations` holds one list of (start, end) pairs per sequence,
        with inclusive end indices; `operation` is 'sum' or 'average'.
        """
        computed_embeddings = []

        for enumeration, embedding in zip(enumerations, embeddings):
            output_embeddings = []

            for start, end in enumeration:
                # Sum the token embeddings inside the span (end index is inclusive).
                output_tensor = embedding[start:end + 1].sum(dim=0)

                if operation == 'average':
                    # Mean-pool instead of sum-pool: divide by the span length.
                    output_tensor = output_tensor / (end - start + 1)

                output_embeddings.append(output_tensor)

            # Sequences can yield different span counts, so the result stays a
            # list of (num_spans, hidden_dim) tensors rather than one tensor.
            computed_embeddings.append(torch.stack(output_embeddings))

        return computed_embeddings

    def forward(self, embeddings, lengths, operation='sum'):
        # Enumerate all candidate spans from the unpadded sequence lengths,
        # then pool the token embeddings over each span.
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(embeddings, enumerations, operation)

        return computed_embeddings, enumerations
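
# ---------------------------------------------------------------------------
# Usage sketch (hedged): this example assumes `enumerate_spans(lengths)` from
# `utils` returns, per sequence length, a list of (start, end) pairs with
# inclusive end indices, e.g. for a length of 2: [(0, 0), (0, 1), (1, 1)].
# The shapes and values below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    layer = SpanEnumerationLayer()

    # Toy batch: 2 sequences padded to 5 tokens, hidden size 8.
    embeddings = torch.randn(2, 5, 8)
    lengths = [5, 3]  # true (unpadded) lengths of each sequence

    span_embeddings, spans = layer(embeddings, lengths)
    print(span_embeddings[0].shape)  # torch.Size([15, 8]): 15 spans in a 5-token sequence
    print(spans[1])                  # the (start, end) pairs enumerated for sequence 1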