query-parser / app.py
avsolatorio's picture
Add model_name to args
78762bd
raw
history blame
4.6 kB
import os
import gradio as gr
import pycountry
from typing import Dict, Union
from gliner import GLiNER
# Module-level cache of loaded models, filled lazily by get_model() and keyed
# by the model id passed to GLiNER.from_pretrained.
_MODEL = {}
# Optional weights cache location read from the CACHE_DIR environment
# variable; when unset this is None and is forwarded to from_pretrained as-is.
_CACHE_DIR = os.environ.get("CACHE_DIR")
print("Cache directory: {}".format(_CACHE_DIR))
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it only once.

    Args:
        model_name: Model id passed to ``GLiNER.from_pretrained``. ``None``
            selects the default, "urchade/gliner_medium-v2.1"
            ("urchade/gliner_base" is the other model offered in the UI).

    Returns:
        The model instance held in the module-level ``_MODEL`` cache. The first
        request for a given id loads it with ``cache_dir=_CACHE_DIR``; later
        requests reuse the cached object.
    """
    name = "urchade/gliner_medium-v2.1" if model_name is None else model_name
    cached = _MODEL.get(name)
    if cached is None:
        # Not loaded yet (or cached as None): load and remember it.
        cached = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)
        _MODEL[name] = cached
    return cached
def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country list.

    Args:
        country_name: Free-text country mention (e.g. an entity span).

    Returns:
        The list returned by ``pycountry.countries.search_fuzzy``, or ``None``
        if that call raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # pycountry reports "nothing matched" with LookupError; map it to None
        # so callers can use a simple truthiness check.
        matches = None
    return matches
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.5, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities of the requested types from ``query`` with GLiNER.

    Args:
        query: Free-text query to parse.
        labels: Entity types to detect, either as a list or as a single
            comma-separated string such as "country, year, indicator". For
            the string form, items are whitespace-stripped and empty items
            (from stray or trailing commas) are dropped.
        threshold: Minimum score for an entity to be kept; forwarded to
            ``predict_entities``.
        nested_ner: When True, request nested (possibly overlapping) entities
            by passing ``flat_ner=False`` to ``predict_entities``; otherwise
            flat spans are requested.
        model_name: Model id forwarded to ``get_model``; ``None`` selects its
            default model.

    Returns:
        ``{"query": query, "entities": entities}``. ``entities`` holds the
        entity dicts produced by the model. Entities labelled "country" that
        ``get_country`` can resolve also get a ``"normalized"`` key: a list of
        the matching pycountry records converted to plain dicts. If no labels
        remain after parsing, ``entities`` is empty and no model is loaded.
    """
    if isinstance(labels, str):
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    if not labels:
        # Nothing to look for; avoid loading/calling the model with no labels.
        return {"query": query, "entities": []}

    model = get_model(model_name)
    # The "Nested NER" option was previously ignored; GLiNER exposes it as the
    # inverse flag ``flat_ner``.
    entities = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    for entity in entities:
        if entity["label"] != "country":
            continue
        matches = get_country(entity["text"])
        if matches:
            entity["normalized"] = [dict(match) for match in matches]

    return {"query": query, "entities": entities}
# Gradio UI: a query box, model / entity-type / threshold / nesting controls,
# and a JSON view of parse_query's result. Every control re-runs parse_query.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)

This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entity types to track, the model identifies instances of these entities in the query, and the parsed entities are displayed in the output. A special case is the "country" entity, which is normalized to the matching ISO 3166-1 country records (including the alpha-2 code) using the `pycountry` library. The default GLiNER model (urchade/gliner_medium-v2.1) is licensed under the Apache 2.0 license.

## Links
* Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    query = gr.Textbox(
        value="gdp of the philippines in 2024", label="query", placeholder="Enter your query here"
    )

    with gr.Row():
        model_name = gr.Radio(
            choices=["urchade/gliner_medium-v2.1", "urchade/gliner_base"],
            value="urchade/gliner_medium-v2.1",
            label="Model",
        )
        entities = gr.Textbox(
            value="country, year, indicator",
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Submitting: every trigger runs parse_query with the same inputs. The
    # registration order is the same as before (query.submit first);
    # presumably it determines which listener gets the "/parse_query" API
    # name, so confirm before reordering.
    parse_inputs = [query, entities, threshold, is_nested, model_name]
    for listener in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        listener(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)
"""
Example: call this Space's `/parse_query` endpoint remotely with gradio_client.

from gradio_client import Client
client = Client("avsolatorio/query-parser")
result = client.predict(
    query="gdp, m3, and child mortality of india and southeast asia 2024",
    labels="country, year, statistical indicator, region",
    threshold=0.3,
    nested_ner=False,
    model_name="urchade/gliner_medium-v2.1",
    api_name="/parse_query"
)
print(result)
"""