File size: 2,805 Bytes
44d9aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
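"""Custom code that makes the causal_classification models compatible with the
Hugging Face Auto classes.

The repository must provide the causal LM backbone together with a
``classification_head.pth`` file holding the weights of the classification head.

Usage::

    from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained("my/repo")
    model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)

    classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
"""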

from typing import Optional, Union

from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
import torch

import os


def load_head(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: PretrainedConfig,
    device,
    **hub_kwargs,
) -> torch.nn.Linear:
    """Build the classification head and load its weights from ``classification_head.pth``.

    The head is a bias-free linear layer mapping ``config.vocab_size`` inputs to
    ``config.num_labels`` outputs.

    Args:
        pretrained_model_name_or_path: Local directory or Hugging Face Hub repo id
            containing ``classification_head.pth``.
        config: Config providing ``vocab_size`` and ``num_labels``.
        device: Device the stored weight is mapped onto.
        **hub_kwargs: Optional settings forwarded to ``transformers.utils.cached_file``
            (e.g. ``cache_dir``, ``revision``, ``token``, ``local_files_only``,
            ``subfolder``, ``proxies``).

    Returns:
        The ``torch.nn.Linear`` head holding the stored weight (as float32) on ``device``.

    Raises:
        OSError: If ``classification_head.pth`` cannot be found or downloaded.
        TypeError: If the file does not contain a tensor.
        ValueError: If the stored tensor is not of shape ``(num_labels, vocab_size)``.
    """
    # Local import keeps the module's import block unchanged. ``cached_file`` resolves
    # both local directories and Hub repo ids (downloading into the HF cache), whereas
    # ``os.path.join`` alone only works for local directories.
    from transformers.utils import cached_file

    head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)
    head_path = cached_file(
        pretrained_model_name_or_path, "classification_head.pth", **hub_kwargs
    )
    # The file may come from a remote repository: refuse arbitrary pickled objects,
    # only a plain tensor is expected here.
    weight = torch.load(head_path, map_location=device, weights_only=True)
    if not isinstance(weight, torch.Tensor):
        raise TypeError(
            f"classification_head.pth must contain a tensor, got {type(weight).__name__}"
        )
    if tuple(weight.shape) != tuple(head.weight.shape):
        raise ValueError(
            f"classification_head.pth has shape {tuple(weight.shape)}, expected "
            f"{tuple(head.weight.shape)} (num_labels, vocab_size)"
        )
    # forward() always feeds float32 activations into the head, so keep the weight float32.
    head.weight.data = weight.float()
    return head


class CustomModelForSequenceClassification(PreTrainedModel):
    """Sequence classifier made of a causal LM backbone and a linear head.

    The backbone's logits at the last attended position of each sequence are
    projected by ``head`` to one score per label.
    """

    # Suppress the warning "Some weights were not initialized...You should probably TRAIN this model..."
    _keys_to_ignore_on_load_missing = ["model.*", "head.*"]

    def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
        """Wrap an already-loaded backbone and classification head.

        Args:
            config: Model config; ``PreTrainedModel`` requires a ``PretrainedConfig``.
            backbone: Causal LM whose output exposes ``.logits``.
            head: Linear layer applied to the selected backbone logits.
        """
        super().__init__(config)
        self.model_backbone = backbone
        self.head = head

    def forward(self, **kwargs):
        """Run the backbone and classify from the last attended token.

        Args:
            **kwargs: Forwarded unchanged to the backbone (e.g. ``input_ids``,
                ``attention_mask``).

        Returns:
            ``{"logits": Tensor}`` with one row of label scores per sequence.
        """
        # Indexing below assumes logits of shape (batch, seq_len, vocab_size).
        r = self.model_backbone(**kwargs).logits
        attention_mask = kwargs.get("attention_mask")
        if (
            isinstance(attention_mask, torch.Tensor)
            and attention_mask.dim() == 2
            and attention_mask.shape == r.shape[:2]
        ):
            # Pick the highest attended position per row. This equals -1 for
            # unpadded/left-padded input, but avoids reading a pad token when the
            # batch is right-padded.
            positions = torch.arange(r.shape[1], device=r.device)
            last_idx = (positions * (attention_mask.to(r.device) != 0)).argmax(dim=-1)
            out_last = r[torch.arange(r.shape[0], device=r.device), last_idx]
        else:
            out_last = r[:, -1]
        logits = self.head(out_last.float().to(self.head.weight.device))
        return {"logits": logits}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load the causal LM backbone and the classification head from one location.

        All arguments are forwarded to ``AutoModelForCausalLM.from_pretrained``; the
        download-related ones (``cache_dir``, ``force_download``, ``local_files_only``,
        ``token``, ``revision`` and, if given, ``subfolder``/``proxies``) are also used
        to locate ``classification_head.pth``.

        Returns:
            A ``CustomModelForSequenceClassification`` wrapping both parts.

        Raises:
            OSError: If ``classification_head.pth`` cannot be found or downloaded.
        """
        model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            trust_remote_code=True,
            **kwargs,
        )
        # The head and PreTrainedModel.__init__ need a config object; when none (or
        # only a path) was given, reuse the config the backbone loaded.
        if not isinstance(config, PretrainedConfig):
            config = model_backbone.config
        device = next(model_backbone.parameters()).device
        head = load_head(
            pretrained_model_name_or_path,
            config,
            device=device,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=kwargs.get("subfolder", ""),
            proxies=kwargs.get("proxies"),
        )

        return cls(config, model_backbone, head)