import torch
from pytorch_utils.utils import get_sequence_mask
from model.document_encoder.base_encoder import BaseDocEncoder
from omegaconf import DictConfig
from typing import Dict, List
from torch import Tensor


class IndependentDocEncoder(BaseDocEncoder):
    """Document encoder that runs the LM over pre-chunked sentences in one batch."""

    def __init__(self, config: DictConfig):
        super(IndependentDocEncoder, self).__init__(config)

    def forward(self, document: Dict) -> Tensor:
        """Encode a tensorized document and return unpadded token representations.

        Args:
            document: mapping with "tensorized_sent" (token ids per chunk, as a
                list or tensor) and "sent_len_list" (true length of each chunk,
                excluding the leading special token).

        Returns:
            A (total_tokens x hidden) tensor: per-chunk encodings with the
            leading special token and trailing padding removed, concatenated.
        """
        tokens = document["tensorized_sent"]
        tokens = (
            torch.tensor(tokens, device=self.device)
            if isinstance(tokens, list)
            else tokens.to(self.device)
        )

        chunk_lens: List[int] = document["sent_len_list"]
        # A lone chunk carries no padding to mask; multiple chunks are padded
        # to a common length and need a per-chunk attention mask.
        mask = (
            None
            if len(chunk_lens) == 1
            else get_sequence_mask(torch.tensor(chunk_lens, device=self.device))
        )

        if self.config.finetune:
            lm_out = self.lm_encoder(tokens, attention_mask=mask)  # C x L x E
        else:
            # Frozen encoder: skip autograd bookkeeping entirely.
            with torch.no_grad():
                lm_out = self.lm_encoder(tokens, attention_mask=mask)  # C x L x E

        hidden = lm_out[0]
        # Per chunk, drop the leading special token (index 0) and keep exactly
        # chunk_lens[i] real tokens, then stitch the chunks back together.
        pieces = [hidden[i, 1 : n + 1, :] for i, n in enumerate(chunk_lens)]
        return torch.cat(pieces, dim=0)  # .float() LLaMA :) Not sure how much space this eats up