# query-parser / app.py
# avsolatorio's picture
# Use the predict_entities method
# 0864cfe
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER
# Loaded GLiNER models, keyed by model name (filled lazily by get_model).
_MODEL = dict()

# Optional download/cache location for model weights; None when unset.
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Defaults used both for the warm-up prediction and for the UI widgets.
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"

# Model checkpoints offered in the UI.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict so later calls
    reuse them. When CUDA is available and the model's parameters are not
    already on a CUDA device, the model is moved to ``"cuda"`` and the cache
    entry is updated. The elapsed time is printed before returning.

    Args:
        model_name: Name passed to ``GLiNER.from_pretrained``. ``None`` selects
            ``"urchade/gliner_base"``.

    Returns:
        The cached model instance.
    """
    global _MODEL

    start = datetime.now()
    model_name = "urchade/gliner_base" if model_name is None else model_name

    # Load once per model name; weights are cached under _CACHE_DIR.
    if _MODEL.get(model_name) is None:
        _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)

    # NOTE(review): the device check is assumed to run on every call (not only
    # right after loading) -- confirm against the original indentation.
    if torch.cuda.is_available():
        device_type = next(_MODEL[model_name].parameters()).device.type
        if not device_type.startswith("cuda"):
            _MODEL[model_name] = _MODEL[model_name].to("cuda")

    print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")

    return _MODEL[model_name]
def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country database.

    Args:
        country_name: Free-text country name to look up.

    Returns:
        Whatever ``pycountry.countries.search_fuzzy`` returns for the name, or
        ``None`` when the lookup raises ``LookupError``.
    """
    matches = None
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # No candidate found: fall through and report None to the caller.
        pass
    return matches
@spaces.GPU(enable_queue=True, duration=5)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run the selected GLiNER model on ``query`` and return the raw entities.

    Args:
        model_name: Model to use; forwarded to ``get_model``.
        query: Text to extract entities from.
        labels: Entity labels to detect, either a list or a comma-separated
            string. For a string, each item is stripped and blank items
            (from trailing or doubled commas, or an empty string) are dropped.
        threshold: Score threshold forwarded to the model.
        nested_ner: When True, the model is called with ``flat_ner=False``.

    Returns:
        The value returned by ``model.predict_entities``, or an empty list
        when there are no labels to detect (the model is not called then).
    """
    start = datetime.now()

    if isinstance(labels, str):
        # "a, b," or "" would otherwise produce empty-string labels.
        stripped = (part.strip() for part in labels.split(","))
        labels = [label for label in stripped if label]

    if labels:
        model = get_model(model_name)
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    else:
        # Nothing to look for: skip loading/running the model entirely.
        entities = []

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")

    return entities
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and attach country normalizations.

    Entities come from ``predict_entities``. Entities labelled ``"country"``
    are looked up with ``get_country``; on a match the entity gets a
    ``"normalized"`` key holding one dict per match. All other entities are
    kept unchanged. The resulting payload is printed as JSON and returned.

    Args:
        query: Text to parse.
        labels: Entity labels, as a list or comma-separated string.
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: Forwarded to ``predict_entities``.
        model_name: Model to use; forwarded to ``predict_entities``.

    Returns:
        ``{"query": query, "entities": [...]}``.
    """
    predicted = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

    entities = []
    for candidate in predicted:
        if candidate["label"] != "country":
            entities.append(candidate)
            continue

        matches = get_country(candidate["text"])
        # NOTE(review): a "country" entity with no pycountry match is dropped
        # from the output -- confirm this filtering is intended.
        if matches:
            candidate["normalized"] = [dict(match) for match in matches]
            entities.append(candidate)

    payload = {"query": query, "entities": entities}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")

    return payload
# Warm-up: run one prediction per listed model at import time, using the
# default query, labels, and threshold.
print("Initializing models...")

for model_name in MODELS:
    predict_entities(
        model_name=model_name,
        query=QUERY,
        labels=LABELS,
        threshold=THRESHOLD,
    )
with gr.Blocks(title="GLiNER-query-parser") as demo:
    # Introductory text rendered at the top of the page.
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)
This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER mode is licensed under the Apache 2.0 license.
## Links
* Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    # Free-text query to parse.
    query = gr.Textbox(value=QUERY, label="query", placeholder="Enter your query here")

    # Parser settings.
    # NOTE(review): assumes only these four controls sit inside the row and
    # that the output and button are below it -- confirm against the original layout.
    with gr.Row() as row:
        model_name = gr.Radio(choices=MODELS, value="urchade/gliner_base", label="Model")
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Submitting: every trigger below calls parse_query with the same inputs
    # and writes to the same output. The registration order is kept as-is.
    parser_inputs = [query, entities, threshold, is_nested, model_name]
    triggers = (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    )
    for register in triggers:
        # Pass a fresh list each time, as the original did per registration.
        register(fn=parse_query, inputs=list(parser_inputs), outputs=output)

# Enable the request queue, then start the app.
demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)
"""
from gradio_client import Client
client = Client("avsolatorio/query-parser")
result = client.predict(
query="gdp, m3, and child mortality of india and southeast asia 2024",
labels="country, year, statistical indicator, region",
threshold=0.3,
nested_ner=False,
api_name="/parse_query"
)
print(result)
"""