Update rpc.py
rpc.py
CHANGED
@@ -201,6 +201,7 @@ def load_index(index_path="/dev/shm/rpc-vecdb/index"):
 # Generate Function
 def generate(text, use_rpc=True, max_tokens=128):
     enc_text = tokenizer.encode(text, add_special_tokens=False)
+    text = tokenizer.decode(enc_text)
     tok = None
     i = 0
     while i < max_tokens and tok != vocab_size - 2:
@@ -208,7 +209,7 @@ def generate(text, use_rpc=True, max_tokens=128):
         enc_text = enc_text[-input_size:]
         if use_rpc:
             xq = vectorize_texts([enc_text])[-1]
-            _id = index.search(xq, size=1, epsilon=
+            _id = index.search(xq, size=1, epsilon=1)[0][0]
             #_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
             if all_toks[_id] in carry_toks:
                 tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]