File size: 4,310 Bytes
186b145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr

from transformers import pipeline
from typing import Dict, Union
from gliner import GLiNER

# Hugging Face checkpoints loaded once at import time and shared by both tabs.
_NER_CHECKPOINT = "numind/NuNER_Zero"
_CLASSIFIER_CHECKPOINT = "MoritzLaurer/deberta-v3-base-zeroshot-v1"

# GLiNER model used by `ner` (the "Zero-Shot NER" tab).
model = GLiNER.from_pretrained(_NER_CHECKPOINT)
# Transformers zero-shot-classification pipeline used by `zero_shot`.
classifier = pipeline("zero-shot-classification", model=_CLASSIFIER_CHECKPOINT)

def zero_shot(doc, candidates):
    """Score ``doc`` against each comma-separated label in ``candidates``.

    Args:
        doc: Text to classify.
        candidates: Candidate labels separated by commas, e.g.
            ``"sports, politics,science"``. Whitespace around each label is
            ignored and empty entries are dropped.

    Returns:
        Mapping of label -> score taken from the zero-shot ``classifier``
        pipeline's ``labels``/``scores`` output (the dict form ``gr.Label``
        displays). Empty when no usable label was given.
    """
    # Split on "," (not ", ") so both "a,b" and "a, b" work, matching `ner`.
    given_labels = [label.strip() for label in candidates.split(",") if label.strip()]
    if not given_labels:
        # Nothing to score against; don't call the pipeline with no labels.
        return {}
    result = classifier(doc, given_labels)
    return dict(zip(result["labels"], result["scores"]))

# Pre-filled inputs for the NER tab. Each row is
# [text, comma-separated labels, threshold, nested_ner].
_MOON_TEXT = (
    "The Moon is Earth's only natural satellite. "
    "It orbits at an average distance of 384,400 km (238,900 mi), about 30 times the diameter of Earth. "
    "Over time Earth's gravity has caused tidal locking, causing the same side of the Moon to always face Earth. "
    "Because of this, the lunar day and the lunar month are the same length, at 29.5 Earth days. "
    "The Moon's gravitational pull – and to a lesser extent, the Sun's – are the main drivers of Earth's tides."
)
examples = [
    [_MOON_TEXT, "celestial body,quantity,physical concept", 0.3, False],
]

def merge_entities(entities):
    """Merge runs of touching same-label entities into single spans.

    Two consecutive entities are merged when they have the same ``entity``
    label and the second starts exactly at, or one character after, the end
    of the first. The merged span keeps the first entity's other fields
    (e.g. ``score``) and takes the last entity's ``end``.

    Args:
        entities: Entity dicts with at least ``entity``, ``word``, ``start``
            and ``end`` keys. Assumed to be ordered by ``start`` (merging only
            looks at consecutive items).

    Returns:
        A new list of new dicts; the input dicts are never modified. Returns
        an empty list for empty or ``None`` input.
    """
    if not entities:
        return []
    merged = []
    # Work on copies so the caller's entity dicts are left untouched.
    current = dict(entities[0])
    for nxt in entities[1:]:
        gap = nxt["start"] - current["end"]
        if nxt["entity"] == current["entity"] and gap in (0, 1):
            # Only insert a space when the spans were separated by one
            # character; directly adjacent spans are concatenated as-is.
            separator = " " if gap == 1 else ""
            current["word"] += separator + nxt["word"]
            current["end"] = nxt["end"]
        else:
            merged.append(current)
            current = dict(nxt)
    merged.append(current)
    return merged

def ner(
    text: str, labels: str, threshold: float, nested_ner: bool = False
) -> Dict[str, Union[str, list]]:
    """Run zero-shot NER over ``text`` and format it for ``gr.HighlightedText``.

    Args:
        text: Input text.
        labels: Comma-separated entity labels, e.g. ``"person, location"``.
            Whitespace around each label is stripped and empty entries are
            dropped.
        threshold: Score threshold forwarded to ``model.predict_entities``.
        nested_ner: When True, call the model with ``flat_ner=False`` so
            nested spans may be returned. Defaults to False so callers that
            supply only the first three arguments still work.

    Returns:
        ``{"text": text, "entities": [...]}``. Each entity dict has
        ``entity``, ``word``, ``start``, ``end`` and ``score`` keys, and
        touching same-label spans are combined by ``merge_entities``.
    """
    label_list = [label.strip() for label in labels.split(",") if label.strip()]
    if not label_list:
        # No labels means nothing to predict; skip the model call.
        return {"text": text, "entities": []}

    predictions = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )
    entities = [
        {
            "entity": entity["label"],
            "word": entity["text"],
            "start": entity["start"],
            "end": entity["end"],
            # Keep the model's confidence if provided instead of discarding it.
            "score": entity.get("score", 0),
        }
        for entity in predictions
    ]
    return {"text": text, "entities": merge_entities(entities)}

with gr.Blocks(title="Zero-Shot Demo") as demo:  # , theme=gr.themes.Soft()

    # ---- Tab 1: zero-shot text classification, backed by `zero_shot` ----
    with gr.Tab("Zero-Shot Text Classification"):
        # Components passed to gr.Interface below, which wires them to `zero_shot`.
        input1 = gr.Textbox(label="Text")
        input2 = gr.Textbox(
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
        )
        output = gr.Label(label="Output")
        gui = gr.Interface(
            title="Zero-Shot Text Classification",
            fn=zero_shot,
            inputs=[input1, input2],
            outputs=[output],
        )

    # ---- Tab 2: zero-shot NER, backed by `ner` ----
    with gr.Tab("Zero-Shot NER"):
        gr.Markdown(
            """
            # Zero-Shot Named Entity Recognition (NER)
            """
        )

        input_text = gr.Textbox(
            value=examples[0][0], label="Text input", placeholder="Enter your text here", lines=3
        )
        with gr.Row():
            labels = gr.Textbox(
                value=examples[0][1],
                label="Labels",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                0,
                1,
                value=examples[0][2],
                step=0.01,
                label="Threshold",
                info="Lower the threshold to increase how many entities get predicted.",
                scale=1,
            )
            # `ner` takes (text, labels, threshold, nested_ner); this checkbox
            # supplies the fourth argument.
            nested_ner = gr.Checkbox(
                value=examples[0][3],
                label="Nested NER",
                info="Allow predicted entities to overlap / nest.",
                scale=1,
            )

        ner_output = gr.HighlightedText(label="Predicted Entities")

        submit_btn = gr.Button("Submit")

        # Predictions run only on an explicit Submit click; no listeners are
        # attached to edits of the individual inputs.
        submit_btn.click(
            fn=ner,
            inputs=[input_text, labels, threshold, nested_ner],
            outputs=ner_output,
        )

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling that only needs `demo`) does not launch it.
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)