Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import pipeline
|
|
|
|
|
|
|
3 |
|
4 |
# Initialize the classifiers
|
5 |
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
|
6 |
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
|
7 |
|
8 |
-
|
9 |
-
gr.load("models/answerdotai/ModernBERT-base").launch()
|
10 |
-
|
11 |
-
# Define examples
|
12 |
zero_shot_examples = [
|
13 |
["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
|
14 |
["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
|
@@ -25,17 +25,54 @@ nli_examples = [
|
|
25 |
["A German Shepherd is exhibiting defensive behavior towards someone approaching the property", "The animal making noise is feline"]
|
26 |
]
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
def process_input(text_input, labels_or_premise, mode):
|
29 |
if mode == "Zero-Shot Classification":
|
30 |
labels = [label.strip() for label in labels_or_premise.split(',')]
|
31 |
prediction = zero_shot_classifier(text_input, labels)
|
32 |
results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
|
33 |
return results, ''
|
34 |
-
|
35 |
-
pred= nli_classifier([{"text": text_input, "text_pair": labels_or_premise}],return_all_scores=True)[0]
|
36 |
-
results= {pred['label']:pred['score'] for pred in pred}
|
37 |
-
|
38 |
return results, ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def update_interface(mode):
|
41 |
if mode == "Zero-Shot Classification":
|
@@ -47,7 +84,7 @@ def update_interface(mode):
|
|
47 |
),
|
48 |
gr.update(value=zero_shot_examples[0][0])
|
49 |
)
|
50 |
-
|
51 |
return (
|
52 |
gr.update(
|
53 |
label="π Hypothesis",
|
@@ -56,19 +93,28 @@ def update_interface(mode):
|
|
56 |
),
|
57 |
gr.update(value=nli_examples[0][0])
|
58 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
with gr.Blocks() as demo:
|
61 |
gr.Markdown("""
|
62 |
# tasksource/ModernBERT-nli demonstration
|
63 |
|
64 |
-
This
|
65 |
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
|
66 |
on tasksource classification tasks.
|
67 |
This NLI model achieves high accuracy on logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
|
68 |
""")
|
69 |
|
70 |
mode = gr.Radio(
|
71 |
-
["Zero-Shot Classification", "Natural Language Inference"],
|
72 |
label="Select Mode",
|
73 |
value="Zero-Shot Classification"
|
74 |
)
|
@@ -78,21 +124,21 @@ with gr.Blocks() as demo:
|
|
78 |
label="βοΈ Input Text",
|
79 |
placeholder="Enter your text...",
|
80 |
lines=3,
|
81 |
-
value=zero_shot_examples[0][0]
|
82 |
)
|
83 |
|
84 |
labels_or_premise = gr.Textbox(
|
85 |
label="π·οΈ Categories",
|
86 |
placeholder="Enter comma-separated categories...",
|
87 |
lines=2,
|
88 |
-
value=zero_shot_examples[0][1]
|
89 |
)
|
90 |
|
91 |
submit_btn = gr.Button("Submit")
|
92 |
|
93 |
outputs = [
|
94 |
gr.Label(label="π Results"),
|
95 |
-
gr.Markdown(label="π Analysis", visible=
|
96 |
]
|
97 |
|
98 |
with gr.Column(variant="panel") as zero_shot_examples_panel:
|
@@ -107,12 +153,20 @@ with gr.Blocks() as demo:
|
|
107 |
examples=nli_examples,
|
108 |
inputs=[text_input, labels_or_premise],
|
109 |
label="Natural Language Inference Examples",
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
def update_visibility(mode):
|
113 |
return (
|
114 |
gr.update(visible=(mode == "Zero-Shot Classification")),
|
115 |
-
gr.update(visible=(mode == "Natural Language Inference"))
|
|
|
116 |
)
|
117 |
|
118 |
mode.change(
|
@@ -124,7 +178,7 @@ with gr.Blocks() as demo:
|
|
124 |
mode.change(
|
125 |
fn=update_visibility,
|
126 |
inputs=[mode],
|
127 |
-
outputs=[zero_shot_examples_panel, nli_examples_panel]
|
128 |
)
|
129 |
|
130 |
submit_btn.click(
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import pipeline
|
3 |
+
import nltk
|
4 |
+
nltk.download('punkt')
|
5 |
+
from nltk.tokenize import sent_tokenize
|
6 |
|
7 |
# Initialize the classifiers
|
8 |
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
|
9 |
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
|
10 |
|
11 |
+
# Define examples (including new long context example)
|
|
|
|
|
|
|
12 |
zero_shot_examples = [
|
13 |
["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
|
14 |
["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
|
|
|
25 |
["A German Shepherd is exhibiting defensive behavior towards someone approaching the property", "The animal making noise is feline"]
|
26 |
]
|
27 |
|
28 |
+
long_context_examples = [
|
29 |
+
["""The small cafe on the corner has been bustling with activity all morning. The aroma of freshly baked pastries wafts through the air, drawing in passersby. The baristas work efficiently behind the counter, crafting intricate latte art. Several customers are seated at wooden tables, engaged in quiet conversations or working on laptops. Through the large windows, sunshine streams in, creating a warm and inviting atmosphere.""",
|
30 |
+
"The cafe is experiencing a slow, quiet morning"]
|
31 |
+
]
|
32 |
+
|
33 |
def process_input(text_input, labels_or_premise, mode):
|
34 |
if mode == "Zero-Shot Classification":
|
35 |
labels = [label.strip() for label in labels_or_premise.split(',')]
|
36 |
prediction = zero_shot_classifier(text_input, labels)
|
37 |
results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
|
38 |
return results, ''
|
39 |
+
elif mode == "Natural Language Inference":
|
40 |
+
pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
41 |
+
results = {pred['label']: pred['score'] for pred in pred}
|
|
|
42 |
return results, ''
|
43 |
+
else: # Long Context NLI
|
44 |
+
# Global prediction
|
45 |
+
global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
46 |
+
global_results = {pred['label']: pred['score'] for pred in global_pred}
|
47 |
+
|
48 |
+
# Sentence-level analysis
|
49 |
+
sentences = sent_tokenize(text_input)
|
50 |
+
sentence_results = []
|
51 |
+
|
52 |
+
for sentence in sentences:
|
53 |
+
sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
54 |
+
sent_scores = {pred['label']: pred['score'] for pred in sent_pred}
|
55 |
+
max_label = max(sent_scores.items(), key=lambda x: x[1])[0]
|
56 |
+
sentence_results.append({
|
57 |
+
'sentence': sentence,
|
58 |
+
'prediction': max_label,
|
59 |
+
'scores': sent_scores
|
60 |
+
})
|
61 |
+
|
62 |
+
# Create markdown analysis
|
63 |
+
analysis_md = "## Global Prediction\n"
|
64 |
+
max_global_label = max(global_results.items(), key=lambda x: x[1])[0]
|
65 |
+
analysis_md += f"Overall prediction: **{max_global_label}**\n\n"
|
66 |
+
analysis_md += "## Sentence-Level Analysis\n"
|
67 |
+
|
68 |
+
for i, result in enumerate(sentence_results, 1):
|
69 |
+
analysis_md += f"\n### Sentence {i}\n"
|
70 |
+
analysis_md += f"*{result['sentence']}*\n"
|
71 |
+
analysis_md += f"Prediction: **{result['prediction']}**\n"
|
72 |
+
scores_str = ", ".join([f"{label}: {score:.2f}" for label, score in result['scores'].items()])
|
73 |
+
analysis_md += f"Scores: {scores_str}\n"
|
74 |
+
|
75 |
+
return global_results, analysis_md
|
76 |
|
77 |
def update_interface(mode):
|
78 |
if mode == "Zero-Shot Classification":
|
|
|
84 |
),
|
85 |
gr.update(value=zero_shot_examples[0][0])
|
86 |
)
|
87 |
+
elif mode == "Natural Language Inference":
|
88 |
return (
|
89 |
gr.update(
|
90 |
label="π Hypothesis",
|
|
|
93 |
),
|
94 |
gr.update(value=nli_examples[0][0])
|
95 |
)
|
96 |
+
else: # Long Context NLI
|
97 |
+
return (
|
98 |
+
gr.update(
|
99 |
+
label="π Global Hypothesis",
|
100 |
+
placeholder="Enter a hypothesis to test against the full context...",
|
101 |
+
value=long_context_examples[0][1]
|
102 |
+
),
|
103 |
+
gr.update(value=long_context_examples[0][0])
|
104 |
+
)
|
105 |
|
106 |
with gr.Blocks() as demo:
|
107 |
gr.Markdown("""
|
108 |
# tasksource/ModernBERT-nli demonstration
|
109 |
|
110 |
+
This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
|
111 |
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
|
112 |
on tasksource classification tasks.
|
113 |
This NLI model achieves high accuracy on logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
|
114 |
""")
|
115 |
|
116 |
mode = gr.Radio(
|
117 |
+
["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
|
118 |
label="Select Mode",
|
119 |
value="Zero-Shot Classification"
|
120 |
)
|
|
|
124 |
label="βοΈ Input Text",
|
125 |
placeholder="Enter your text...",
|
126 |
lines=3,
|
127 |
+
value=zero_shot_examples[0][0]
|
128 |
)
|
129 |
|
130 |
labels_or_premise = gr.Textbox(
|
131 |
label="π·οΈ Categories",
|
132 |
placeholder="Enter comma-separated categories...",
|
133 |
lines=2,
|
134 |
+
value=zero_shot_examples[0][1]
|
135 |
)
|
136 |
|
137 |
submit_btn = gr.Button("Submit")
|
138 |
|
139 |
outputs = [
|
140 |
gr.Label(label="π Results"),
|
141 |
+
gr.Markdown(label="π Sentence Analysis", visible=True)
|
142 |
]
|
143 |
|
144 |
with gr.Column(variant="panel") as zero_shot_examples_panel:
|
|
|
153 |
examples=nli_examples,
|
154 |
inputs=[text_input, labels_or_premise],
|
155 |
label="Natural Language Inference Examples",
|
156 |
+
)
|
157 |
+
|
158 |
+
with gr.Column(variant="panel") as long_context_examples_panel:
|
159 |
+
gr.Examples(
|
160 |
+
examples=long_context_examples,
|
161 |
+
inputs=[text_input, labels_or_premise],
|
162 |
+
label="Long Context NLI Examples",
|
163 |
+
)
|
164 |
|
165 |
def update_visibility(mode):
|
166 |
return (
|
167 |
gr.update(visible=(mode == "Zero-Shot Classification")),
|
168 |
+
gr.update(visible=(mode == "Natural Language Inference")),
|
169 |
+
gr.update(visible=(mode == "Long Context NLI"))
|
170 |
)
|
171 |
|
172 |
mode.change(
|
|
|
178 |
mode.change(
|
179 |
fn=update_visibility,
|
180 |
inputs=[mode],
|
181 |
+
outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel]
|
182 |
)
|
183 |
|
184 |
submit_btn.click(
|