File size: 1,476 Bytes
3ddbd57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import gradio as gr
from dataclasses import dataclass
from pytorch_ie import AnnotationList, LabeledSpan, Pipeline, TextDocument, annotation_field
from pytorch_ie.models import TransformerSpanClassificationModel
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule
from spacy import displacy
@dataclass
class ExampleDocument(TextDocument):
    """Text document with a single layer of labeled entity spans.

    ``entities`` is an annotation layer of ``LabeledSpan`` objects whose
    offsets refer to the document's ``text`` field (``target="text"``).
    In this script the NER pipeline is told to predict into this layer
    (``predict_field="entities"``) and its output is read back from
    ``entities.predictions``.
    """

    # Entity spans anchored to ``text``.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
# Checkpoint identifier shared by the task module and the model
# (presumably a Hugging Face Hub repo id — confirm).
model_name_or_path = "pie/example-ner-spanclf-conll03"

# The task module and the model are both restored from the same checkpoint.
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(
    model_name_or_path
)
ner_model = TransformerSpanClassificationModel.from_pretrained(
    model_name_or_path
)

# Inference pipeline used by predict().
# NOTE(review): device=-1 presumably selects CPU and num_workers=0 presumably
# disables worker processes for data loading — confirm against pytorch_ie.
ner_pipeline = Pipeline(
    model=ner_model,
    taskmodule=ner_taskmodule,
    device=-1,
    num_workers=0,
)
def predict(text):
    """Tag named entities in *text* and return them as displaCy HTML.

    The text is wrapped in an ``ExampleDocument`` and run through the
    module-level ``ner_pipeline``, which predicts into the ``entities``
    layer. The predicted spans are then passed to displaCy in its manual
    (plain-dict) input format.

    Args:
        text: Raw input string from the Gradio textbox.

    Returns:
        An HTML string: the displaCy entity visualisation wrapped in a
        ``<div>`` capped at 360px height with scrolling overflow.
    """
    annotated = ExampleDocument(text)
    # The pipeline's return value is not used; results are read back from
    # ``annotated.entities.predictions`` below.
    ner_pipeline(annotated, predict_field="entities")

    # Emit the spans ordered by start offset so they are rendered left to
    # right (NOTE(review): assumed to be what displaCy's manual mode expects).
    ent_dicts = []
    for span in sorted(annotated.entities.predictions, key=lambda span: span.start):
        ent_dicts.append({"start": span.start, "end": span.end, "label": span.label})

    manual_input = {"text": annotated.text, "ents": ent_dicts, "title": None}
    rendered = displacy.render(
        manual_input, style="ent", page=True, manual=True, minify=True
    )

    # Bound the visualisation's size so long inputs scroll inside the output
    # panel instead of stretching it.
    wrapper_open = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    return wrapper_open + rendered + "</div>"
# Web UI: a plain textbox in, predict()'s HTML string out.
iface = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="html",
)
# Starts the Gradio server as soon as this module runs.
iface.launch()
|