nehalelkaref commited on
Commit
40d48bd
·
1 Parent(s): 7916f53

Delete layers.py

Browse files
Files changed (1) hide show
  1. layers.py +0 -42
layers.py DELETED
@@ -1,42 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from utils import enumerate_spans
6
- from torch.nn.utils.rnn import pad_sequence
7
-
8
- class SpanEnumerationLayer(nn.Module):
9
- def __init__(self, *args, **kwargs) -> None:
10
- super().__init__(*args, **kwargs)
11
-
12
-
13
- def compute_embeddings(self,embeddings, enumerations, operation = 'sum'):
14
-
15
- computed_embeddings = []
16
-
17
- for enumeration, embedding in zip(enumerations, embeddings):
18
-
19
- output_embeddings = []
20
- dim_size = embedding.shape[-1]
21
-
22
- for idx in range(len(enumeration)):
23
-
24
- x1,x2 = enumeration[idx]
25
- output_tensor = embedding[x1:x2+1].sum(dim=0)
26
-
27
- if(operation == 'average'):
28
- divisor = abs((x2+1)-x1)
29
- output_tensor=torch.div(output_tensor, divisor)
30
-
31
- output_embeddings.append(output_tensor)
32
- computed_embeddings.append(torch.stack(output_embeddings))
33
-
34
- return computed_embeddings
35
-
36
- def forward(self, embeddings, lengths):
37
-
38
- enumerations = enumerate_spans(lengths)
39
- computed_embeddings = self.compute_embeddings(embeddings, enumerations)
40
-
41
- return computed_embeddings, enumerations
42
-