File size: 4,291 Bytes
04e7b78
9604b3c
 
2adecad
 
 
9604b3c
2adecad
 
 
 
 
 
 
816c523
 
 
 
 
 
 
2adecad
 
fb5842d
 
 
 
816c523
fb5842d
2adecad
 
d3061d0
2adecad
 
 
 
 
d3061d0
2adecad
 
 
 
 
 
 
 
fb5842d
 
2adecad
 
 
 
 
 
 
 
 
068f0da
fb5842d
 
 
 
 
 
 
 
 
 
 
 
2adecad
fb5842d
 
 
 
 
 
 
 
 
 
 
 
b38e092
fb5842d
 
 
 
 
2adecad
b38e092
fb5842d
 
 
 
 
 
 
b38e092
fb5842d
b38e092
d3061d0
2adecad
 
 
 
 
d3061d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
from transformers import pipeline

# Initialize the classifiers.
# Both tasks are served by the same NLI checkpoint, so load it once and hand
# the already-loaded model/tokenizer to the second pipeline instead of reading
# (and keeping) two copies of the weights in memory.
MODEL_NAME = "tasksource/ModernBERT-base-nli"
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_NAME)
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)

def process_input(text_input, labels_or_premise, mode):
    """Run the selected ModernBERT task and format the scores for ``gr.Label``.

    Args:
        text_input: The text to classify (zero-shot mode) or the premise
            (NLI mode).
        labels_or_premise: Comma-separated candidate categories in zero-shot
            mode; in NLI mode this is the hypothesis (despite the name).
        mode: The radio selection. Any value other than
            ``"Zero-Shot Classification"`` is treated as NLI.

    Returns:
        A ``(scores, analysis)`` pair. ``scores`` maps label -> score for the
        ``gr.Label`` output; ``analysis`` is the (currently always empty)
        text for the hidden Markdown output.

    Raises:
        gr.Error: In zero-shot mode, when no non-blank category is given.
    """
    if mode == "Zero-Shot Classification":
        # Skip blank entries so inputs like "a, b," or "a,,b" don't produce
        # empty candidate labels.
        labels = [
            label.strip() for label in labels_or_premise.split(",") if label.strip()
        ]
        if not labels:
            raise gr.Error("Please enter at least one comma-separated category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode. top_k=None makes the text-classification pipeline return the
    # score of every class rather than only the top one, so all three labels
    # show real probabilities instead of zeros for the non-winning classes.
    predictions = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}], top_k=None
    )[0]
    # Normalise label case so matching doesn't depend on the checkpoint's
    # id2label spelling.
    scores = {item["label"].lower(): item["score"] for item in predictions}
    results = {
        label: scores.get(label, 0.0)
        for label in ("entailment", "contradiction", "neutral")
    }
    return results, ""

def update_interface(mode):
    """Relabel the second textbox to match the selected mode.

    Args:
        mode: The radio selection from the mode selector.

    Returns:
        A ``gr.update`` for the categories/hypothesis textbox: a
        comma-separated category prompt in zero-shot mode, and a hypothesis
        prompt for any other mode (NLI).
    """
    if mode == "Zero-Shot Classification":
        return gr.update(label="🏷️ Categories", placeholder="Enter comma-separated categories...")
    # Any other mode is NLI; the label previously held a mis-decoded 🔎.
    return gr.update(label="🔎 Hypothesis", placeholder="Enter a hypothesis to compare with the premise...")

with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ModernBERT Text Analysis")

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3
        )

        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2
        )

        submit_btn = gr.Button("Submit")

        # Order must match process_input's (scores, analysis) return tuple.
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False)
        ]

        # Example panels: only the one for the initially selected mode
        # (Zero-Shot) starts visible; update_visibility toggles them.
        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=[
                    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
                    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
                    ["I love playing video games", "entertainment, sports, education, business"],
                    ["The car won't start", "transportation, art, cooking, literature"],
                    ["She wrote a beautiful poem", "creativity, finance, exercise, technology"]
                ],
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples"
            )

        # Hidden at startup because the default mode is Zero-Shot.
        with gr.Column(variant="panel", visible=False) as nli_examples_panel:
            gr.Examples(
                examples=[
                    ["A man is sleeping on a couch", "The man is awake"],
                    ["The restaurant is full of people", "The place is empty"],
                    ["The child is playing with toys", "The kid is having fun"],
                    ["It's raining outside", "The weather is wet"],
                    ["The dog is barking at the mailman", "There is a cat"]
                ],
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples"
            )

    def update_visibility(selected_mode):
        """Show only the example panel that matches ``selected_mode``.

        Returns a pair of updates for (zero-shot panel, NLI panel).
        """
        return (
            gr.update(visible=(selected_mode == "Zero-Shot Classification")),
            gr.update(visible=(selected_mode == "Natural Language Inference"))
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise]
    )

    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel]
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()