# falcon-finetuning-demo-10-epochs / modeling_custom.py
# Uploaded by almersawi: "Upload custom model code" (commit 44d9aba, verified, 2.81 kB).
"""Custom code that makes the causal_classification models compatible with the
Hugging Face Auto classes.

Usage::

    tokenizer = AutoTokenizer.from_pretrained("my/repo")
    model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)
    classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
"""
from typing import Optional, Union
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
import torch
import os
def load_head(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: PretrainedConfig,
    device,
    **hub_kwargs,
) -> torch.nn.Linear:
    """Create the classification head and load its weight from ``classification_head.pth``.

    The head is a bias-free linear layer mapping ``config.vocab_size`` inputs
    to ``config.num_labels`` outputs.

    Args:
        pretrained_model_name_or_path: Either a local directory containing
            ``classification_head.pth`` or a Hugging Face Hub repo id from
            which that file is fetched.
        config: Model config; only ``vocab_size`` and ``num_labels`` are read.
        device: Passed to ``torch.load`` as ``map_location``; the loaded weight
            (and therefore the returned head's weight) lives on this device.
        **hub_kwargs: Forwarded to ``transformers.utils.cached_file`` when the
            file has to be resolved from the Hub (e.g. ``revision``, ``token``,
            ``cache_dir``). Unused for local directories.

    Returns:
        The ``torch.nn.Linear`` head carrying the loaded weight.

    Raises:
        TypeError: If the checkpoint does not contain a single tensor.
        ValueError: If the stored weight's shape is not
            ``(config.num_labels, config.vocab_size)``.
    """
    filename = "classification_head.pth"
    if os.path.isdir(pretrained_model_name_or_path):
        weight_path = os.path.join(pretrained_model_name_or_path, filename)
    else:
        # Not a local checkout (e.g. a repo id like "my/repo"): joining it as a
        # relative path would point nowhere, so resolve through the Hub cache.
        from transformers.utils import cached_file

        weight_path = cached_file(pretrained_model_name_or_path, filename, **hub_kwargs)

    # weights_only=True restricts unpickling to tensors and plain containers, so
    # a checkpoint pulled from a remote repo cannot execute arbitrary code.
    weight = torch.load(weight_path, map_location=device, weights_only=True)
    if not isinstance(weight, torch.Tensor):
        raise TypeError(
            f"{weight_path} must contain a single weight tensor, "
            f"got {type(weight).__name__}"
        )

    head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)
    if weight.shape != head.weight.shape:
        raise ValueError(
            f"classification head weight in {weight_path} has shape "
            f"{tuple(weight.shape)}, expected {tuple(head.weight.shape)} "
            "(num_labels, vocab_size)"
        )
    # Keep the head in its own (default floating) dtype: the model's forward pass
    # casts the activations with .float(), so a half-precision checkpoint would
    # otherwise fail with a dtype mismatch at matmul time.
    head.weight.data = weight.to(dtype=head.weight.dtype)
    return head
class CustomModelForSequenceClassification(PreTrainedModel):
    """Sequence classifier built from a causal-LM backbone plus a linear head.

    The backbone's ``.logits`` at the last attended token of each sequence are
    cast to float32 and projected by ``head`` into per-class scores.
    """

    # Suppress the warning "Some weights were not initialized...You should probably TRAIN this model..."
    # NOTE(review): from_pretrained is overridden below, so this only matters if
    # the base-class loading path is ever used.
    _keys_to_ignore_on_load_missing = ["model.*", "head.*"]

    def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
        super().__init__(config)
        self.model_backbone = backbone
        self.head = head

    @staticmethod
    def _last_token_logits(lm_logits, attention_mask):
        """Return the backbone logits at the last attended position of each row.

        Without a usable 2-D ``attention_mask`` (missing, not 2-D, or not matching
        the logits' sequence length), this falls back to the final position,
        i.e. ``lm_logits[:, -1]``.
        """
        if (
            attention_mask is None
            or attention_mask.dim() != 2
            or attention_mask.shape[-1] != lm_logits.shape[1]
        ):
            return lm_logits[:, -1]
        positions = torch.arange(attention_mask.shape[-1], device=attention_mask.device)
        # Largest index whose mask entry is non-zero. This works for both left and
        # right padding: with left padding (or none) it equals seq_len - 1, which
        # is exactly the original [:, -1] behaviour.
        last_idx = (positions * attention_mask.ne(0).long()).amax(dim=-1)
        batch_idx = torch.arange(lm_logits.shape[0], device=lm_logits.device)
        return lm_logits[batch_idx, last_idx.to(lm_logits.device)]

    def forward(self, **kwargs):
        """Run the backbone with ``kwargs`` and classify the last attended token.

        Returns:
            ``{"logits": tensor}`` with one row of ``num_labels`` scores per input.
        """
        lm_logits = self.model_backbone(**kwargs).logits
        last = self._last_token_logits(lm_logits, kwargs.get("attention_mask"))
        # Match the head: it holds float32 weights and may sit on a different
        # device than the backbone's output when the backbone is sharded.
        last = last.float().to(self.head.weight.device)
        logits = self.head(last)
        return {"logits": logits}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load the causal-LM backbone and the classification head.

        All arguments are forwarded to ``AutoModelForCausalLM.from_pretrained``
        (with ``trust_remote_code=True``). The head is then loaded by
        ``load_head`` onto the device of the backbone's first parameter.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            config: Optional config. If it is not a ``PretrainedConfig``
                instance (e.g. omitted), the config the backbone was loaded
                with is used instead.

        Returns:
            A ``CustomModelForSequenceClassification`` in evaluation mode.
        """
        model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            trust_remote_code=True,
            **kwargs,
        )
        if not isinstance(config, PretrainedConfig):
            # config=None (the default) or a name/path would crash both load_head
            # (config.vocab_size) and PreTrainedModel.__init__; use the resolved one.
            config = model_backbone.config
        device = next(model_backbone.parameters()).device
        head = load_head(pretrained_model_name_or_path, config, device=device)
        model = cls(config, model_backbone, head)
        # PreTrainedModel.from_pretrained returns models in eval mode; this
        # override bypasses it, so restore that contract explicitly.
        model.eval()
        return model