Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,875 Bytes
1938c90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
import gradio as gr
import pycountry
from typing import Dict, Union
from gliner import GLiNER
# Loaded GLiNER models keyed by model name; filled lazily by get_model().
_MODEL: dict = {}

# Optional directory handed to GLiNER.from_pretrained as ``cache_dir``;
# stays None when the CACHE_DIR environment variable is not set.
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Report the configured cache location once at start-up.
print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict, so a given
    name is passed to ``GLiNER.from_pretrained`` only the first time it is
    requested.

    Args:
        model_name: Name of the GLiNER checkpoint. ``None`` selects
            ``"urchade/gliner_medium-v2.1"``.

    Returns:
        The object returned by ``GLiNER.from_pretrained`` for that name.
    """
    # "urchade/gliner_base" is an alternative checkpoint for the default.
    name = "urchade/gliner_medium-v2.1" if model_name is None else model_name

    model = _MODEL.get(name)
    if model is None:
        model = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)
        _MODEL[name] = model
    return model
def get_country(country_name: str):
    """Look up ``country_name`` with pycountry's fuzzy country search.

    Args:
        country_name: Country text to search for.

    Returns:
        Whatever ``pycountry.countries.search_fuzzy`` returns for the name,
        or ``None`` when the search raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # The lookup failed; signal "no match" instead of propagating.
        matches = None
    return matches
def parse_query(
    query: str,
    labels: Union[str, list],
    threshold: float = 0.5,
    nested_ner: bool = False,
    model_name: str = None,
) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` with GLiNER and normalize countries.

    Args:
        query: Free-text query to analyse.
        labels: Entity labels to detect: either a list of labels, or one
            comma-separated string such as ``"country, year"``. In the string
            form, surrounding whitespace is stripped and blank items are
            discarded.
        threshold: Score threshold forwarded to the model.
        nested_ner: When true, ask the model for nested entities. Forwarded
            as ``flat_ner=not nested_ner``.
        model_name: GLiNER checkpoint name passed to ``get_model``; ``None``
            selects that function's default.

    Returns:
        ``{"query": query, "entities": [...]}``. Each entity is the dict
        produced by the model. Entities labelled ``"country"`` additionally
        carry a ``"normalized"`` list with one dict per pycountry match;
        country entities without any match are left out of the result.
        A blank query or an empty label set yields an empty entity list
        without loading or running the model.
    """
    if isinstance(labels, str):
        # "a, b," -> ["a", "b"]: a stray comma must not become an "" label.
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    # Nothing to extract: skip the model load and the inference call.
    if not query or not query.strip() or not labels:
        return {"query": query, "entities": []}

    model = get_model(model_name)
    # GLiNER's switch is the inverse of this function's parameter:
    # flat_ner=True means "no nested entities".
    predictions = model.predict_entities(
        query, labels, flat_ner=not nested_ner, threshold=threshold
    )

    entities = []
    for entity in predictions:
        if entity["label"] != "country":
            entities.append(entity)
            continue

        matches = get_country(entity["text"])
        if matches:
            entity["normalized"] = [dict(match) for match in matches]
            entities.append(entity)

    return {"query": query, "entities": entities}
# Demo UI: a query box, a row of extraction settings, a JSON result panel and
# a submit button, all wired to parse_query.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)
This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER mode is licensed under the Apache 2.0 license.
## Links
* Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    query = gr.Textbox(
        label="query",
        placeholder="Enter your query here",
        value="gdp of the philippines in 2024",
    )

    # Extraction settings share one row; `scale` sets their relative widths.
    with gr.Row() as row:
        entities = gr.Textbox(
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            value="country, year, indicator",
            scale=2,
        )
        threshold = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.5,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            label="Nested NER",
            info="Setting to True extracts nested entities",
            value=False,
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Each of these triggers re-runs the parser on the current control values
    # and writes the result to the JSON panel.
    for _listener in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
    ):
        _listener(
            fn=parse_query,
            inputs=[query, entities, threshold, is_nested],
            outputs=output,
        )

demo.queue()
demo.launch(debug=True)
|