File size: 1,185 Bytes
b745835
 
 
695c722
b745835
 
 
9de1f50
b745835
 
 
 
 
 
 
 
 
 
 
 
e859a81
b745835
18057d4
 
 
 
f03f842
18057d4
17b108d
18057d4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers import Trainer
import gradio as gr

# Hugging Face Hub repository IDs for the tokenizer and the fine-tuned classifier.
# NOTE(review): the tokenizer is loaded from a different repo than the model;
# confirm the classifier was fine-tuned with this exact tokenizer/vocabulary.
TOKENIZER_ID = "smallbenchnlp/roberta-small"
MODEL_ID = "frostymelonade/roberta-small-pun-detector-v2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
# Two output classes; classify_pun interprets index 1 as "Pun".
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)

# Collator that pads tokenized examples when they are grouped into a batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# The Trainer is built without training arguments or datasets; it is used only
# for its predict() loop in classify_pun.
# NOTE(review): newer transformers releases rename `tokenizer=` to
# `processing_class=`; confirm against the installed version.
trainer = Trainer(model=model, data_collator=data_collator, tokenizer=tokenizer)

def classify_pun(text):
    """Label a piece of text as a pun or not using the module-level trainer.

    Args:
        text: The text to classify. It is truncated by the tokenizer if too long.

    Returns:
        A ``(label, raw)`` tuple. ``label`` is ``"Pun"`` when the model's
        output at index 1 is strictly greater than at index 0, otherwise
        ``"Not a pun"``. ``raw`` is the string form of the model's output
        row for this single input.
    """
    # trainer.predict consumes a sequence of encoded examples, so the single
    # encoding is wrapped in a one-element list.
    encoded = tokenizer(text, truncation=True)
    prediction_output = trainer.predict([encoded])

    # Element 0 of predict()'s result holds the per-example model outputs;
    # row 0 is the only example submitted.
    row = prediction_output[0][0]
    not_pun_score = row[0]
    pun_score = row[1]

    # Ties (and NaN comparisons) fall through to "Not a pun".
    if pun_score > not_pun_score:
        label = "Pun"
    else:
        label = "Not a pun"
    return label, str(row)

# Minimal Blocks UI: one input textbox, two output textboxes, and a submit
# button that runs classify_pun. The endpoint is also exposed via the API
# under the name "classify_pun".
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Text")
    verdict_box = gr.Textbox(label="Classification")
    raw_box = gr.Textbox(label="Raw Results")
    # classify_pun returns (label, raw); they map to the two outputs in order.
    gr.Button("Submit").click(
        fn=classify_pun,
        inputs=text_input,
        outputs=[verdict_box, raw_box],
        api_name="classify_pun",
    )


demo.launch()