"""Gradio demo: zero-shot query parsing with GLiNER.

Country entities are additionally normalized with ``pycountry``.
"""

import os
from typing import Dict, Optional, Union

import gradio as gr
import pycountry
from gliner import GLiNER

# Models offered in the UI; the first one is also the default for API calls.
_MODEL_CHOICES = ["urchade/gliner_medium-v2.1", "urchade/gliner_base"]
_DEFAULT_MODEL_NAME = _MODEL_CHOICES[0]

# Loaded models, keyed by model name, so each one is downloaded/loaded once.
_MODEL = {}
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
print(f"Cache directory: {_CACHE_DIR}")


def get_model(model_name: Optional[str] = None) -> GLiNER:
    """Return the GLiNER model ``model_name``, loading and caching it on first use.

    Args:
        model_name: Hugging Face model id. ``None`` or an empty string selects
            the default model.

    Returns:
        The cached ``GLiNER`` instance for that model id.
    """
    model_name = model_name or _DEFAULT_MODEL_NAME
    if model_name not in _MODEL:
        _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
    return _MODEL[model_name]


def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country database.

    Returns:
        The list of matches from ``pycountry.countries.search_fuzzy``, or
        ``None`` when pycountry raises ``LookupError`` (no match).
    """
    try:
        return pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        return None


def parse_query(
    query: str,
    labels: Union[str, list],
    threshold: float = 0.5,
    nested_ner: bool = False,
    model_name: Optional[str] = None,
) -> Dict[str, Union[str, list]]:
    """Extract entities of the requested ``labels`` from ``query`` with GLiNER.

    Args:
        query: Free-text query to parse.
        labels: Entity labels to detect, either as a list or as a
            comma-separated string. Surrounding whitespace is stripped and
            blank labels (e.g. from a trailing comma) are ignored.
        threshold: Minimum prediction score for an entity to be kept.
        nested_ner: If True, allow nested/overlapping entity spans
            (forwarded to GLiNER as ``flat_ner=False``).
        model_name: Hugging Face model id; ``None`` uses the default model.

    Returns:
        ``{"query": query, "entities": [...]}``, where each entity is the dict
        produced by GLiNER. Entities labelled "country" that pycountry can
        resolve also get a "normalized" list of matching country records.
    """
    if isinstance(labels, str):
        labels = labels.split(",")
    labels = [label.strip() for label in labels if label.strip()]

    # Nothing to look for, or nowhere to look: skip loading/running the model.
    if not labels or not query or not query.strip():
        return {"query": query, "entities": []}

    model = get_model(model_name)
    # GLiNER's ``flat_ner`` disallows overlapping spans, so nested NER is its negation.
    predictions = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    entities = []
    for entity in predictions:
        if entity["label"] == "country":
            matches = get_country(entity["text"])
            if matches:
                # Convert each pycountry record to a plain dict for the JSON output.
                entity["normalized"] = [dict(match) for match in matches]
        entities.append(entity)

    return {"query": query, "entities": entities}


with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the matching ISO 3166-1 country record(s), including the alpha-2 code, using the `pycountry` library. This GLiNER model is licensed under the Apache 2.0 license.

        ## Links

        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )

    query = gr.Textbox(
        value="gdp of the philippines in 2024",
        label="query",
        placeholder="Enter your query here",
    )
    with gr.Row():
        model_name = gr.Radio(
            choices=_MODEL_CHOICES,
            value=_DEFAULT_MODEL_NAME,
            label="Model",
        )
        entities = gr.Textbox(
            value="country, year, indicator",
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Re-run the parser whenever any input is committed. The listeners are
    # registered in the same order as before; Gradio presumably derives the
    # default API endpoint names (e.g. "/parse_query", used in the client
    # example at the end of this file) from that order, so keep it stable.
    parse_inputs = [query, entities, threshold, is_nested, model_name]
    for trigger in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        trigger(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)

# Example: calling this Space programmatically.
#
#     from gradio_client import Client
#
#     client = Client("avsolatorio/query-parser")
#     result = client.predict(
#         query="gdp, m3, and child mortality of india and southeast asia 2024",
#         labels="country, year, statistical indicator, region",
#         threshold=0.3,
#         nested_ner=False,
#         api_name="/parse_query",
#     )
#     print(result)