File size: 1,159 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import *

import torch
from allennlp.modules.span_extractors import SpanExtractor


@SpanExtractor.register('combo')
class ComboSpanExtractor(SpanExtractor):
    """
    Span extractor that runs several sub-extractors over the same spans and
    concatenates their span representations along the last (feature) axis.

    Parameters
    ----------
    input_dim : int
        Dimension reported by ``get_input_dim``. It is stored as given and is
        not checked against the sub-extractors' own input dimensions.
    sub_extractors : List[SpanExtractor]
        Extractors to combine. Their outputs are concatenated in list order,
        so the output dimension is the sum of their output dimensions.

    Raises
    ------
    ValueError
        If ``sub_extractors`` is empty (there would be nothing to concatenate).
    """

    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        super().__init__()
        if not sub_extractors:
            raise ValueError('ComboSpanExtractor requires at least one sub-extractor.')
        # Copy so later mutation of the caller's list cannot desynchronize
        # this list from the modules registered below.
        self.sub_extractors = list(sub_extractors)
        # A plain list attribute is not tracked by torch.nn.Module, so each
        # sub-extractor is registered explicitly to expose its parameters.
        # The names are kept unchanged so existing checkpoints still load.
        for i, sub in enumerate(self.sub_extractors):
            self.add_module(f'SpanExtractor-{i+1}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the ``input_dim`` given at construction."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the sum of the sub-extractors' output dimensions."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor = None,
            span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:
        """
        Embed the spans with every sub-extractor and concatenate the results.

        All arguments are forwarded unchanged, by keyword, to each sub-extractor.

        Returns
        -------
        torch.FloatTensor
            The sub-extractor outputs concatenated along the last axis.
            Presumably shaped (batch, num_spans, get_output_dim()) per the
            SpanExtractor contract. TODO confirm for custom sub-extractors.
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                # Previously dropped. Sub-extractors that use the sequence
                # mask would silently receive None.
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask,
            )
            for sub in self.sub_extractors
        ]
        # Concatenate on the feature (last) axis, matching get_output_dim().
        # This is equivalent to dim=2 for rank-3 span representations.
        return torch.cat(outputs, dim=-1)