import torch.nn as nn
from transformers import (
    AutoModel,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    PreTrainedModel,
)
import torch

from omegaconf import DictConfig
from typing import Dict
import transformers

transformers.logging.set_verbosity_error()


class BaseDocEncoder(nn.Module):
    """Base document encoder wrapping a pretrained transformer model and its tokenizer."""

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config

        gradient_checkpointing = bool(config.finetune)

        model_str: str = config.transformer.model_str

        self.lm_encoder: PreTrainedModel = AutoModel.from_pretrained(
            pretrained_model_name_or_path=model_str,
            output_hidden_states=False,
            add_pooling_layer=False,  # Not accepted by all architectures (e.g. LLaMA); remove this argument for such models
        )

        if gradient_checkpointing:
            # Gradient checkpointing trades extra compute for lower memory when fine-tuning the encoder.
            self.lm_encoder.gradient_checkpointing_enable()

        self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_str,
            use_fast=True,
            clean_up_tokenization_spaces=True,
        )
        if config.add_speaker_tokens:
            self.tokenizer.add_special_tokens(
                {
                    "additional_special_tokens": [
                        config.speaker_start,
                        config.speaker_end,
                    ]
                }
            )

            # Expand the embedding matrix to cover the newly added special tokens.
            self.lm_encoder.resize_token_embeddings(len(self.tokenizer))

        if not config.finetune:
            # Freeze the encoder so its parameters are not updated during training.
            for param in self.lm_encoder.parameters():
                param.requires_grad = False

        self.hidden_size: int = self.lm_encoder.config.hidden_size

    @property
    def device(self) -> torch.device:
        return next(self.lm_encoder.parameters()).device

    def get_tokenizer(self) -> PreTrainedTokenizerFast:
        return self.tokenizer

    def to_add_speaker_tokens(self) -> bool:
        # The flag lives on the config, not on the module itself.
        return self.config.add_speaker_tokens

    def forward(self, document: Dict):
        """Encode a document; to be implemented by concrete subclasses."""
        raise NotImplementedError
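

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The subclass
# name `ExampleDocEncoder`, the `document["text"]` key, and the config values
# below are assumptions made purely for demonstration; real subclasses define
# their own `forward` over the project's document format.
# ---------------------------------------------------------------------------
class ExampleDocEncoder(BaseDocEncoder):
    def forward(self, document: Dict) -> torch.Tensor:
        # Tokenize the raw document text and move the batch to the encoder's device.
        encoded = self.tokenizer(
            document["text"],
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(self.device)
        # Return contextual subtoken representations: [batch, seq_len, hidden_size].
        return self.lm_encoder(**encoded).last_hidden_state


if __name__ == "__main__":
    from omegaconf import OmegaConf

    # Assumed config layout, mirroring the fields accessed in __init__ above.
    cfg = OmegaConf.create(
        {
            "finetune": False,
            "transformer": {"model_str": "bert-base-uncased"},
            "add_speaker_tokens": True,
            "speaker_start": "[SPEAKER_START]",
            "speaker_end": "[SPEAKER_END]",
        }
    )
    encoder = ExampleDocEncoder(cfg)
    with torch.no_grad():
        reps = encoder({"text": ["[SPEAKER_START] Alice [SPEAKER_END] Hello there."]})
    print(reps.shape)  # e.g. torch.Size([1, seq_len, 768])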