avsolatorio committed on
Commit
1938c90
·
1 Parent(s): 33a406c

Add query parser app

Browse files

Signed-off-by: Aivin V. Solatorio <[email protected]>

Files changed (2) hide show
  1. app.py +119 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pycountry
4
+ from typing import Dict, Union
5
+ from gliner import GLiNER
6
+
7
+
8
# Loaded GLiNER models keyed by model name; starts empty.
_MODEL: dict = {}

# Optional download/cache location read from the CACHE_DIR environment
# variable; None when the variable is not set.
_CACHE_DIR = os.getenv("CACHE_DIR")

# Report the configured cache location once, at import time.
print(f"Cache directory: {_CACHE_DIR}")
12
+
13
+
14
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it if not yet stored.

    Models are kept in the module-level ``_MODEL`` dict keyed by name, so a
    name that already has a stored model is not loaded again.

    Args:
        model_name: Identifier passed to ``GLiNER.from_pretrained``. ``None``
            selects ``"urchade/gliner_medium-v2.1"``.

    Returns:
        The object produced by ``GLiNER.from_pretrained`` for that name.
    """
    name = "urchade/gliner_medium-v2.1" if model_name is None else model_name

    loaded = _MODEL.get(name)
    if loaded is None:
        # First request for this name: load it (using the configured cache
        # directory) and remember it for subsequent calls.
        loaded = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)
        _MODEL[name] = loaded

    return loaded
25
+
26
+
27
def get_country(country_name: str):
    """Look up ``country_name`` with pycountry's fuzzy country search.

    Args:
        country_name: Free-text country name to search for.

    Returns:
        Whatever ``pycountry.countries.search_fuzzy`` returns for the name, or
        ``None`` when the search raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        matches = None
    return matches
32
+
33
+
34
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.5, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities of the requested types from ``query`` using GLiNER.

    Args:
        query: Text to analyse.
        labels: Entity types to detect, either as a list or as a single
            comma-separated string (surrounding whitespace is stripped).
            Blank labels are ignored.
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: When True, ask the model for nested entities by
            forwarding ``flat_ner=False``; otherwise flat NER is used.
        model_name: Model identifier forwarded to ``get_model``; ``None``
            uses that function's default.

    Returns:
        ``{"query": query, "entities": [...]}``. Each entity is the dict
        produced by the model; entities labelled ``"country"`` additionally
        get a ``"normalized"`` list of pycountry matches (as dicts) when
        ``get_country`` finds any. When there are no usable labels or the
        query is blank, the entity list is empty and no model is loaded.
    """
    if isinstance(labels, str):
        labels = [part.strip() for part in labels.split(",")]

    # Drop blank labels such as those from "country,,year" or an empty textbox.
    labels = [label for label in labels if label]

    # Nothing to look for (or nothing to look in): skip model loading entirely.
    if not labels or not query.strip():
        return {"query": query, "entities": []}

    model = get_model(model_name)

    # `nested_ner` was previously ignored; GLiNER exposes it as the inverse
    # flag `flat_ner`.
    predictions = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    entities = []
    for entity in predictions:
        if entity["label"] == "country":
            matches = get_country(entity["text"])
            if matches:
                entity["normalized"] = [dict(match) for match in matches]
        entities.append(entity)

    return {"query": query, "entities": entities}
54
+
55
+
56
+
57
# Build the demo UI: a description, the query/labels/threshold/nested inputs,
# a JSON output panel, and event bindings that all run parse_query.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER model is licensed under the Apache 2.0 license.

        ## Links
        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )

    query = gr.Textbox(
        value="gdp of the philippines in 2024", label="query", placeholder="Enter your query here"
    )
    with gr.Row() as row:
        entities = gr.Textbox(
            value="country, year, indicator",
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Every interaction re-runs the parser with the same inputs and output, so
    # the bindings are registered in one loop (same order as before) instead
    # of five copy-pasted calls.
    parse_inputs = [query, entities, threshold, is_nested]
    for register in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
    ):
        register(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue()
demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gliner
2
+ pycountry
3
+ scipy==1.12
4
+ gradio