|
"""This file contains the custom code needed to make the causal_classification models |
|
compatible with Hugging Face Auto classes.

Example usage:

    tokenizer = AutoTokenizer.from_pretrained("my/repo")
    model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)

    classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
|
""" |
|
|
|
from typing import Optional, Union |
|
|
|
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig |
|
import torch |
|
|
|
import os |
|
|
|
|
|
def load_head(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: "PretrainedConfig",
    device,
) -> torch.nn.Linear:
    """Build the classification head and load its weight from ``classification_head.pth``.

    The head is a bias-free ``torch.nn.Linear`` with ``config.vocab_size`` inputs
    and ``config.num_labels`` outputs. Its weight is replaced by the tensor stored
    in ``classification_head.pth``.

    Args:
        pretrained_model_name_or_path: Local directory that contains
            ``classification_head.pth``. If no such local file exists, the value
            is treated as a Hugging Face Hub repo id and the file is resolved
            through ``transformers.utils.cached_file``.
        config: Model configuration; only ``vocab_size`` and ``num_labels`` are
            read here.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The linear head carrying the loaded weight.

    Raises:
        OSError: If the weight file can be found neither locally nor on the Hub.
    """
    head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)

    weights_file = os.path.join(
        pretrained_model_name_or_path, "classification_head.pth"
    )
    if not os.path.isfile(weights_file):
        # Not a local checkpoint directory: resolve the file the same way the
        # backbone weights are resolved for a repo id such as "my/repo".
        # Imported lazily so purely local loading does not touch the hub code.
        # NOTE(review): only the default revision/cache/token are used here,
        # because this function does not receive those options from its caller.
        from transformers.utils import cached_file

        weights_file = cached_file(
            pretrained_model_name_or_path, "classification_head.pth"
        )

    # weights_only=True restricts unpickling to tensor data, so a tampered
    # checkpoint cannot execute arbitrary code while being loaded.
    head.weight.data = torch.load(
        weights_file, map_location=device, weights_only=True
    )
    return head
|
|
|
|
|
class CustomModelForSequenceClassification(PreTrainedModel):
    """Sequence classifier made of a causal-LM backbone plus a linear head.

    The backbone is called with the forward keyword arguments and must return
    an object exposing ``.logits``. The head is applied to the logits found at
    the last index of dimension 1.
    """

    # Key patterns that transformers may report as missing when loading a state
    # dict. NOTE(review): ``from_pretrained`` is overridden below and loads the
    # backbone and head itself, so confirm whether this is still needed.
    _keys_to_ignore_on_load_missing = ["model.*", "head.*"]

    def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
        """Store the already-constructed backbone and classification head.

        Args:
            config: Configuration passed on to ``PreTrainedModel.__init__``.
            backbone: Module whose output exposes ``.logits``.
            head: Linear layer applied to the selected backbone logits.
        """
        super().__init__(config)
        self.model_backbone = backbone
        self.head = head

    def forward(self, **kwargs):
        """Run the backbone and classify from its last-position logits.

        Args:
            **kwargs: Forwarded unchanged to the backbone.

        Returns:
            A dict with the single key ``"logits"`` holding the head output.
        """
        r = self.model_backbone(**kwargs).logits
        # Only the final index along dimension 1 is used.
        # NOTE(review): presumably dimension 1 is the sequence axis, which
        # would require left-padded batches - confirm against the tokenizer.
        out_last = r[:, -1].float()
        logits = self.head(out_last)
        return {"logits": logits}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load the causal-LM backbone and the classification head.

        All arguments are forwarded to ``AutoModelForCausalLM.from_pretrained``
        to build the backbone. The head is then loaded with ``load_head`` onto
        the device of the backbone's first parameter.

        Args:
            pretrained_model_name_or_path: Model location, used for both the
                backbone and the head.
            *model_args: Forwarded to the backbone loader.
            config: Configuration object, or a name/path to one. If it is not
                a ``PretrainedConfig`` instance (``None`` or a path), the
                configuration resolved by the backbone loader is used for the
                head and for this model.
            cache_dir, ignore_mismatched_sizes, force_download,
            local_files_only, token, revision, use_safetensors: Forwarded
                unchanged to the backbone loader.
            **kwargs: Forwarded to the backbone loader. ``trust_remote_code``
                defaults to ``True`` when it is not supplied.

        Returns:
            An instance of this class wrapping the loaded backbone and head.
        """
        # The flag used to be hard-coded next to **kwargs, so a caller passing
        # it explicitly hit "multiple values for keyword argument".
        kwargs.setdefault("trust_remote_code", True)

        model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

        # ``config`` may be None or a path; the head builder and the base-class
        # constructor both need a real configuration object.
        if not isinstance(config, PretrainedConfig):
            config = model_backbone.config

        device = next(model_backbone.parameters()).device
        head = load_head(pretrained_model_name_or_path, config, device=device)

        return cls(config, model_backbone, head)
|
|