Spaces:
Runtime error
Runtime error
File size: 1,185 Bytes
b745835 695c722 b745835 9de1f50 b745835 e859a81 b745835 18057d4 f03f842 18057d4 17b108d 18057d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers import Trainer
import gradio as gr
# Checkpoint identifiers: the tokenizer and the classifier weights are loaded
# from two different pretrained checkpoints.
TOKENIZER_CHECKPOINT = "smallbenchnlp/roberta-small"
MODEL_CHECKPOINT = "frostymelonade/roberta-small-pun-detector-v2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

# Two output labels; classify_pun compares the score at index 0 with the
# score at index 1 to pick the label.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=2,
)

# Collator that batches tokenized inputs, padding them with the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# No training arguments or datasets are supplied: this Trainer is only used
# by classify_pun to run predictions.
# NOTE(review): newer transformers releases appear to deprecate the
# `tokenizer=` argument of Trainer in favour of `processing_class=` --
# confirm against the installed version before changing it.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
def classify_pun(text):
    """Classify *text* as a pun or not using the module-level trainer.

    The text is tokenized (with truncation enabled) and passed to
    ``trainer.predict`` as a one-element list. The first row of the
    returned predictions holds two scores; the label is "Pun" when the
    score at index 1 is strictly greater than the score at index 0, and
    "Not a pun" otherwise.

    Returns:
        A ``(label, raw)`` tuple where ``raw`` is ``str()`` of that first
        row of scores.
    """
    # A single example wrapped in a list so it can be consumed as a dataset.
    encoded_batch = [tokenizer(text, truncation=True)]
    result = trainer.predict(encoded_batch)

    # result[0] is the predictions array; its first row belongs to our
    # single input.
    scores = result[0][0]
    not_pun_score = scores[0]
    pun_score = scores[1]

    if not_pun_score < pun_score:
        label = "Pun"
    else:
        label = "Not a pun"
    return label, str(scores)
# Gradio UI: one input textbox, two output textboxes (the label and the raw
# scores returned by classify_pun) and a submit button.
with gr.Blocks() as demo:
    text = gr.Textbox(label="Text")
    output = gr.Textbox(label="Classification")
    output2 = gr.Textbox(label="Raw Results")
    submit_btn = gr.Button("Submit")
    # Clicking the button feeds the input text to classify_pun and routes its
    # two return values to the two output boxes. api_name also names the
    # endpoint "classify_pun" for programmatic API access.
    submit_btn.click(
        fn=classify_pun,
        inputs=text,
        outputs=[output, output2],
        api_name="classify_pun",
    )

# Start the web app; runs at import time, as in the original script.
demo.launch()