ZoniaChatbot committed on
Commit
56fb6ea
·
verified ·
1 Parent(s): 7948ebe

Update chatpdf.py

Browse files
Files changed (1) hide show
  1. chatpdf.py +36 -105
chatpdf.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import argparse
2
  import hashlib
3
  import os
@@ -38,24 +43,6 @@ MODEL_CLASSES = {
38
  "auto": (AutoModelForCausalLM, AutoTokenizer),
39
  }
40
 
41
- PROMPT_TEMPLATE1 = """基于以下已知信息,简洁和专业的来回答用户的问题。
42
- 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
43
-
44
- 已知内容:
45
- {context_str}
46
-
47
- 问题:
48
- {query_str}
49
- """
50
- PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
51
- Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
52
-
53
- Contexto: {context_str}
54
- Pregunta: {query_str}
55
-
56
- Devuelve sólo la respuesta útil que aparece a continuación y nada más, y ésta debe estar en Español.
57
- Respuesta útil:
58
- """
59
  PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
60
  concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
61
  información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
@@ -69,6 +56,7 @@ Pregunta:
69
  """
70
 
71
 
 
72
  class SentenceSplitter:
73
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
74
  self.chunk_size = chunk_size
@@ -134,7 +122,7 @@ class SentenceSplitter:
134
 
135
 
136
 
137
- class ChatPDF:
138
  def __init__(
139
  self,
140
  similarity_model: SimilarityABC = None,
@@ -151,8 +139,8 @@ class ChatPDF:
151
  rerank_model_name_or_path: str = None,
152
  enable_history: bool = False,
153
  num_expand_context_chunk: int = 2,
154
- similarity_top_k: int = 5,
155
- rerank_top_k: int =3,
156
  ):
