savasy committed on
Commit
142690b
·
1 Parent(s): 568a3e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -1,18 +1,21 @@
# --- Previous version of app.py (before this commit) ---
from transformers import AutoModelForMaskedLM , AutoTokenizer
import torch

model_path = "bert-large-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# load Prompting class (project-local helper)
from prompt import Prompting

prompting = Prompting(model=model_path)
# Cloze-style prompt: the model fills the mask with a sentiment word.
prompt = ". Because it was " + prompting.tokenizer.mask_token + "."


def predict(text):
    """Classify *text* as POSITIVE or NEGATIVE via masked-LM prompting.

    Returns a single-entry dict mapping the predicted label to a score.
    NOTE(review): assumes compute_tokens_prob returns probabilities for
    ("good", "bad") at the mask position — confirm against prompt.Prompting.
    """
    # Baseline score of the bare prompt; text scoring above it is positive.
    THRESHOLD = prompting.compute_tokens_prob(prompt, token_list1=["good"], token_list2=["bad"])[0].item()
    res = prompting.compute_tokens_prob(text + prompt, token_list1=["good"], token_list2=["bad"])
    if res[0] > THRESHOLD:
        return {"POSITIVE": res[0]}
    return {"NEGATIVE": res[0]}
 
 
 
# --- Current version of app.py (after this commit) ---
"""Gradio demo: zero-shot sentiment via masked-LM prompting.

Appends the cloze prompt ". Because it was [MASK]." to the input text
and compares the model's score for "good" vs "bad" at the mask.
"""
from transformers import AutoModelForMaskedLM , AutoTokenizer
import torch

# Moved up from the bottom of the file: imports belong at the top.
import gradio as gr

# load Prompting class (project-local helper)
from prompt import Prompting

model_path = "bert-base-multilingual-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_path)

prompting = Prompting(model=model_path)
prompt = ". Because it was " + prompting.tokenizer.mask_token + "."

# Baseline score of the bare prompt. Computed ONCE at startup instead of
# on every request — the original recomputed this call-invariant model
# forward pass inside predict() for each input.
THRESHOLD = prompting.compute_tokens_prob(prompt, token_list1=["good"], token_list2=["bad"])[0].item()


def predict(text):
    """Return a single-entry label->score dict for *text*.

    Parameters
    ----------
    text : str
        Raw input sentence to classify.

    Returns
    -------
    dict
        {"POSITIVE": margin above baseline} or {"NEGATIVE": margin below it}.
    """
    res = prompting.compute_tokens_prob(text + prompt, token_list1=["good"], token_list2=["bad"])
    if res[0] > THRESHOLD:
        return {"POSITIVE": res[0].item() - THRESHOLD}
    # NOTE(review): the negative margin mixes res[1] (presumably the "bad"
    # score) with a threshold derived from the "good" score — confirm this
    # asymmetry is intended; it can also yield negative values, which the
    # gradio "label" output treats oddly.
    return {"NEGATIVE": THRESHOLD - res[1].item()}


iface = gr.Interface(fn=predict, inputs=["text"], outputs=["label"]).launch()