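"""Prompt-based named entity recognition with knowledgator/UTC-DeBERTa-small.

Long inputs are split into sentence chunks via spaCy's sentencizer, each
chunk is prefixed with a per-label prompt, and predicted entity offsets
are mapped back into the original text.
"""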
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import spacy
import torch

# Lightweight spaCy pipeline used only for sentence splitting.
nlp = spacy.load('en_core_web_sm', disable=['lemmatizer', 'parser', 'tagger', 'ner'])
nlp.add_pipe('sentencizer')

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class NER:
    # Prompt-tuned universal token classifier; the target label is
    # injected into the prompt template below.
    model_name = 'knowledgator/UTC-DeBERTa-small'
    prompt = """
Identify entities in the text having the following classes:
{}

Text:
"""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    ner_pipeline = pipeline(
        'ner',
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy='first',
        batch_size=12,
        device=device,
    )
    
    @classmethod
    def chunkanize(cls, text, prompt_='', n_sents=10):
        """Split `text` into chunks of up to `n_sents` sentences, each
        prefixed with `prompt_`; also return the character offset of each
        chunk's first sentence within the original text."""
        doc = nlp(text)
        chunks = []
        starts = []
        start = 0
        end = 0
        in_chunk = False
        for sent_id, sent in enumerate(doc.sents, start=1):
            if not in_chunk:
                # Record where the current chunk begins in the raw text.
                start = sent[0].idx
                starts.append(start)
            in_chunk = True
            end = sent[-1].idx + len(sent[-1].text)
            if sent_id % n_sents == 0:
                chunks.append(prompt_ + text[start:end])
                in_chunk = False
        if in_chunk:
            # Flush a trailing chunk shorter than `n_sents` sentences.
            chunks.append(prompt_ + text[start:end])
        return chunks, starts


    @classmethod
    def ner(cls, labels, text, threshold=0.0):
        """Run prompt-based NER over `text` for each label in the
        comma-separated `labels` string and return entities with offsets
        mapped back into the original text."""
        chunks, starts, classes = [], [], []
        label2prompt_len = {}
        for label in labels.split(', '):
            # Build a dedicated prompt per label and chunk the text under it.
            prompt_ = cls.prompt.format(label)
            label2prompt_len[label] = len(prompt_)
            curr_chunks, curr_starts = cls.chunkanize(text, prompt_)
            chunks += curr_chunks
            starts += curr_starts
            classes += [label] * len(curr_chunks)
        outputs = []
        for idx, output in enumerate(cls.ner_pipeline(chunks)):
            label = classes[idx]
            # Entity offsets are relative to the prompt-prefixed chunk:
            # subtract the prompt length, then add the chunk's position
            # in the original text to get absolute character offsets.
            offset = starts[idx] - label2prompt_len[label]
            for ent in output:
                if ent['score'] > threshold:
                    ent['start'] += offset
                    ent['end'] += offset
                    ent['entity'] = label
                    outputs.append(ent)
        return {"text": text, "entities": outputs}
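

# --- Hypothetical usage sketch; the labels and sample text below are
# --- illustrative only. Labels must be separated by ', ' to match the
# --- `labels.split(', ')` call in `NER.ner`.
if __name__ == '__main__':
    sample = 'Tim Cook presented the new iPhone at Apple headquarters in Cupertino.'
    result = NER.ner('person, company, location', sample, threshold=0.5)
    for ent in result['entities']:
        print(ent['entity'], sample[ent['start']:ent['end']], float(ent['score']))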