File size: 1,476 Bytes
3ddbd57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
from dataclasses import dataclass

from pytorch_ie import AnnotationList, LabeledSpan, Pipeline, TextDocument, annotation_field
from pytorch_ie.models import TransformerSpanClassificationModel
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule
from spacy import displacy


@dataclass()
class ExampleDocument(TextDocument):
    """Text document carrying one annotation layer of labeled spans.

    The ``entities`` layer is declared with ``target="text"``, so its
    ``LabeledSpan`` annotations are anchored on the document's ``text``.
    """

    # Labeled entity spans over ``text``; the NER pipeline fills its predictions.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


# Identifier of the pretrained checkpoint (presumably a Hugging Face Hub repo
# id); both the taskmodule and the model are loaded from the same location.
model_name_or_path = "pie/example-ner-spanclf-conll03"

ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(
    model_name_or_path
)
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)

# NOTE(review): device=-1 / num_workers=0 look like "CPU, no worker
# processes" — confirm against the pytorch_ie Pipeline documentation.
ner_pipeline = Pipeline(
    model=ner_model,
    taskmodule=ner_taskmodule,
    device=-1,
    num_workers=0,
)


def predict(text: str) -> str:
    """Run NER on ``text`` and return the predicted entities as displaCy HTML.

    The text is wrapped in an ``ExampleDocument``, the module-level
    ``ner_pipeline`` is run on it with ``predict_field="entities"``, and the
    resulting entity predictions are rendered with spaCy's displaCy entity
    visualizer in manual mode.

    Args:
        text: Input text coming from the Gradio textbox.

    Returns:
        The displaCy markup wrapped in a scrollable ``<div>``.
    """
    document = ExampleDocument(text)

    # The pipeline's return value is not used; predictions are read back
    # from ``document.entities.predictions`` below.
    ner_pipeline(document, predict_field="entities")

    # displaCy "manual" input: a plain dict with character offsets into the
    # text, entities ordered by their start offset.
    doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted(document.entities.predictions, key=lambda e: e.start)
        ],
        "title": None,
    }

    # jupyter=False: with the default (None) displacy.render auto-detects a
    # notebook kernel and, if found, displays the markup inline and returns
    # None instead of the HTML string — which would make the concatenation
    # below raise TypeError when the app is launched from Jupyter/Colab.
    html = displacy.render(
        doc, style="ent", page=True, manual=True, minify=True, jupyter=False
    )

    # Cap the height so long texts scroll instead of growing the output box.
    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + html
        + "</div>"
    )


# Gradio UI: one textbox in, the rendered displaCy HTML out.
iface = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="html",
)
# Starts the Gradio server as soon as this module is executed.
iface.launch()