ZoniaChatbot commited on
Commit
f45d702
verified
1 Parent(s): 219396b

Update chatpdf.py

Browse files
Files changed (1) hide show
  1. chatpdf.py +34 -34
chatpdf.py CHANGED
@@ -373,44 +373,44 @@ class ChatPDF:
373
  return scores
374
 
375
  def get_reference_results(self, query: str):
376
- reference_results = []
377
- sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
378
-
379
- # Ajustar seg煤n el tipo de retorno de sim_contents
380
- if isinstance(sim_contents, dict): # Si es un diccionario
 
 
 
 
 
 
 
381
  for query_id, id_score_dict in sim_contents.items():
382
  for corpus_id, s in id_score_dict.items():
383
  hit_chunk = self.sim_model.corpus[corpus_id]
384
  reference_results.append(hit_chunk)
385
- elif isinstance(sim_contents, list): # Si es una lista
386
- for item in sim_contents:
387
- # Ajusta esto dependiendo de la estructura de los elementos de la lista
388
- # Ejemplo: si es una lista de (corpus_id, score) tuplas
389
- corpus_id, _ = item
390
- hit_chunk = self.sim_model.corpus[corpus_id]
391
- reference_results.append(hit_chunk)
392
-
393
- # Resto del c贸digo...
394
- if reference_results:
395
- if self.rerank_model is not None:
396
- # Rerank reference results
397
- rerank_scores = self._get_reranker_score(query, reference_results)
398
- logger.debug(f"rerank_scores: {rerank_scores}")
399
- # Get rerank top k chunks
400
- reference_results = [reference for reference, score in sorted(
401
- zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
402
- hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
403
- hit_chunk in reference_results}
404
- # Expand reference context chunk
405
- if self.num_expand_context_chunk > 0:
406
- new_reference_results = []
407
- for corpus_id, hit_chunk in hit_chunk_dict.items():
408
- expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
409
- for i in range(self.num_expand_context_chunk):
410
- expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
411
- new_reference_results.append(expanded_reference)
412
- reference_results = new_reference_results
413
- return reference_results
414
 
415
  def predict_stream(
416
  self,
 
373
  return scores
374
 
375
  def get_reference_results(self, query: str):
376
+ """
377
+ Get reference results.
378
+ 1. Similarity model get similar chunks
379
+ 2. Rerank similar chunks
380
+ 3. Expand reference context chunk
381
+ :param query:
382
+ :return:
383
+ """
384
+ reference_results = []
385
+ sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
386
+ # Get reference results from corpus
387
+ hit_chunk_dict = dict()
388
  for query_id, id_score_dict in sim_contents.items():
389
  for corpus_id, s in id_score_dict.items():
390
  hit_chunk = self.sim_model.corpus[corpus_id]
391
  reference_results.append(hit_chunk)
392
+ hit_chunk_dict[corpus_id] = hit_chunk
393
+
394
+ if reference_results:
395
+ if self.rerank_model is not None:
396
+ # Rerank reference results
397
+ rerank_scores = self._get_reranker_score(query, reference_results)
398
+ logger.debug(f"rerank_scores: {rerank_scores}")
399
+ # Get rerank top k chunks
400
+ reference_results = [reference for reference, score in sorted(
401
+ zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
402
+ hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
403
+ hit_chunk in reference_results}
404
+ # Expand reference context chunk
405
+ if self.num_expand_context_chunk > 0:
406
+ new_reference_results = []
407
+ for corpus_id, hit_chunk in hit_chunk_dict.items():
408
+ expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
409
+ for i in range(self.num_expand_context_chunk):
410
+ expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
411
+ new_reference_results.append(expanded_reference)
412
+ reference_results = new_reference_results
413
+ return reference_results
 
 
 
 
 
 
 
414
 
415
  def predict_stream(
416
  self,