Spaces:
Sleeping
Sleeping
Fix of repeating inputs
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
|
6 |
from transformers import AutoTokenizer
|
7 |
import numpy as np
|
8 |
import time
|
|
|
9 |
|
10 |
# JS
|
11 |
import nltk
|
@@ -37,6 +38,15 @@ def get_models(llama=False):
|
|
37 |
|
38 |
model, tokenizer = get_models()
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def return_top_k(sentence, k=10):
|
41 |
|
42 |
if sentence[-1] != ".":
|
@@ -64,9 +74,10 @@ def return_top_k(sentence, k=10):
|
|
64 |
#all word predictions
|
65 |
predictions = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_sequences['sequences']]
|
66 |
probabilities = [round(float(prob), 2) for prob in decoded_probabilities]
|
67 |
-
|
|
|
68 |
for pred in predictions:
|
69 |
-
if (len(pred) < 2) | (pred in
|
70 |
predictions.pop(predictions.index(pred))
|
71 |
|
72 |
return predictions[:10]
|
|
|
6 |
from transformers import AutoTokenizer
|
7 |
import numpy as np
|
8 |
import time
|
9 |
+
import string
|
10 |
|
11 |
# JS
|
12 |
import nltk
|
|
|
38 |
|
39 |
model, tokenizer = get_models()
|
40 |
|
41 |
+
# Translation table that deletes every character in ``string.punctuation``.
# Built once at import time: it never changes, and ``remove_punctuation`` is
# called once per word of the input sentence.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def remove_punctuation(word: str) -> str:
    """Return ``word`` with all ``string.punctuation`` characters removed.

    Only the ASCII punctuation listed in ``string.punctuation`` is stripped;
    letters, digits, whitespace and non-ASCII characters are left untouched.

    Args:
        word: Text to clean (a single word or any longer string).

    Returns:
        A new string without punctuation. The result is empty when ``word``
        is empty or consists solely of punctuation.
    """
    return word.translate(_PUNCTUATION_TABLE)
|
49 |
+
|
50 |
def return_top_k(sentence, k=10):
|
51 |
|
52 |
if sentence[-1] != ".":
|
|
|
74 |
#all word predictions
|
75 |
predictions = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_sequences['sequences']]
|
76 |
probabilities = [round(float(prob), 2) for prob in decoded_probabilities]
|
77 |
+
|
78 |
+
stripped_sent = [remove_punctuation(word) for word in sentence.split()]
|
79 |
for pred in predictions:
|
80 |
+
if (len(pred) < 2) | (pred in stripped_sent):
|
81 |
predictions.pop(predictions.index(pred))
|
82 |
|
83 |
return predictions[:10]
|