ZoniaChatbot commited on
Commit
9280c25
·
verified ·
1 Parent(s): ed8a94b

Update chatpdf.py

Browse files
Files changed (1) hide show
  1. chatpdf.py +122 -60
chatpdf.py CHANGED
@@ -1,8 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- @author:XuMing([email protected])
4
- @description:
5
- """
6
  import argparse
7
  import hashlib
8
  import os
@@ -18,6 +13,7 @@ from similarities import (
18
  EnsembleSimilarity,
19
  BertSimilarity,
20
  BM25Similarity,
 
21
  )
22
  from similarities.similarity import SimilarityABC
23
  from transformers import (
@@ -43,10 +39,9 @@ MODEL_CLASSES = {
43
  "auto": (AutoModelForCausalLM, AutoTokenizer),
44
  }
45
 
46
- PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
47
- concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
48
- información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
49
- inventados en la respuesta, y ésta debe estar en Español.
50
 
51
  Contenido conocido:
52
  {context_str}
@@ -55,8 +50,6 @@ Pregunta:
55
  {query_str}
56
  """
57
 
58
-
59
-
60
  class SentenceSplitter:
61
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
62
  self.chunk_size = chunk_size
@@ -121,8 +114,7 @@ class SentenceSplitter:
121
  return overlapped_chunks
122
 
123
 
124
-
125
- class Rag:
126
  def __init__(
127
  self,
128
  similarity_model: SimilarityABC = None,
@@ -139,8 +131,8 @@ class Rag:
139
  rerank_model_name_or_path: str = None,
140
  enable_history: bool = False,
141
  num_expand_context_chunk: int = 2,
142
- similarity_top_k: int = 10,
143
- rerank_top_k: int = 3,
144
  ):
