|
import gradio as gr |
|
from dataclasses import dataclass |
|
|
|
from pytorch_ie import AnnotationList, LabeledSpan, Pipeline, TextDocument, annotation_field |
|
from pytorch_ie.models import TransformerSpanClassificationModel |
|
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule |
|
from spacy import displacy |
|
|
|
|
|
@dataclass
class ExampleDocument(TextDocument):
    """A text document with a single annotation layer of labeled entity spans."""

    # Labeled spans attached to the document's ``text`` field; the NER pipeline
    # writes its output into this layer's ``predictions``.
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
|
|
|
|
|
# Checkpoint identifier passed to both ``from_pretrained`` loaders below
# (presumably a Hugging Face Hub repo id for a CoNLL-03 NER model — verify).
model_name_or_path = "pie/example-ner-spanclf-conll03"

# The task module (input encoding / output decoding) and the model weights are
# loaded from the same checkpoint so that their label vocabularies agree.
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)

# Inference pipeline shared by every request; device=-1 conventionally selects
# the CPU and num_workers=0 keeps data loading in the calling process (confirm
# against the pytorch-ie Pipeline docs).
ner_pipeline = Pipeline(
    taskmodule=ner_taskmodule,
    model=ner_model,
    num_workers=0,
    device=-1,
)
|
|
|
|
|
def _entities_to_displacy(text, entities):
    """Build a displaCy "manual" entity-render dict from text and span annotations.

    Args:
        text: The document text the span offsets refer to.
        entities: Iterable of objects exposing ``start``, ``end`` and ``label``.

    Returns:
        A dict with ``text``, ``ents`` and ``title`` keys, as accepted by
        ``displacy.render(..., style="ent", manual=True)``.
    """
    return {
        "text": text,
        # Sorted by start offset so entities are rendered in document order.
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted(entities, key=lambda e: e.start)
        ],
        "title": None,
    }


def _wrap_scrollable(html):
    """Wrap rendered markup in a container that caps its height and scrolls."""
    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + html
        + "</div>"
    )


def predict(text):
    """Run NER on ``text`` and return the entities highlighted as HTML.

    Empty or whitespace-only input skips model inference entirely (there is
    nothing to tag) and renders the text with no entities.

    Args:
        text: Raw input text from the Gradio textbox; ``None`` is treated as "".

    Returns:
        An HTML string produced by displaCy's entity renderer, wrapped in a
        scrollable ``<div>``.
    """
    # Gradio may hand over None for a cleared component; normalize so both the
    # document and the renderer always receive a string.
    text = text or ""
    document = ExampleDocument(text)

    if text.strip():
        # The pipeline stores its output in ``document.entities.predictions``.
        ner_pipeline(document, predict_field="entities")
        entities = document.entities.predictions
    else:
        entities = []

    doc = _entities_to_displacy(document.text, entities)
    html = displacy.render(doc, style="ent", page=True, manual=True, minify=True)
    return _wrap_scrollable(html)
|
|
|
|
|
# Minimal web UI: a free-text box in, the rendered entity HTML out.
iface = gr.Interface(inputs="textbox", outputs="html", fn=predict)

# Start the Gradio server; this runs as soon as the script is executed.
iface.launch()
|
|