schlenker commited on
Commit
a4def93
·
1 Parent(s): baef832

add related words

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -38,7 +38,7 @@ def get_models(llama=False):
38
  model, tokenizer = get_models()
39
 
40
  def return_top_k(sentence, k=10):
41
-
42
  if sentence[-1] != ".":
43
  sentence = sentence + "."
44
 
@@ -69,9 +69,34 @@ def return_top_k(sentence, k=10):
69
  if (len(pred) < 2) | (pred in sentence.split()):
70
  predictions.pop(predictions.index(pred))
71
 
72
- return predictions[:10]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
-
75
  if 'messages' not in st.session_state:
76
  st.session_state.messages = []
77
 
@@ -267,10 +292,9 @@ if st.session_state.actions[-1] == 'cue':
267
  target = st.session_state.results["results"][word_count]
268
  if get_available_cues(target):
269
  avail_cues = get_available_cues(target)
270
- cues_buttons = {}
271
- for cue_type, cues in avail_cues.items():
272
- #st.button(cue_type)
273
- cues_buttons[cue_type] = st.button(cue_type)
274
 
275
  with col3:
276
  b3 = st.button("All words", key="3")
@@ -279,6 +303,8 @@ if st.session_state.actions[-1] == 'cue':
279
  with col5:
280
  b5 = st.button("Exit", key="5", type='primary')
281
 
 
 
282
  if b1:
283
  st.session_state.counters["letter_count"] += 1
284
  word_count = st.session_state.counters["word_count"]
@@ -307,6 +333,12 @@ if st.session_state.actions[-1] == 'cue':
307
  elif b3:
308
  write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
309
 
 
 
 
 
 
 
310
  elif b4:
311
  write_bot("Yay! I am happy I could be of help!")
312
  new = st.button('Play again', key=63)
 
38
  model, tokenizer = get_models()
39
 
40
  def return_top_k(sentence, k=10):
41
+
42
  if sentence[-1] != ".":
43
  sentence = sentence + "."
44
 
 
69
  if (len(pred) < 2) | (pred in sentence.split()):
70
  predictions.pop(predictions.index(pred))
71
 
72
+ return predictions[:10]
73
+
74
+ # JS
75
+ def get_related_words(word, num=5):
76
+ model.eval()
77
+ with torch.no_grad():
78
+ sentence = [f"Descripton : It is related to {word} but not {word}. Word : "]
79
+ #inputs = ["Description: It is something to cut stuff with. Word: "]
80
+ print(sentence)
81
+ inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt",)
82
+
83
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ model.to(device)
85
+
86
+ batch = {k: v.to(device) for k, v in inputs.items()}
87
+ beam_outputs = model.generate(
88
+ input_ids=batch['input_ids'], max_new_tokens=10, num_beams=num+2, num_return_sequences=num+2, early_stopping=True
89
+ )
90
+
91
+ #beam_preds = [tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True) for beam_output in beam_outputs if ]
92
+ beam_preds = []
93
+ for beam_output in beam_outputs:
94
+ prediction = tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True).strip()
95
+ if prediction not in " ".join(sentence):
96
+ beam_preds.append(prediction)
97
+
98
+ return ", ".join(beam_preds[:num])
99
 
 
100
  if 'messages' not in st.session_state:
101
  st.session_state.messages = []
102
 
 
292
  target = st.session_state.results["results"][word_count]
293
  if get_available_cues(target):
294
  avail_cues = get_available_cues(target)
295
+ cues_buttons = {cue_type: st.button(cue_type) for cue_type in avail_cues}
296
+ #for cue_type in avail_cues:
297
+ # cues_buttons[cue_type] = st.button(cue_type)
 
298
 
299
  with col3:
300
  b3 = st.button("All words", key="3")
 
303
  with col5:
304
  b5 = st.button("Exit", key="5", type='primary')
305
 
306
+ b6 = st.button("Related words")
307
+
308
  if b1:
309
  st.session_state.counters["letter_count"] += 1
310
  word_count = st.session_state.counters["word_count"]
 
333
  elif b3:
334
  write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
335
 
336
+ elif b6:
337
+ sent = f"It is related to '{target}' but not '{target}'."
338
+ rels = return_top_k(target)
339
+
340
+ write_bot(f'Here are words that are related to your word: {", ".join(rels)}')
341
+
342
  elif b4:
343
  write_bot("Yay! I am happy I could be of help!")
344
  new = st.button('Play again', key=63)