"""A scikit-learn-compatible text classifier backed by a fine-tuned DistilBERT."""

from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.typing import NDArray
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class BertClassifier(BaseEstimator, ClassifierMixin):
    """Exposes a DistilBERT sequence classifier through the sklearn estimator API."""

    def __init__(self, seed: int = 42, epochs: int = 5, device: str = "cpu"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.seed = seed
        self.epochs = epochs
        self.model = None
        self.labels = None
        self.device = device

    def _get_classes(self, y: List[str]) -> Tuple[List[int], List[str]]:
        # Map the sorted set of string labels to contiguous integer ids.
        labels = sorted(set(y))
        ids = list(range(len(labels)))
        return ids, labels

    def _compute_metrics(self, eval_pairs) -> Dict[str, float]:
        # Exact accuracy plus top-3 accuracy over the ranked logits.
        logits, labels = eval_pairs
        n = 3
        ordered_choices = (-logits).argsort(-1)[:, :n]
        metrics = {}
        metrics["top_n_accuracy"] = np.mean(
            [label in choices for label, choices in zip(labels, ordered_choices)]
        )
        metrics["accuracy"] = np.mean(labels == ordered_choices[:, 0])
        return metrics

    def load_weights(self, path: str) -> None:
        # Restore a previously fine-tuned checkpoint instead of calling fit().
        self.model = AutoModelForSequenceClassification.from_pretrained(path).to(
            self.device
        )
        # Rebuild the label list in id order rather than trusting dict ordering.
        self.labels = [
            self.model.config.id2label[i] for i in range(self.model.config.num_labels)
        ]

    def _tokenize(self, texts: List[str]):
        # Returns a BatchEncoding (input_ids, attention_mask) on self.device.
        return self.tokenizer(
            texts, padding=True, truncation=True, max_length=100, return_tensors="pt"
        ).to(self.device)

    def fit(self, X: List[str], y: List[str]) -> "BertClassifier":
        ids, labels = self._get_classes(y)
        self.labels = labels
        id2label = dict(zip(ids, labels))
        label2id = dict(zip(labels, ids))

        encoded = self._tokenize(X)
        # Carry the attention mask along so the model ignores padded positions.
        dataset = [
            {"input_ids": input_ids, "attention_mask": mask, "label": label2id[label]}
            for input_ids, mask, label in zip(
                encoded["input_ids"], encoded["attention_mask"], y
            )
        ]
        train_ds, test_ds = train_test_split(
            dataset, shuffle=True, random_state=self.seed, train_size=0.85
        )

        batch_size = 64
        model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased",
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        ).to(self.device)

        weights_path = "weights/bert_classifier"
        training_args = TrainingArguments(
            output_dir=weights_path,
            learning_rate=2e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=self.epochs,
            weight_decay=0.01,
            evaluation_strategy="epoch",  # renamed to `eval_strategy` in recent transformers releases
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=False,
            use_mps_device=self.device == "mps",  # deprecated in recent releases; the device is auto-detected
        )
        trainer = WeightedTrainer(
            class_ids=ids,
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=test_ds,
            tokenizer=self.tokenizer,
            compute_metrics=self._compute_metrics,
        )
        trainer.train()
        model.eval()
        self.model = model
        return self

    def predict_proba(self, X: List[str]) -> NDArray:
        if self.model is None:
            raise NotFittedError("Fit the model (or call load_weights) before inference.")
        tokens = self._tokenize(X)
        with torch.no_grad():
            logits = self.model(**tokens).logits
        return F.softmax(logits, -1).cpu().numpy()

    def predict(self, X: List[str]) -> List[str]:
        preds = self.predict_proba(X)
        return [self.labels[i] for i in preds.argmax(-1)]


class WeightedTrainer(Trainer):
    """Trainer that counters class imbalance with a weighted cross-entropy loss."""

    def __init__(self, class_ids, train_dataset, *args, **kwargs):
        super().__init__(*args, train_dataset=train_dataset, **kwargs)
        # Inverse-frequency ("balanced") weights computed from the training labels.
        y_train = [example["label"] for example in train_dataset]
        class_weights = compute_class_weight(
            "balanced", classes=np.asarray(class_ids), y=y_train
        ).astype("float32")
        class_weights = torch.from_numpy(class_weights).to(self.args.device)
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # Trainer versions pass to compute_loss.
        labels = inputs.get("labels")
        # Forward pass, then re-score the logits with the weighted criterion.
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.criterion(
            logits.view(-1, self.model.config.num_labels), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss
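

# A minimal usage sketch. The texts, labels, and checkpoint path below are
# made up for illustration; any parallel lists of texts and string labels work.
if __name__ == "__main__":
    texts = [
        "the refund was processed within two days",
        "please refund my last order",
        "i want my money back for this purchase",
        "requesting a refund for a duplicate charge",
        "my card payment was declined again",
        "the transfer failed with an error code",
        "payment keeps failing at checkout",
        "the direct debit bounced this morning",
    ]
    labels = ["refund"] * 4 + ["payment_failure"] * 4

    clf = BertClassifier(seed=0, epochs=1, device="cpu")
    clf.fit(texts, labels)
    print(clf.predict(["charge failed twice today"]))
    print(clf.predict_proba(["charge failed twice today"]))

    # Trainer writes epoch checkpoints under weights/bert_classifier; a fresh
    # instance can skip training by pointing load_weights at one of them, e.g.:
    # clf = BertClassifier()
    # clf.load_weights("weights/bert_classifier/checkpoint-1")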