NursNurs commited on
Commit
d23f61d
Β·
1 Parent(s): 4e18247

2 modes added

Browse files
pages/1_Descriptive_chatbot.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from tqdm import tqdm
4
+ from peft import PeftModel, PeftConfig
5
+ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
6
+ from transformers import AutoTokenizer
7
+ import numpy as np
8
+ import time
9
+ import string
10
+
11
+
12
+ # JS
13
+ import nltk
14
+ nltk.download('wordnet')
15
+ from nltk.corpus import wordnet as wn
16
+ from nltk.tokenize import word_tokenize
17
+
18
+ @st.cache_resource
19
+ def get_models(llama=False):
20
+ st.write('Loading the model...')
21
+ # config = PeftConfig.from_pretrained("NursNurs/T5ForReverseDictionary")
22
+ # model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
23
+ # model = PeftModel.from_pretrained(model, "NursNurs/T5ForReverseDictionary")
24
+
25
+ config = PeftConfig.from_pretrained("YouNameIt/T5ForReverseDictionary_prefix_tuned")
26
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
27
+ model = PeftModel.from_pretrained(model, "YouNameIt/T5ForReverseDictionary_prefix_tuned")
28
+
29
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
30
+
31
+ # JS
32
+ if llama:
33
+ model_name = 'meta-llama/Llama-2-7b-chat-hf'
34
+ access_token = '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'
35
+ llama_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token, use_fast=True)#, use_fast=True)
36
+ llama_model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=access_token, device_map={'':0})#, load_in_4bit=True)
37
+ st.write("The assistant is loaded and ready to use!")
38
+ return model, tokenizer, llama_model, llama_tokenizer
39
+
40
+ else:
41
+ st.write("_The assistant is loaded and ready to use! :tada:_")
42
+ return model, tokenizer
43
+
44
+ model, tokenizer = get_models()
45
+
46
+ def remove_punctuation(word):
47
+ # Create a translation table that maps all punctuation characters to None
48
+ translator = str.maketrans('', '', string.punctuation)
49
+
50
+ # Use the translate method to remove punctuation from the word
51
+ word_without_punctuation = word.translate(translator)
52
+
53
+ return word_without_punctuation
54
+
55
+ def return_top_k(sentence, k=10, word=None, rels=False):
56
+
57
+ if sentence[-1] != ".":
58
+ sentence = sentence + "."
59
+
60
+ if rels:
61
+ inputs = [f"Description : It is related to '{word}' but not '{word}'. Word : "]
62
+ else:
63
+ inputs = [f"Description : {sentence} Word : "]
64
+
65
+ inputs = tokenizer(
66
+ inputs,
67
+ padding=True, truncation=True,
68
+ return_tensors="pt",
69
+ )
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ model.to(device)
72
+
73
+ with torch.no_grad():
74
+ inputs = {k: v.to(device) for k, v in inputs.items()}
75
+ output_sequences = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10, num_beams=k+5, num_return_sequences=k+5, #max_length=3,
76
+ top_p = 50, output_scores=True, return_dict_in_generate=True) #repetition_penalty=10000.0
77
+
78
+ logits = output_sequences['sequences_scores'].clone().detach()
79
+ decoded_probabilities = torch.softmax(logits, dim=0)
80
+
81
+
82
+ #all word predictions
83
+ predictions = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_sequences['sequences']]
84
+ probabilities = [round(float(prob), 2) for prob in decoded_probabilities]
85
+
86
+ stripped_sent = [remove_punctuation(word.lower()) for word in sentence.split()]
87
+ for pred in predictions:
88
+ if (len(pred) < 2) | (pred in stripped_sent):
89
+ predictions.pop(predictions.index(pred))
90
+
91
+ return predictions[:10]
92
+
93
+ # JS
94
+ def get_related_words(word, num=5):
95
+ model.eval()
96
+ with torch.no_grad():
97
+ sentence = [f"Descripton : It is related to {word} but not {word}. Word : "]
98
+ #inputs = ["Description: It is something to cut stuff with. Word: "]
99
+ print(sentence)
100
+ inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt",)
101
+
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+ model.to(device)
104
+
105
+ batch = {k: v.to(device) for k, v in inputs.items()}
106
+ beam_outputs = model.generate(
107
+ input_ids=batch['input_ids'], max_new_tokens=10, num_beams=num+2, num_return_sequences=num+2, early_stopping=True
108
+ )
109
+
110
+ #beam_preds = [tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True) for beam_output in beam_outputs if ]
111
+ beam_preds = []
112
+ for beam_output in beam_outputs:
113
+ prediction = tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True).strip()
114
+ if prediction not in " ".join(sentence):
115
+ beam_preds.append(prediction)
116
+
117
+ return ", ".join(beam_preds[:num])
118
+
119
+ #if 'messages' not in st.session_state:
120
+
121
+ def get_text():
122
+ input_text = st.chat_input()
123
+ return input_text
124
+
125
+ def write_bot(input, remember=True, blink=True):
126
+ with st.chat_message('assistant'):
127
+ message_placeholder = st.empty()
128
+ full_response = input
129
+ if blink == True:
130
+ response = ''
131
+ for chunk in full_response.split():
132
+ response += chunk + " "
133
+ time.sleep(0.05)
134
+ # Add a blinking cursor to simulate typing
135
+ message_placeholder.markdown(response + "β–Œ")
136
+ time.sleep(0.5)
137
+ message_placeholder.markdown(full_response)
138
+ if remember == True:
139
+ st.session_state.messages.append({'role': 'assistant', 'content': full_response})
140
+
141
+ def ask_if_helped():
142
+ y = st.button('Yes!', key=60)
143
+ n = st.button('No...', key=61)
144
+ new = st.button('I have a new word', key=62)
145
+ if y:
146
+ write_bot("I am happy to help!")
147
+ again = st.button('Play again')
148
+ if again:
149
+ write_bot("Please describe your word!")
150
+ st.session_state.is_helpful['ask'] = False
151
+ elif n:
152
+ st.session_state.actions.append('cue')
153
+ st.session_state.is_helpful['ask'] = False
154
+ #cue_generation()
155
+ elif new:
156
+ write_bot("Please describe your word!")
157
+ st.session_state.is_helpful['ask'] = False
158
+
159
+ ## removed: if st.session_state.actions[-1] == "result":
160
+
161
+ # JS
162
+ def get_related_words_llama(relation, target, device, num=5):
163
+ prompt = f"Provide {num} {relation}s for the word '{target}'. Your answer consists of these {num} words only. Do not include the word '{target}' itself in your answer"
164
+
165
+ inputs = tokenizer([prompt], return_tensors='pt').to(device)
166
+ output = model.generate(
167
+ **inputs, max_new_tokens=40, temperature=.75, early_stopping=True,
168
+ )
169
+ chatbot_response = tokenizer.decode(output[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True).strip()
170
+
171
+ postproc = [word for word in word_tokenize(chatbot_response) if len(word)>=3]
172
+
173
+ return postproc[-num:] if len(postproc)>=num else postproc
174
+
175
+
176
+ def postproc_wn(related_words, syns=False):
177
+ if syns:
178
+ related_words = [word.split('.')[0] if word[0] != "." else word.split('.')[1] for word in related_words]
179
+ else:
180
+ related_words = [word.name().split('.')[0] if word.name()[0] != "." else word.name().split('.')[1] for word in related_words]
181
+ related_words = [word.replace("_", " ") for word in related_words]
182
+
183
+ return related_words
184
+
185
+ # JS
186
+ def get_available_cues(target):
187
+ wn_nouns = [word.name() for word in wn.all_synsets(pos='n')]
188
+ wn_nouns = [word.split('.')[0] if word[0] != "." else word.split('.')[1] for word in wn_nouns]
189
+
190
+ if target in wn_nouns:
191
+ available_cues = {}
192
+ synset_target = wn.synsets(target, pos=wn.NOUN)[0]
193
+
194
+ #if wn.synonyms(target)[0]:
195
+ # available_cues['Synonyms'] = postproc_wn(wn.synonyms(target)[0], syns=True)
196
+
197
+ #if synset_target.hypernyms():
198
+ # available_cues['Hypernyms'] = postproc_wn(synset_target.hypernyms())
199
+
200
+
201
+ #if synset_target.hyponyms():
202
+ # available_cues['Hyponyms'] = postproc_wn(synset_target.hyponyms())
203
+
204
+ if synset_target.examples():
205
+ examples = []
206
+
207
+ for example in synset_target.examples():
208
+ examples.append(example.replace(target, "..."))
209
+
210
+ available_cues['Examples'] = examples
211
+
212
+ return available_cues
213
+
214
+ else:
215
+ return None
216
+
217
+ # JS: moved the cue generation further down
218
+ #def cue_generation():
219
+ # if st.session_state.actions[-1] == 'cue':
220
+
221
+ if 'messages' not in st.session_state:
222
+ st.session_state.messages = []
223
+
224
+ if 'results' not in st.session_state:
225
+ st.session_state.results = {'results': False, 'results_print': False}
226
+
227
+ if 'actions' not in st.session_state:
228
+ st.session_state.actions = [""]
229
+
230
+ if 'counters' not in st.session_state:
231
+ st.session_state.counters = {"letter_count": 0, "word_count": 0}
232
+
233
+ if 'is_helpful' not in st.session_state:
234
+ st.session_state.is_helpful = {'ask':False}
235
+
236
+ if 'descriptions' not in st.session_state:
237
+ st.session_state.descriptions = []
238
+
239
+ st.title("You name it! πŸ—£")
240
+
241
+ # JS: would remove Simon by some neutral avatar
242
+ with st.chat_message('user'):
243
+ st.write("Hey assistant!")
244
+
245
+ bot = st.chat_message('assistant')
246
+ bot.write("Hello human! Wanna practice naming some words?")
247
+
248
+ #for showing history of messages
249
+ for message in st.session_state.messages:
250
+ if message['role'] == 'user':
251
+ with st.chat_message(message['role']):
252
+ st.markdown(message['content'])
253
+ else:
254
+ with st.chat_message(message['role']):
255
+ st.markdown(message['content'])
256
+
257
+ #display user message in chat message container
258
+ prompt = get_text()
259
+ if prompt:
260
+ #JS: would replace Simon by some neutral character
261
+ with st.chat_message('user'):
262
+ st.markdown(prompt)
263
+ #add to history
264
+ st.session_state.messages.append({'role': 'user', 'content': prompt})
265
+ #TODO: replace it with zero-shot classifier
266
+ yes = ['yes', 'again', 'Yes', 'sure', 'new word', 'yes!', 'yep', 'yeah']
267
+ if prompt in yes:
268
+ write_bot("Please describe your word!")
269
+ elif prompt == 'it is similar to the best place on earth':
270
+ write_bot("Great! Let me think what it could be...")
271
+ time.sleep(3)
272
+ write_bot("Do you mean Saarland?")
273
+ #if previously we asked to give a prompt
274
+ elif (st.session_state.messages[-2]['content'] == "Please describe your word!") & (st.session_state.messages[-1]['content'] != "no"):
275
+ write_bot("Great! Let me think what it could be...")
276
+ st.session_state.descriptions.append(prompt)
277
+ st.session_state.results['results'] = return_top_k(st.session_state.descriptions[-1])
278
+ st.session_state.results['results_print'] = dict(zip(range(1, 11), st.session_state.results['results']))
279
+ write_bot("I think I have some ideas. Do you want to see my guesses or do you want a cue?")
280
+ st.session_state.actions.append("result")
281
+
282
+ if st.session_state.actions[-1] == "result":
283
+ col1, col2, col3, col4, col5 = st.columns(5)
284
+ with col1:
285
+ a1 = st.button('Results', key=10)
286
+ with col2:
287
+ a2 = st.button('Cue', key=11)
288
+ if a1:
289
+ write_bot("Here are my guesses about your word:")
290
+ st.write(st.session_state.results['results_print'])
291
+ time.sleep(1)
292
+ write_bot('Does it help you remember the word?', remember=False)
293
+ st.session_state.is_helpful['ask'] = True
294
+ elif a2:
295
+ #write_bot(f'The first letter is {st.session_state.results["results"][0][0]}.')
296
+ #time.sleep(1)
297
+ st.session_state.actions.append('cue')
298
+ #cue_generation()
299
+ #write_bot('Does it help you remember the word?', remember=False)
300
+ #st.session_state.is_helpful['ask'] = True
301
+
302
+ if st.session_state.is_helpful['ask']:
303
+ ask_if_helped()
304
+
305
+ if st.session_state.actions[-1] == 'cue':
306
+ guessed = False
307
+ write_bot('What do you want to see?', remember=False, blink=False)
308
+
309
+ while guessed == False:
310
+ # JS
311
+ word_count = st.session_state.counters["word_count"]
312
+ target = st.session_state.results["results"][word_count]
313
+
314
+ col1, col2, col3, col4, col5 = st.columns(5)
315
+
316
+
317
+ with col1:
318
+ b1 = st.button("Next letter", key="1")
319
+ with col2:
320
+ b2 = st.button("Related words")
321
+ with col3:
322
+ b3 = st.button("Next word", key="2")
323
+ with col4:
324
+ b4 = st.button("All words", key="3")
325
+
326
+ # JS
327
+ #if get_available_cues(target):
328
+ # avail_cues = get_available_cues(target)
329
+ #cues_buttons = {cue_type: st.button(cue_type) for cue_type in avail_cues}
330
+
331
+ b5 = st.button("I remembered the word!", key="4", type='primary')
332
+ b6 = st.button("Exit", key="5", type='primary')
333
+ new = st.button('Play again', key=64, type='primary')
334
+
335
+ if b1:
336
+ st.session_state.counters["letter_count"] += 1
337
+ #word_count = st.session_state.counters["word_count"]
338
+ letter_count = st.session_state.counters["letter_count"]
339
+ if letter_count < len(target):
340
+ write_bot(f'The word starts with {st.session_state.results["results"][word_count][:letter_count]}. \n Does this help you remember the word?', remember=False)
341
+ #ask_if_helped()
342
+ st.session_state.is_helpful['ask'] = True
343
+ else:
344
+ write_bot(f'This is my predicted word: "{target}". Does this match your query?')
345
+ #ask_if_helped()
346
+ st.session_state.is_helpful['ask'] = True
347
+
348
+ elif b2:
349
+ rels = return_top_k(st.session_state.descriptions[-1], word=target, rels=True)
350
+ write_bot(f'Here are words that are related to your word: {", ".join(rels)}. \n Does this help you remember the word?', remember=False)
351
+ #ask_if_helped()
352
+ st.session_state.is_helpful['ask'] = True
353
+
354
+ elif b3:
355
+ st.session_state.counters["letter_count"] = 1
356
+ letter_count = st.session_state.counters["letter_count"]
357
+ st.session_state.counters["word_count"] += 1
358
+ word_count = st.session_state.counters["word_count"]
359
+ #write_bot(f'The next word starts with {st.session_state.results["results"][word_count][:letter_count]}', remember=False)
360
+ if letter_count < len(target):
361
+ write_bot(f'The next word starts with {st.session_state.results["results"][word_count][:letter_count]}. \n Does this help you remember the word?', remember=False)
362
+ #ask_if_helped()
363
+ st.session_state.is_helpful['ask'] = True
364
+ else:
365
+ write_bot(f'This is my predicted word: "{target}". Does this match your query?')
366
+ #ask_if_helped()
367
+ st.session_state.is_helpful['ask'] = True
368
+
369
+ #elif get_available_cues(target) and "Synonyms" in cues_buttons and cues_buttons['Synonyms']:
370
+ #write_bot(f'Here are synonyms for the current word: {", ".join(avail_cues["Synonyms"])}', remember=False)
371
+
372
+ #elif get_available_cues(target) and "Hypernyms" in cues_buttons and cues_buttons['Hypernyms']:
373
+ #write_bot(f'Here are hypernyms for the current word: {", ".join(avail_cues["Hypernyms"])}', remember=False)
374
+
375
+ #elif get_available_cues(target) and "Hyponyms" in cues_buttons and cues_buttons['Hyponyms']:
376
+ #write_bot(f'Here are hyponyms for the current word: {", ".join(avail_cues["Hyponyms"])}', remember=False)
377
+
378
+ #elif get_available_cues(target) and "Examples" in cues_buttons and cues_buttons['Examples']:
379
+ #write_bot(f'Here are example contexts for the current word: {", ".join(avail_cues["Examples"])}', remember=False)
380
+
381
+ elif b4:
382
+ write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
383
+
384
+ elif b5:
385
+ write_bot("Yay! I am happy I could be of help!")
386
+ st.session_state.counters["word_count"] = 0
387
+ st.session_state.counters["letter_count"] = 0
388
+ new = st.button('Play again', key=63)
389
+ if new:
390
+ write_bot("Please describe your word!")
391
+ guessed = True
392
+
393
+ break
394
+
395
+ elif b6:
396
+ write_bot("I am sorry I couldn't help you this time. See you soon!")
397
+ st.session_state.counters["word_count"] = 0
398
+ st.session_state.counters["letter_count"] = 0
399
+ st.session_state.actions.append('cue')
400
+
401
+ if new:
402
+ write_bot("Please describe your word!")
403
+ st.session_state.counters["word_count"] = 0
404
+ st.session_state.counters["letter_count"] = 0
405
+
406
+ break
407
+
pages/2_Context-based_chatbot.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from tqdm import tqdm
4
+ from transformers import pipeline
5
+ import numpy as np
6
+ import time
7
+ import string
8
+
9
+
10
+ # JS
11
+ import nltk
12
+ nltk.download('wordnet')
13
+ from nltk.corpus import wordnet as wn
14
+ from nltk.tokenize import word_tokenize
15
+
16
+ @st.cache_resource
17
+ def get_models(llama=False):
18
+ st.write('Loading the model...')
19
+ model = pipeline("fill-mask")
20
+ st.write("_The assistant is loaded and ready to use! :tada:_")
21
+ return model
22
+
23
+ model = get_models()
24
+
25
+ def remove_punctuation(word):
26
+ # Create a translation table that maps all punctuation characters to None
27
+ translator = str.maketrans('', '', string.punctuation)
28
+
29
+ # Use the translate method to remove punctuation from the word
30
+ word_without_punctuation = word.translate(translator)
31
+
32
+ return word_without_punctuation
33
+
34
+ def return_top_k(sentence, word=None, rels=False):
35
+
36
+ if sentence[-1] != ".":
37
+ sentence = sentence + "."
38
+
39
+ # if rels:
40
+ # inputs = [f"Description : It is related to '{word}' but not '{word}'. Word : "]
41
+ # else:
42
+ # inputs = [f"Description : {sentence} Word : "]
43
+
44
+ output = model(sentence)
45
+ output = [output[i]['token_str'] for i in output.keys()]
46
+ return output
47
+
48
+
49
+ # JS
50
+ # def get_related_words(word, num=5):
51
+ # model.eval()
52
+ # with torch.no_grad():
53
+ # sentence = [f"Descripton : It is related to {word} but not {word}. Word : "]
54
+ # #inputs = ["Description: It is something to cut stuff with. Word: "]
55
+ # print(sentence)
56
+ # inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt",)
57
+
58
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ # model.to(device)
60
+
61
+ # batch = {k: v.to(device) for k, v in inputs.items()}
62
+ # beam_outputs = model.generate(
63
+ # input_ids=batch['input_ids'], max_new_tokens=10, num_beams=num+2, num_return_sequences=num+2, early_stopping=True
64
+ # )
65
+
66
+ # #beam_preds = [tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True) for beam_output in beam_outputs if ]
67
+ # beam_preds = []
68
+ # for beam_output in beam_outputs:
69
+ # prediction = tokenizer.decode(beam_output.detach().cpu().numpy(), skip_special_tokens=True).strip()
70
+ # if prediction not in " ".join(sentence):
71
+ # beam_preds.append(prediction)
72
+
73
+ # return ", ".join(beam_preds[:num])
74
+
75
+ #if 'messages' not in st.session_state:
76
+
77
+ def get_text():
78
+ input_text = st.chat_input()
79
+ return input_text
80
+
81
+ def write_bot(input, remember=True, blink=True):
82
+ with st.chat_message('assistant'):
83
+ message_placeholder = st.empty()
84
+ full_response = input
85
+ if blink == True:
86
+ response = ''
87
+ for chunk in full_response.split():
88
+ response += chunk + " "
89
+ time.sleep(0.05)
90
+ # Add a blinking cursor to simulate typing
91
+ message_placeholder.markdown(response + "β–Œ")
92
+ time.sleep(0.5)
93
+ message_placeholder.markdown(full_response)
94
+ if remember == True:
95
+ st.session_state.messages.append({'role': 'assistant', 'content': full_response})
96
+
97
+ def ask_if_helped():
98
+ y = st.button('Yes!', key=60)
99
+ n = st.button('No...', key=61)
100
+ new = st.button('I have a new word', key=62)
101
+ if y:
102
+ write_bot("I am happy to help!")
103
+ again = st.button('Play again')
104
+ if again:
105
+ write_bot("Please describe your word!")
106
+ st.session_state.is_helpful['ask'] = False
107
+ elif n:
108
+ st.session_state.actions.append('cue')
109
+ st.session_state.is_helpful['ask'] = False
110
+ #cue_generation()
111
+ elif new:
112
+ write_bot("Please describe your word!")
113
+ st.session_state.is_helpful['ask'] = False
114
+
115
+ ## removed: if st.session_state.actions[-1] == "result":
116
+
117
+ # JS
118
+ # def get_related_words_llama(relation, target, device, num=5):
119
+ # prompt = f"Provide {num} {relation}s for the word '{target}'. Your answer consists of these {num} words only. Do not include the word '{target}' itself in your answer"
120
+
121
+ # inputs = tokenizer([prompt], return_tensors='pt').to(device)
122
+ # output = model.generate(
123
+ # **inputs, max_new_tokens=40, temperature=.75, early_stopping=True,
124
+ # )
125
+ # chatbot_response = tokenizer.decode(output[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True).strip()
126
+
127
+ # postproc = [word for word in word_tokenize(chatbot_response) if len(word)>=3]
128
+
129
+ # return postproc[-num:] if len(postproc)>=num else postproc
130
+
131
+
132
+ def postproc_wn(related_words, syns=False):
133
+ if syns:
134
+ related_words = [word.split('.')[0] if word[0] != "." else word.split('.')[1] for word in related_words]
135
+ else:
136
+ related_words = [word.name().split('.')[0] if word.name()[0] != "." else word.name().split('.')[1] for word in related_words]
137
+ related_words = [word.replace("_", " ") for word in related_words]
138
+
139
+ return related_words
140
+
141
+ # JS
142
+ def get_available_cues(target):
143
+ wn_nouns = [word.name() for word in wn.all_synsets(pos='n')]
144
+ wn_nouns = [word.split('.')[0] if word[0] != "." else word.split('.')[1] for word in wn_nouns]
145
+
146
+ if target in wn_nouns:
147
+ available_cues = {}
148
+ synset_target = wn.synsets(target, pos=wn.NOUN)[0]
149
+
150
+ #if wn.synonyms(target)[0]:
151
+ # available_cues['Synonyms'] = postproc_wn(wn.synonyms(target)[0], syns=True)
152
+
153
+ #if synset_target.hypernyms():
154
+ # available_cues['Hypernyms'] = postproc_wn(synset_target.hypernyms())
155
+
156
+
157
+ #if synset_target.hyponyms():
158
+ # available_cues['Hyponyms'] = postproc_wn(synset_target.hyponyms())
159
+
160
+ if synset_target.examples():
161
+ examples = []
162
+
163
+ for example in synset_target.examples():
164
+ examples.append(example.replace(target, "..."))
165
+
166
+ available_cues['Examples'] = examples
167
+
168
+ return available_cues
169
+
170
+ else:
171
+ return None
172
+
173
+ # JS: moved the cue generation further down
174
+ #def cue_generation():
175
+ # if st.session_state.actions[-1] == 'cue':
176
+
177
+ if 'messages' not in st.session_state:
178
+ st.session_state.messages = []
179
+
180
+ if 'results' not in st.session_state:
181
+ st.session_state.results = {'results': False, 'results_print': False}
182
+
183
+ if 'actions' not in st.session_state:
184
+ st.session_state.actions = [""]
185
+
186
+ if 'counters' not in st.session_state:
187
+ st.session_state.counters = {"letter_count": 0, "word_count": 0}
188
+
189
+ if 'is_helpful' not in st.session_state:
190
+ st.session_state.is_helpful = {'ask':False}
191
+
192
+ if 'descriptions' not in st.session_state:
193
+ st.session_state.descriptions = []
194
+
195
+ st.title("You name it! πŸ—£")
196
+
197
+ # JS: would remove Simon by some neutral avatar
198
+ with st.chat_message('user'):
199
+ st.write("Hey assistant!")
200
+
201
+ bot = st.chat_message('assistant')
202
+ bot.write("Hello human! Wanna practice naming some words?")
203
+
204
+ #for showing history of messages
205
+ for message in st.session_state.messages:
206
+ if message['role'] == 'user':
207
+ with st.chat_message(message['role']):
208
+ st.markdown(message['content'])
209
+ else:
210
+ with st.chat_message(message['role']):
211
+ st.markdown(message['content'])
212
+
213
+ #display user message in chat message container
214
+ prompt = get_text()
215
+ if prompt:
216
+ #JS: would replace Simon by some neutral character
217
+ with st.chat_message('user'):
218
+ st.markdown(prompt)
219
+ #add to history
220
+ st.session_state.messages.append({'role': 'user', 'content': prompt})
221
+ #TODO: replace it with zero-shot classifier
222
+ yes = ['yes', 'again', 'Yes', 'sure', 'new word', 'yes!', 'yep', 'yeah']
223
+ if prompt in yes:
224
+ write_bot("Please give a sentence using a <mask> instead of the word you have in mind!")
225
+ elif prompt == 'it is similar to the best place on earth':
226
+ write_bot("Great! Let me think what it could be...")
227
+ time.sleep(3)
228
+ write_bot("Do you mean Saarland?")
229
+ #if previously we asked to give a prompt
230
+ elif (st.session_state.messages[-2]['content'] == "Please give a sentence using a <mask> instead of the word you have in mind!") & (st.session_state.messages[-1]['content'] != "no"):
231
+ write_bot("Great! Let me think what it could be...")
232
+ st.session_state.descriptions.append(prompt)
233
+ st.session_state.results['results'] = return_top_k(st.session_state.descriptions[-1])
234
+ st.session_state.results['results_print'] = dict(zip(range(1, 11), st.session_state.results['results']))
235
+ write_bot("I think I have some ideas. Do you want to see my guesses or do you want a cue?")
236
+ st.session_state.actions.append("result")
237
+
238
+ if st.session_state.actions[-1] == "result":
239
+ col1, col2, col3, col4, col5 = st.columns(5)
240
+ with col1:
241
+ a1 = st.button('Results', key=10)
242
+ with col2:
243
+ a2 = st.button('Cue', key=11)
244
+ if a1:
245
+ write_bot("Here are my guesses about your word:")
246
+ st.write(st.session_state.results['results_print'])
247
+ time.sleep(1)
248
+ write_bot('Does it help you remember the word?', remember=False)
249
+ st.session_state.is_helpful['ask'] = True
250
+ elif a2:
251
+ #write_bot(f'The first letter is {st.session_state.results["results"][0][0]}.')
252
+ #time.sleep(1)
253
+ st.session_state.actions.append('cue')
254
+ #cue_generation()
255
+ #write_bot('Does it help you remember the word?', remember=False)
256
+ #st.session_state.is_helpful['ask'] = True
257
+
258
+ if st.session_state.is_helpful['ask']:
259
+ ask_if_helped()
260
+
261
+ if st.session_state.actions[-1] == 'cue':
262
+ guessed = False
263
+ write_bot('What do you want to see?', remember=False, blink=False)
264
+
265
+ while guessed == False:
266
+ # JS
267
+ word_count = st.session_state.counters["word_count"]
268
+ target = st.session_state.results["results"][word_count]
269
+
270
+ col1, col2, col3, col4, col5 = st.columns(5)
271
+
272
+
273
+ with col1:
274
+ b1 = st.button("Next letter", key="1")
275
+ with col2:
276
+ b2 = st.button("Related words")
277
+ with col3:
278
+ b3 = st.button("Next word", key="2")
279
+ with col4:
280
+ b4 = st.button("All words", key="3")
281
+
282
+ # JS
283
+ #if get_available_cues(target):
284
+ # avail_cues = get_available_cues(target)
285
+ #cues_buttons = {cue_type: st.button(cue_type) for cue_type in avail_cues}
286
+
287
+ b5 = st.button("I remembered the word!", key="4", type='primary')
288
+ b6 = st.button("Exit", key="5", type='primary')
289
+ new = st.button('Play again', key=64, type='primary')
290
+
291
+ if b1:
292
+ st.session_state.counters["letter_count"] += 1
293
+ #word_count = st.session_state.counters["word_count"]
294
+ letter_count = st.session_state.counters["letter_count"]
295
+ if letter_count < len(target):
296
+ write_bot(f'The word starts with {st.session_state.results["results"][word_count][:letter_count]}. \n Does this help you remember the word?', remember=False)
297
+ #ask_if_helped()
298
+ st.session_state.is_helpful['ask'] = True
299
+ else:
300
+ write_bot(f'This is my predicted word: "{target}". Does this match your query?')
301
+ #ask_if_helped()
302
+ st.session_state.is_helpful['ask'] = True
303
+
304
+ elif b2:
305
+ rels = return_top_k(st.session_state.descriptions[-1], word=target, rels=True)
306
+ write_bot(f'Here are words that are related to your word: {", ".join(rels)}. \n Does this help you remember the word?', remember=False)
307
+ #ask_if_helped()
308
+ st.session_state.is_helpful['ask'] = True
309
+
310
+ elif b3:
311
+ st.session_state.counters["letter_count"] = 1
312
+ letter_count = st.session_state.counters["letter_count"]
313
+ st.session_state.counters["word_count"] += 1
314
+ word_count = st.session_state.counters["word_count"]
315
+ #write_bot(f'The next word starts with {st.session_state.results["results"][word_count][:letter_count]}', remember=False)
316
+ if letter_count < len(target):
317
+ write_bot(f'The next word starts with {st.session_state.results["results"][word_count][:letter_count]}. \n Does this help you remember the word?', remember=False)
318
+ #ask_if_helped()
319
+ st.session_state.is_helpful['ask'] = True
320
+ else:
321
+ write_bot(f'This is my predicted word: "{target}". Does this match your query?')
322
+ #ask_if_helped()
323
+ st.session_state.is_helpful['ask'] = True
324
+
325
+ #elif get_available_cues(target) and "Synonyms" in cues_buttons and cues_buttons['Synonyms']:
326
+ #write_bot(f'Here are synonyms for the current word: {", ".join(avail_cues["Synonyms"])}', remember=False)
327
+
328
+ #elif get_available_cues(target) and "Hypernyms" in cues_buttons and cues_buttons['Hypernyms']:
329
+ #write_bot(f'Here are hypernyms for the current word: {", ".join(avail_cues["Hypernyms"])}', remember=False)
330
+
331
+ #elif get_available_cues(target) and "Hyponyms" in cues_buttons and cues_buttons['Hyponyms']:
332
+ #write_bot(f'Here are hyponyms for the current word: {", ".join(avail_cues["Hyponyms"])}', remember=False)
333
+
334
+ #elif get_available_cues(target) and "Examples" in cues_buttons and cues_buttons['Examples']:
335
+ #write_bot(f'Here are example contexts for the current word: {", ".join(avail_cues["Examples"])}', remember=False)
336
+
337
+ elif b4:
338
+ write_bot(f"Here are all my guesses about your word: {st.session_state.results['results_print']}")
339
+
340
+ elif b5:
341
+ write_bot("Yay! I am happy I could be of help!")
342
+ st.session_state.counters["word_count"] = 0
343
+ st.session_state.counters["letter_count"] = 0
344
+ new = st.button('Play again', key=63)
345
+ if new:
346
+ write_bot("Please describe your word!")
347
+ guessed = True
348
+
349
+ break
350
+
351
+ elif b6:
352
+ write_bot("I am sorry I couldn't help you this time. See you soon!")
353
+ st.session_state.counters["word_count"] = 0
354
+ st.session_state.counters["letter_count"] = 0
355
+ st.session_state.actions.append('cue')
356
+
357
+ if new:
358
+ write_bot("Please describe your word!")
359
+ st.session_state.counters["word_count"] = 0
360
+ st.session_state.counters["letter_count"] = 0
361
+
362
+ break
363
+
pages/App.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.set_page_config(
4
+ page_title="You Name It!",
5
+ page_icon="πŸ‘‹",
6
+ )
7
+
8
+ st.write("# Welcome to YouNameIt chatbot! πŸ‘‹")
9
+
10
+ st.sidebar.success("Select a chatbot mode above.")
11
+
12
+ st.markdown(
13
+ """
14
+ YouNameIt is a project helping people with aphasia practice their word retrieval skill and assisting them to remember words on a daily basis.
15
+ **πŸ‘ˆ Select a chatbot mode from the sidebar** to test our app!
16
+ ### What new features are planned?
17
+ - Adaptation to German language and more;
18
+ - Speech-to-text suppport;
19
+ - Android & IOS mobile apps.
20
+ ### For any suggestions or ideas please contact us.
21
+ - Julian []()
22
+ - Nursulu []()
23
+ """
24
+ )
25
+