File size: 5,092 Bytes
04e7b78
9604b3c
 
2adecad
 
 
9604b3c
7400288
 
 
6d5fe23
 
362e959
6d5fe23
 
 
 
 
 
 
 
 
 
 
 
 
 
2adecad
 
 
 
 
 
 
816c523
 
 
 
 
 
2adecad
 
fb5842d
 
6d5fe23
 
 
 
 
 
 
 
fb5842d
6d5fe23
 
 
 
 
 
 
 
fb5842d
2adecad
6d5fe23
362e959
d3061d0
6d5fe23
 
 
 
 
2adecad
 
 
 
 
d3061d0
2adecad
 
 
 
6d5fe23
 
2adecad
 
 
fb5842d
 
6d5fe23
 
2adecad
 
 
 
 
 
 
 
068f0da
fb5842d
 
6d5fe23
fb5842d
 
 
2adecad
fb5842d
 
6d5fe23
fb5842d
 
 
b38e092
fb5842d
 
 
 
 
2adecad
b38e092
fb5842d
 
6d5fe23
fb5842d
 
 
 
b38e092
fb5842d
b38e092
d3061d0
2adecad
 
 
 
 
d3061d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
from transformers import pipeline

# Single NLI checkpoint that backs both tasks.
MODEL_ID = "tasksource/ModernBERT-base-nli"

# Initialize the classifiers. The NLI pipeline reuses the zero-shot
# pipeline's already-loaded model and tokenizer instead of loading the same
# checkpoint a second time, so the weights are held in memory only once.
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_ID)
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)

# Example rows for gr.Examples, each a two-item list ``[text, second_field]``:
# zero-shot rows pair an input text with comma-separated candidate labels;
# NLI rows pair a premise with a hypothesis. Row 0 of each list also supplies
# the default textbox values for its mode.
zero_shot_examples = [
    list(pair)
    for pair in (
        ("I absolutely love this product, it's amazing!", "positive, negative, neutral"),
        ("I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"),
        ("The sun is very bright today", "weather, astronomy, complaints, poetry"),
        ("I love playing video games", "entertainment, sports, education, business"),
        ("The car won't start", "transportation, art, cooking, literature"),
    )
]

nli_examples = [
    list(pair)
    for pair in (
        ("A man is sleeping on a couch", "The man is awake"),
        ("The restaurant is full of people", "The place is empty"),
        ("The child is playing with toys", "The kid is having fun"),
        ("It's raining outside", "The weather is wet"),
        ("The dog is barking at the mailman", "There is a cat"),
    )
]

def process_input(text_input, labels_or_premise, mode):
    """Run the classifier selected by ``mode`` and return scores for ``gr.Label``.

    Args:
        text_input: The text to classify (zero-shot mode) or the premise
            (NLI mode).
        labels_or_premise: Comma-separated candidate categories in zero-shot
            mode; in NLI mode this is the *hypothesis* (despite the name).
        mode: "Zero-Shot Classification" selects zero-shot; any other value
            runs NLI.

    Returns:
        A ``(scores, analysis)`` tuple. ``scores`` maps each label to its
        score. ``analysis`` is always an empty string.

    Raises:
        gr.Error: In zero-shot mode when no non-blank category was entered.
    """
    if mode == "Zero-Shot Classification":
        # Strip whitespace, drop blank entries (e.g. from a trailing comma)
        # and drop duplicates while keeping the user's order. Otherwise blanks
        # would become meaningless empty candidate labels.
        labels = list(dict.fromkeys(
            label.strip() for label in labels_or_premise.split(",") if label.strip()
        ))
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode: text_input is the premise, labels_or_premise the hypothesis.
    # top_k=None makes the pipeline return a score for every class (not just
    # the top one), so all three entries below carry real probabilities.
    predictions = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}], top_k=None
    )[0]
    results = {"entailment": 0, "contradiction": 0, "neutral": 0}
    for prediction in predictions:
        # Compare case-insensitively in case the model config capitalizes labels.
        label = prediction["label"].lower()
        if label in results:
            results[label] = prediction["score"]
    return results, ""

def update_interface(mode):
    """Reconfigure both text boxes for the newly selected task mode.

    Args:
        mode: The selected radio value. "Zero-Shot Classification" selects
            zero-shot; any other value selects NLI.

    Returns:
        A ``(labels_or_premise_update, text_input_update)`` pair of
        ``gr.update`` objects, in the order the ``mode.change`` handler lists
        its outputs. Each box is relabelled and pre-filled with the first
        example row for the mode.
    """
    if mode == "Zero-Shot Classification":
        example_text, example_second = zero_shot_examples[0]
        second_box = gr.update(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            value=example_second,
        )
        # Restore the generic label in case we are switching back from NLI.
        text_box = gr.update(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            value=example_text,
        )
    else:
        example_text, example_second = nli_examples[0]
        second_box = gr.update(
            label="🔎 Hypothesis",
            placeholder="Enter a hypothesis to compare with the premise...",
            value=example_second,
        )
        # In NLI mode the main text box holds the premise.
        text_box = gr.update(
            label="✍️ Premise",
            placeholder="Enter a premise...",
            value=example_text,
        )
    return second_box, text_box

with gr.Blocks() as demo:
    gr.Markdown("""
    # tasksource/ModernBERT-nli demonstration
    
    Using [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli), 
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) 
    on large scale tasksource classification tasks. The tuned model achieves high accuracy on reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
    """)

    # Task selector. Its .change events (wired below) reconfigure the text
    # boxes and toggle which example panel is shown.
    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        # Text to classify (zero-shot) or the premise (NLI).
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],  # Initial value
        )

        # Comma-separated categories (zero-shot) or the hypothesis (NLI).
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],  # Initial value
        )

        submit_btn = gr.Button("Submit")

        # Order matches process_input's (scores, analysis) return tuple; the
        # analysis Markdown is hidden since process_input returns "" for it.
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False),
        ]

        # Only the panel for the active mode should be visible. The radio
        # starts on zero-shot, so the NLI panel starts hidden;
        # update_visibility toggles both whenever the mode changes.
        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=zero_shot_examples,
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples",
            )

        with gr.Column(variant="panel", visible=False) as nli_examples_panel:
            gr.Examples(
                examples=nli_examples,
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples",
            )

    def update_visibility(mode):
        """Show only the example panel that matches ``mode``.

        Returns a ``(zero_shot_panel_update, nli_panel_update)`` pair.
        """
        return (
            gr.update(visible=(mode == "Zero-Shot Classification")),
            gr.update(visible=(mode == "Natural Language Inference")),
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )

    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()