tcy6 commited on
Commit
becce76
·
1 Parent(s): f63f4f8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -135,6 +135,8 @@ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
135
  doc_list = [f for f in os.listdir(target_cache_dir) if f.endswith('.npy')]
136
  doc_list = sorted(doc_list)
137
  doc_reps = [np.load(os.path.join(target_cache_dir, f)) for f in doc_list]
 
 
138
 
139
  query_with_instruction = "Represent this query for retrieving relevant document: " + query
140
  with torch.no_grad():
@@ -142,7 +144,6 @@ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
142
 
143
  query_md5 = hashlib.md5(query.encode()).hexdigest()
144
 
145
- doc_reps_cat = torch.cat([torch.Tensor(i) for i in doc_reps], dim=0)
146
  print(f"query_rep_shape: {query_rep.shape}, doc_reps_cat_shape: {doc_reps_cat.shape}")
147
  similarities = torch.matmul(query_rep, doc_reps_cat.T)
148
 
 
135
  doc_list = [f for f in os.listdir(target_cache_dir) if f.endswith('.npy')]
136
  doc_list = sorted(doc_list)
137
  doc_reps = [np.load(os.path.join(target_cache_dir, f)) for f in doc_list]
138
+ doc_reps_cat = torch.cat([torch.Tensor(i) for i in doc_reps], dim=0)
139
+ doc_reps_cat = torch.cat([i for i in doc_reps_cat], dim=0)
140
 
141
  query_with_instruction = "Represent this query for retrieving relevant document: " + query
142
  with torch.no_grad():
 
144
 
145
  query_md5 = hashlib.md5(query.encode()).hexdigest()
146
 
 
147
  print(f"query_rep_shape: {query_rep.shape}, doc_reps_cat_shape: {doc_reps_cat.shape}")
148
  similarities = torch.matmul(query_rep, doc_reps_cat.T)
149