# Spaces status (page header captured together with this source): Runtime error
import logging
import string
from typing import Any, Tuple

import spacy
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
class NER:
    """Prompt-based entity extractor built on a token-classification model.

    For every requested label, the text is split into chunks of sentences and
    each chunk is prefixed with ``prompt`` (formatted with that label). The
    model's entity predictions are then mapped back to character offsets in
    the original text.
    """

    # Prefix prepended to every chunk; ``{}`` is replaced by a single label.
    # NOTE(review): confirm this matches the prompt format the model expects;
    # all offsets in ``predict`` are computed relative to ``len(prompt)``.
    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""

    def __init__(self, model_name: str, sents_batch: int = 10, batch_size: int = 12):
        """Load the sentence splitter and the token-classification pipeline.

        Args:
            model_name: Name or path passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModelForTokenClassification.from_pretrained``.
            sents_batch: Number of consecutive sentences grouped into one chunk.
            batch_size: Batch size forwarded to the ``transformers`` pipeline.

        Raises:
            ValueError: If ``sents_batch`` is smaller than 1.
        """
        if sents_batch < 1:
            raise ValueError(f"sents_batch must be >= 1, got {sents_batch}")
        self.sents_batch = sents_batch
        # spaCy is only used for sentence boundaries (see ``chunkanize``), so the
        # listed components are disabled and a rule-based sentencizer is added.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner'],
        )
        self.nlp.add_pipe('sentencizer')
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy='first',
            batch_size=batch_size,
            device=device,
        )

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        """Return the index of the last sentence in the chunk starting at ``i``.

        The chunk holds at most ``sents_batch`` sentences and never runs past
        the final sentence (index ``sentences_len - 1``).
        """
        return min(i + self.sents_batch, sentences_len) - 1

    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        """Split ``text`` into chunks of up to ``sents_batch`` sentences.

        Returns:
            ``(chunks, starts)`` where ``chunks[k]`` is the substring of ``text``
            covering the k-th group of sentences and ``starts[k]`` is its
            character offset in ``text``. Both lists are empty when spaCy
            yields no sentences.
        """
        sentences = list(self.nlp(text).sents)
        chunks: list[str] = []
        starts: list[int] = []
        for first in range(0, len(sentences), self.sents_batch):
            last = self.get_last_sentence_id(first, len(sentences))
            start = sentences[first].start_char
            end = sentences[last].end_char
            starts.append(start)
            chunks.append(text[start:end])
        return chunks, starts

    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        """Build one model input per (label, chunk) pair.

        Inputs are ordered label-major: all chunks for ``labels[0]``, then all
        chunks for ``labels[1]``, and so on. ``predict`` relies on this order.

        Returns:
            ``(inputs, prompts_lens)`` where ``prompts_lens[j]`` is the length of
            the prompt built for ``labels[j]``.
        """
        prompts = [self.prompt.format(label) for label in labels]
        inputs = [prompt + chunk for prompt in prompts for chunk in chunks]
        return inputs, [len(prompt) for prompt in prompts]

    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        """Trim punctuation and whitespace from both ends of ``span``.

        ``start``/``end`` are the offsets of ``span`` in the source text; they
        are shifted by exactly the number of characters removed from each side,
        so the returned offsets always delimit the returned span.

        Returns:
            ``(start, end, span)`` after trimming; ``span`` may be empty.
        """
        def _is_trimmed(ch: str) -> bool:
            return ch in string.punctuation or ch.isspace()

        lo, hi = 0, len(span)
        while lo < hi and _is_trimmed(span[lo]):
            lo += 1
        while hi > lo and _is_trimmed(span[hi - 1]):
            hi -= 1
        return start + lo, end - (len(span) - hi), span[lo:hi]

    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        """Run the pipeline on ``inputs`` and map entities back onto ``text``.

        ``inputs`` must be ordered as produced by ``get_inputs`` (label-major),
        so input ``k`` belongs to label ``k // n_chunks`` and chunk
        ``k % n_chunks``.

        Returns:
            A list of ``{'span', 'start', 'end', 'entity'}`` dicts with offsets
            into ``text`` and a score of at least ``threshold``.
        """
        outputs: list[dict[str, Any]] = []
        n_chunks = len(chunks_starts)
        if not inputs or n_chunks == 0:
            return outputs
        for input_idx, entities in enumerate(self.pipeline(inputs)):
            label_idx, chunk_idx = divmod(input_idx, n_chunks)
            label = labels[label_idx]
            prompt_len = prompts_lens[label_idx]
            chunk_start = chunks_starts[chunk_idx]
            for ent in entities:
                if ent['score'] < threshold:
                    continue
                # Offsets relative to the chunk (the model saw prompt + chunk).
                local_start = ent['start'] - prompt_len
                local_end = ent['end'] - prompt_len
                if local_end <= 0:
                    # Entity lies entirely inside the prompt (e.g. the label
                    # word itself); it has no counterpart in ``text``.
                    continue
                # Some tokenizers report a leading space/newline as part of the
                # word, which may reach back into the prompt; clamp to the chunk
                # and let clean_span strip the whitespace.
                start = chunk_start + max(local_start, 0)
                end = chunk_start + local_end
                start, end, span = self.clean_span(start, end, text[start:end])
                if not span:
                    continue
                outputs.append({
                    'span': span,
                    'start': start,
                    'end': end,
                    'entity': label,
                })
        return outputs

    def __call__(
        self, labels: str, text: str, threshold: float = 0.
    ) -> dict[str, Any]:
        """Extract entities of the comma-separated ``labels`` from ``text``.

        Empty labels (e.g. from a trailing comma) are ignored.

        Returns:
            ``{"text": text, "entities": [...]}`` with entities as produced by
            ``predict``.
        """
        labels_list = [
            label.strip() for label in labels.split(',') if label.strip()
        ]
        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)
        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )
        logging.getLogger(__name__).debug("NER entities: %s", outputs)
        return {"text": text, "entities": outputs}