from contextlib import nullcontext
from typing import Dict, List

import torch
from omegaconf import DictConfig
from torch import Tensor

from model.document_encoder.base_encoder import BaseDocEncoder
from pytorch_utils.utils import get_sequence_mask
class IndependentDocEncoder(BaseDocEncoder):
    """Document encoder that encodes each sentence chunk independently.

    The tensorized document is a batch of padded chunks; all chunks are run
    through the underlying LM encoder in one forward pass, after which the
    padding (and the leading special token of each chunk) is stripped and the
    chunk representations are concatenated into a single token sequence.
    """

    def __init__(self, config: DictConfig):
        super().__init__(config)

    def forward(self, document: Dict) -> Tensor:
        """Encode ``document`` and return one concatenated embedding tensor.

        Args:
            document: Mapping with (at least) the keys
                ``"tensorized_sent"`` — token ids as a nested list or a
                tensor of shape (num_chunks, max_chunk_len), and
                ``"sent_len_list"`` — the true (unpadded) token length of
                each chunk.

        Returns:
            Tensor of shape (sum(sent_len_list), emb_dim): per-token LM
            embeddings with padding and each chunk's first position removed.
        """
        doc_tens = document["tensorized_sent"]
        if isinstance(doc_tens, list):
            doc_tens = torch.tensor(doc_tens, device=self.device)
        else:
            doc_tens = doc_tens.to(self.device)

        sent_len_list: List[int] = document["sent_len_list"]
        num_chunks = len(sent_len_list)

        # A single chunk carries no padding rows to mask; for multiple chunks,
        # mask out the padding that aligns shorter chunks to the longest one.
        if num_chunks == 1:
            attn_mask = None
        else:
            attn_mask = get_sequence_mask(
                torch.tensor(sent_len_list, device=self.device)
            )

        # Skip gradient tracking when the LM is frozen; nullcontext() leaves
        # the caller's grad mode untouched when finetuning (same as the
        # original duplicated-branch version).
        grad_ctx = nullcontext() if self.config.finetune else torch.no_grad()
        with grad_ctx:
            outputs = self.lm_encoder(doc_tens, attention_mask=attn_mask)  # C x L x E

        encoded_repr = outputs[0]
        # Slice [1 : len + 1] drops each chunk's leading special token
        # ([CLS]-style -- TODO confirm against the tokenizer) plus the padded
        # tail, then flatten all chunks into one token sequence.
        encoded_output = torch.cat(
            [
                encoded_repr[idx, 1 : sent_len + 1, :]
                for idx, sent_len in enumerate(sent_len_list)
            ],
            dim=0,
        )  # NOTE(review): a .float() cast was considered here for LLaMA outputs
        return encoded_output