schlenker commited on
Commit
971a4bf
·
2 Parent(s): a4def93 2a20515

Merge branch 'main' of https://huggingface.co/spaces/YouNameIt/YouNameIt_chatbot into main

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
6
  from transformers import AutoTokenizer
7
  import numpy as np
8
  import time
 
9
 
10
  # JS
11
  import nltk
@@ -37,6 +38,15 @@ def get_models(llama=False):
37
 
38
  model, tokenizer = get_models()
39
 
 
 
 
 
 
 
 
 
 
40
  def return_top_k(sentence, k=10):
41
 
42
  if sentence[-1] != ".":
@@ -64,9 +74,10 @@ def return_top_k(sentence, k=10):
64
  #all word predictions
65
  predictions = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_sequences['sequences']]
66
  probabilities = [round(float(prob), 2) for prob in decoded_probabilities]
67
-
 
68
  for pred in predictions:
69
- if (len(pred) < 2) | (pred in sentence.split()):
70
  predictions.pop(predictions.index(pred))
71
 
72
  return predictions[:10]
 
6
  from transformers import AutoTokenizer
7
  import numpy as np
8
  import time
9
+ import string
10
 
11
  # JS
12
  import nltk
 
38
 
39
  model, tokenizer = get_models()
40
 
41
+ def remove_punctuation(word):
42
+ # Create a translation table that maps all punctuation characters to None
43
+ translator = str.maketrans('', '', string.punctuation)
44
+
45
+ # Use the translate method to remove punctuation from the word
46
+ word_without_punctuation = word.translate(translator)
47
+
48
+ return word_without_punctuation
49
+
50
  def return_top_k(sentence, k=10):
51
 
52
  if sentence[-1] != ".":
 
74
  #all word predictions
75
  predictions = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_sequences['sequences']]
76
  probabilities = [round(float(prob), 2) for prob in decoded_probabilities]
77
+
78
+ stripped_sent = [remove_punctuation(word.lower()) for word in sentence.split()]
79
  for pred in predictions:
80
+ if (len(pred) < 2) | (pred in stripped_sent):
81
  predictions.pop(predictions.index(pred))
82
 
83
  return predictions[:10]