File size: 5,793 Bytes
d4e7f01
1938c90
aadf93b
1938c90
 
3f9fe25
aadf93b
1938c90
 
 
 
 
 
49fed2a
78d1da2
 
49fed2a
1938c90
 
 
 
 
3f9fe25
 
1938c90
78d1da2
1938c90
 
 
 
 
3f9fe25
df8632c
3f9fe25
 
 
1938c90
 
 
 
49fed2a
 
 
 
 
 
 
1938c90
 
 
 
 
 
 
49fed2a
3f9fe25
81fbf66
3f9fe25
1938c90
 
 
 
81fbf66
 
 
 
 
d4e7f01
 
 
1938c90
 
3f9fe25
1938c90
 
 
 
 
 
 
 
 
 
aadf93b
81fbf66
aadf93b
 
1938c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78d1da2
1938c90
 
78762bd
49fed2a
78d1da2
78762bd
 
1938c90
78d1da2
1938c90
 
 
 
 
 
 
49fed2a
1938c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
 
 
 
1938c90
 
78762bd
1938c90
78762bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER


# Loaded GLiNER models, keyed by model name (populated lazily by get_model).
_MODEL = dict()
# Optional directory for downloaded model files; None when CACHE_DIR is unset.
_CACHE_DIR = os.getenv("CACHE_DIR")
# Default score threshold used for the UI slider and the start-up prediction.
THRESHOLD = 0.3
# Default entity labels shown in the UI.
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
# Default example query shown in the UI.
QUERY = (
    "gdp, co2 emissions, and mortality rate of the philippines "
    "vs. south asia in 2024"
)
# Model checkpoints selectable in the UI.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")


def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict so each
    checkpoint is only loaded once per process. When CUDA is available and
    the model's first parameter is not on a CUDA device, the model is moved
    to ``"cuda"`` and the cached entry is replaced with the moved model.

    Args:
        model_name: Checkpoint identifier; ``None`` selects
            ``"urchade/gliner_base"``.

    Returns:
        The cached (and possibly GPU-resident) model instance.
    """
    started_at = datetime.now()

    key = "urchade/gliner_base" if model_name is None else model_name

    loaded = _MODEL.get(key)
    if loaded is None:
        loaded = GLiNER.from_pretrained(key, cache_dir=_CACHE_DIR)
        _MODEL[key] = loaded

    if torch.cuda.is_available():
        # Only move the model when it is not already on a CUDA device.
        current_device = next(loaded.parameters()).device.type
        if not current_device.startswith("cuda"):
            loaded = loaded.to("cuda")
            _MODEL[key] = loaded

    print(f"{datetime.now()} :: get_model :: {datetime.now() - started_at}")

    return loaded


# Initialize model here.
# Runs at import time: every checkpoint in MODELS is loaded through get_model
# (which stores it in the _MODEL cache) before the UI is built.
print("Initializing models...")
for model_name in MODELS:
    model = get_model(model_name=model_name)
    # One prediction on the default query; the result is discarded.
    # NOTE(review): presumably a warm-up so the first user request is fast — confirm.
    model.predict_entities(QUERY, LABELS, threshold=THRESHOLD)


def get_country(country_name: str):
    """Fuzzy-search ``pycountry`` for ``country_name``.

    Returns:
        Whatever ``pycountry.countries.search_fuzzy`` returns, or ``None``
        when the search raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        matches = None

    return matches


@spaces.GPU(enable_queue=True, duration=5)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER entity prediction on ``query``.

    Args:
        model_name: Checkpoint identifier passed to ``get_model``.
        query: Text to extract entities from.
        labels: Entity labels to detect, either a list or a single
            comma-separated string. For the string form, surrounding
            whitespace is stripped and blank items (e.g. from a trailing
            comma or an empty textbox) are dropped.
        threshold: Score threshold forwarded to the model.
        nested_ner: When True, the model is called with ``flat_ner=False``.

    Returns:
        The entities returned by ``model.predict_entities``, or an empty list
        when no labels were supplied.
    """
    start = datetime.now()

    if isinstance(labels, str):
        # "a, , b" or "a," must not produce an empty-string label.
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    if labels:
        model = get_model(model_name)
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    else:
        # No entity types requested, so there is nothing to detect.
        entities = []

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")

    return entities


def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and attach country lookups.

    Entities come from ``predict_entities``. Each entity labelled
    ``"country"`` is looked up with ``get_country``: when matches are found
    they are stored under the entity's ``"normalized"`` key as a list of
    dicts; when the lookup yields nothing the entity is left out of the
    result. All other entities are passed through unchanged.

    Returns:
        A dict with the original ``"query"`` and the resulting ``"entities"``.
        The same payload is printed as JSON for logging.
    """
    predicted = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

    kept = []
    for item in predicted:
        if item["label"] != "country":
            kept.append(item)
            continue

        matches = get_country(item["text"])
        if not matches:
            # Country text that cannot be resolved is dropped.
            continue

        item["normalized"] = [dict(match) for match in matches]
        kept.append(item)

    result = {"query": query, "entities": kept}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(result)}\n")

    return result



# Build the Gradio UI: a query box, a row of controls (model, labels,
# threshold, nested-NER toggle), a JSON output and a submit button. Every
# control re-runs parse_query with the current values.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER model is licensed under the Apache 2.0 license.

        ## Links
        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )

    query = gr.Textbox(
        value=QUERY, label="query", placeholder="Enter your query here"
    )
    with gr.Row() as row:
        model_name = gr.Radio(
            choices=MODELS,
            value="urchade/gliner_base",
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Submitting
    # All events call parse_query with the same inputs, listed in the order of
    # its positional parameters (query, labels, threshold, nested_ner, model_name).
    parse_inputs = [query, entities, threshold, is_nested, model_name]

    # The registration order is the same as before (query.submit first).
    # NOTE(review): Gradio appears to derive endpoint names from the function
    # name and registration order — confirm "/parse_query" is unaffected.
    for register in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        register(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)


"""
from gradio_client import Client

client = Client("avsolatorio/query-parser")
result = client.predict(
		query="gdp, m3, and child mortality of india and southeast asia 2024",
		labels="country, year, statistical indicator, region",
		threshold=0.3,
		nested_ner=False,
		api_name="/parse_query"
)
print(result)
"""