import gradio as gr
from transformers import pipeline
# Initialize the classifiers
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
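# For orientation, an illustrative sketch of the raw pipeline outputs consumed below
# (scores are made-up numbers, not actual model predictions):
#   zero_shot_classifier("I love this", ["positive", "negative"])
#   -> {'sequence': 'I love this', 'labels': ['positive', 'negative'], 'scores': [0.97, 0.03]}
#   nli_classifier([{"text": "premise text", "text_pair": "hypothesis text"}])
#   -> [{'label': 'entailment', 'score': 0.91}]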
# Disabled: standalone demo of the base model, kept for reference only
if False:
    gr.load("models/answerdotai/ModernBERT-base").launch()
# Define examples
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"]
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant is full of people", "The place is empty"],
    ["The child is playing with toys", "The kid is having fun"],
    ["It's raining outside", "The weather is wet"],
    ["The dog is barking at the mailman", "There is a cat"]
]
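# Both example sets feed the same two textboxes: zero-shot rows pair an input text with
# comma-separated candidate labels, while NLI rows pair a premise with a hypothesis;
# update_interface below only relabels the second textbox when the mode changes.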
def process_input(text_input, labels_or_premise, mode):
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(',')]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
        return results, ''
    else:  # NLI mode
        prediction = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}])[0]
        # The text-classification pipeline returns only the top label by default,
        # so the two remaining classes are reported with a score of 0.
        results = {
            "entailment": prediction.get("score", 0) if prediction.get("label") == "entailment" else 0,
            "contradiction": prediction.get("score", 0) if prediction.get("label") == "contradiction" else 0,
            "neutral": prediction.get("score", 0) if prediction.get("label") == "neutral" else 0
        }
        return results, ''
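# Illustrative calls to process_input (return values are made-up, not model outputs):
#   process_input("I love this!", "positive, negative", "Zero-Shot Classification")
#   -> ({'positive': 0.97, 'negative': 0.03}, '')
#   process_input("A man is sleeping on a couch", "The man is awake", "Natural Language Inference")
#   -> ({'entailment': 0, 'contradiction': 0.93, 'neutral': 0}, '')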
def update_interface(mode):
    if mode == "Zero-Shot Classification":
        return (
            gr.update(
                label="🏷️ Categories",
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1]
            ),
            gr.update(value=zero_shot_examples[0][0])
        )
    else:
        return (
            gr.update(
                label="📝 Hypothesis",
                placeholder="Enter a hypothesis to compare with the premise...",
                value=nli_examples[0][1]
            ),
            gr.update(value=nli_examples[0][0])
        )
with gr.Blocks() as demo:
    gr.Markdown("""
# tasksource/ModernBERT-nli demonstration

This demo uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on large-scale tasksource classification tasks. The fine-tuned model achieves high accuracy on
reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
    """)

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )
    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0]  # Initial value
        )
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1]  # Initial value
        )
        submit_btn = gr.Button("Submit")
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📝 Analysis", visible=False)
        ]
    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples"
        )
    with gr.Column(variant="panel") as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples"
        )
    def update_visibility(mode):
        return (
            gr.update(visible=(mode == "Zero-Shot Classification")),
            gr.update(visible=(mode == "Natural Language Inference"))
        )

    # Switching modes relabels the inputs and toggles which example panel is visible
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input]
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel]
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )
if __name__ == "__main__":
    demo.launch()