demo_icd10 / required_classes.py
lyangas
first demo code added
1bb21bd
import numpy as np
from typing import List
class BertEmbedder:
    """Turn texts into embeddings using the pooled output of a pretrained BERT model."""

    def __init__(self, model_path: str, cut_head: bool = False):
        """
        Load the model and tokenizer from ``model_path`` and move the model to a device.

        Args:
            model_path: Location passed to both ``from_pretrained`` calls.
            cut_head: Set to True if the model has a classifier head; the
                inner ``.bert`` module is then used as the embedder instead
                of the full classification model.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, max_length=self.max_length)
        # NOTE(review): with cut_head=False the full classification model is
        # kept; its output presumably has no ``pooler_output`` attribute, so
        # embedding calls would fail -- confirm and consider always cutting.
        if cut_head:
            self.embedder = self.embedder.bert
        # Prefer the first CUDA device when one is available.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.embedder.to(self.device)

    def _pooled_output(self, texts):
        """Tokenize ``texts``, run the model and return ``pooler_output`` on the CPU."""
        encoded_input = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=self.max_length,
            padding=True,
            truncation=True,
        ).to(self.device)
        # Inference only: skip autograd bookkeeping so no graph is kept alive
        # and the returned embeddings do not require gradients.
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        return model_output.pooler_output.cpu()

    def __call__(self, text: str):
        """
        Embed a single text.

        Args:
            text: The text to embed.

        Returns:
            The first row of the model's pooled output, as a CPU tensor.
        """
        return self._pooled_output(text)[0]

    def batch_predict(self, texts: List[str]):
        """
        Embed several texts in one forward pass.

        Args:
            texts: The texts to embed; they are padded/truncated together.

        Returns:
            The model's pooled output for the whole batch, as a CPU tensor.
        """
        return self._pooled_output(texts)
class PredictModel:
    """Chain a text embedder and a classifier: texts -> vectors -> labels."""

    def __init__(self, embedder, classifier, batch_size=8):
        """
        Args:
            embedder: Object with ``batch_predict(texts)`` whose result
                supports ``.tolist()``.
            classifier: Object with ``fit(X, y)`` and ``predict(X)``.
            batch_size: Maximum number of texts sent to the embedder at once.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier = classifier

    def _texts2vecs(self, texts, log=False):
        """
        Embed ``texts`` batch by batch and stack the results.

        Args:
            texts: Sequence of texts to embed.
            log: When True, wrap the batch loop in a ``tqdm`` progress bar.

        Returns:
            ``np.ndarray`` built from the concatenated per-batch embeddings.
        """
        # Normalise to a plain list so slicing is positional for any input
        # container (list, numpy array, ...).
        texts = texts.tolist() if hasattr(texts, "tolist") else list(texts)

        # Fixed-stride slicing: every batch holds at most ``batch_size`` texts,
        # and inputs shorter than ``batch_size`` (or empty) are handled without
        # asking for zero sections.
        batches_texts = [
            texts[start:start + self.batch_size]
            for start in range(0, len(texts), self.batch_size)
        ]
        iterator = tqdm(batches_texts) if log else batches_texts

        embeds = []
        for batch_texts in iterator:
            embeds.extend(self.embedder.batch_predict(batch_texts).tolist())
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], log: bool = False):
        """
        Embed ``texts`` and fit the classifier on the resulting vectors.

        Args:
            texts: Training texts.
            labels: Labels passed to the classifier alongside the embeddings.
            log: When True, print progress messages and show a progress bar.
        """
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier fitting')
        self.classifier.fit(embeds, labels)

    def predict(self, texts: List[str], log: bool = False):
        """
        Embed ``texts`` and return the classifier's prediction for them.

        Args:
            texts: Texts to classify.
            log: When True, print progress messages and show a progress bar.

        Returns:
            Whatever ``classifier.predict`` returns for the embeddings.
        """
        if log:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, log)
        if log:
            print('Start classifier prediction')
        return self.classifier.predict(embeds)