File size: 2,449 Bytes
daf1a30
 
 
 
2268fdb
daf1a30
 
 
 
 
 
540c766
7a74ade
daf1a30
 
2268fdb
 
daf1a30
 
 
 
 
 
 
 
132fa59
daf1a30
 
 
 
c7db18a
 
 
 
 
 
daf1a30
 
 
 
 
 
 
c7db18a
 
 
 
 
 
2268fdb
 
 
daf1a30
 
 
c7db18a
daf1a30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9806d6
daf1a30
d9ee424
daf1a30
 
 
cdb9404
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import re
import gradio as gr
from dataclasses import dataclass
from prettytable import PrettyTable
import logging

from pytorch_ie.annotations import LabeledSpan, BinaryRelation
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument

import transformer_re_text_classification2

from typing import List

# Module-level logger; predict() reports every detected entity through it.
logger = logging.getLogger(__name__)


@dataclass
class ExampleDocument(TextDocument):
    """Text document with entity spans and binary relations over those entities.

    ``entities`` annotates spans of ``text``; ``relations`` annotates relations
    whose arguments (head/tail) are annotations from ``entities``.
    """

    # Labeled spans over the document text (filled from NER output in predict()).
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
    # Binary relations whose head/tail point at annotations in ``entities``.
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


# Identifiers of the two pretrained pipelines loaded through AutoPipeline:
# an entity tagger and a relation classifier.
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"

# Load both pipelines (NER first, then RE) with identical runtime options.
# NOTE(review): device=-1 presumably selects CPU and num_workers=0 presumably
# disables data-loading worker processes -- confirm against pytorch_ie.
ner_pipeline, re_pipeline = (
    AutoPipeline.from_pretrained(model_id, device=-1, num_workers=0)
    for model_id in (ner_model_name_or_path, re_model_name_or_path)
)

# Renames the NER model's short tags to the label names that predict() checks
# against the RE taskmodule's entity_labels; tags not listed here are kept as-is.
ner_tag_mapping = dict(
    ORG="ORGANIZATION",
    PER="PERSON",
    LOC="LOCATION",
)


def predict(text: str) -> str:
    """Run NER followed by relation extraction on ``text``; return relations as HTML.

    Steps:
      1. The NER pipeline fills ``document.entities.predictions``.
      2. Those predictions are moved into ``document.entities``. Labels found in
         ``ner_tag_mapping`` are renamed. Only entities whose (possibly renamed)
         label is in the RE taskmodule's ``entity_labels`` are kept.
      3. The RE pipeline runs on the document; its ``relations.predictions``
         are rendered as a table.

    Args:
        text: Raw input text (the content of the Gradio textbox).

    Returns:
        An HTML fragment: a scrollable ``<div>`` wrapping a table with one
        (head, tail, relation) row per predicted relation.
    """
    document = ExampleDocument(text)

    ner_pipeline(document)
    _promote_entity_predictions(document)
    re_pipeline(document)

    return _relations_to_html(document)


def _promote_entity_predictions(document: "ExampleDocument") -> None:
    """Move NER predictions into ``document.entities``, relabeling and filtering them."""
    # Loop-invariant lookup, hoisted out of the loop.
    accepted_labels = re_pipeline.taskmodule.entity_labels
    predictions = document.entities.predictions
    # Entities are popped front-first (preserving order) instead of iterated.
    # NOTE(review): presumably a pytorch_ie annotation must be detached from the
    # predictions list before it can be appended to another list -- confirm.
    while predictions:
        entity = predictions.pop(0)
        mapped_label = ner_tag_mapping.get(entity.label)
        if mapped_label is not None:
            # Build a new span carrying the RE-compatible label and the original score.
            entity = LabeledSpan(
                start=entity.start,
                end=entity.end,
                label=mapped_label,
                score=entity.score,
            )
        if entity.label in accepted_labels:
            document.entities.append(entity)
            logger.warning("detected entity: %s (added)", entity)
        else:
            logger.warning("detected entity: %s (NOT added)", entity)


def _relations_to_html(document: "ExampleDocument") -> str:
    """Render ``document.relations.predictions`` as a scrollable HTML table."""
    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for relation in document.relations.predictions:
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    # NOTE(review): cell text is derived from user input; this relies on
    # PrettyTable escaping cell data in get_html_string -- verify for the
    # installed prettytable version.
    # The wrapper caps the height so long result lists scroll instead of growing.
    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + table.get_html_string(format=True)
        + "</div>"
    )


# Gradio front end: free text goes into predict(), whose HTML table is shown as-is.
iface = gr.Interface(
    predict,
    gr.Textbox(
        placeholder=(
            "There is still some uncertainty that Musk - also chief executive of "
            "electric car maker Tesla and rocket company SpaceX - will pull off "
            "his planned buyout."
        ),
        lines=5,
    ),
    "html",
)
# Start the web app; this runs unconditionally whenever the module is executed.
iface.launch()