145
  """
146
  Init RAG model.
@@ -176,9 +168,11 @@ class Rag:
176
  if similarity_model is not None:
177
  self.sim_model = similarity_model
178
  else:
179
- m1 = BertSimilarity(model_name_or_path="hiiamsid/sentence_similarity_spanish_es", device=self.device)
180
  m2 = BM25Similarity()
181
- default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
 
 
182
  self.sim_model = default_sim_model
183
  self.gen_model, self.tokenizer = self._init_gen_model(
184
  generate_model_type,
@@ -243,14 +237,14 @@ class Rag:
243
  try:
244
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
245
  except Exception as e:
246
- logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
247
  if peft_name:
248
  model = PeftModel.from_pretrained(
249
  model,
250
  peft_name,
251
  torch_dtype="auto",
252
  )
253
- logger.info(f"Loaded peft model from {peft_name}")
254
  model.eval()
255
  return model, tokenizer
256
 
@@ -341,6 +335,7 @@ class Rag:
341
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
342
  new_text = ''
343
  for text in raw_text:
 
344
  if new_text:
345
  new_text += ' '
346
  new_text += text
@@ -397,25 +392,37 @@ class Rag:
397
  return scores
398
 
399
  def get_reference_results(self, query: str):
400
- """
401
- Get reference results.
402
- 1. Similarity model get similar chunks
403
- 2. Rerank similar chunks
404
- 3. Expand reference context chunk
405
- :param query:
406
- :return:
407
- """
408
- reference_results = []
 
 
 
 
 
 
 
 
409
  sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
410
- # Get reference results from corpus
411
- hit_chunk_dict = dict()
412
- for c in sim_contents:
413
- for id_score_dict in c:
414
- corpus_id = id_score_dict['corpus_id']
415
- hit_chunk = id_score_dict["corpus_doc"]
416
- reference_results.append(hit_chunk)
417
- hit_chunk_dict[corpus_id] = hit_chunk
418
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  if reference_results:
420
  if self.rerank_model is not None:
421
  # Rerank reference results
@@ -440,9 +447,9 @@ class Rag:
440
  def predict_stream(
441
  self,
442
  query: str,
443
- max_length: int = 512,
444
- context_len: int = 2048,
445
- temperature: float = 0.7,
446
  ):
447
  """Generate predictions stream."""
448
  stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
@@ -450,15 +457,16 @@ class Rag:
450
  self.history = []
451
  if self.sim_model.corpus:
452
  reference_results = self.get_reference_results(query)
453
- if reference_results:
454
- reference_results = self._add_source_numbers(reference_results)
455
- context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
456
- else:
457
- context_str = ''
458
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
459
  else:
460
  prompt = query
461
- logger.debug(f"prompt: {prompt}")
462
  self.history.append([prompt, ''])
463
  response = ""
464
  for new_text in self.stream_generate_answer(
@@ -473,9 +481,9 @@ class Rag:
473
  def predict(
474
  self,
475
  query: str,
476
- max_length: int = 512,
477
- context_len: int = 2048,
478
- temperature: float = 0.7,
479
  ):
480
  """Query from corpus."""
481
  reference_results = []
@@ -483,15 +491,20 @@ class Rag:
483
  self.history = []
484
  if self.sim_model.corpus:
485
  reference_results = self.get_reference_results(query)
486
- if reference_results:
487
- reference_results = self._add_source_numbers(reference_results)
488
- context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
489
- else:
490
- context_str = ''
 
 
 
 
 
491
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
492
  else:
493
  prompt = query
494
- logger.debug(f"prompt: {prompt}")
495
  self.history.append([prompt, ''])
496
  response = ""
497
  for new_text in self.stream_generate_answer(
@@ -504,8 +517,29 @@ class Rag:
504
  self.history[-1][1] = response
505
  return response, reference_results
506
 
507
- def query(self, query: str, **kwargs):
508
- return self.predict(query, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
  def save_corpus_emb(self):
511
  dir_name = self.get_file_hash(self.corpus_files)
@@ -517,13 +551,15 @@ class Rag:
517
 
518
  def load_corpus_emb(self, emb_dir: str):
519
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
520
- logger.debug(f"Loading corpus embeddings from {emb_dir}")
521
  self.sim_model.load_corpus_embeddings(emb_dir)
 
 
522
 
523
 
524
  if __name__ == "__main__":
525
  parser = argparse.ArgumentParser()
526
- parser.add_argument("--sim_model_name", type=str, default="hiiamsid/sentence_similarity_spanish_es")
527
  parser.add_argument("--gen_model_type", type=str, default="auto")
528
  parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
529
  parser.add_argument("--lora_model", type=str, default=None)
@@ -538,7 +574,7 @@ if __name__ == "__main__":
538
  args = parser.parse_args()
539
  print(args)
540
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
541
- m = Rag(
542
  similarity_model=sim_model,
543
  generate_model_type=args.gen_model_type,
544
  generate_model_name_or_path=args.gen_model_name,
@@ -551,4 +587,30 @@ if __name__ == "__main__":
551
  corpus_files=args.corpus_files.split(','),
552
  num_expand_context_chunk=args.num_expand_context_chunk,
553
  rerank_model_name_or_path=args.rerank_model_name,
554
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import hashlib
3
  import os
 
13
  EnsembleSimilarity,
14
  BertSimilarity,
15
  BM25Similarity,
16
+ TfidfSimilarity
17
  )
18
  from similarities.similarity import SimilarityABC
19
  from transformers import (
 
39
  "auto": (AutoModelForCausalLM, AutoTokenizer),
40
  }
41
 
42
+ PROMPT_TEMPLATE = """Basándose únicamente en la información proporcionada a continuación, responda a las preguntas del usuario de manera concisa y profesional.
43
+ No se debe responder a preguntas relacionadas con sentimientos, emociones, temas personales o cualquier información que no esté explícitamente presente en el contenido proporcionado.
44
+ Si la pregunta se refiere a un artículo específico y no se encuentra en el contenido proporcionado, diga: "No se puede encontrar el artículo solicitado en la información conocida".
 
45
 
46
  Contenido conocido:
47
  {context_str}
 
50
  {query_str}
