File size: 4,602 Bytes
1938c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78762bd
 
 
 
 
1938c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
1938c90
 
78762bd
 
 
 
1938c90
 
78762bd
1938c90
78762bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import gradio as gr
import pycountry
from typing import Dict, Union
from gliner import GLiNER


# Registry of already-loaded models, keyed by model name (filled by get_model).
_MODEL: dict = dict()

# Optional download/cache location for model weights, read from the
# CACHE_DIR environment variable; None when the variable is not set.
_CACHE_DIR = os.getenv("CACHE_DIR")

# Report the configured cache location once at startup.
print(f"Cache directory: {_CACHE_DIR}")


def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict so each name is
    fetched via ``GLiNER.from_pretrained`` at most once per entry; weights are
    cached under ``_CACHE_DIR``.

    Args:
        model_name: Model identifier passed to ``GLiNER.from_pretrained``.
            ``None`` selects ``"urchade/gliner_medium-v2.1"``.

    Returns:
        The loaded model object stored in ``_MODEL`` for that name.
    """
    if model_name is None:
        model_name = "urchade/gliner_medium-v2.1"

    loaded = _MODEL.get(model_name)
    if loaded is None:
        # First request for this name: load it and remember it for later calls.
        loaded = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
        _MODEL[model_name] = loaded

    return loaded


def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country database.

    Args:
        country_name: Free-text country name to look up.

    Returns:
        Whatever ``pycountry.countries.search_fuzzy`` returns for the name, or
        ``None`` when the lookup raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # pycountry signals "nothing found" with LookupError; report it as None.
        matches = None
    return matches


def parse_query(query: str, labels: Union[str, list], threshold: float = 0.5, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` with a GLiNER model.

    Args:
        query: Free-text query to analyse.
        labels: Entity labels to detect, given either as a list of strings or
            as one comma-separated string (e.g. ``"country, year"``).
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: When True, request nested entities from the model
            (forwarded as ``flat_ner=False``); when False, request flat spans.
        model_name: Model identifier forwarded to ``get_model``; ``None`` uses
            that function's default.

    Returns:
        A dict ``{"query": query, "entities": [...]}``. Each entity is the dict
        produced by the model. Entities labelled ``"country"`` additionally get
        a ``"normalized"`` list built from the ``get_country`` matches, and are
        dropped entirely when no country match is found.
    """
    if isinstance(labels, str):
        # Split the comma-separated form and ignore blank items such as the
        # ones produced by an empty textbox or a trailing comma.
        labels = [part.strip() for part in labels.split(",")]
        labels = [label for label in labels if label]

    if not labels:
        # Nothing to look for: skip model loading and inference altogether.
        return {"query": query, "entities": []}

    model = get_model(model_name)

    # GLiNER exposes nesting through `flat_ner`, the inverse of our flag.
    # Previously `nested_ner` was accepted but never passed to the model.
    _entities = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    entities = []

    for entity in _entities:
        if entity["label"] != "country":
            entities.append(entity)
            continue

        country = get_country(entity["text"])
        if country:
            # Attach every fuzzy match as a plain dict so it is serialisable.
            entity["normalized"] = [dict(c) for c in country]
            entities.append(entity)
        # Country mentions that cannot be normalised are intentionally omitted.

    return {"query": query, "entities": entities}



# Build the Gradio interface: intro text, input controls, JSON output, and the
# event wiring that re-runs parse_query whenever an input is submitted/changed.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER mode is licensed under the Apache 2.0 license.

        ## Links
        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )

    # Free-text query to parse.
    query = gr.Textbox(
        label="query",
        placeholder="Enter your query here",
        value="gdp of the philippines in 2024",
    )

    # Model choice, label list and inference options share one row.
    with gr.Row() as row:
        model_name = gr.Radio(
            label="Model",
            choices=["urchade/gliner_medium-v2.1", "urchade/gliner_base"],
            value="urchade/gliner_medium-v2.1",
        )
        entities = gr.Textbox(
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            value="country, year, indicator",
            scale=2,
        )
        threshold = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.5,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            label="Nested NER",
            info="Setting to True extracts nested entities",
            value=False,
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Submitting: every trigger below runs parse_query with the same inputs and
    # writes to the same output. Registration order is kept as listed.
    _parse_inputs = (query, entities, threshold, is_nested, model_name)
    for _register in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        _register(fn=parse_query, inputs=list(_parse_inputs), outputs=output)

# Enable Gradio's request queue before serving.
# NOTE(review): default_concurrency_limit=5 is understood to be the default
# per-event concurrency cap for listeners that set none — confirm against the
# installed Gradio version.
demo.queue(default_concurrency_limit=5)
# Start the app. There is no `__main__` guard, so this runs whenever the module
# is executed or imported.
demo.launch(debug=True)


# Usage example (kept as an unused string literal, never executed): calling
# the app's `/parse_query` endpoint remotely through `gradio_client`.
"""
from gradio_client import Client

client = Client("avsolatorio/query-parser")
result = client.predict(
		query="gdp, m3, and child mortality of india and southeast asia 2024",
		labels="country, year, statistical indicator, region",
		threshold=0.3,
		nested_ner=False,
		api_name="/parse_query"
)
print(result)
"""