157
  """
158
  Init RAG model.
@@ -188,7 +176,7 @@ class ChatPDF:
188
  if similarity_model is not None:
189
  self.sim_model = similarity_model
190
  else:
191
- m1 = BertSimilarity(model_name_or_path="jaimevera1107/all-MiniLM-L6-v2-similarity-es", device=self.device)
192
  m2 = BM25Similarity()
193
  default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
194
  self.sim_model = default_sim_model
@@ -205,7 +193,7 @@ class ChatPDF:
205
  self.add_corpus(corpus_files)
206
  self.save_corpus_emb_dir = save_corpus_emb_dir
207
  if rerank_model_name_or_path is None:
208
- rerank_model_name_or_path = "BAAI/bge-reranker-large"
209
  if rerank_model_name_or_path:
210
  self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
211
  self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
@@ -255,14 +243,14 @@ class ChatPDF:
255
  try:
256
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
257
  except Exception as e:
258
- logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
259
  if peft_name:
260
  model = PeftModel.from_pretrained(
261
  model,
262
  peft_name,
263
  torch_dtype="auto",
264
  )
265
- logger.info(f"Modelo peft cargado desde {peft_name}")
266
  model.eval()
267
  return model, tokenizer
268
 
@@ -353,9 +341,6 @@ class ChatPDF:
353
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
354
  new_text = ''
355
  for text in raw_text:
356
- # Añadir un espacio antes de concatenar si new_text no está vacío
357
- if new_text:
358
- new_text += ' '
359
  new_text += text
360
  if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
361
  '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
@@ -422,9 +407,10 @@ class ChatPDF:
422
  sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
423
  # Get reference results from corpus
424
  hit_chunk_dict = dict()
425
- for query_id, id_score_dict in sim_contents.items():
426
- for corpus_id, s in id_score_dict.items():
427
- hit_chunk = self.sim_model.corpus[corpus_id]
 
428
  reference_results.append(hit_chunk)
429
  hit_chunk_dict[corpus_id] = hit_chunk
430
 
@@ -462,16 +448,15 @@ class ChatPDF:
462
  self.history = []
463
  if self.sim_model.corpus:
464
  reference_results = self.get_reference_results(query)
465
- if not reference_results:
466
- yield 'No se ha proporcionado suficiente información relevante', reference_results
467
- reference_results = self._add_source_numbers(reference_results)
468
- context_str = '\n'.join(reference_results)[:]
469
- print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
470
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
471
- logger.debug(f"prompt: {prompt}")
472
  else:
473
  prompt = query
474
- logger.debug(prompt)
475
  self.history.append([prompt, ''])
476
  response = ""
477
  for new_text in self.stream_generate_answer(
@@ -496,20 +481,15 @@ class ChatPDF:
496
  self.history = []
497
  if self.sim_model.corpus:
498
  reference_results = self.get_reference_results(query)
499
-
500
- if not reference_results:
501
- return 'No se ha proporcionado suficiente información relevante', reference_results
502
- reference_results = self._add_source_numbers(reference_results)
503
- # context_str = '\n'.join(reference_results) # Usa todos los fragmentos
504
- context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
505
- print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
506
- print(".......................................................")
507
- context_str = '\n'.join(reference_results)[:]
508
- print("context_str: ", context_str)
509
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
510
- logger.debug(f"prompt: {prompt}")
511
  else:
512
  prompt = query
 
513
  self.history.append([prompt, ''])
514
  response = ""
515
  for new_text in self.stream_generate_answer(
@@ -522,29 +502,8 @@ class ChatPDF:
522
  self.history[-1][1] = response
523
  return response, reference_results
524
 
525
- def save_corpus_text(self):
526
- if not self.corpus_files:
527
- logger.warning("No hay archivos de corpus para guardar.")
528
- return
529
-
530
- corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
531
-
532
- with open(corpus_text_file, 'w', encoding='utf-8') as f:
533
- for chunk in self.sim_model.corpus.values():
534
- f.write(chunk + "\n\n") # Añade dos saltos de línea entre chunks para mejor legibilidad
535
-
536
- logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
537
- return corpus_text_file
538
-
539
- def load_corpus_text(self, emb_dir: str):
540
- corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
541
- if os.path.exists(corpus_text_file):
542
- with open(corpus_text_file, 'r', encoding='utf-8') as f:
543
- corpus_text = f.read().split("\n\n") # Asumiendo que usamos dos saltos de línea como separador
544
- self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
545
- logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
546
- else:
547
- logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
548
 
549
  def save_corpus_emb(self):
550
  dir_name = self.get_file_hash(self.corpus_files)
@@ -556,20 +515,18 @@ class ChatPDF:
556
 
557
  def load_corpus_emb(self, emb_dir: str):
558
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
559
- logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
560
  self.sim_model.load_corpus_embeddings(emb_dir)
561
- # Cargar el texto del corpus
562
- self.load_corpus_text(emb_dir)
563
 
564
 
565
  if __name__ == "__main__":
566
  parser = argparse.ArgumentParser()
567
- parser.add_argument("--sim_model_name", type=str, default="jaimevera1107/all-MiniLM-L6-v2-similarity-es")
568
  parser.add_argument("--gen_model_type", type=str, default="auto")
569
  parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
570
  parser.add_argument("--lora_model", type=str, default=None)
571
  parser.add_argument("--rerank_model_name", type=str, default="")
572
- parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
573
  parser.add_argument("--device", type=str, default=None)
574
  parser.add_argument("--int4", action='store_true', help="use int4 quantization")
575
  parser.add_argument("--int8", action='store_true', help="use int8 quantization")
@@ -579,7 +536,7 @@ if __name__ == "__main__":
579
  args = parser.parse_args()
580
  print(args)
581
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
582
- m = ChatPDF(
583
  similarity_model=sim_model,
584
  generate_model_type=args.gen_model_type,
585
  generate_model_name_or_path=args.gen_model_name,
@@ -592,30 +549,4 @@ if __name__ == "__main__":
592
  corpus_files=args.corpus_files.split(','),
593
  num_expand_context_chunk=args.num_expand_context_chunk,
594
  rerank_model_name_or_path=args.rerank_model_name,
595
- )
596
-
597
- # Comprobar si existen incrustaciones guardadas
598
- dir_name = m.get_file_hash(args.corpus_files.split(','))
599
- save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
600
-
601
- if os.path.exists(save_dir):
602
- # Cargar las incrustaciones guardadas
603
- m.load_corpus_emb(save_dir)
604
- print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
605
- else:
606
- # Procesar el corpus y guardar las incrustaciones
607
- m.add_corpus(args.corpus_files.split(','))
608
- save_dir = m.save_corpus_emb()
609
- # Guardar el texto del corpus
610
- m.save_corpus_text()
611
- print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
612
-
613
- while True:
614
- query = input("\nEnter a query: ")
615
- if query == "exit":
616
- break
617
- if query.strip() == "":
618
- continue
619
- r, refs = m.predict(query)
620
- print(r, refs)
621
- print("\nRespuesta: ", r)
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
  import argparse
7
  import hashlib
8
  import os
 
43
  "auto": (AutoModelForCausalLM, AutoTokenizer),
44
  }
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
47
  concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
48
  información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
 
56
  """
57
 
58
 
59
+
60
  class SentenceSplitter:
61
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
62
  self.chunk_size = chunk_size
 
