Spaces:
Running
Running
File size: 5,793 Bytes
d4e7f01 1938c90 aadf93b 1938c90 3f9fe25 aadf93b 1938c90 49fed2a 78d1da2 49fed2a 1938c90 3f9fe25 1938c90 78d1da2 1938c90 3f9fe25 df8632c 3f9fe25 1938c90 49fed2a 1938c90 49fed2a 3f9fe25 81fbf66 3f9fe25 1938c90 81fbf66 d4e7f01 1938c90 3f9fe25 1938c90 aadf93b 81fbf66 aadf93b 1938c90 78d1da2 1938c90 78762bd 49fed2a 78d1da2 78762bd 1938c90 78d1da2 1938c90 49fed2a 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd 1938c90 78762bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER
# Loaded GLiNER models keyed by model name; populated on demand by get_model().
_MODEL = {}

# Optional cache directory handed to GLiNER.from_pretrained (None when unset).
_CACHE_DIR = os.getenv("CACHE_DIR")

# Defaults shared by the start-up warm-up call and the UI widgets.
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"

# Model names offered in the UI; each one is loaded during start-up.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict so each name is
    loaded at most once. When ``model_name`` is ``None`` the
    ``"urchade/gliner_base"`` model is used. The elapsed time of the call is
    printed before returning.
    """
    started_at = datetime.now()
    name = "urchade/gliner_base" if model_name is None else model_name

    if _MODEL.get(name) is None:
        _MODEL[name] = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)

    # NOTE(review): the source indentation was lost; the device check is
    # placed at function level (evaluated on every call) -- confirm.
    if torch.cuda.is_available():
        already_on_gpu = next(_MODEL[name].parameters()).device.type.startswith("cuda")
        if not already_on_gpu:
            _MODEL[name] = _MODEL[name].to("cuda")

    print(f"{datetime.now()} :: get_model :: {datetime.now() - started_at}")
    return _MODEL[name]
# Warm-up at import time: load every model listed in MODELS and run one sample
# prediction on the default query with the default labels and threshold.
print("Initializing models...")
for model_name in MODELS:
    model = get_model(model_name)
    model.predict_entities(QUERY, LABELS, threshold=THRESHOLD)
def get_country(country_name: str):
    """Look up ``country_name`` with pycountry's fuzzy country search.

    Returns the result of ``pycountry.countries.search_fuzzy`` or ``None``
    when that call raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        return None
    return matches
@spaces.GPU(enable_queue=True, duration=5)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER entity prediction on ``query``.

    Args:
        model_name: Name of the model to use; passed to ``get_model``.
        query: Text to extract entities from.
        labels: Entity labels to look for, either a list or a single
            comma-separated string (e.g. ``"country, year"``).
        threshold: Score threshold forwarded to the model.
        nested_ner: When True the model is called with ``flat_ner=False``.

    Returns:
        The value returned by the model's ``predict_entities``; an empty list
        when no usable label was supplied.
    """
    start = datetime.now()

    if isinstance(labels, str):
        # Drop blank pieces so inputs such as "", "a,,b" or a trailing comma
        # never hand an empty-string label to the model.
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    entities = []
    if labels:
        # Only load the model and run inference when there is something to
        # look for.
        model = get_model(model_name)
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")
    return entities
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and attach pycountry matches to countries.

    This is the callback wired to the UI events; its positional parameter
    order matches the ``inputs`` lists used there.

    Args:
        query: Text to parse.
        labels: Entity labels, as a list or a comma-separated string;
            forwarded unchanged to ``predict_entities``.
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: Forwarded to ``predict_entities``.
        model_name: Model to use; ``None`` lets ``get_model`` fall back to
            its default model.

    Returns:
        ``{"query": query, "entities": [...]}``. Entities labelled
        ``"country"`` carry an extra ``"normalized"`` key holding the
        pycountry matches converted to dicts.
    """
    entities = []
    _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
    # NOTE(review): the source indentation was lost. The `else` below is
    # assumed to pair with the label check, which means a "country" entity
    # with no pycountry match is left out of the result -- confirm.
    for entity in _entities:
        if entity["label"] == "country":
            country = get_country(entity["text"])
            if country:
                # One dict per match returned by the fuzzy search.
                entity["normalized"] = [dict(c) for c in country]
                entities.append(entity)
        else:
            entities.append(entity)
    payload = {"query": query, "entities": entities}
    # Log the full payload as JSON for each request.
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
    return payload
with gr.Blocks(title="GLiNER-query-parser") as demo:
    # Page header. The text is kept flush-left so the string handed to
    # gr.Markdown matches the text exactly as it appears in this file.
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)
This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the `pycountry` library. This GLiNER mode is licensed under the Apache 2.0 license.
## Links
* Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    # Free-text query, pre-filled with the default example.
    query = gr.Textbox(value=QUERY, label="query", placeholder="Enter your query here")

    # Model choice and extraction settings.
    with gr.Row() as row:
        model_name = gr.Radio(choices=MODELS, value="urchade/gliner_base", label="Model")
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Every interaction re-runs parse_query with the same inputs and output.
    # The listeners are registered in the same order as before.
    parse_inputs = [query, entities, threshold, is_nested, model_name]
    for register_listener in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        register_listener(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)
"""
from gradio_client import Client
client = Client("avsolatorio/query-parser")
result = client.predict(
    query="gdp, m3, and child mortality of india and southeast asia 2024",
    labels="country, year, statistical indicator, region",
    threshold=0.3,
    nested_ner=False,
    api_name="/parse_query"
)
print(result)
"""