File size: 2,848 Bytes
6304a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import List

import numpy as np
import torch


class BertEmbedder:
    """Produce pooled BERT sentence embeddings from raw text.

    NOTE(review): relies on ``BertForSequenceClassification`` and
    ``AutoTokenizer`` (transformers) and ``tqdm`` being imported elsewhere
    in this module — they are not in the visible import block; confirm.
    """

    def __init__(self, model_path: str, cut_head: bool = False):
        """
        Args:
            model_path: path or hub id of the pretrained checkpoint.
            cut_head: set True when the checkpoint carries a classifier
                head; the bare ``.bert`` encoder is then kept so its
                output exposes ``pooler_output``.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, max_length=self.max_length)

        if cut_head:
            # Keep only the base encoder; drops the classification head.
            # NOTE(review): with cut_head=False the forward output below is a
            # SequenceClassifierOutput, which has no `pooler_output` — this
            # class likely only works with cut_head=True; confirm with callers.
            self.embedder = self.embedder.bert

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.embedder.to(self.device)
        # Embedding is inference-only: switch off dropout / batch-norm updates.
        self.embedder.eval()

    def __call__(self, text: str):
        """Embed a single text.

        Returns:
            1-D CPU tensor of shape (hidden_size,) — the pooler output.
        """
        encoded_input = self.tokenizer(text,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        # No gradients needed for embedding; avoids building an autograd graph.
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        text_embed = model_output.pooler_output[0].cpu()
        return text_embed

    def batch_predict(self, texts: List[str]):
        """Embed a batch of texts.

        Returns:
            2-D CPU tensor of shape (len(texts), hidden_size).
        """
        encoded_input = self.tokenizer(texts,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        texts_embeds = model_output.pooler_output.cpu()
        return texts_embeds

class PredictModel:
    """Pipeline: embed texts with ``embedder``, then fit/predict ``classifier``.

    ``embedder`` must expose ``batch_predict(list[str])`` returning a
    sequence of vectors; ``classifier`` follows the sklearn
    ``fit``/``predict`` interface.
    """

    def __init__(self, embedder, classifier, batch_size=8):
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier = classifier

    def _texts2vecs(self, texts, log=False):
        """Embed ``texts`` in batches of at most ``batch_size`` each.

        Returns:
            2-D ``np.ndarray`` of shape (len(texts), embed_dim); an empty
            array when ``texts`` is empty.
        """
        if len(texts) == 0:
            return np.array([])
        # Ceil division. The original `len(texts) // self.batch_size`
        # raised ValueError (0 sections) whenever len(texts) < batch_size,
        # and floor division produced batches larger than batch_size.
        n_batches = -(-len(texts) // self.batch_size)
        batches_texts = np.array_split(texts, n_batches)
        # NOTE(review): tqdm is assumed imported at module level — confirm.
        iterator = tqdm(batches_texts) if log else batches_texts
        embeds = []
        for batch_texts in iterator:
            embeds += self.embedder.batch_predict(batch_texts.tolist()).tolist()
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], log: bool = False):
        """Embed ``texts`` and fit the classifier on (embeddings, labels)."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier fitting')
        self.classifier.fit(embeds, labels)

    def predict(self, texts: List[str], log: bool = False):
        """Embed ``texts`` and return the classifier's predictions."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier prediction')
        prediction = self.classifier.predict(embeds)
        return prediction