import numpy as np
import torch
from typing import List

# NOTE(review): AutoTokenizer and BertForSequenceClassification (from
# `transformers`) and `tqdm` are used below but are not imported in this
# file's visible imports -- confirm they are imported/defined elsewhere.


class BertEmbedder:
    """Turn text into fixed-size vectors using a BERT model's ``pooler_output``."""

    def __init__(self, model_path: str, cut_head: bool = False):
        """Load the model and tokenizer from ``model_path``.

        Args:
            model_path: Path or hub id passed to ``from_pretrained``.
            cut_head: If True, keep only the ``.bert`` backbone of the loaded
                ``BertForSequenceClassification`` (dropping its classifier
                head) and use that for embedding.

        NOTE(review): the model is always loaded as
        ``BertForSequenceClassification``; its forward output presumably
        exposes ``logits`` rather than ``pooler_output``, so with
        ``cut_head=False`` the embedding methods below likely fail with
        AttributeError -- confirm intended usage.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        # Longest sequence the position embeddings support; inputs are
        # truncated to this many tokens.
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, max_length=self.max_length)
        if cut_head:
            self.embedder = self.embedder.bert
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.embedder.to(self.device)
        # Used for inference only: make sure dropout-style layers are disabled.
        self.embedder.eval()

    def __call__(self, text: str):
        """Return the embedding of a single ``text`` as a CPU tensor."""
        return self.batch_predict([text])[0]

    def batch_predict(self, texts: List[str]):
        """Embed ``texts`` and return a CPU tensor with one row per text.

        Texts are padded to the longest one in the batch and truncated to
        ``self.max_length`` tokens.
        """
        encoded_input = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=self.max_length,
            padding=True,
            truncation=True,
        ).to(self.device)
        # No gradients are needed to compute embeddings: skipping the autograd
        # graph saves memory and lets callers use `.numpy()` on the result.
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        return model_output.pooler_output.cpu()


class PredictModel:
    """Text classifier: embed texts with ``embedder``, then fit/predict with ``classifier``."""

    def __init__(self, embedder, classifier, batch_size=8):
        """
        Args:
            embedder: Object exposing ``batch_predict(list_of_texts)`` whose
                result supports ``.tolist()`` (e.g. ``BertEmbedder``).
            classifier: Object exposing ``fit(X, y)`` and ``predict(X)``
                (presumably a scikit-learn style estimator).
            batch_size: Maximum number of texts passed to the embedder per
                call; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier = classifier

    def _texts2vecs(self, texts, log=False):
        """Embed ``texts`` in consecutive batches of at most ``self.batch_size``.

        Returns a numpy array with one row per input text, in input order.
        When ``log`` is true, progress over the batches is shown with tqdm.
        """
        texts = list(texts)
        size = self.batch_size
        # Fixed-size slicing: every batch has at most `size` texts and inputs
        # shorter than `size` still form one (smaller) batch.
        batches = [texts[start:start + size] for start in range(0, len(texts), size)]
        iterator = tqdm(batches) if log else batches
        embeds = []
        for batch_texts in iterator:
            embeds.extend(self.embedder.batch_predict(batch_texts).tolist())
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], log: bool = False):
        """Embed ``texts`` and fit the classifier on them with ``labels``."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier fitting')
        self.classifier.fit(embeds, labels)

    def predict(self, texts: List[str], log: bool = False):
        """Embed ``texts`` and return the classifier's predictions for them."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier prediction')
        prediction = self.classifier.predict(embeds)
        return prediction