# Spaces: Sleeping
# Sleeping
# (extraction artifact: HuggingFace Spaces status header, kept as comments)
import torch
from torch import Tensor
from typing import Dict, List

from omegaconf import DictConfig

from pytorch_utils.utils import get_sequence_mask
from model.document_encoder.base_encoder import BaseDocEncoder
class IndependentDocEncoder(BaseDocEncoder):
    """Encode each sentence chunk of a document independently with the LM encoder.

    The tensorized document is run through ``self.lm_encoder`` in one batched
    call (optionally under ``torch.no_grad`` when the encoder is frozen), then
    the per-chunk padding is stripped and the chunks are concatenated into a
    single ``(total_tokens, hidden_dim)`` tensor.
    """

    def __init__(self, config: DictConfig):
        # Zero-arg super() is the idiomatic Python 3 form; behavior unchanged.
        super().__init__(config)

    def forward(self, document: Dict) -> Tensor:
        """Encode a tensorized document.

        Args:
            document: mapping with at least:
                - "tensorized_sent": token ids per chunk, either a nested list
                  or a tensor of shape (num_chunks, max_len) — TODO confirm
                  shape against the caller;
                - "sent_len_list": per-chunk token counts (excluding the
                  boundary special tokens, judging by the slicing below).

        Returns:
            Tensor of shape (sum(sent_len_list), hidden_dim): unpadded encoder
            states of every chunk, concatenated along the token axis.
        """
        doc_tens = document["tensorized_sent"]
        if isinstance(doc_tens, list):
            doc_tens = torch.tensor(doc_tens, device=self.device)
        else:
            doc_tens = doc_tens.to(self.device)

        sent_len_list: List[int] = document["sent_len_list"]

        # A single chunk has no padding rows to mask; with multiple chunks,
        # mask out positions past each chunk's true length.
        # NOTE(review): the mask covers sent_len positions only — confirm
        # get_sequence_mask accounts for any special tokens in doc_tens.
        if len(sent_len_list) == 1:
            attn_mask = None
        else:
            attn_mask = get_sequence_mask(
                torch.tensor(sent_len_list, device=self.device)
            )

        if not self.config.finetune:
            # Frozen encoder: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                outputs = self.lm_encoder(
                    doc_tens, attention_mask=attn_mask
                )  # C x L x E
        else:
            outputs = self.lm_encoder(doc_tens, attention_mask=attn_mask)  # C x L x E

        encoded_repr = outputs[0]

        # Drop position 0 of each chunk (presumably a [CLS]-style special
        # token — verify against the tokenizer) and the trailing padding,
        # then concatenate. Comprehension replaces the append loop (PERF401).
        unpadded = [
            encoded_repr[idx, 1 : sent_len + 1, :]
            for idx, sent_len in enumerate(sent_len_list)
        ]
        return torch.cat(unpadded, dim=0)