import torch
from pytorch_utils.utils import get_sequence_mask
from model.document_encoder.base_encoder import BaseDocEncoder

from omegaconf import DictConfig
from typing import Dict, List
from torch import Tensor


class IndependentDocEncoder(BaseDocEncoder):
    """Encode each chunk of a tensorized document independently with the pretrained LM."""

    def __init__(self, config: DictConfig):
        super().__init__(config)

    def forward(self, document: Dict) -> Tensor:
        # "tensorized_sent": C x L padded subtoken ids, supplied as a tensor or a nested list
        doc_tens = document["tensorized_sent"]
        if isinstance(doc_tens, list):
            doc_tens = torch.tensor(doc_tens, device=self.device)
        else:
            doc_tens = doc_tens.to(self.device)

        # Per-chunk (unpadded) subtoken lengths; an attention mask is only needed
        # when there is more than one chunk and hence padding.
        sent_len_list: List[int] = document["sent_len_list"]
        num_chunks = len(sent_len_list)
        if num_chunks == 1:
            attn_mask = None
        else:
            attn_mask = get_sequence_mask(
                torch.tensor(sent_len_list, device=self.device)
            )

        if not self.config.finetune:
            # Frozen encoder: run without gradient tracking to save memory
            with torch.no_grad():
                outputs = self.lm_encoder(
                    doc_tens, attention_mask=attn_mask
                )  # C x L x E
        else:
            outputs = self.lm_encoder(doc_tens, attention_mask=attn_mask)  # C x L x E

        encoded_repr = outputs[0]  # last hidden state, C x L x E

        # Strip the leading special token and the trailing padding/special tokens from each chunk
        unpadded_encoded_output = []
        for idx, sent_len in enumerate(sent_len_list):
            unpadded_encoded_output.append(encoded_repr[idx, 1 : sent_len + 1, :])

        # Concatenate chunks into a single (total_subtokens, E) tensor. A .float() cast
        # may be needed here for encoders that run in half precision (e.g. LLaMA).
        encoded_output = torch.cat(unpadded_encoded_output, dim=0)

        return encoded_output
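
# Minimal usage sketch (illustrative only, not part of the original module). The expected
# ``document`` layout is inferred from ``forward`` above; the config contents and the
# variable names below are assumptions.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create({"finetune": False})   # hypothetical minimal config
#   encoder = IndependentDocEncoder(config)
#   document = {
#       "tensorized_sent": chunk_ids,       # C x L padded subtoken ids
#       "sent_len_list": chunk_lengths,     # unpadded length of each chunk
#   }
#   encoded = encoder(document)             # (sum of chunk lengths) x hidden_dim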