pedrocas15 commited on
Commit
7ff3f2d
·
verified ·
1 Parent(s): 2f324c8

Update rpc.py

Browse files
Files changed (1) hide show
  1. rpc.py +2 -1
rpc.py CHANGED
@@ -201,6 +201,7 @@ def load_index(index_path="/dev/shm/rpc-vecdb/index"):
201
  # Generate Function
202
  def generate(text, use_rpc=True, max_tokens=128):
203
  enc_text = tokenizer.encode(text, add_special_tokens=False)
 
204
  tok = None
205
  i = 0
206
  while i < max_tokens and tok != vocab_size - 2:
@@ -208,7 +209,7 @@ def generate(text, use_rpc=True, max_tokens=128):
208
  enc_text = enc_text[-input_size:]
209
  if use_rpc:
210
  xq = vectorize_texts([enc_text])[-1]
211
- _id = index.search(xq, size=1, epsilon=2)[0][0]
212
  #_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
213
  if all_toks[_id] in carry_toks:
214
  tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
 
201
  # Generate Function
202
  def generate(text, use_rpc=True, max_tokens=128):
203
  enc_text = tokenizer.encode(text, add_special_tokens=False)
204
+ text = tokenizer.decode(enc_text)
205
  tok = None
206
  i = 0
207
  while i < max_tokens and tok != vocab_size - 2:
 
209
  enc_text = enc_text[-input_size:]
210
  if use_rpc:
211
  xq = vectorize_texts([enc_text])[-1]
212
+ _id = index.search(xq, size=1, epsilon=1)[0][0]
213
  #_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
214
  if all_toks[_id] in carry_toks:
215
  tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]