import torch.nn as nn
from transformers import (
    AutoModel,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    PreTrainedModel,
)
import torch
from omegaconf import DictConfig
from typing import Dict

import transformers

transformers.logging.set_verbosity_error()


class BaseDocEncoder(nn.Module):
    def __init__(self, config: DictConfig):
        super(BaseDocEncoder, self).__init__()
        self.config = config

        gradient_checkpointing = False
        if config.finetune:
            gradient_checkpointing = True

        model_str: str = config.transformer.model_str
        self.lm_encoder: PreTrainedModel = AutoModel.from_pretrained(
            pretrained_model_name_or_path=model_str,
            output_hidden_states=False,
            add_pooling_layer=False,  # Remove this kwarg for models without a pooling layer (e.g. LLaMA)
        )
        if gradient_checkpointing:
            # Trade compute for memory when the encoder is being finetuned
            self.lm_encoder.gradient_checkpointing_enable()

        self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_str,
            use_fast=True,
            clean_up_tokenization_spaces=True,
        )

        if config.add_speaker_tokens:
            self.tokenizer.add_special_tokens(
                {
                    "additional_special_tokens": [
                        config.speaker_start,
                        config.speaker_end,
                    ]
                }
            )
            # Expand the embedding matrix to cover the newly added special tokens
            self.lm_encoder.resize_token_embeddings(len(self.tokenizer))

        if not config.finetune:
            for param in self.lm_encoder.parameters():
                # Don't update encoder params
                param.requires_grad = False

        self.hidden_size: int = self.lm_encoder.config.hidden_size

    def device(self) -> torch.device:
        return next(self.lm_encoder.parameters()).device

    def get_tokenizer(self) -> PreTrainedTokenizerFast:
        return self.tokenizer

    def to_add_speaker_tokens(self) -> bool:
        return self.config.add_speaker_tokens

    def forward(self, document: Dict):
        raise NotImplementedError
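

# Minimal usage sketch (not part of the original module): builds the encoder from a
# hand-written OmegaConf config just to show the fields BaseDocEncoder expects.
# The model name and speaker-token strings below are placeholder assumptions; in the
# actual project these values come from its Hydra/OmegaConf config files.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "finetune": False,  # keep the encoder frozen for this quick check
            "add_speaker_tokens": True,
            "speaker_start": "[SPEAKER_START]",  # assumed token string
            "speaker_end": "[SPEAKER_END]",  # assumed token string
            "transformer": {"model_str": "bert-base-uncased"},  # assumed model
        }
    )

    encoder = BaseDocEncoder(cfg)
    # forward() is left to subclasses, so only inspect the loaded components here
    print(encoder.hidden_size, encoder.device())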