Tolga commited on
Commit
39ef911
·
1 Parent(s): 1cb41a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -1
app.py CHANGED
@@ -1,3 +1,42 @@
1
  import gradio as gr
 
2
 
3
- gr.Interface.load("models/tkurtulus/rottentomato-classifier").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
 
4
+ def classify_sentence(sent:str):
5
+ toksentence = tokenizer(sent,truncation=True,return_tensors="pt")
6
+ model.eval()
7
+ with torch.no_grad():
8
+ toksentence.to(device)
9
+ output = model(**toksentence)
10
+
11
+ return F.softmax(output.logits,dim=1).argmax(dim=1)
12
+
13
+
14
+ def classify_text(text:str):
15
+ sentences = sent_tokenize(text)
16
+ annotations = np.array(list(map(classify_sentence,sentences)),dtype=object)
17
+ result = list(zip(sentences,[mapping[val] for val in annotations]))
18
+ return (annotations,result)
19
+
20
+ def classify_text_wrapper(text:str):
21
+ preds,result = classify_text(text)
22
+ n = len(preds)
23
+ non_biased = np.where(preds==0)[0].shape[0]
24
+ biased = np.where(preds==1)[0].shape[0]
25
+
26
+ return (result,{'bias ratio':biased/n})
27
+
28
+ examples=[["[Newsoms's] obsession with masks has created an almost hostile environment in our neighborhoods and streets.\n“He won because the Election was Rigged,” Trump wrote, not referring to Biden by name, adding a list of complaints about vote counting"]]
29
+
30
+ model = AutoModelForSequenceClassification.from_pretrained("tkurtulus/autotrain-rottentomato-2981285985")
31
+ tokenizer = AutoTokenizer.from_pretrained("tkurtulus/autotrain-rottentomato-2981285985");
32
+ model.eval();
33
+
34
+ label = gr.outputs.Label(num_top_classes=None,label='')
35
+ text_h = gr.outputs.HighlightedText(color_map={'Unbiased':'#9ad1A1','Biased':'#db8a8a'},label='Classification')
36
+ inputs = gr.inputs.Textbox(placeholder=None, default="", label=None)
37
+
38
+ app = gr.Interface(fn=classify_text_wrapper,title='Bias classifier',theme='default',
39
+ inputs="textbox",layout='unaligned', outputs=[text_h,label], capture_session=True
40
+ ,examples=examples)
41
+
42
+ app.launch(inbrowser=True)