Saripudin commited on
Commit
186b145
·
verified ·
1 Parent(s): 12bfd93

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import pipeline
4
+ from typing import Dict, Union
5
+ from gliner import GLiNER
6
+
7
+ model = GLiNER.from_pretrained("numind/NuNER_Zero")
8
+
9
+ classifier = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-zeroshot-v1")
10
+
11
+ #define a function to process your input and output
12
+ def zero_shot(doc, candidates):
13
+ given_labels = candidates.split(", ")
14
+ dictionary = classifier(doc, given_labels)
15
+ labels = dictionary['labels']
16
+ scores = dictionary['scores']
17
+ return dict(zip(labels, scores))
18
+
19
+ examples = [
20
+ [
21
+ "The Moon is Earth's only natural satellite. It orbits at an average distance of 384,400 km (238,900 mi), about 30 times the diameter of Earth. Over time Earth's gravity has caused tidal locking, causing the same side of the Moon to always face Earth. Because of this, the lunar day and the lunar month are the same length, at 29.5 Earth days. The Moon's gravitational pull – and to a lesser extent, the Sun's – are the main drivers of Earth's tides.",
22
+ "celestial body,quantity,physical concept",
23
+ 0.3,
24
+ False
25
+ ],
26
+ ]
27
+
28
+ def merge_entities(entities):
29
+ if not entities:
30
+ return []
31
+ merged = []
32
+ current = entities[0]
33
+ for next_entity in entities[1:]:
34
+ if next_entity['entity'] == current['entity'] and (next_entity['start'] == current['end'] + 1 or next_entity['start'] == current['end']):
35
+ current['word'] += ' ' + next_entity['word']
36
+ current['end'] = next_entity['end']
37
+ else:
38
+ merged.append(current)
39
+ current = next_entity
40
+ merged.append(current)
41
+ return merged
42
+
43
+ def ner(
44
+ text, labels: str, threshold: float, nested_ner: bool
45
+ ) -> Dict[str, Union[str, int, float]]:
46
+ labels = labels.split(",")
47
+ r = {
48
+ "text": text,
49
+ "entities": [
50
+ {
51
+ "entity": entity["label"],
52
+ "word": entity["text"],
53
+ "start": entity["start"],
54
+ "end": entity["end"],
55
+ "score": 0,
56
+ }
57
+ for entity in model.predict_entities(
58
+ text, labels, flat_ner=not nested_ner, threshold=threshold
59
+ )
60
+ ],
61
+ }
62
+ r["entities"] = merge_entities(r["entities"])
63
+ return r
64
+
65
+ with gr.Blocks(title="Zero-Shot Demo") as demo: #, theme=gr.themes.Soft()
66
+
67
+ #create input and output objects
68
+ with gr.Tab("Zero-Shot Text Classification"):
69
+ #input object1
70
+ input1 = gr.Textbox(label="Text")
71
+ #input object 2
72
+ input2 = gr.Textbox(label="Labels")
73
+ #output object
74
+ output = gr.Label(label="Output")
75
+ #create interface
76
+ gui = gr.Interface(
77
+ title="Zero-Shot Text Classification",
78
+ fn=zero_shot,
79
+ inputs=[input1, input2],
80
+ outputs=[output]
81
+ )
82
+
83
+ with gr.Tab("Zero-Shot NER"):
84
+ gr.Markdown(
85
+ """
86
+ # Zero-Shot Named Entity Recognition (NER)
87
+ """
88
+ )
89
+
90
+ input_text = gr.Textbox(
91
+ value=examples[0][0], label="Text input", placeholder="Enter your text here", lines=3
92
+ )
93
+ with gr.Row() as row:
94
+ labels = gr.Textbox(
95
+ value=examples[0][1],
96
+ label="Labels",
97
+ placeholder="Enter your labels here (comma separated)",
98
+ scale=2,
99
+ )
100
+ threshold = gr.Slider(
101
+ 0,
102
+ 1,
103
+ value=0.3,
104
+ step=0.01,
105
+ label="Threshold",
106
+ info="Lower the threshold to increase how many entities get predicted.",
107
+ scale=1,
108
+ )
109
+
110
+ output = gr.HighlightedText(label="Predicted Entities")
111
+
112
+ submit_btn = gr.Button("Submit")
113
+
114
+ # Submitting
115
+ # input_text.submit(
116
+ # fn=ner, inputs=[input_text, labels, threshold], outputs=output
117
+ # )
118
+ # labels.submit(
119
+ # fn=ner, inputs=[input_text, labels, threshold], outputs=output
120
+ # )
121
+ # threshold.release(
122
+ # fn=ner, inputs=[input_text, labels, threshold], outputs=output
123
+ # )
124
+ submit_btn.click(
125
+ fn=ner, inputs=[input_text, labels, threshold], outputs=output
126
+ )
127
+
128
+ demo.queue()
129
+ demo.launch(debug=True)