Spaces:
Sleeping
Sleeping
add related words
Browse files
app.py
CHANGED
@@ -38,7 +38,7 @@ def get_models(llama=False):
|
|
38 |
model, tokenizer = get_models()
|
39 |
|
40 |
def return_top_k(sentence, k=10):
|
41 |
-
|
42 |
if sentence[-1] != ".":
|
43 |
sentence = sentence + "."
|
44 |
|
@@ -69,9 +69,34 @@ def return_top_k(sentence, k=10):
|
|
69 |
if (len(pred) < 2) | (pred in sentence.split()):
|
70 |
predictions.pop(predictions.index(pred))
|
71 |
|
72 |
-
return predictions[:10]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
75 |
if 'messages' not in st.session_state:
|
76 |
st.session_state.messages = []
|
77 |
|
@@ -267,10 +292,9 @@ if st.session_state.actions[-1] == 'cue':
|
|
267 |
target = st.session_state.results["results"][word_count]
|
268 |
if get_available_cues(target):
|
269 |
avail_cues = get_available_cues(target)
|
270 |
-
cues_buttons = {}
|
271 |
-
for cue_type
|
272 |
-
|
273 |
-
cues_buttons[cue_type] = st.button(cue_type)
|
274 |
|
275 |
with col3:
|
276 |
b3 = st.button("All words", key="3")
|
@@ -279,6 +303,8 @@ if st.session_state.actions[-1] == 'cue':
|
|
279 |
with col5:
|
280 |
b5 = st.button("Exit", key="5", type='primary')
|
281 |
|
|
|
|
|
282 |
if b1:
|
283 |
st.session_state.counters["letter_count"] += 1
|
284 |
word_count = st.session_state.counters["word_count"]
|
@@ -307,6 +333,12 @@ if st.session_state.actions[-1] == 'cue':
|
|
307 |
elif b3:
|
308 |
write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
elif b4:
|
311 |
write_bot("Yay! I am happy I could be of help!")
|
312 |
new = st.button('Play again', key=63)
|
|
|
38 |
model, tokenizer = get_models()
|
39 |
|
40 |
def return_top_k(sentence, k=10):
|
41 |
+
|
42 |
if sentence[-1] != ".":
|
43 |
sentence = sentence + "."
|
44 |
|
|
|
69 |
if (len(pred) < 2) | (pred in sentence.split()):
|
70 |
predictions.pop(predictions.index(pred))
|
71 |
|
72 |
+
return predictions[:10]
|
73 |
+
|
74 |
+
# JS
|
75 |
+
def get_related_words(word, num=5):
|
76 |
+
model.eval()
|
77 |
+
with torch.no_grad():
|
78 |
+
sentence = [f"Descripton : It is related to {word} but not {word}. Word : "]
|
79 |
+
#inputs = ["Description: It is something to cut stuff with. Word: "]
|
80 |
+
print(sentence)
|
81 |
+
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt",)
|
82 |
+
|
83 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
+
model.to(device)
|
85 |
+
|
86 |
+
batch = {k: v.to(device) for k, v in inputs.items()}
|
87 |
+
beam_outputs = model.generate(
|
88 |
+
input_ids=batch['input_ids'], max_new_tokens=10, num_beams=num+2, num_return_sequences=num+2, early_stopping=True
|
89 |
+
)
|
90 |
+
|
91 |
+
#beam_preds = [tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True) for beam_output in beam_outputs if ]
|
92 |
+
beam_preds = []
|
93 |
+
for beam_output in beam_outputs:
|
94 |
+
prediction = tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True).strip()
|
95 |
+
if prediction not in " ".join(sentence):
|
96 |
+
beam_preds.append(prediction)
|
97 |
+
|
98 |
+
return ", ".join(beam_preds[:num])
|
99 |
|
|
|
100 |
if 'messages' not in st.session_state:
|
101 |
st.session_state.messages = []
|
102 |
|
|
|
292 |
target = st.session_state.results["results"][word_count]
|
293 |
if get_available_cues(target):
|
294 |
avail_cues = get_available_cues(target)
|
295 |
+
cues_buttons = {cue_type: st.button(cue_type) for cue_type in avail_cues}
|
296 |
+
#for cue_type in avail_cues:
|
297 |
+
# cues_buttons[cue_type] = st.button(cue_type)
|
|
|
298 |
|
299 |
with col3:
|
300 |
b3 = st.button("All words", key="3")
|
|
|
303 |
with col5:
|
304 |
b5 = st.button("Exit", key="5", type='primary')
|
305 |
|
306 |
+
b6 = st.button("Related words")
|
307 |
+
|
308 |
if b1:
|
309 |
st.session_state.counters["letter_count"] += 1
|
310 |
word_count = st.session_state.counters["word_count"]
|
|
|
333 |
elif b3:
|
334 |
write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
|
335 |
|
336 |
+
elif b6:
|
337 |
+
sent = f"It is related to '{target}' but not '{target}'."
|
338 |
+
rels = return_top_k(target)
|
339 |
+
|
340 |
+
write_bot(f'Here are words that are related to your word: {", ".join(rels)}')
|
341 |
+
|
342 |
elif b4:
|
343 |
write_bot("Yay! I am happy I could be of help!")
|
344 |
new = st.button('Play again', key=63)
|