# Hugging Face Space -- status when captured: Running on CPU Upgrade
from typing import Any, Dict, List

import gradio as gr
from gliner import GLiNER
# Load the GLiNER multitask checkpoint once at import time and place it on the
# CPU; `process` below reuses this single module-level instance for every call.
model = GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0").to("cpu")
# Instruction placed in front of the user's text; `{}` receives the
# comma-separated list of candidate classes.
PROMPT_TEMPLATE = """Classify the given text having the following classes: {}"""

# Example rows shown in the UI. Each row is
# [input text, comma-separated labels, threshold], in the same order as the
# input components the examples are bound to.
classification_examples = [
    [
        "The sun is shining and the weather is warm today.",
        "Weather, Food, Technology",
        0.5,
    ],
    [
        "I really enjoyed the pizza we had for dinner last night.",
        "Food, Weather, Sports",
        0.5,
    ],
    [
        "Das Kind spielt im Park und genießt die frische Luft.",
        "Nature, Technology, Politics",
        0.5,
    ],
]
def prepare_prompts(text: str, labels: List[str]) -> str:
    """Build the classification prompt for *text*.

    Args:
        text: The text to classify.
        labels: Candidate class names; they are joined with ", " and
            substituted into ``PROMPT_TEMPLATE``.

    Returns:
        The filled-in template, a newline, and then *text* unchanged.
    """
    labels_str = ", ".join(labels)
    return PROMPT_TEMPLATE.format(labels_str) + "\n" + text
def process(text: str, labels: str, threshold: float) -> Dict[str, Any]:
    """Classify *text* against comma-separated *labels* and return highlights.

    Args:
        text: The text to classify.
        labels: Candidate class names separated by commas. Surrounding
            whitespace is stripped and blank entries (for example from a
            trailing or doubled comma) are ignored.
        threshold: Score threshold forwarded to ``model.run``.

    Returns:
        A dict with a ``"text"`` key and an ``"entities"`` list. When the
        text is blank or no usable label remains, ``"text"`` is the input
        text and ``"entities"`` is empty. Otherwise ``"text"`` is the full
        prompt that was sent to the model and each entity carries the
        ``"word"``, ``"start"``, ``"end"`` and ``"score"`` of one prediction,
        all tagged with the entity name ``"match"``.
    """
    # Treat a missing value like an empty one instead of failing on `.strip()`.
    text = text or ""
    labels = labels or ""

    label_names = [name.strip() for name in labels.split(",") if name.strip()]
    if not text.strip() or not label_names:
        return {"text": text, "entities": []}

    prompt = prepare_prompts(text, label_names)
    predictions = model.run([prompt], ["match"], threshold=threshold)

    # One prompt was submitted, so only the first result list is relevant.
    spans = (predictions[0] if predictions else None) or []
    entities = [
        {
            "entity": "match",
            "word": span["text"],
            "start": span["start"],
            "end": span["end"],
            "score": span["score"],
        }
        for span in spans
    ]

    # The model ran on the prompt, so the prompt (not the raw text) is returned
    # alongside the spans -- presumably so start/end offsets line up with it.
    return {"text": prompt, "entities": entities}
# Theme for the page. It is created before the layout and handed to
# `gr.Blocks`; assigning it after the layout is built has no effect on the UI.
theme = gr.themes.Base()

with gr.Blocks(
    title="Text Classification with Highlighted Labels",
    theme=theme,
) as classification_interface:
    gr.Markdown("# Text Classification with Highlighted Labels")

    # Inputs: the text, the candidate labels, and the score threshold.
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text for classification")
    input_labels = gr.Textbox(label="Labels (Comma-Separated)", placeholder="Enter labels separated by commas (e.g., Positive, Negative, Neutral)")
    threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold")

    # Output: the dict returned by `process`, rendered as highlighted spans.
    output = gr.HighlightedText(label="Classification Results")
    submit_btn = gr.Button("Classify")

    # Clickable sample rows with cached outputs produced by `process`.
    examples = gr.Examples(
        examples=classification_examples,
        inputs=[input_text, input_labels, threshold],
        outputs=output,
        fn=process,
        cache_examples=True,
    )

    # Three triggers run the same classification: submitting the text box,
    # releasing the threshold slider, and clicking the button.
    input_text.submit(fn=process, inputs=[input_text, input_labels, threshold], outputs=output)
    threshold.release(fn=process, inputs=[input_text, input_labels, threshold], outputs=output)
    submit_btn.click(fn=process, inputs=[input_text, input_labels, threshold], outputs=output)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    classification_interface.launch()