122
 
123
 
124
 
125
+ class Rag:
126
  def __init__(
127
  self,
128
  similarity_model: SimilarityABC = None,
 
139
  rerank_model_name_or_path: str = None,
140
  enable_history: bool = False,
141
  num_expand_context_chunk: int = 2,
142
+ similarity_top_k: int = 10,
143
+ rerank_top_k: int = 3,
144
  ):
145
  """
146
  Init RAG model.
 
176
  if similarity_model is not None:
177
  self.sim_model = similarity_model
178
  else:
179
+ m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual", device=self.device)
180
  m2 = BM25Similarity()
181
  default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
182
  self.sim_model = default_sim_model
 
193
  self.add_corpus(corpus_files)
194
  self.save_corpus_emb_dir = save_corpus_emb_dir
195
  if rerank_model_name_or_path is None:
196
+ rerank_model_name_or_path = "BAAI/bge-reranker-base"
197
  if rerank_model_name_or_path:
198
  self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
199
  self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
 
243
  try:
244
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
245
  except Exception as e:
246
+ logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
247
  if peft_name:
248
  model = PeftModel.from_pretrained(
249
  model,
250
  peft_name,
251
  torch_dtype="auto",
252
  )
253
+ logger.info(f"Loaded peft model from {peft_name}")
254
  model.eval()
255
  return model, tokenizer
256
 
 
341
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
342
  new_text = ''
343
  for text in raw_text:
 
 
 
344
  new_text += text
345
  if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
346
  '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
 
407
  sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
408
  # Get reference results from corpus
409
  hit_chunk_dict = dict()
410
+ for c in sim_contents:
411
+ for id_score_dict in c:
412
+ corpus_id = id_score_dict['corpus_id']
413
+ hit_chunk = id_score_dict["corpus_doc"]
414
  reference_results.append(hit_chunk)
415
  hit_chunk_dict[corpus_id] = hit_chunk
416
 
 
448
  self.history = []
449
  if self.sim_model.corpus:
450
  reference_results = self.get_reference_results(query)
451
+ if reference_results:
452
+ reference_results = self._add_source_numbers(reference_results)
453
+ context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
454
+ else:
455
+ context_str = ''
456
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
457
  else:
458
  prompt = query
459
+ logger.debug(f"prompt: {prompt}")
460
  self.history.append([prompt, ''])
461
  response = ""
462
  for new_text in self.stream_generate_answer(
 
481
  self.history = []
482
  if self.sim_model.corpus:
483
  reference_results = self.get_reference_results(query)
484
+ if reference_results:
485
+ reference_results = self._add_source_numbers(reference_results)
486
+ context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
487
+ else:
488
+ context_str = ''
 
 
 
 
 
489
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
490
  else:
491
  prompt = query
492
+ logger.debug(f"prompt: {prompt}")
493
  self.history.append([prompt, ''])
494
  response = ""
495
  for new_text in self.stream_generate_answer(
 
502
  self.history[-1][1] = response
503
  return response, reference_results
504
 
505
+ def query(self, query: str, **kwargs):
506
+ return self.predict(query, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  def save_corpus_emb(self):
509
  dir_name = self.get_file_hash(self.corpus_files)
 
515
 
516
  def load_corpus_emb(self, emb_dir: str):
517
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
518
+ logger.debug(f"Loading corpus embeddings from {emb_dir}")
519
  self.sim_model.load_corpus_embeddings(emb_dir)
 
 
520
 
521
 
522
  if __name__ == "__main__":
523
  parser = argparse.ArgumentParser()
524
+ parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
525
  parser.add_argument("--gen_model_type", type=str, default="auto")
526
  parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
527
  parser.add_argument("--lora_model", type=str, default=None)
528
  parser.add_argument("--rerank_model_name", type=str, default="")
529
+ parser.add_argument("--corpus_files", type=str, default="data/sample.pdf")
530
  parser.add_argument("--device", type=str, default=None)
531
  parser.add_argument("--int4", action='store_true', help="use int4 quantization")
532
  parser.add_argument("--int8", action='store_true', help="use int8 quantization")
 
536
  args = parser.parse_args()
537
  print(args)
538
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
539
+ m = Rag(
540
  similarity_model=sim_model,
541
  generate_model_type=args.gen_model_type,
542
  generate_model_name_or_path=args.gen_model_name,
 
549
  corpus_files=args.corpus_files.split(','),
550
  num_expand_context_chunk=args.num_expand_context_chunk,
551
  rerank_model_name_or_path=args.rerank_model_name,
552
+ )