"""This file contains the custom code needed to make the causal_classification models compatible with huggingface Auto classes. tokenizer = AutoTokenizer.from_pretrained("my/repo") model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True) classifier = pipeline("text-classification", "my/repo", trust_remote_code=True) """ from typing import Optional, Union from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig import torch import os def load_head( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], config: PretrainedConfig, device, ) -> torch.nn.Linear: head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False) classification_head = os.path.join( pretrained_model_name_or_path, "classification_head.pth" ) head.weight.data = torch.load(classification_head, map_location=device) return head class CustomModelForSequenceClassification(PreTrainedModel): # Suppress the warning "Some weights were not initialized...You should probably TRAIN this model..." _keys_to_ignore_on_load_missing = ["model.*", "head.*"] def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear): super().__init__(config) self.model_backbone = backbone self.head = head def forward(self, **kwargs): r = self.model_backbone(**kwargs).logits out_last = r[:, -1].float() logits = self.head(out_last) return {"logits": logits} @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, ignore_mismatched_sizes: bool = False, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: bool = None, **kwargs, ): model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, use_safetensors=use_safetensors, trust_remote_code=True, **kwargs, ) device = next(model_backbone.parameters()).device head = load_head(pretrained_model_name_or_path, config, device=device) return cls(config, model_backbone, head)