BigSalmon committed on
Commit
e7e1f57
·
1 Parent(s): cfc8fee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -31,6 +31,7 @@ def get_model():
31
  #model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
32
  tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln71Paraphrase")
33
  model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln71Paraphrase")
 
34
  model2 = AutoModelForCausalLM.from_pretrained("sberbank-ai/mGPT")
35
  return model, model2, tokenizer
36
 
@@ -72,13 +73,13 @@ def run_generate(text, bad_words):
72
 
73
  def run_generate2(text, bad_words):
74
  yo = []
75
- input_ids = tokenizer.encode(text, return_tensors='pt')
76
- res = len(tokenizer.encode(text))
77
  bad_words = bad_words.split()
78
  bad_word_ids = []
79
  for bad_word in bad_words:
80
  bad_word = " " + bad_word
81
- ids = tokenizer(bad_word).input_ids
82
  bad_word_ids.append(ids)
83
  sample_outputs = model2.generate(
84
  input_ids,
@@ -91,7 +92,7 @@ def run_generate2(text, bad_words):
91
  bad_words_ids=bad_word_ids
92
  )
93
  for i in range(number_of_outputs):
94
- e = tokenizer.decode(sample_outputs[i])
95
  e = e.replace(text, "")
96
  yo.append(e)
97
  return yo
@@ -126,12 +127,12 @@ with st.form(key='my_form'):
126
  if submit_button4:
127
  text2 = str(text)
128
  print(text2)
129
- text3 = tokenizer.encode(text2)
130
  myinput, past_key_values = torch.tensor([text3]), None
131
  myinput = myinput
132
  logits, past_key_values = model2(myinput, past_key_values = past_key_values, return_dict=False)
133
  logits = logits[0,-1]
134
  probabilities = torch.nn.functional.softmax(logits)
135
  best_logits, best_indices = logits.topk(logs_outputs)
136
- best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
137
  st.write(best_words)
 
31
  #model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
32
  tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln71Paraphrase")
33
  model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln71Paraphrase")
34
+ tokenizer2 = AutoTokenizer.from_pretrained("sberbank-ai/mGPT")
35
  model2 = AutoModelForCausalLM.from_pretrained("sberbank-ai/mGPT")
36
  return model, model2, tokenizer
37
 
 
73
 
74
  def run_generate2(text, bad_words):
75
  yo = []
76
+ input_ids = tokenizer2.encode(text, return_tensors='pt')
77
+ res = len(tokenizer2.encode(text))
78
  bad_words = bad_words.split()
79
  bad_word_ids = []
80
  for bad_word in bad_words:
81
  bad_word = " " + bad_word
82
+ ids = tokenizer2(bad_word).input_ids
83
  bad_word_ids.append(ids)
84
  sample_outputs = model2.generate(
85
  input_ids,
 
92
  bad_words_ids=bad_word_ids
93
  )
94
  for i in range(number_of_outputs):
95
+ e = tokenizer2.decode(sample_outputs[i])
96
  e = e.replace(text, "")
97
  yo.append(e)
98
  return yo
 
127
  if submit_button4:
128
  text2 = str(text)
129
  print(text2)
130
+ text3 = tokenizer2.encode(text2)
131
  myinput, past_key_values = torch.tensor([text3]), None
132
  myinput = myinput
133
  logits, past_key_values = model2(myinput, past_key_values = past_key_values, return_dict=False)
134
  logits = logits[0,-1]
135
  probabilities = torch.nn.functional.softmax(logits)
136
  best_logits, best_indices = logits.topk(logs_outputs)
137
+ best_words = [tokenizer2.decode([idx.item()]) for idx in best_indices]
138
  st.write(best_words)