"""Gradio demo: zero-shot classification and natural language inference with
tasksource/ModernBERT-base-nli."""

import gradio as gr
from transformers import pipeline
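# Both modes are served by the same NLI checkpoint: the zero-shot pipeline wraps it
# to score arbitrary candidate labels, while the text-classification pipeline exposes
# the raw entailment / neutral / contradiction head directly.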
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
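
# Example rows for zero-shot mode: [text, comma-separated candidate labels]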
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"],
]
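
# Example rows for NLI mode: [premise, hypothesis]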
nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant is full of people", "The place is empty"],
    ["The child is playing with toys", "The kid is having fun"],
    ["It's raining outside", "The weather is wet"],
    ["The dog is barking at the mailman", "There is a cat"],
]
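
# The zero-shot pipeline returns {"sequence": ..., "labels": [...], "scores": [...]};
# process_input reshapes either pipeline's output into the {label: score} dict
# expected by the gr.Label output component.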
def process_input(text_input, labels_or_premise, mode):
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(",")]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction["labels"], prediction["scores"])}
        return results, ""
    else:
        # The text-classification pipeline returns only the top label by default,
        # so the two remaining NLI classes are shown with a score of 0.
        prediction = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}])[0]
        results = {
            "entailment": prediction.get("score", 0) if prediction.get("label") == "entailment" else 0,
            "contradiction": prediction.get("score", 0) if prediction.get("label") == "contradiction" else 0,
            "neutral": prediction.get("score", 0) if prediction.get("label") == "neutral" else 0,
        }
        return results, ""


def update_interface(mode):
    """Relabel the second textbox and reset both inputs when the mode changes."""
    if mode == "Zero-Shot Classification":
        return (
            gr.update(
                label="Categories",
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1],
            ),
            gr.update(value=zero_shot_examples[0][0]),
        )
    else:
        return (
            gr.update(
                label="Hypothesis",
                placeholder="Enter a hypothesis to compare with the premise...",
                value=nli_examples[0][1],
            ),
            gr.update(value=nli_examples[0][0]),
        )


with gr.Blocks() as demo:
    gr.Markdown("""
# tasksource/ModernBERT-nli demonstration

Using [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on large-scale tasksource classification tasks. The fine-tuned model achieves high accuracy on
reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
""")

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],
        )

        labels_or_premise = gr.Textbox(
            label="Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],
        )

        submit_btn = gr.Button("Submit")

    # The hidden Markdown component is a placeholder second output; process_input
    # always returns an empty string for it.
    outputs = [
        gr.Label(label="Results"),
        gr.Markdown(label="Analysis", visible=False),
    ]

    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )

    with gr.Column(variant="panel") as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )

    def update_visibility(mode):
        """Show only the example panel that matches the selected mode."""
        return (
            gr.update(visible=(mode == "Zero-Shot Classification")),
            gr.update(visible=(mode == "Natural Language Inference")),
        )

    # Two listeners on the same radio: one swaps the textbox labels and example
    # values, the other toggles which example panel is visible.
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )

    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )
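
# demo.launch() starts a local server; share=True can be passed to launch() to
# create a temporary public Gradio link if remote access is needed.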
if __name__ == "__main__":
    demo.launch()