sileod's picture
Update app.py
816c523 verified
raw
history blame
4.29 kB
import gradio as gr
from transformers import pipeline
# Initialize the classifiers.
# Both pipelines use the same NLI checkpoint, so it is loaded only once: the
# second pipeline reuses the first one's model and tokenizer instead of
# downloading / materialising the same weights a second time.
MODEL_NAME = "tasksource/ModernBERT-base-nli"

# Zero-shot pipeline: scores free-form candidate labels against the input text.
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_NAME)

# Sequence-pair classifier used for (premise, hypothesis) NLI predictions.
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)
# NLI classes that are always shown in the results widget.
NLI_LABELS = ("entailment", "contradiction", "neutral")


def parse_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty labels.

    Empty entries (e.g. from a trailing comma or ",,") are dropped so they
    are never sent to the classifier as blank candidate labels.

    >>> parse_labels("shopping, urgent tasks,, leisure ,")
    ['shopping', 'urgent tasks', 'leisure']
    """
    return [label.strip() for label in (raw_labels or "").split(",") if label.strip()]


def nli_scores(predictions):
    """Turn a list of ``{"label": ..., "score": ...}`` dicts into a label->score dict.

    Every entry of ``NLI_LABELS`` is present in the result (0.0 when missing
    from *predictions*), so the UI always lists all three NLI classes.
    """
    scores = dict.fromkeys(NLI_LABELS, 0.0)
    scores.update({item["label"]: item["score"] for item in predictions})
    return scores


def process_input(text_input, labels_or_premise, mode):
    """Run the classifier selected by *mode* and return results for the UI.

    Args:
        text_input: Text to classify (zero-shot mode) or the premise (NLI mode).
        labels_or_premise: Comma-separated candidate labels in zero-shot mode;
            in NLI mode this is the hypothesis (despite the parameter name).
        mode: "Zero-Shot Classification"; any other value selects NLI.

    Returns:
        A ``(scores, markdown)`` pair: a label->score dict for ``gr.Label`` and
        an empty string for the hidden Markdown output.

    Raises:
        gr.Error: if the input text, the category list, or the hypothesis is empty.
    """
    if not text_input or not text_input.strip():
        raise gr.Error("Please enter some input text.")

    if mode == "Zero-Shot Classification":
        labels = parse_labels(labels_or_premise)
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode
    if not labels_or_premise or not labels_or_premise.strip():
        raise gr.Error("Please enter a hypothesis.")
    # top_k=None makes the pipeline return the score of every class rather
    # than only the top one, so all three probabilities shown are real.
    predictions = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}], top_k=None
    )[0]
    return nli_scores(predictions), ""
def update_interface(mode):
    """Relabel the second textbox so it matches the selected *mode*.

    Returns a ``gr.update`` for that textbox: a comma-separated category
    field in zero-shot mode, a hypothesis field for any other mode.
    """
    if mode == "Zero-Shot Classification":
        return gr.update(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
        )
    # The label emoji was previously mis-encoded ("πŸ”Ž"); it is 🔎.
    return gr.update(
        label="🔎 Hypothesis",
        placeholder="Enter a hypothesis to compare with the premise...",
    )
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ModernBERT Text Analysis")
    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
        )
        # Holds comma-separated categories in zero-shot mode and the hypothesis
        # in NLI mode; update_interface relabels it when the mode changes.
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
        )
        submit_btn = gr.Button("Submit")
        # process_input returns (scores dict, markdown string) for these two.
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False),
        ]

    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=[
                ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
                ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
                ["I love playing video games", "entertainment, sports, education, business"],
                ["The car won't start", "transportation, art, cooking, literature"],
                ["She wrote a beautiful poem", "creativity, finance, exercise, technology"],
            ],
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )

    # Hidden at start-up: the default mode is zero-shot, and update_visibility
    # only runs when the radio selection changes.
    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=[
                ["A man is sleeping on a couch", "The man is awake"],
                ["The restaurant is full of people", "The place is empty"],
                ["The child is playing with toys", "The kid is having fun"],
                ["It's raining outside", "The weather is wet"],
                ["The dog is barking at the mailman", "There is a cat"],
            ],
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )

    def update_visibility(selected_mode):
        """Show only the examples panel that matches *selected_mode*.

        Returns updates for (zero_shot_examples_panel, nli_examples_panel).
        """
        return (
            gr.update(visible=(selected_mode == "Zero-Shot Classification")),
            gr.update(visible=(selected_mode == "Natural Language Inference")),
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise],
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()