nehalelkaref committed on
Commit
1fd614c
·
1 Parent(s): 572b22d

Upload layers.py

Browse files
Files changed (1) hide show
  1. layers.py +42 -0
layers.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .utils import enumerate_spans
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
class SpanEnumerationLayer(nn.Module):
    """Pool token embeddings into one vector per candidate span.

    Given per-sequence token embeddings and per-sequence lists of
    (start, end) index pairs, produces one pooled vector per span by
    summing (default) or averaging the token embeddings inside the span.
    Span endpoints are treated as inclusive on both sides.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def compute_embeddings(self, embeddings, enumerations, operation='sum'):
        """Compute pooled span embeddings for a batch of sequences.

        Args:
            embeddings: iterable of tensors, one per sequence, each of
                shape ``(seq_len, dim)``.
            enumerations: iterable parallel to ``embeddings``; each item
                is a list of ``(x1, x2)`` index pairs with both endpoints
                inclusive (assumes ``x2 >= x1`` — spans come from
                ``enumerate_spans``; TODO confirm).
            operation: ``'sum'`` (default) or ``'average'``. Any other
                value falls back to sum pooling.

        Returns:
            list of tensors, one per sequence, each of shape
            ``(num_spans, dim)``. A sequence with an empty enumeration
            yields a ``(0, dim)`` tensor instead of raising.
        """
        averaging = operation == 'average'
        computed_embeddings = []

        for enumeration, embedding in zip(enumerations, embeddings):
            output_embeddings = []

            # Unpack span endpoints directly rather than indexing by position.
            for x1, x2 in enumeration:
                # Inclusive span: tokens x1..x2, hence the +1 on the slice end.
                output_tensor = embedding[x1:x2 + 1].sum(dim=0)

                if averaging:
                    # Span length for inclusive endpoints is x2 - x1 + 1.
                    output_tensor = torch.div(output_tensor, x2 - x1 + 1)

                output_embeddings.append(output_tensor)

            if output_embeddings:
                computed_embeddings.append(torch.stack(output_embeddings))
            else:
                # torch.stack raises on an empty list; keep output aligned
                # with the batch by emitting an empty (0, dim) tensor.
                computed_embeddings.append(
                    embedding.new_zeros((0, embedding.shape[-1]))
                )

        return computed_embeddings

    def forward(self, embeddings, lengths, operation='sum'):
        """Enumerate spans for each sequence length and pool them.

        Args:
            embeddings: iterable of per-sequence token embedding tensors.
            lengths: per-sequence lengths, forwarded to ``enumerate_spans``.
            operation: pooling mode passed through to
                ``compute_embeddings`` (``'sum'`` or ``'average'``);
                previously the average mode was unreachable from here.

        Returns:
            tuple ``(computed_embeddings, enumerations)``.
        """
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(
            embeddings, enumerations, operation=operation
        )

        return computed_embeddings, enumerations