# GLiNER_HandyLab / interfaces/classification.py
# Author: BioMike — "Update interfaces/classification.py" (commit 5187656, verified)
# Size: 2.63 kB
from gliner import GLiNER
import gradio as gr
# Shared GLiNER multitask model, loaded once when the module is imported and
# kept on the CPU; the request handler below reuses this single instance.
model = GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0")
model = model.to("cpu")
# Instruction prefix sent to the model; "{}" is filled with the
# comma-separated class names.
PROMPT_TEMPLATE = "Classify the given text having the following classes: {}"

# Demo rows of [text, comma-separated labels, threshold], matching the order of
# the inputs wired into gr.Examples (text box, labels box, threshold slider).
# The last row is German text, kept to show non-English input.
classification_examples = [
    [sample_text, label_csv, 0.5]
    for sample_text, label_csv in (
        (
            "The sun is shining and the weather is warm today.",
            "Weather, Food, Technology",
        ),
        (
            "I really enjoyed the pizza we had for dinner last night.",
            "Food, Weather, Sports",
        ),
        (
            "Das Kind spielt im Park und genießt die frische Luft.",
            "Nature, Technology, Politics",
        ),
    )
]
def prepare_prompts(text, labels):
    """Return the model prompt for classifying *text* into *labels*.

    The prompt is ``PROMPT_TEMPLATE`` filled with the labels joined by
    ``", "``, followed by a newline and then *text* unchanged.
    """
    header = PROMPT_TEMPLATE.format(", ".join(labels))
    return "\n".join([header, text])
def process(text, labels, threshold):
    """Classify *text* against comma-separated *labels* with the shared GLiNER model.

    The text and label names are combined into one prompt by
    ``prepare_prompts``. The model is then asked for spans of the single
    entity type ``"match"``; presumably each match is one of the label names
    in the prompt header (verify against the model card).

    Args:
        text: Text to classify. ``None`` is treated as an empty string.
        labels: Comma-separated class names, e.g. ``"Food, Sports"``.
            ``None`` and blank entries (from stray or trailing commas) are
            ignored.
        threshold: Minimum score, forwarded unchanged to ``model.run``.

    Returns:
        A dict shaped for ``gr.HighlightedText``:
        ``{"text": str, "entities": list[dict]}``.
        If there is no text or no usable label, ``text`` is the input text
        and ``entities`` is empty. Otherwise ``text`` is the full prompt,
        because the model ran on the prompt and its ``start``/``end`` offsets
        refer to it.
    """
    text = text or ""
    labels = labels or ""
    # Drop empty names so "Food, " or ",," never reach the prompt.
    label_list = [label.strip() for label in labels.split(",") if label.strip()]
    if not text.strip() or not label_list:
        return {"text": text, "entities": []}

    prompt = prepare_prompts(text, label_list)
    predictions = model.run([prompt], ["match"], threshold=threshold)
    # One input prompt, so only the first result list is relevant; it may be
    # missing or empty when nothing clears the threshold.
    matches = (predictions[0] if predictions else None) or []

    entities = [
        {
            "entity": "match",
            "word": pred["text"],
            "start": pred["start"],
            "end": pred["end"],
            "score": pred["score"],
        }
        for pred in matches
    ]
    return {"text": prompt, "entities": entities}
# Passed to gr.Blocks below so the Base theme is actually applied to this demo.
theme = gr.themes.Base()

# Gradio UI: text and label inputs, a threshold slider, and a highlighted view
# of the prompt with the matched class names marked.
with gr.Blocks(title="Text Classification with Highlighted Labels", theme=theme) as classification_interface:
    gr.Markdown("# Text Classification with Highlighted Labels")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text for classification")
    input_labels = gr.Textbox(
        label="Labels (Comma-Separated)",
        placeholder="Enter labels separated by commas (e.g., Positive, Negative, Neutral)",
    )
    threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold")
    output = gr.HighlightedText(label="Classification Results")
    submit_btn = gr.Button("Classify")
    # cache_examples=True makes Gradio run `process` on these rows ahead of time.
    examples = gr.Examples(
        examples=classification_examples,
        inputs=[input_text, input_labels, threshold],
        outputs=output,
        fn=process,
        cache_examples=True,
    )

    # Every trigger runs the same handler with the same inputs and output.
    # Enter in the labels box classifies too, just like Enter in the text box.
    _classify_inputs = [input_text, input_labels, threshold]
    for _trigger in (input_text.submit, input_labels.submit, threshold.release, submit_btn.click):
        _trigger(fn=process, inputs=_classify_inputs, outputs=output)

if __name__ == "__main__":
    classification_interface.launch()