File size: 4,913 Bytes
aadb779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import pipeline
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
from typing import List, Tuple
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset
from pathlib import Path
import json
from numpy.typing import NDArray

class BertClassifier(BaseEstimator, ClassifierMixin):
    """Scikit-learn compatible text classifier backed by fine-tuned DistilBERT.

    String labels are mapped to contiguous integer ids at fit time; ``labels``
    holds the class names so that index == model label id.
    """

    def __init__(self, seed=42, epochs=5, device="cpu"):
        super().__init__()
        # Tokenizer is loaded eagerly; the model itself is created lazily by
        # fit() or load_weights().
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.seed = seed
        self.epochs = epochs
        self.model = None   # set by fit() or load_weights()
        self.labels = None  # ordered class names; position == label id
        self.device = device

    def _get_classes(self, y: List[str]) -> Tuple[List[int], List[str]]:
        """Return (integer ids, sorted unique label names) for targets *y*."""
        labels = sorted(set(y))
        ids = list(range(len(labels)))
        return ids, labels

    def _compute_metrics(self, eval_pairs):
        """Compute accuracy and top-3 accuracy from a (logits, labels) pair.

        *eval_pairs* is the (predictions, label_ids) tuple the HF Trainer
        passes to ``compute_metrics``.
        """
        logits, labels = eval_pairs
        n = 3
        # Indices of the n highest-scoring classes per example, best first.
        ordered_choices = (-logits).argsort(-1)[:, :n]
        metrics = {}
        metrics["top_n_accuracy"] = np.mean(
            [label in choices for label, choices in zip(labels, ordered_choices)]
        )
        metrics["accuracy"] = np.mean(labels == ordered_choices[:, 0])
        return metrics

    def load_weights(self, path: str):
        """Load a previously fine-tuned model from *path* and restore labels."""
        self.model = AutoModelForSequenceClassification.from_pretrained(
            path).to(self.device)
        # Recover label names *by id* so positions are guaranteed to match the
        # model's output indices (dict key order of label2id is not a contract).
        id2label = self.model.config.id2label
        self.labels = [id2label[i] for i in range(len(id2label))]

    def _tokenize(self, texts: List[str]) -> torch.Tensor:
        """Tokenize *texts* into padded/truncated tensors on ``self.device``."""
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=100,
            return_tensors="pt",
        ).to(self.device)

    def fit(self, X: List[str], y: List[str]):
        """Fine-tune DistilBERT on texts *X* with string labels *y*.

        Returns self (scikit-learn convention).
        """
        ids, labels = self._get_classes(y)
        self.labels = labels
        id2label = dict(zip(ids, labels))
        label2id = dict(zip(labels, ids))
        encoded = self._tokenize(X)
        # Keep the attention mask alongside the input ids: _tokenize pads the
        # whole corpus to one length, so without the mask the collator would
        # emit an all-ones mask and the model would attend to padding tokens.
        dataset = [
            {"input_ids": input_ids, "attention_mask": mask, "label": label2id[label]}
            for input_ids, mask, label in zip(
                encoded["input_ids"], encoded["attention_mask"], y
            )
        ]
        train_ds, test_ds = train_test_split(
            dataset, shuffle=True, random_state=self.seed, train_size=0.85
        )
        batch_size = 64

        model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=len(labels), id2label=id2label, label2id=label2id
        ).to(self.device)
        weights_path = "weights/bert_classifier"
        training_args = TrainingArguments(
            output_dir=weights_path,
            learning_rate=2e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=self.epochs,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=False,
            use_mps_device=self.device == "mps",
        )

        trainer = WeightedTrainer(
            class_ids=ids,
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=test_ds,
            tokenizer=self.tokenizer,
            compute_metrics=self._compute_metrics,
        )

        trainer.train()
        model.eval()
        self.model = model
        return self

    def predict_proba(self, X: List[str]) -> NDArray:
        """Return class probabilities of shape (len(X), n_classes).

        Raises RuntimeError if called before fit()/load_weights().
        """
        if self.model is None:
            raise RuntimeError("Fit the model before inference.")
        tokens = self._tokenize(X)
        with torch.no_grad():
            logits = self.model(**tokens).logits
        return F.softmax(logits, -1).cpu().numpy()

    def predict(self, X: List[str]) -> List[str]:
        """Return the most likely string label for each text in *X*."""
        preds = self.predict_proba(X)
        return [self.labels[i] for i in preds.argmax(-1)]




class WeightedTrainer(Trainer):
    """HF Trainer that uses class-balanced cross-entropy to counter label imbalance."""

    def __init__(self, class_ids, train_dataset, *args, **kwargs):
        super().__init__(train_dataset=train_dataset, *args, **kwargs)
        # "balanced" weights = n_samples / (n_classes * class_count), computed
        # from the integer labels of the training split.
        y_train = [example["label"] for example in train_dataset]
        class_weights = compute_class_weight(
            "balanced", classes=class_ids, y=y_train
        ).astype("float32")
        # Move to the full device (including index), not just device.type, so
        # the weights land on the same GPU as the model in multi-device setups.
        class_weights = torch.from_numpy(class_weights).to(self.args.device)
        self.criteria = nn.CrossEntropyLoss(weight=class_weights)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Weighted cross-entropy loss.

        *num_items_in_batch* is accepted (and ignored) for compatibility with
        newer transformers versions, which pass it to compute_loss.
        """
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.criteria(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss