File size: 2,857 Bytes
6304a81 da7535b 6304a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import numpy as np
from typing import List
class BertEmbedder:
    """Produce fixed-size sentence embeddings from a (fine-tuned) BERT model.

    Embeddings are taken from the encoder's ``pooler_output``.
    NOTE(review): when ``cut_head`` is False the model stays a
    ``BertForSequenceClassification``, whose forward output has no
    ``pooler_output`` attribute — callers appear to rely on ``cut_head=True``
    (or a bare encoder checkpoint); confirm against usage.
    """
    def __init__(self, model_path: str, cut_head: bool = False):
        """
        model_path: path or hub id passed to ``from_pretrained``.
        cut_head: True if the model has a classifier head; the head is then
            dropped and only the underlying BERT encoder is kept.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, max_length=self.max_length)
        if cut_head:
            self.embedder = self.embedder.bert
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.embedder.to(self.device)
        # Inference-only wrapper: switch off dropout so repeated calls on the
        # same text give identical embeddings (was missing — nondeterministic).
        self.embedder.eval()

    def __call__(self, text: str) -> List[float]:
        """Embed a single text; returns the pooled vector as a plain list."""
        encoded_input = self.tokenizer(text,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        # no_grad: pure inference — avoids building the autograd graph
        # (memory leak over many calls in the original).
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        text_embed = model_output.pooler_output[0].cpu()
        return text_embed.tolist()

    def batch_predict(self, texts: List[str]):
        """Embed a batch of texts; returns a CPU tensor of pooled vectors."""
        encoded_input = self.tokenizer(texts,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        texts_embeds = model_output.pooler_output.cpu()
        return texts_embeds
class PredictModel:
    """Text classifier: an embedder turns texts into vectors, a classifier
    (scikit-learn style ``fit``/``predict``) works on those vectors."""
    def __init__(self, embedder, classifier, batch_size=8):
        """
        embedder: object with ``batch_predict(List[str]) -> array-like``.
        classifier: object with ``fit(X, y)`` and ``predict(X)``.
        batch_size: number of texts per embedder call.
        """
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier = classifier

    def _texts2vecs(self, texts, log=False):
        """Embed ``texts`` in batches; returns a 2-D numpy array of vectors."""
        embeds = []
        # max(1, ...): the original passed ``len(texts) // batch_size`` which is
        # 0 when fewer texts than batch_size — np.array_split raises on 0 sections.
        n_batches = max(1, len(texts) // self.batch_size)
        batches_texts = np.array_split(texts, n_batches)
        iterator = tqdm(batches_texts) if log else batches_texts
        for batch_texts in iterator:
            batch_texts = batch_texts.tolist()
            embeds += self.embedder.batch_predict(batch_texts).tolist()
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], log: bool = False):
        """Embed ``texts`` and fit the classifier on (embeddings, labels)."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier fitting')
        self.classifier.fit(embeds, labels)

    def predict(self, texts: List[str], log: bool = False):
        """Embed ``texts`` and return the classifier's predictions."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier prediction')
        prediction = self.classifier.predict(embeds)
        return prediction