|
import gradio as gr |
|
from transformers import pipeline |
|
import torch |
|
|
|
|
|
# Build the zero-shot pipeline once, when the module is imported. If loading
# fails, the app still starts: the error is printed and ``classifier`` is set
# to None so the request handler can detect the missing model.
try:
    classifier = pipeline(
        task="zero-shot-classification",
        model="models/tasksource/ModernBERT-nli",
        # transformers device convention: -1 runs on CPU, 0 is the first CUDA device.
        device=(0 if torch.cuda.is_available() else -1),
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    classifier = None
|
|
|
def classify_text(text, candidate_labels):
    """
    Perform zero-shot classification on input text.

    Args:
        text (str): Input text to classify.
        candidate_labels (str): Comma-separated string of possible labels.
            Whitespace around each label is stripped, empty entries (e.g.
            from a trailing comma) are dropped, and a repeated label is kept
            only once, in first-seen order.

    Returns:
        dict: Mapping of label -> score (float), in the order returned by
        the classifier. This is the value format ``gr.Label`` displays.

    Raises:
        gr.Error: If the model failed to load, the text or the label list is
            empty, or the classifier itself raises. Gradio shows the message
            to the user.
    """
    # Errors are raised as gr.Error instead of being returned as
    # {"Error": msg}: gr.Label needs numeric confidences, so a dict whose
    # value is a message string cannot be rendered.
    if classifier is None:
        raise gr.Error("Model failed to load")

    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    # Drop empty entries ("a,,b", "a,b,") and de-duplicate while keeping the
    # user's order. Otherwise repeated labels would be scored separately
    # and then collapse into a single key of the output dict.
    stripped_labels = (part.strip() for part in (candidate_labels or "").split(","))
    labels = list(dict.fromkeys(label for label in stripped_labels if label))
    if not labels:
        raise gr.Error("Please enter at least one label.")

    try:
        result = classifier(text, labels)
    except Exception as e:
        # UI boundary: report any pipeline failure to the user, keep the cause.
        raise gr.Error(str(e)) from e

    # Return real floats (not "0.1234" strings) so gr.Label can rank and
    # render the confidences.
    return {
        label: float(score)
        for label, score in zip(result["labels"], result["scores"])
    }
|
|
|
|
|
# Input widgets, created in the same order as the ``inputs`` list below:
# the passage to classify, then the comma-separated candidate labels.
_text_box = gr.Textbox(
    label="Text to classify",
    placeholder="Enter text here...",
    value="all cats are blue",
)
_labels_box = gr.Textbox(
    label="Possible labels (comma-separated)",
    placeholder="Enter labels...",
    value="true,false",
)

# Wire classify_text into a simple form UI; its dict result is shown by a
# Label component.
iface = gr.Interface(
    fn=classify_text,
    inputs=[_text_box, _labels_box],
    outputs=gr.Label(label="Classification Results"),
    title="Zero-Shot Text Classification",
    description="Classify text into given categories without any training examples.",
    # Each example row supplies values in the same order as ``inputs``.
    examples=[
        ["all cats are blue", "true,false"],
        ["the sky is above us", "true,false"],
        ["birds can fly", "true,false,unknown"],
    ],
)
|
|
|
|
|
def main():
    """Start the Gradio app with ``share=True`` (a public share link is requested)."""
    iface.launch(share=True)


if __name__ == "__main__":
    main()