File size: 2,628 Bytes
b03b2db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5187656
 
b03b2db
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from gliner import GLiNER
import gradio as gr

# Multitask GLiNER checkpoint, loaded once at import time and moved to CPU.
model = GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0").to("cpu")

# Instruction placed before the user's text; the `{}` placeholder receives the
# comma-joined label names (filled in by prepare_prompts).
PROMPT_TEMPLATE = """Classify the given text having the following classes: {}"""
# Demo rows for gr.Examples. Each row is [text, comma-separated labels,
# threshold], in the same order as the Examples `inputs` list
# (input_text, input_labels, threshold).
classification_examples = [
    [
        "The sun is shining and the weather is warm today.",
        "Weather, Food, Technology",
        0.5
    ],
    [
        "I really enjoyed the pizza we had for dinner last night.",
        "Food, Weather, Sports",
        0.5
    ],
    [
        "Das Kind spielt im Park und genießt die frische Luft.",
        "Nature, Technology, Politics",
        0.5
    ]
]

def prepare_prompts(text, labels):
    """Build the model prompt for *text*.

    The result is PROMPT_TEMPLATE filled with the labels joined by ", ",
    followed by a newline and then *text* itself.

    Args:
        text: The user's input text.
        labels: Iterable of label strings.

    Returns:
        The full prompt string.
    """
    header = PROMPT_TEMPLATE.format(", ".join(labels))
    return "\n".join((header, text))

def process(text, labels, threshold):
    """Classify *text* against comma-separated *labels* for gr.HighlightedText.

    The labels are embedded in a prompt (see prepare_prompts). The model is
    then asked for spans of the single entity type "match" within that prompt.
    The spans come from running the model on the prompt, so the prompt (not
    the raw input) is returned as "text" alongside them.

    Args:
        text: Input text to classify. ``None`` is treated as empty.
        labels: Comma-separated label names, e.g. ``"Food, Weather"``.
            Blank entries (e.g. from ``"a, , b"`` or ``" , "``) are ignored.
        threshold: Score threshold forwarded to ``model.run``.

    Returns:
        ``{"text": str, "entities": list[dict]}``. When the text is blank or
        no non-empty label remains, the model is not called; the original
        text is returned with an empty entity list.
    """
    text = text or ""
    # Drop blank labels: "a, , b" or " , " would otherwise inject empty class
    # names into the prompt.
    label_list = [label.strip() for label in (labels or "").split(",")]
    label_list = [label for label in label_list if label]
    if not text.strip() or not label_list:
        return {"text": text, "entities": []}

    prompt = prepare_prompts(text, label_list)
    predictions = model.run([prompt], ["match"], threshold=threshold)

    # Only one prompt was submitted, so only the first result list matters;
    # guard against an empty/None result.
    first = predictions[0] if predictions else None
    entities = [
        {
            "entity": "match",
            "word": pred["text"],
            "start": pred["start"],
            "end": pred["end"],
            "score": pred["score"],
        }
        for pred in first or []
    ]
    return {"text": prompt, "entities": entities}

# Theme for the whole app. It must be passed to gr.Blocks to take effect;
# merely assigning it inside the Blocks context (as before) left it unused.
theme = gr.themes.Base()

with gr.Blocks(title="Text Classification with Highlighted Labels", theme=theme) as classification_interface:
    gr.Markdown("# Text Classification with Highlighted Labels")

    input_text = gr.Textbox(label="Input Text", placeholder="Enter text for classification")
    input_labels = gr.Textbox(
        label="Labels (Comma-Separated)",
        placeholder="Enter labels separated by commas (e.g., Positive, Negative, Neutral)",
    )
    threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold")

    output = gr.HighlightedText(label="Classification Results")

    submit_btn = gr.Button("Classify")

    # Same argument order as process(text, labels, threshold); shared by the
    # examples and every event listener below.
    process_inputs = [input_text, input_labels, threshold]

    # cache_examples=True asks Gradio to cache the outputs of `process` for
    # these example rows rather than recomputing them on every click.
    examples = gr.Examples(
        examples=classification_examples,
        inputs=process_inputs,
        outputs=output,
        fn=process,
        cache_examples=True,
    )

    # Re-run classification on Enter in the text box, on slider release, and
    # on the button click.
    input_text.submit(fn=process, inputs=process_inputs, outputs=output)
    threshold.release(fn=process, inputs=process_inputs, outputs=output)
    submit_btn.click(fn=process, inputs=process_inputs, outputs=output)

if __name__ == "__main__":
    classification_interface.launch()