Spaces:
Build error
Build error
File size: 1,159 Bytes
6680682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from typing import *
import torch
from allennlp.modules.span_extractors import SpanExtractor
@SpanExtractor.register('combo')
class ComboSpanExtractor(SpanExtractor):
    """Combine several span extractors by concatenating their outputs.

    Every sub-extractor is run on the same inputs, and the resulting span
    representations are concatenated along the last (feature) dimension.
    The output dimension is therefore the sum of the sub-extractors'
    output dimensions.

    Registered in the ``SpanExtractor`` registry under the name ``"combo"``.

    Parameters
    ----------
    input_dim : int
        Value reported by :meth:`get_input_dim`. It is stored as given and is
        not checked against the sub-extractors' own input dimensions.
    sub_extractors : List[SpanExtractor]
        The extractors to combine. Their outputs are concatenated in list order.

    Raises
    ------
    ValueError
        If ``sub_extractors`` is empty, because there would be nothing to
        concatenate.
    """

    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        super().__init__()
        if not sub_extractors:
            # Fail at construction time instead of with an opaque torch.cat
            # error on the first forward pass.
            raise ValueError("ComboSpanExtractor requires at least one sub-extractor")
        # Kept as a plain list to fix iteration order. Each module is also
        # registered individually so its parameters belong to this module.
        # The registered names become state_dict key prefixes, so they must
        # stay unchanged for checkpoint compatibility.
        self.sub_extractors = sub_extractors
        for i, sub in enumerate(sub_extractors):
            self.add_module(f'SpanExtractor-{i+1}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the ``input_dim`` passed to the constructor."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the sum of the sub-extractors' output dimensions."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:
        """Run every sub-extractor and concatenate their span representations.

        All arguments are forwarded unchanged, as keywords, to each
        sub-extractor. That includes ``sequence_mask``, which some extractors
        need in order to ignore padding.

        The outputs are presumably ``(batch, num_spans, dim_i)`` per the
        AllenNLP ``SpanExtractor`` contract (TODO confirm for custom
        sub-extractors). They are concatenated along the last axis, so the
        feature size matches :meth:`get_output_dim`.
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask,
            )
            for sub in self.sub_extractors
        ]
        # Last axis == feature axis; equivalent to dim=2 for 3-D span outputs.
        return torch.cat(outputs, dim=-1)
|