"""This file contains the custom code needed to make the causal_classification models
compatible with huggingface Auto classes.
tokenizer = AutoTokenizer.from_pretrained("my/repo")
model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)
classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
"""

import os
from typing import Optional, Union

import torch
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel


def load_head(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: PretrainedConfig,
    device: torch.device,
) -> torch.nn.Linear:
    """Build the classification head and load its weights from the checkpoint directory."""
    # The head projects the backbone's last-token vocabulary logits onto the label space.
    head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)
    # This assumes `pretrained_model_name_or_path` is a local directory; a plain Hub
    # repo id would require downloading "classification_head.pth" first.
    classification_head = os.path.join(
        pretrained_model_name_or_path, "classification_head.pth"
    )
    head.weight.data = torch.load(classification_head, map_location=device)
    return head
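

# A minimal sketch (assumption: this mirrors how the checkpoint was produced, which
# is not shown in this file) of writing the file that `load_head` reads back:
#
#   torch.save(model.head.weight.data, os.path.join(save_dir, "classification_head.pth"))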


class CustomModelForSequenceClassification(PreTrainedModel):
    # Suppress the warning "Some weights were not initialized... You should probably
    # TRAIN this model..." for weights that are loaded outside the standard state dict.
    _keys_to_ignore_on_load_missing = ["model.*", "head.*"]

    def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
        super().__init__(config)
        self.model_backbone = backbone
        self.head = head

    def forward(self, **kwargs):
        # Run the causal LM backbone and keep only the logits at the last position.
        r = self.model_backbone(**kwargs).logits
        out_last = r[:, -1].float()
        logits = self.head(out_last)
        return {"logits": logits}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        # Load the causal LM backbone through the standard Auto machinery.
        model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            trust_remote_code=True,
            **kwargs,
        )
        # `config` may be None or a path/identifier; `load_head` and the parent
        # constructor need the resolved PretrainedConfig, which the backbone now holds.
        if not isinstance(config, PretrainedConfig):
            config = model_backbone.config
        device = next(model_backbone.parameters()).device
        head = load_head(pretrained_model_name_or_path, config, device=device)
        return cls(config, model_backbone, head)
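

if __name__ == "__main__":
    # Quick smoke test (assumption: "my/repo" is a local checkpoint directory holding
    # the backbone weights plus "classification_head.pth"; swap in a real path).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("my/repo")
    model = CustomModelForSequenceClassification.from_pretrained("my/repo")
    inputs = tokenizer("An example input to classify.", return_tensors="pt")
    print(model(**inputs)["logits"])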