BigSalmon commited on
Commit
08112a4
·
1 Parent(s): fbe6573

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -2
app.py CHANGED
@@ -27,10 +27,11 @@ def get_model():
27
  #model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
28
  tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
29
  model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln68Paraphrase")
 
30
  #model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
31
- return model, tokenizer
32
 
33
- model, tokenizer = get_model()
34
 
35
  st.text('''For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln63Paraphrase''')
36
 
@@ -65,10 +66,41 @@ def run_generate(text, bad_words):
65
  e = e.replace(text, "")
66
  yo.append(e)
67
  return yo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with st.form(key='my_form'):
69
  text = st.text_area(label='Enter sentence', value=first)
70
  submit_button = st.form_submit_button(label='Submit')
71
  submit_button2 = st.form_submit_button(label='Submit Log Probs')
 
 
 
 
72
  if submit_button:
73
  translated_text = run_generate(text, bad_words)
74
  st.write(translated_text if translated_text else "No translation found")
@@ -84,4 +116,19 @@ with st.form(key='my_form'):
84
  probabilities = torch.nn.functional.softmax(logits)
85
  best_logits, best_indices = logits.topk(logs_outputs)
86
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  st.write(best_words)
 
27
  #model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
28
  tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
29
  model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln68Paraphrase")
30
+ model2 = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
31
  #model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln63Paraphrase")
32
+ return model, model2, tokenizer
33
 
34
+ model, model2, tokenizer = get_model()
35
 
36
  st.text('''For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln63Paraphrase''')
37
 
 
66
  e = e.replace(text, "")
67
  yo.append(e)
68
  return yo
69
+
70
+ def run_generate2(text, bad_words):
71
+ yo = []
72
+ input_ids = tokenizer.encode(text, return_tensors='pt')
73
+ res = len(tokenizer.encode(text))
74
+ bad_words = bad_words.split()
75
+ bad_word_ids = []
76
+ for bad_word in bad_words:
77
+ bad_word = " " + bad_word
78
+ ids = tokenizer(bad_word).input_ids
79
+ bad_word_ids.append(ids)
80
+ sample_outputs = model2.generate(
81
+ input_ids,
82
+ do_sample=True,
83
+ max_length= res + lengths,
84
+ min_length = res + lengths,
85
+ top_k=50,
86
+ temperature=temp,
87
+ num_return_sequences=number_of_outputs,
88
+ bad_words_ids=bad_word_ids
89
+ )
90
+ for i in range(number_of_outputs):
91
+ e = tokenizer.decode(sample_outputs[i])
92
+ e = e.replace(text, "")
93
+ yo.append(e)
94
+ return yo
95
+
96
  with st.form(key='my_form'):
97
  text = st.text_area(label='Enter sentence', value=first)
98
  submit_button = st.form_submit_button(label='Submit')
99
  submit_button2 = st.form_submit_button(label='Submit Log Probs')
100
+
101
+ submit_button3 = st.form_submit_button(label='Submit Other Model')
102
+ submit_button4 = st.form_submit_button(label='Submit Log Probs Other Model')
103
+
104
  if submit_button:
105
  translated_text = run_generate(text, bad_words)
106
  st.write(translated_text if translated_text else "No translation found")
 
116
  probabilities = torch.nn.functional.softmax(logits)
117
  best_logits, best_indices = logits.topk(logs_outputs)
118
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
119
+ st.write(best_words)
120
+ if submit_button3:
121
+ translated_text = run_generate2(text, bad_words)
122
+ st.write(translated_text if translated_text else "No translation found")
123
+ if submit_button4:
124
+ text2 = str(text)
125
+ print(text2)
126
+ text3 = tokenizer.encode(text2)
127
+ myinput, past_key_values = torch.tensor([text3]), None
128
+ myinput = myinput
129
+ logits, past_key_values = model2(myinput, past_key_values = past_key_values, return_dict=False)
130
+ logits = logits[0,-1]
131
+ probabilities = torch.nn.functional.softmax(logits)
132
+ best_logits, best_indices = logits.topk(logs_outputs)
133
+ best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
134
  st.write(best_words)