sileod's picture
Update app.py
362e959 verified
raw
history blame
5.09 kB
import gradio as gr
from transformers import pipeline
# Both tasks are served by the same NLI checkpoint. Load it once and hand the
# already-loaded model/tokenizer objects to the second pipeline, instead of
# loading the weights into memory a second time.
MODEL_ID = "tasksource/ModernBERT-base-nli"
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_ID)
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)
# Example rows for the two gr.Examples panels.
#   zero_shot_examples: [text to classify, comma-separated candidate labels]
#   nli_examples:       [premise, hypothesis]
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"],
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant is full of people", "The place is empty"],
    ["The child is playing with toys", "The kid is having fun"],
    ["It's raining outside", "The weather is wet"],
    ["The dog is barking at the mailman", "There is a cat"],
]
def process_input(text_input, labels_or_premise, mode):
    """Run the model selected by *mode* on the user's input.

    Args:
        text_input: Text to classify (zero-shot mode) or the premise (NLI mode).
        labels_or_premise: Comma-separated candidate labels (zero-shot mode) or
            the hypothesis (NLI mode). Despite its name, in NLI mode this
            value is passed as ``text_pair``, i.e. the hypothesis.
        mode: Selected mode string; any value other than
            "Zero-Shot Classification" is treated as NLI.

    Returns:
        A tuple ``(scores, "")`` where ``scores`` maps each label to the
        model's score for it. The second element is always an empty string.

    Raises:
        gr.Error: In zero-shot mode, when no non-blank category is given.
    """
    if mode == "Zero-Shot Classification":
        # Skip blank entries so inputs like "a, , b," don't send empty labels.
        labels = [
            label.strip() for label in labels_or_premise.split(",") if label.strip()
        ]
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode: text_input is the premise, labels_or_premise the hypothesis.
    # top_k=None asks the pipeline for a score for every label rather than
    # only the top one, so the full distribution is returned.
    predictions = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}], top_k=None
    )[0]
    return {p["label"]: p["score"] for p in predictions}, ""
def update_interface(mode):
    """Reset the two text boxes for the newly selected mode.

    Args:
        mode: Selected mode string; any value other than
            "Zero-Shot Classification" is treated as NLI.

    Returns:
        A pair of ``gr.update`` objects:
        (update for the categories/hypothesis box, update for the main text box).
        The first sets a mode-specific label, placeholder and value; the
        second sets the value. Both values come from the mode's first example.
    """
    if mode == "Zero-Shot Classification":
        label = "🏷️ Categories"
        placeholder = "Enter comma-separated categories..."
        main_text, secondary_text = zero_shot_examples[0]
    else:
        label = "🔎 Hypothesis"
        placeholder = "Enter a hypothesis to compare with the premise..."
        main_text, secondary_text = nli_examples[0]
    return (
        gr.update(label=label, placeholder=placeholder, value=secondary_text),
        gr.update(value=main_text),
    )
# Intro text rendered at the top of the page.
_DESCRIPTION = """
# tasksource/ModernBERT-nli demonstration
Using [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on large scale tasksource classification tasks. The tuned model achieves high accuracy on reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
"""

# Gradio UI: a mode selector, two text boxes whose meaning depends on the
# mode, a results label, and one example panel per mode.
with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        # Text to classify (zero-shot) or the premise (NLI).
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],  # Initial value
        )
        # Comma-separated categories (zero-shot) or the hypothesis (NLI).
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],  # Initial value
        )
        submit_btn = gr.Button("Submit")
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False),
        ]

    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )

    # Hidden at startup: the radio defaults to zero-shot mode, and panel
    # visibility is otherwise only updated when the mode changes.
    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )

    def update_visibility(selected_mode):
        """Show only the example panel matching *selected_mode*."""
        return (
            gr.update(visible=(selected_mode == "Zero-Shot Classification")),
            gr.update(visible=(selected_mode == "Natural Language Inference")),
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()