Spaces:
Build error
Build error
import numpy as np | |
from typing import List | |
class BertEmbedder:
    """Turn text into BERT sentence embeddings taken from the model's ``pooler_output``.

    NOTE(review): this class references ``torch``, ``AutoTokenizer`` and
    ``BertForSequenceClassification``, none of which are imported in the visible
    file header — confirm they are imported before this class is defined.
    """

    def __init__(self, model_path: str, cut_head: bool = False):
        """Load model and tokenizer from ``model_path`` and move the model to a device.

        Args:
            model_path: checkpoint name or directory passed to ``from_pretrained``.
            cut_head: if True, keep only the ``.bert`` body of the loaded
                classification model so forward passes expose ``pooler_output``.
                The checkpoint is always loaded as ``BertForSequenceClassification``,
                so a classifier head is always present; presumably this must be True
                for ``__call__``/``batch_predict`` to work, because the classifier's
                own output does not carry ``pooler_output`` (per transformers docs)
                — TODO confirm the intended meaning of False.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        # Longest sequence the position embeddings support; inputs are truncated to it.
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, max_length=self.max_length)
        if cut_head:
            self.embedder = self.embedder.bert
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.embedder.to(self.device)

    def _pooled(self, texts):
        """Tokenize ``texts`` (a str or a list of str) and return the pooler output on CPU."""
        encoded_input = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=self.max_length,
            padding=True,
            truncation=True,
        ).to(self.device)
        # Inference only: without no_grad autograd records a graph for every
        # forward pass and the returned embeddings would require grad.
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        return model_output.pooler_output.cpu()

    def __call__(self, text: str):
        """Return the embedding of a single text (first row of the pooled batch) on CPU."""
        return self._pooled(text)[0]

    def batch_predict(self, texts: List[str]):
        """Return the pooled embeddings for ``texts``, one row per input text, on CPU."""
        return self._pooled(texts)
class PredictModel:
    """Classify texts by embedding them and feeding the vectors to a classifier.

    Args:
        embedder: object with a ``batch_predict(list_of_texts)`` method whose
            result supports ``.tolist()`` (e.g. ``BertEmbedder``).
        classifier: object exposing ``fit(X, y)`` and ``predict(X)`` —
            presumably a scikit-learn style estimator; confirm against callers.
        batch_size: maximum number of texts sent to the embedder per call.
    """

    def __init__(self, embedder, classifier, batch_size=8):
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier = classifier

    def _batches(self, texts):
        """Split ``texts`` into consecutive lists of at most ``batch_size`` items."""
        texts = list(texts)
        size = self.batch_size
        return [texts[start:start + size] for start in range(0, len(texts), size)]

    def _texts2vecs(self, texts, log=False):
        """Embed ``texts`` batch by batch and return a 2-D array, one row per text.

        The previous ``np.array_split(texts, len(texts) // batch_size)`` raised
        ValueError when fewer than ``batch_size`` texts were given and produced
        batches larger than ``batch_size``; plain slicing avoids both.
        An empty ``texts`` yields an empty array.
        """
        batches = self._batches(texts)
        # NOTE(review): `tqdm` is not imported in the visible file header —
        # confirm it is imported, otherwise log=True raises NameError.
        iterator = tqdm(batches) if log else batches
        embeds = []
        for batch_texts in iterator:
            embeds.extend(self.embedder.batch_predict(batch_texts).tolist())
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], log: bool=False):
        """Embed ``texts`` and fit the classifier on the embeddings and ``labels``."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier fitting')
        self.classifier.fit(embeds, labels)

    def predict(self, texts: List[str], log: bool=False):
        """Embed ``texts`` and return the classifier's predictions for them."""
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier prediction')
        prediction = self.classifier.predict(embeds)
        return prediction