|
import gradio as gr |
|
from transformers import pipeline |
|
import re |
|
|
|
|
|
# Sentence boundary: whitespace (or end of text) immediately after '.', '?'
# or '!', unless the preceding characters look like "x.y." (e.g. "e.g.") or a
# capitalised two-letter abbreviation such as "Mr.". Compiled once at import
# time instead of on every call; the group is non-capturing so split() does
# not inject the separator text into the result.
_SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(?:\s|$)')


def sent_tokenize(text):
    """Split ``text`` into sentences using a lightweight regex heuristic.

    Args:
        text: The text to split. ``None`` or ``""`` is accepted.

    Returns:
        A list of stripped, non-empty sentence strings (each keeps its
        terminating punctuation). Returns ``[]`` for empty or ``None`` input.
    """
    if not text:
        return []
    return [piece.strip() for piece in _SENTENCE_BOUNDARY.split(text) if piece.strip()]
|
|
|
|
|
# Both pipelines use the same NLI checkpoint. Load the weights and tokenizer
# once and share them, instead of instantiating the model twice (which
# doubled memory use and start-up time).
MODEL_NAME = "tasksource/ModernBERT-base-nli"

zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_NAME, device="cpu")

nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
    device="cpu",
)
|
|
|
|
|
# Demo inputs for each mode. Every entry is a two-item list:
# [value for the main input box, value for the categories/hypothesis box].
# update_interface pre-fills the boxes from entry [0], and each gr.Examples
# panel lists all entries of its mode.

# Zero-shot: [text to classify, comma-separated candidate labels]
zero_shot_examples = [
    [
        "I absolutely love this product, it's amazing!",
        "positive, negative, neutral",
    ],
    [
        "I need to buy groceries",
        "shopping, urgent tasks, leisure, philosophy",
    ],
    [
        "The sun is very bright today",
        "weather, astronomy, complaints, poetry",
    ],
    [
        "I love playing video games",
        "entertainment, sports, education, business",
    ],
    [
        "The car won't start",
        "transportation, art, cooking, literature",
    ],
]

# Sentence-pair NLI: [premise, hypothesis]
nli_examples = [
    [
        "A man is sleeping on a couch",
        "The man is awake",
    ],
    [
        "The restaurant's waiting area is bustling, but several tables remain vacant",
        "The establishment is at maximum capacity",
    ],
    [
        "The child is methodically arranging blocks while frowning in concentration",
        "The kid is experiencing joy",
    ],
    [
        "Dark clouds are gathering and the pavement shows scattered wet spots",
        "It's been raining heavily all day",
    ],
    [
        "A German Shepherd is exhibiting defensive behavior towards someone approaching the property",
        "The animal making noise is feline",
    ],
]

# Long-context NLI: [multi-sentence premise, hypothesis]
long_context_examples = [
    [
        (
            "The small cafe on the corner has been bustling with activity all morning. "
            "The aroma of freshly baked pastries wafts through the air, drawing in passersby. "
            "The baristas work efficiently behind the counter, crafting intricate latte art. "
            "Several customers are seated at wooden tables, engaged in quiet conversations or working on laptops. "
            "Through the large windows, sunshine streams in, creating a warm and inviting atmosphere."
        ),
        "The cafe is experiencing a slow, quiet morning",
    ],
]
|
|
|
# Background color per NLI label, keyed by the upper-cased label name.
_LABEL_COLORS = {
    'ENTAILMENT': '#90EE90',
    'NEUTRAL': '#FFE5B4',
    'CONTRADICTION': '#FFB6C1',
}


def get_label_color(label):
    """Return the hex background color used to display an NLI label.

    The lookup is case-insensitive, so ``"entailment"`` and ``"ENTAILMENT"``
    get the same color. (NOTE(review): model ``id2label`` names may be
    lowercase; the original exact-match lookup then fell back to white for
    every row.)

    Args:
        label: The label name (any value; it is converted with ``str``).

    Returns:
        A ``'#RRGGBB'`` string; ``'#FFFFFF'`` (white) for unknown labels.
    """
    return _LABEL_COLORS.get(str(label).upper(), '#FFFFFF')
|
|
|
def create_analysis_html(sentence_results, global_label):
    """Build a color-coded HTML report of NLI predictions.

    Args:
        sentence_results: Iterable of dicts, each with at least a
            ``'sentence'`` key (the sentence text) and a ``'prediction'`` key
            (its predicted label), as built by ``process_input`` in
            long-context mode.
        global_label: The label predicted for the whole premise.

    Returns:
        An HTML string containing a ``<style>`` block, a banner showing the
        global prediction, and a table with one row per sentence. Each row's
        background color comes from ``get_label_color``. All text values are
        HTML-escaped, because the sentences come straight from user input.
    """
    # Local import: the module name ``html`` would otherwise be easy to shadow.
    from html import escape

    parts = [
        """
<style>
.analysis-table {
    width: 100%;
    border-collapse: collapse;
    margin: 20px 0;
    font-family: Arial, sans-serif;
}
.analysis-table th, .analysis-table td {
    padding: 12px;
    border: 1px solid #ddd;
    text-align: left;
}
.analysis-table th {
    background-color: #f5f5f5;
}
.global-prediction {
    padding: 15px;
    margin: 20px 0;
    border-radius: 5px;
    font-weight: bold;
}
</style>
""",
        f'<div class="global-prediction" style="background-color: {get_label_color(global_label)}">'
        f"Global Prediction: {escape(str(global_label))}"
        "</div>",
        '<table class="analysis-table">',
        "<tr><th>Sentence</th><th>Prediction</th></tr>",
    ]

    for result in sentence_results:
        label = result['prediction']
        parts.append(
            f'<tr style="background-color: {get_label_color(label)}">'
            f"<td>{escape(str(result['sentence']))}</td>"
            f"<td>{escape(str(label))}</td>"
            "</tr>"
        )

    parts.append("</table>")
    return "\n".join(parts)
