Unggi commited on
Commit
4933ff8
·
1 Parent(s): d583822

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -3,6 +3,7 @@ pip.main(['install', 'torch'])
3
  pip.main(['install', 'transformers'])
4
 
5
  import torch
 
6
  import gradio as gr
7
  import transformers
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -32,10 +33,18 @@ def inference(prompt):
32
  with torch.no_grad():
33
  logits = model(**inputs).logits
34
 
35
- predicted_class_id = logits.argmax().item()
 
 
 
 
36
  class_id = model.config.id2label[predicted_class_id]
37
 
38
- return class_id
 
 
 
 
39
 
40
  demo = gr.Interface(
41
  fn=inference,
 
3
  pip.main(['install', 'transformers'])
4
 
5
  import torch
6
+ import torch.nn as nn
7
  import gradio as gr
8
  import transformers
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
35
 
36
+ # for binary classification
37
+ sigmoid = nn.Sigmoid()
38
+ bi_prob = sigmoid(logits)
39
+
40
+ predicted_class_id = bi_prob.argmax().item()
41
  class_id = model.config.id2label[predicted_class_id]
42
 
43
+ return {
44
+ "class_id": class_id,
45
+ "clean_prob": bi_prob[0][0].item(),
46
+ "unclean_prob": bi_prob[0][1].item()
47
+ }
48
 
49
  demo = gr.Interface(
50
  fn=inference,