Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,602 Bytes
1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import os
import threading
from typing import Dict, Optional, Union

import gradio as gr
import pycountry
from gliner import GLiNER
# Module-level cache of loaded GLiNER models, keyed by model name.
_MODEL = dict()
# Optional cache directory taken from the environment (None when unset);
# forwarded as ``cache_dir`` when models are loaded.
_CACHE_DIR = os.getenv("CACHE_DIR")
print(f"Cache directory: {_CACHE_DIR}")
# Serializes first-time model loads. parse_query may run concurrently (the app
# queues with default_concurrency_limit > 1). Without this lock, two requests
# could both miss the cache and load the same model twice.
_MODEL_LOCK = threading.Lock()


def get_model(model_name: Optional[str] = None):
    """Return the GLiNER model *model_name*, loading and caching it on first use.

    Args:
        model_name: Hugging Face model id to load. ``None`` or an empty string
            selects ``"urchade/gliner_medium-v2.1"``.

    Returns:
        The cached ``GLiNER`` instance for that model id. The first call for
        a given id loads it via ``GLiNER.from_pretrained`` using ``_CACHE_DIR``.
    """
    if not model_name:
        model_name = "urchade/gliner_medium-v2.1"

    # Fast path: no locking once the model is cached.
    model = _MODEL.get(model_name)
    if model is None:
        with _MODEL_LOCK:
            # Re-check under the lock: another request may have finished
            # loading this model while we were waiting.
            model = _MODEL.get(model_name)
            if model is None:
                model = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
                _MODEL[model_name] = model
    return model
def get_country(country_name: str):
    """Fuzzy-match *country_name* against pycountry's country database.

    Returns:
        The list of candidate country records produced by
        ``pycountry.countries.search_fuzzy``, or ``None`` when it finds no match.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # search_fuzzy signals "no match" by raising; report that as None.
        matches = None
    return matches
def parse_query(
    query: str,
    labels: Union[str, list],
    threshold: float = 0.5,
    nested_ner: bool = False,
    model_name: Optional[str] = None,
) -> Dict[str, Union[str, list]]:
    """Extract entities of the requested *labels* from *query* with GLiNER.

    Args:
        query: Free-text query to parse.
        labels: Entity labels to detect. Either a list (used as-is) or a
            comma-separated string such as ``"country, year, indicator"``;
            blank entries in the string form are ignored.
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: When True, request nested entities from the model
            (forwarded as ``flat_ner=False``).
        model_name: GLiNER model id; ``None`` lets ``get_model`` pick its default.

    Returns:
        ``{"query": query, "entities": [...]}``, where each entity is a dict
        returned by ``predict_entities``. Entities labelled ``"country"`` that
        pycountry can resolve also get a ``"normalized"`` list holding every
        matching country record as a plain dict.
    """
    if isinstance(labels, str):
        # Split on commas and drop blanks, so input like "country, year,"
        # does not send an empty label to the model.
        stripped = [part.strip() for part in labels.split(",")]
        labels = [label for label in stripped if label]
    if not labels:
        # Nothing to look for; skip loading and running the model.
        return {"query": query, "entities": []}

    model = get_model(model_name)
    # GLiNER's predict_entities takes `flat_ner`; the UI's "Nested NER"
    # option is its inverse.
    predictions = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    entities = []
    for entity in predictions:
        if entity["label"] == "country":
            matches = get_country(entity["text"])
            if matches:
                # Convert pycountry records to plain dicts for the JSON output.
                entity["normalized"] = [dict(match) for match in matches]
        entities.append(entity)
    return {"query": query, "entities": entities}
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This Space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entity types to track, the model identifies instances of these entities in the query, and the parsed entities are displayed in the output. A special case is the "country" entity, which is normalized to the matching ISO 3166-1 country records (including the alpha-2 code) using the `pycountry` library. The default GLiNER model (`urchade/gliner_medium-v2.1`) is licensed under the Apache 2.0 license.

        ## Links

        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )
    query = gr.Textbox(
        value="gdp of the philippines in 2024",
        label="query",
        placeholder="Enter your query here",
    )
    with gr.Row():
        model_name = gr.Radio(
            choices=["urchade/gliner_medium-v2.1", "urchade/gliner_base"],
            value="urchade/gliner_medium-v2.1",
            label="Model",
        )
        entities = gr.Textbox(
            value="country, year, indicator",
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )
    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Every trigger re-runs the same parser with the same inputs.
    # NOTE(review): registration order is kept from the original. Gradio
    # presumably names the first registration "/parse_query", which the
    # gradio_client example at the end of this file calls. Verify before
    # reordering.
    parse_inputs = [query, entities, threshold, is_nested, model_name]
    for trigger in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        trigger(fn=parse_query, inputs=parse_inputs, outputs=output)
# Let up to this many queued events run at the same time.
_CONCURRENCY_LIMIT = 5
demo.queue(default_concurrency_limit=_CONCURRENCY_LIMIT)
demo.launch(debug=True)
"""
from gradio_client import Client
client = Client("avsolatorio/query-parser")
result = client.predict(
query="gdp, m3, and child mortality of india and southeast asia 2024",
labels="country, year, statistical indicator, region",
threshold=0.3,
nested_ner=False,
api_name="/parse_query"
)
print(result)
""" |