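"""Scikit-learn compatible text classifier built on DistilBERT.

Wraps Hugging Face's Trainer behind a BaseEstimator/ClassifierMixin
interface and uses a class-weighted cross-entropy loss to compensate
for imbalanced label distributions.
"""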
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple
from numpy.typing import NDArray
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BatchEncoding,
    Trainer,
    TrainingArguments,
)
class BertClassifier(BaseEstimator, ClassifierMixin):
    """Scikit-learn compatible text classifier backed by DistilBERT."""

    def __init__(self, seed=42, epochs=5, device="cpu"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.seed = seed
        self.epochs = epochs
        self.model = None
        self.labels = None
        self.device = device

    def _get_classes(self, y: List[str]) -> Tuple[List[int], List[str]]:
        """Map the unique labels in y to consecutive integer ids."""
        labels = sorted(set(y))
        ids = list(range(len(labels)))
        return ids, labels
    def _compute_metrics(self, eval_pairs):
        """Report accuracy and top-n accuracy on the eval split."""
        logits, labels = eval_pairs
        n = 3
        # Class indices sorted by descending logit; keep the top n per example.
        ordered_choices = (-logits).argsort(-1)[:, :n]
        metrics = {}
        metrics["top_n_accuracy"] = np.mean(
            [label in choices for label, choices in zip(labels, ordered_choices)]
        )
        metrics["accuracy"] = np.mean(labels == ordered_choices[:, 0])
        return metrics
    def load_weights(self, path: str):
        """Load a fine-tuned checkpoint from disk."""
        self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
        # Recover the labels in id order so predict() can index into them.
        id2label = self.model.config.id2label
        self.labels = [id2label[i] for i in range(len(id2label))]
    def _tokenize(self, texts: List[str]) -> BatchEncoding:
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=100,
            return_tensors="pt",
        ).to(self.device)
    def fit(self, X: List[str], y: List[str]):
        ids, labels = self._get_classes(y)
        self.labels = labels
        id2label = dict(zip(ids, labels))
        label2id = dict(zip(labels, ids))
        encodings = self._tokenize(X)
        # Keep the attention mask so the model ignores padding tokens.
        dataset = [
            {"input_ids": input_ids, "attention_mask": mask, "label": label2id[label]}
            for input_ids, mask, label in zip(
                encodings["input_ids"], encodings["attention_mask"], y
            )
        ]
        train_ds, test_ds = train_test_split(
            dataset, shuffle=True, random_state=self.seed, train_size=0.85
        )
        batch_size = 64
        model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased",
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        ).to(self.device)
        weights_path = "weights/bert_classifier"
        training_args = TrainingArguments(
            output_dir=weights_path,
            learning_rate=2e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=self.epochs,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=False,
            use_mps_device=self.device == "mps",
        )
        trainer = WeightedTrainer(
            class_ids=ids,
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=test_ds,
            tokenizer=self.tokenizer,
            compute_metrics=self._compute_metrics,
        )
        trainer.train()
        model.eval()
        self.model = model
        return self
    def predict_proba(self, X: List[str]) -> NDArray:
        if self.model is None:
            raise RuntimeError("Fit the model before inference.")
        tokens = self._tokenize(X)
        with torch.no_grad():
            logits = self.model(**tokens).logits
        return F.softmax(logits, -1).cpu().numpy()
    def predict(self, X: List[str]) -> List[str]:
        preds = self.predict_proba(X)
        return [self.labels[i] for i in preds.argmax(-1)]
class WeightedTrainer(Trainer):
    """Trainer that rebalances the loss with inverse-frequency class weights."""

    def __init__(self, class_ids, train_dataset, *args, **kwargs):
        super().__init__(train_dataset=train_dataset, *args, **kwargs)
        y_train = [example["label"] for example in train_dataset]
        class_weights = compute_class_weight(
            "balanced", classes=np.asarray(class_ids), y=y_train
        ).astype("float32")
        class_weights = torch.from_numpy(class_weights).to(self.args.device)
        self.criteria = nn.CrossEntropyLoss(weight=class_weights)
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # Forward pass, then class-weighted cross-entropy over the logits.
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.criteria(
            logits.view(-1, self.model.config.num_labels), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss
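

if __name__ == "__main__":
    # Minimal usage sketch. The texts, labels, and checkpoint path below
    # are illustrative only; any parallel lists of strings will do.
    texts = ["great movie", "terrible plot", "loved it", "awful acting"] * 32
    labels = ["pos", "neg", "pos", "neg"] * 32

    clf = BertClassifier(epochs=1, device="cpu")
    clf.fit(texts, labels)
    print(clf.predict(["what a fantastic film"]))        # e.g. ["pos"]
    print(clf.predict_proba(["what a fantastic film"]))  # shape (1, n_classes)

    # To reload a checkpoint saved by the Trainer instead of retraining
    # (the exact checkpoint directory name depends on the run):
    # clf = BertClassifier()
    # clf.load_weights("weights/bert_classifier/checkpoint-...")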