Spaces:
Build error
Build error
File size: 1,255 Bytes
1fd614c aaeb391 1fd614c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import enumerate_spans
from torch.nn.utils.rnn import pad_sequence
class SpanEnumerationLayer(nn.Module):
    """Pool token embeddings over enumerated spans.

    For each sequence, every ``(start, end)`` pair (inclusive end) selects
    ``embedding[start:end + 1]`` along the first dimension. That slice is
    reduced to a single embedding by summing or averaging over that
    dimension.
    """

    # Reductions accepted by ``compute_embeddings``.
    _OPERATIONS = ('sum', 'average')

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def compute_embeddings(self, embeddings, enumerations, operation='sum'):
        """Reduce each span of each sequence to one embedding.

        Args:
            embeddings: Iterable of per-sequence tensors. Spans index their
                first dimension (presumably ``(seq_len, hidden)`` -- confirm
                against the caller).
            enumerations: Iterable, parallel to ``embeddings``, of sequences
                of inclusive ``(start, end)`` index pairs.
            operation: ``'sum'`` adds the span's rows. ``'average'`` divides
                that sum by the number of rows actually selected.

        Returns:
            A list with one tensor per sequence, of shape
            ``(num_spans, *embedding.shape[1:])``. A sequence with no spans
            yields an empty ``(0, *embedding.shape[1:])`` tensor.

        Raises:
            ValueError: If ``operation`` is not ``'sum'`` or ``'average'``.
        """
        if operation not in self._OPERATIONS:
            # Fail loudly instead of silently falling back to summation.
            raise ValueError(
                f"operation must be one of {self._OPERATIONS}, "
                f"got {operation!r}"
            )
        average = operation == 'average'

        computed_embeddings = []
        for enumeration, embedding in zip(enumerations, embeddings):
            pooled_spans = []
            for x1, x2 in enumeration:
                span = embedding[x1:x2 + 1]
                pooled = span.sum(dim=0)
                if average:
                    # Divide by the rows actually sliced, so spans clipped at
                    # the tensor end average correctly. Clamp to 1 so an
                    # empty span yields zeros rather than NaN.
                    pooled = pooled / max(span.shape[0], 1)
                pooled_spans.append(pooled)
            if pooled_spans:
                computed_embeddings.append(torch.stack(pooled_spans))
            else:
                # torch.stack rejects an empty list; emit an empty batch that
                # keeps the trailing dimensions, dtype and device.
                computed_embeddings.append(
                    embedding.new_zeros((0, *embedding.shape[1:]))
                )
        return computed_embeddings

    def forward(self, embeddings, lengths, operation='sum'):
        """Enumerate spans for ``lengths`` and pool ``embeddings`` over them.

        Args:
            embeddings: Per-sequence tensors, as for ``compute_embeddings``.
            lengths: Passed unchanged to ``enumerate_spans``. Its output is
                presumably one list of inclusive ``(start, end)`` pairs per
                sequence -- confirm in ``utils``.
            operation: Reduction forwarded to ``compute_embeddings``
                (``'sum'`` by default, as before).

        Returns:
            A tuple ``(computed_embeddings, enumerations)``.
        """
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(
            embeddings, enumerations, operation=operation
        )
        return computed_embeddings, enumerations
|