51
  """
52
 
 
 
53
  class SentenceSplitter:
54
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
55
  self.chunk_size = chunk_size
 
114
  return overlapped_chunks
115
 
116
 
117
+ class ChatPDF:
 
118
  def __init__(
119
  self,
120
  similarity_model: SimilarityABC = None,
 
131
  rerank_model_name_or_path: str = None,
132
  enable_history: bool = False,
133
  num_expand_context_chunk: int = 2,
134
+ similarity_top_k: int = 15,
135
+ rerank_top_k: int = 5,
136
  ):
137
  """
138
  Init RAG model.
 
168
  if similarity_model is not None:
169
  self.sim_model = similarity_model
170
  else:
171
+ m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2", device=self.device)
172
  m2 = BM25Similarity()
173
+ m3 = TfidfSimilarity()
174
+ default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5],
175
+ c=2) # Ajuste los pesos según los resultados
176
  self.sim_model = default_sim_model
177
  self.gen_model, self.tokenizer = self._init_gen_model(
178
  generate_model_type,
 
237
  try:
238
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
239
  except Exception as e:
240
+ logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
241
  if peft_name:
242
  model = PeftModel.from_pretrained(
243
  model,
244
  peft_name,
245
  torch_dtype="auto",
246
  )
247
+ logger.info(f"Modelo peft cargado desde {peft_name}")
248
  model.eval()
249
  return model, tokenizer
250
 
 
335
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
336
  new_text = ''
337
  for text in raw_text:
338
+ # Añadir un espacio antes de concatenar si new_text no está vacío
339
  if new_text:
340
  new_text += ' '
341
  new_text += text
 
392
  return scores
393
 
394
  def get_reference_results(self, query: str):
395
+ # Verificar si la consulta incluye un "Artículo X"
396
+ exact_match = None
397
+ if re.search(r'Artículo\s*\d+', query, re.IGNORECASE):
398
+ # Buscar el término específico "Artículo X" en el corpus de manera más precisa
399
+ term = re.search(r'Artículo\s*\d+', query, re.IGNORECASE).group()
400
+ # Buscar coincidencias exactas en el corpus
401
+ for corpus_id, content in self.sim_model.corpus.items():
402
+ # Agregar espacio o signo de puntuación alrededor de "term" para evitar coincidencias parciales
403
+ if re.search(r'\b' + re.escape(term) + r'\b', content, re.IGNORECASE):
404
+ exact_match = content
405
+ break
406
+
407
+ if exact_match:
408
+ # Si se encuentra una coincidencia exacta, devolverla como contexto
409
+ return [exact_match]
410
+
411
+ # Si no se encuentra una coincidencia exacta, continuar con la búsqueda general
412
  sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
 
 
 
 
 
 
 
 
413
 
414
+ # Procesar los resultados de similitud
415
+ reference_results = []
416
+
417
+ hit_chunk_dict = dict()
418
+ threshold_score = 0.5 # Establece un umbral para filtrar fragmentos irrelevantes
419
+
420
+ for query_id, id_score_dict in sim_contents.items():
421
+ for corpus_id, s in id_score_dict.items():
422
+ if s > threshold_score: # Filtrar por puntuación de similitud
423
+ hit_chunk = self.sim_model.corpus[corpus_id]
424
+ reference_results.append(hit_chunk)
425
+ hit_chunk_dict[corpus_id] = hit_chunk
426
  if reference_results:
427
  if self.rerank_model is not None:
428
  # Rerank reference results
 
447
  def predict_stream(
448
  self,
449
  query: str,
450
+ max_length: int = 256,
451
+ context_len: int = 1024,
452
+ temperature: float = 0.5,
453
  ):
454
  """Generate predictions stream."""
455
  stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
 
457
  self.history = []
458
  if self.sim_model.corpus:
459
  reference_results = self.get_reference_results(query)
460
+ if not reference_results:
461
+ yield 'No se ha proporcionado suficiente información relevante', reference_results
462
+ reference_results = self._add_source_numbers(reference_results)
463
+ context_str = '\n'.join(reference_results)[:]
464
+ print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
465
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
466
+ logger.debug(f"prompt: {prompt}")
467
  else:
468
  prompt = query
469
+ logger.debug(prompt)
470
  self.history.append([prompt, ''])
471
  response = ""
472
  for new_text in self.stream_generate_answer(
 
481
  def predict(
482
  self,
483
  query: str,
484
+ max_length: int = 256,
485
+ context_len: int = 1024,
486
+ temperature: float = 0.5,
487
  ):
488
  """Query from corpus."""
489
  reference_results = []
 
491
  self.history = []
492
  if self.sim_model.corpus:
493
  reference_results = self.get_reference_results(query)
494
+
495
+ if not reference_results:
496
+ return 'No se ha proporcionado suficiente información relevante', reference_results
497
+ reference_results = self._add_source_numbers(reference_results)
498
+ # context_str = '\n'.join(reference_results) # Usa todos los fragmentos
499
+ context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
500
+ #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
501
+ print(".......................................................")
502
+ context_str = '\n'.join(reference_results)[:]
503
+ #print("context_str: ", context_str)
504
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
505
+ logger.debug(f"prompt: {prompt}")
506
  else:
507
  prompt = query
 
508
  self.history.append([prompt, ''])
509
  response = ""
510
  for new_text in self.stream_generate_answer(
 
517
  self.history[-1][1] = response
518
  return response, reference_results
519
 
520
+ def save_corpus_text(self):
521
+ if not self.corpus_files:
522
+ logger.warning("No hay archivos de corpus para guardar.")
523
+ return
524
+
525
+ corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
526
+
527
+ with open(corpus_text_file, 'w', encoding='utf-8') as f:
528
+ for chunk in self.sim_model.corpus.values():
529
+ f.write(chunk + "\n\n") # Añade dos saltos de línea entre chunks para mejor legibilidad
530
+
531
+ logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
532
+ return corpus_text_file
533
+
534
+ def load_corpus_text(self, emb_dir: str):
535
+ corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
536
+ if os.path.exists(corpus_text_file):
537
+ with open(corpus_text_file, 'r', encoding='utf-8') as f:
538
+ corpus_text = f.read().split("\n\n") # Asumiendo que usamos dos saltos de línea como separador
539
+ self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
540
+ logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
541
+ else:
542
+ logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
543
 
544
  def save_corpus_emb(self):
545
  dir_name = self.get_file_hash(self.corpus_files)
 
551
 
552
  def load_corpus_emb(self, emb_dir: str):
553
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
554
+ logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
555
  self.sim_model.load_corpus_embeddings(emb_dir)
556
+ # Cargar el texto del corpus
557
+ self.load_corpus_text(emb_dir)
558
 
559
 
560
  if __name__ == "__main__":
561
  parser = argparse.ArgumentParser()
562
+ parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/all-mpnet-base-v2")
563
  parser.add_argument("--gen_model_type", type=str, default="auto")
564
  parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
565
  parser.add_argument("--lora_model", type=str, default=None)
 
574
  args = parser.parse_args()
575
  print(args)
576
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
577
+ m = ChatPDF(
578
  similarity_model=sim_model,
579
  generate_model_type=args.gen_model_type,
580
  generate_model_name_or_path=args.gen_model_name,
 
587
  corpus_files=args.corpus_files.split(','),
588
  num_expand_context_chunk=args.num_expand_context_chunk,
589
  rerank_model_name_or_path=args.rerank_model_name,
590
+ )
591
+
592
+ # Comprobar si existen incrustaciones guardadas
593
+ dir_name = m.get_file_hash(args.corpus_files.split(','))
594
+ save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
595
+
596
+ if os.path.exists(save_dir):
597
+ # Cargar las incrustaciones guardadas
598
+ m.load_corpus_emb(save_dir)
599
+ print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
600
+ else:
601
+ # Procesar el corpus y guardar las incrustaciones
602
+ m.add_corpus(args.corpus_files.split(','))
603
+ save_dir = m.save_corpus_emb()
604
+ # Guardar el texto del corpus
605
+ m.save_corpus_text()
606
+ print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
607
+
608
+ while True:
609
+ query = input("\nEnter a query: ")
610
+ if query == "exit":
611
+ break
612
+ if query.strip() == "":
613
+ continue
614
+ r, refs = m.predict(query)
615
+ print(r, refs)
616
+ print("\nRespuesta: ", r)