|
|
|
def _nli_scores(premises, hypothesis):
    """Score each premise against ``hypothesis``; return one {label: score} dict per premise."""
    if not premises:
        return []
    batch = [{"text": premise, "text_pair": hypothesis} for premise in premises]
    # top_k=None returns the score of every label; it replaces the deprecated
    # return_all_scores=True. A list input yields one list of
    # {"label", "score"} dicts per item.
    outputs = nli_classifier(batch, top_k=None)
    return [{item['label']: item['score'] for item in output} for output in outputs]


def _top_label(scores):
    """Return the label with the highest score in a {label: score} dict."""
    return max(scores, key=scores.get)


def process_input(text_input, labels_or_premise, mode):
    """Run the task selected by ``mode`` and format the result for the UI.

    Args:
        text_input: Contents of the main text box. It is the text to classify
            in zero-shot mode and the premise in both NLI modes.
        labels_or_premise: Contents of the second text box. It holds
            comma-separated candidate labels in zero-shot mode and the
            hypothesis in both NLI modes; it is passed as ``text_pair``.
        mode: The selected radio option. Any value other than
            "Zero-Shot Classification" or "Natural Language Inference" is
            handled as long-context NLI.

    Returns:
        A tuple ``(scores, analysis_html)``:
            - ``scores`` maps label -> score and feeds the ``gr.Label`` output.
            - ``analysis_html`` is the per-sentence table in long-context
              mode, and ``''`` otherwise.
    """
    if mode == "Zero-Shot Classification":
        # Trim each label and drop empty entries (e.g. from a trailing comma)
        # and duplicates, preserving the user's order.
        labels = list(dict.fromkeys(
            label.strip() for label in (labels_or_premise or "").split(',') if label.strip()
        ))
        if not labels:
            return {}, ''
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction['labels'], prediction['scores'])), ''

    if mode == "Natural Language Inference":
        return _nli_scores([text_input], labels_or_premise)[0], ''

    # Long-context NLI: score the whole premise, then each sentence separately
    # (batched in a single pipeline call) to show which sentences drive it.
    global_results = _nli_scores([text_input], labels_or_premise)[0]
    global_label = _top_label(global_results)

    sentences = sent_tokenize(text_input)
    sentence_results = [
        {'sentence': sentence, 'prediction': _top_label(scores), 'scores': scores}
        for sentence, scores in zip(sentences, _nli_scores(sentences, labels_or_premise))
    ]

    return global_results, create_analysis_html(sentence_results, global_label)
|
|
|
def update_interface(mode):
    """Reconfigure the two text boxes after the task mode changes.

    Args:
        mode: The selected radio option. Anything other than
            "Zero-Shot Classification" or "Natural Language Inference" is
            treated as "Long Context NLI".

    Returns:
        A pair of ``gr.update`` objects:
            - For the categories/hypothesis box: new label, placeholder and
              the mode's first example value.
            - For the main input box: the mode's first example text.
    """
    # Labels previously contained mis-decoded emoji (e.g. "π·οΈ" for "🏷️").
    if mode == "Zero-Shot Classification":
        label = "🏷️ Categories"
        placeholder = "Enter comma-separated categories..."
        example = zero_shot_examples[0]
    elif mode == "Natural Language Inference":
        label = "🔍 Hypothesis"
        placeholder = "Enter a hypothesis to compare with the premise..."
        example = nli_examples[0]
    else:
        label = "🔍 Global Hypothesis"
        placeholder = "Enter a hypothesis to test against the full context..."
        example = long_context_examples[0]

    return (
        gr.update(label=label, placeholder=placeholder, value=example[1]),
        gr.update(value=example[0]),
    )
|
|
|
def update_visibility(mode):
    """Toggle the three example panels so only the selected mode's is shown.

    Args:
        mode: The selected radio option.

    Returns:
        A tuple of three ``gr.update`` objects, in the order
        Zero-Shot Classification, Natural Language Inference and
        Long Context NLI. Only the entry whose mode equals ``mode`` is
        visible; none are visible if ``mode`` is unrecognized.
    """
    panel_modes = ("Zero-Shot Classification", "Natural Language Inference", "Long Context NLI")
    return tuple(gr.update(visible=(mode == panel_mode)) for panel_mode in panel_modes)
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""
# tasksource/ModernBERT-nli demonstration

This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on tasksource classification tasks.
This NLI model achieves high accuracy on logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
""")

    # Selects the task process_input runs. Changes also relabel the text boxes
    # (update_interface) and switch the visible examples panel
    # (update_visibility).
    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],
        )
        # Candidate labels in zero-shot mode, hypothesis in the NLI modes.
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],
        )
        submit_btn = gr.Button("Submit")

    outputs = [
        gr.Label(label="📊 Results"),
        gr.HTML(label="📝 Sentence Analysis"),
    ]

    # Only the panel for the default mode starts visible. Previously all three
    # were shown until the mode was first changed, because update_visibility
    # only runs on change.
    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )

    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )

    with gr.Column(variant="panel", visible=False) as long_context_examples_panel:
        gr.Examples(
            examples=long_context_examples,
            inputs=[text_input, labels_or_premise],
            label="Long Context NLI Examples",
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )

    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel],
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )


if __name__ == "__main__":
    demo.launch()