NER / app.py
christophalt's picture
Create app.py
3ddbd57
raw
history blame
1.48 kB
import gradio as gr
from dataclasses import dataclass
from pytorch_ie import AnnotationList, LabeledSpan, Pipeline, TextDocument, annotation_field
from pytorch_ie.models import TransformerSpanClassificationModel
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule
from spacy import displacy
@dataclass
class ExampleDocument(TextDocument):
    """A text document with one annotation layer of labeled entity spans.

    ``predict`` builds one of these from the user's input string, runs the
    NER pipeline on it, and then reads each predicted span's ``start``,
    ``end`` and ``label`` as character offsets into ``text`` for displaCy.
    """

    # Labeled spans attached to the document's ``text`` field
    # (``target="text"``); the pipeline's output is read back from
    # ``entities.predictions``.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
# Load the pretrained NER components once, at import time, so every call to
# ``predict`` reuses the same ``ner_pipeline``. The task module and the model
# are both loaded from the same checkpoint identifier.
# NOTE(review): the identifier looks like a Hugging Face Hub repo id — confirm.
model_name_or_path = "pie/example-ner-spanclf-conll03"
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)
# NOTE(review): device=-1 presumably selects CPU, and num_workers=0 presumably
# keeps data loading in the main process — TODO confirm against the
# pytorch_ie Pipeline documentation.
ner_pipeline = Pipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1, num_workers=0)
def _span_to_ent(span):
    """Convert one predicted labeled span into a displaCy manual-mode entity dict."""
    return {"start": span.start, "end": span.end, "label": span.label}


def predict(text):
    """Run NER on *text* and return the predicted entities as displaCy HTML.

    The input is wrapped in an ``ExampleDocument`` and passed through the
    module-level ``ner_pipeline`` with ``predict_field="entities"``. The
    predicted spans are then read back from ``document.entities.predictions``
    and rendered with ``displacy.render`` in manual ``"ent"`` mode.

    Args:
        text: The raw string entered in the Gradio textbox.

    Returns:
        The displaCy markup wrapped in a ``<div>`` that is capped at 360px
        in height and scrolls when its content overflows.
    """
    document = ExampleDocument(text)
    # The call's return value is not used; predictions are read back from
    # ``document.entities`` afterwards.
    ner_pipeline(document, predict_field="entities")

    # Order spans by start offset so they are emitted in text order.
    spans_in_order = sorted(
        document.entities.predictions, key=lambda span: span.start
    )
    manual_doc = {
        "text": document.text,
        "ents": [_span_to_ent(span) for span in spans_in_order],
        "title": None,
    }

    # NOTE(review): with page=True displaCy presumably returns a complete HTML
    # page, which then ends up nested inside the <div> below. Confirm whether
    # page=False (a bare fragment) renders identically before changing it.
    markup = displacy.render(
        manual_doc, style="ent", page=True, manual=True, minify=True
    )

    opening = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    closing = "</div>"
    return opening + markup + closing
# Gradio UI: a single textbox in, the rendered entity HTML from ``predict`` out.
iface = gr.Interface(fn=predict, inputs="textbox", outputs="html")

# Launch only when executed as a script (e.g. ``python app.py``), so that
# importing this module (tests, notebooks, other tooling) builds ``iface``
# without launching the app. Script behavior is unchanged.
if __name__ == "__main__":
    iface.launch()