frostymelonade commited on
Commit
b745835
Β·
1 Parent(s): a9a5a1e

Initial application file

Browse files
Files changed (1) hide show
  1. app.py +25 -0
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ from transformers import AutoModelForSequenceClassification
3
+ from transformers import DataCollatorWithPadding
4
+ from transformers import TrainingArguments, Trainer
5
+ import gradio as gr
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained("smallbenchnlp/roberta-small")
8
+ training_args = TrainingArguments(output_dir="roberta-small-pun-detector", evaluation_strategy="epoch")
9
+ model = AutoModelForSequenceClassification.from_pretrained("frostymelonade/roberta-small-pun-detector", num_labels=2)
10
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
11
+
12
+ trainer = Trainer(
13
+ model=model,
14
+ args=training_args,
15
+ data_collator=data_collator,
16
+ tokenizer=tokenizer,
17
+ )
18
+
19
+ def classify_pun(text):
20
+ inputs = [tokenizer(text, truncation=True)]
21
+ predictions = trainer.predict(inputs)
22
+ label = "Pun" if predictions[0][0][0] < predictions[0][0][1] else "Not a pun"
23
+ return label
24
+
25
+ gr.Interface(fn=classify_pun, inputs=["text"], outputs=["text"]).launch()