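"""Gradio demo for prompt-based zero-shot NER with a token-classification model.

The input text is split into batches of sentences, every batch is prefixed with
a per-label prompt, and the predicted spans are mapped back to character
offsets in the original text.
"""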
from typing import Any, Tuple
import string
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import spacy
import torch
import gradio as gr


class NER:
    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""

    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        return min(i + self.sents_batch, sentences_len) - 1

    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        # Split the text into chunks of `sents_batch` sentences and remember
        # the character offset at which each chunk starts in the original text.
        doc = self.nlp(text)

        chunks = []
        starts = []
        sentences = list(doc.sents)

        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)

            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char
            chunks.append(text[start:end])

        return chunks, starts

    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        # Build one input per (label, chunk) pair in label-major order;
        # predict() relies on this ordering to map outputs back to labels
        # and chunks.
        inputs = []
        prompts_lens = []

        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)

        return inputs, prompts_lens

    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        # Recursively strip leading/trailing punctuation, adjusting the offsets,
        # e.g. clean_span(10, 17, '(Apple)') -> (11, 16, 'Apple').
        if len(span) >= 1:
            if span[0] in string.punctuation:
                return cls.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation:
                return cls.clean_span(start, end - 1, span[:-1])
        return start, end, span.strip()

    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []

        for idx, output in enumerate(self.pipeline(inputs)):
            # Inputs were built label-major (see get_inputs), so the label index
            # is idx // n_chunks and the chunk index is idx % n_chunks. The shift
            # converts offsets relative to `prompt + chunk` back into offsets in
            # the original text.
            label = labels[idx // len(chunks_starts)]
            shift = chunks_starts[idx % len(chunks_starts)] - prompts_lens[idx // len(chunks_starts)]

            for ent in output:
                start = ent['start'] + shift + 1
                end = ent['end'] + shift

                start, end, span = self.clean_span(start, end, text[start:end])
                if not span:
                    continue

                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })

        return outputs

    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')

    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )

    def check_tokens_limit(self, inputs: list[str]) -> None:
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
        if tokens > self.tokens_limit:
            raise gr.Error(
                'Too many tokens! Please reduce size of text or amount of labels.'
                f' Max tokens count is: {self.tokens_limit}.'
            )

    def process(
        self, labels: str, text: str, threshold: float = 0.
    ) -> dict[str, Any]:
        # Deduplicate the comma-separated labels, keeping only non-empty entries.
        labels_list = list({
            l for label in labels.split(',')
            if (l := label.strip())
        })

        self.check_labels(labels_list)
        self.check_text(text)

        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)

        self.check_tokens_limit(inputs)

        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )

        # Dictionary format accepted by gr.HighlightedText.
        return {"text": text, "entities": outputs}