ZoniaChatbot committed on
Commit
82a25ec
·
verified ·
1 Parent(s): 2507581

Update chatpdf.py

Browse files
Files changed (1) hide show
  1. chatpdf.py +55 -101
chatpdf.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import argparse
2
  import hashlib
3
  import os
@@ -13,7 +18,6 @@ from similarities import (
13
  EnsembleSimilarity,
14
  BertSimilarity,
15
  BM25Similarity,
16
- TfidfSimilarity
17
  )
18
  from similarities.similarity import SimilarityABC
19
  from transformers import (
@@ -50,6 +54,7 @@ Pregunta:
50
  {query_str}
51
  """
52
 
 
53
  class SentenceSplitter:
54
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
55
  self.chunk_size = chunk_size
@@ -62,7 +67,7 @@ class SentenceSplitter:
62
  return self._split_english_text(text)
63
 
64
  def _split_chinese_text(self, text: str) -> List[str]:
65
- sentence_endings = {'\n', '。', '!', '?', ';', '…'} # puntuación al final de una frase
66
  chunks, current_chunk = [], ''
67
  for word in jieba.cut(text):
68
  if len(current_chunk) + len(word) > self.chunk_size:
@@ -80,16 +85,22 @@ class SentenceSplitter:
80
  return chunks
81
 
82
  def _split_english_text(self, text: str) -> List[str]:
83
- # Dividir el texto inglés por frases utilizando expresiones regulares
84
  sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
85
- chunks, current_chunk = [], ''
 
86
  for sentence in sentences:
87
- if len(current_chunk) + len(sentence) <= self.chunk_size or not current_chunk:
88
  current_chunk += (' ' if current_chunk else '') + sentence
89
  else:
90
- chunks.append(current_chunk)
91
- current_chunk = sentence
92
- if current_chunk: # Añade el último trozo
 
 
 
 
 
93
  chunks.append(current_chunk)
94
 
95
  if self.chunk_overlap > 0 and len(chunks) > 1:
@@ -98,7 +109,7 @@ class SentenceSplitter:
98
  return chunks
99
 
100
  def _is_has_chinese(self, text: str) -> bool:
101
- # comprobar si contiene caracteres chinos
102
  if any("\u4e00" <= ch <= "\u9fff" for ch in text):
103
  return True
104
  else:
@@ -114,7 +125,7 @@ class SentenceSplitter:
114
  return overlapped_chunks
115
 
116
 
117
- class ChatPDF:
118
  def __init__(
119
  self,
120
  similarity_model: SimilarityABC = None,
@@ -122,7 +133,7 @@ class ChatPDF:
122
  generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
123
  lora_model_name_or_path: str = None,
124
  corpus_files: Union[str, List[str]] = None,
125
- save_corpus_emb_dir: str = "corpus_embs/",
126
  device: str = None,
127
  int8: bool = False,
128
  int4: bool = False,
@@ -131,8 +142,8 @@ class ChatPDF:
131
  rerank_model_name_or_path: str = None,
132
  enable_history: bool = False,
133
  num_expand_context_chunk: int = 2,
134
- similarity_top_k: int = 15,
135
- rerank_top_k: int = 5,
136
  ):
137
  """
138
  Init RAG model.
@@ -171,8 +182,7 @@ class ChatPDF:
171
  m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2", device=self.device)
172
  m2 = BM25Similarity()
173
  m3 = TfidfSimilarity()
174
- default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5],
175
- c=2) # Ajuste los pesos según los resultados
176
  self.sim_model = default_sim_model
177
  self.gen_model, self.tokenizer = self._init_gen_model(
178
  generate_model_type,
@@ -237,14 +247,14 @@ class ChatPDF:
237
  try:
238
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
239
  except Exception as e:
240
- logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
241
  if peft_name:
242
  model = PeftModel.from_pretrained(
243
  model,
244
  peft_name,
245
  torch_dtype="auto",
246
  )
247
- logger.info(f"Modelo peft cargado desde {peft_name}")
248
  model.eval()
249
  return model, tokenizer
250
 
@@ -335,7 +345,6 @@ class ChatPDF:
335
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
336
  new_text = ''
337
  for text in raw_text:
338
- # Añadir un espacio antes de concatenar si new_text no está vacío
339
  if new_text:
340
  new_text += ' '
341
  new_text += text
@@ -408,12 +417,9 @@ class ChatPDF:
408
  # Si se encuentra una coincidencia exacta, devolverla como contexto
409
  return [exact_match]
410
 
411
- # Si no se encuentra una coincidencia exacta, continuar con la búsqueda general
412
- sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
413
-
414
- # Procesar los resultados de similitud
415
  reference_results = []
416
-
 
417
  hit_chunk_dict = dict()
418
  threshold_score = 0.5 # Establece un umbral para filtrar fragmentos irrelevantes
419
 
@@ -423,6 +429,7 @@ class ChatPDF:
423
  hit_chunk = self.sim_model.corpus[corpus_id]
424
  reference_results.append(hit_chunk)
425
  hit_chunk_dict[corpus_id] = hit_chunk
 
426
  if reference_results:
427
  if self.rerank_model is not None:
428
  # Rerank reference results
@@ -447,9 +454,9 @@ class ChatPDF:
447
  def predict_stream(
448
  self,
449
  query: str,
450
- max_length: int = 256,
451
- context_len: int = 1024,
452
- temperature: float = 0.5,
453
  ):
454
  """Generate predictions stream."""
455
  stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
@@ -457,16 +464,15 @@ class ChatPDF:
457
  self.history = []
458
  if self.sim_model.corpus:
459
  reference_results = self.get_reference_results(query)
460
- if not reference_results:
461
- yield 'No se ha proporcionado suficiente información relevante', reference_results
462
- reference_results = self._add_source_numbers(reference_results)
463
- context_str = '\n'.join(reference_results)[:]
464
- print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
465
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
466
- logger.debug(f"prompt: {prompt}")
467
  else:
468
  prompt = query
469
- logger.debug(prompt)
470
  self.history.append([prompt, ''])
471
  response = ""
472
  for new_text in self.stream_generate_answer(
@@ -481,9 +487,9 @@ class ChatPDF:
481
  def predict(
482
  self,
483
  query: str,
484
- max_length: int = 256,
485
- context_len: int = 1024,
486
- temperature: float = 0.5,
487
  ):
488
  """Query from corpus."""
489
  reference_results = []
@@ -491,20 +497,15 @@ class ChatPDF:
491
  self.history = []
492
  if self.sim_model.corpus:
493
  reference_results = self.get_reference_results(query)
494
-
495
- if not reference_results:
496
- return 'No se ha proporcionado suficiente información relevante', reference_results
497
- reference_results = self._add_source_numbers(reference_results)
498
- # context_str = '\n'.join(reference_results) # Usa todos los fragmentos
499
- context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
500
- #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
501
- print(".......................................................")
502
- context_str = '\n'.join(reference_results)[:]
503
- #print("context_str: ", context_str)
504
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
505
- logger.debug(f"prompt: {prompt}")
506
  else:
507
  prompt = query
 
508
  self.history.append([prompt, ''])
509
  response = ""
510
  for new_text in self.stream_generate_answer(
@@ -517,29 +518,8 @@ class ChatPDF:
517
  self.history[-1][1] = response
518
  return response, reference_results
519
 
520
    def save_corpus_text(self):
        """Persist the chunked corpus text to disk for later reloading.

        Writes every chunk held by ``self.sim_model.corpus`` to
        ``corpus_embs/corpus_text.txt``, one chunk per record separated by a
        blank line (the same ``"\n\n"`` separator that ``load_corpus_text``
        splits on).

        Returns:
            The path of the written file, or ``None`` when there are no
            corpus files to save.
        """
        if not self.corpus_files:
            # Nothing was ever added to the corpus; warn and bail out.
            logger.warning("No hay archivos de corpus para guardar.")
            return

        # NOTE(review): the directory is hard-coded instead of using
        # self.save_corpus_emb_dir — confirm this is intentional.
        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")

        with open(corpus_text_file, 'w', encoding='utf-8') as f:
            for chunk in self.sim_model.corpus.values():
                # Two newlines between chunks for readability and as the
                # record separator used when loading back.
                f.write(chunk + "\n\n")

        logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
        return corpus_text_file
533
-
534
    def load_corpus_text(self, emb_dir: str):
        """Reload previously saved corpus chunks into the similarity model.

        Reads ``corpus_embs/corpus_text.txt`` (written by
        ``save_corpus_text``), splits it on blank-line separators, and
        rebuilds ``self.sim_model.corpus`` as an index->chunk dict.

        Args:
            emb_dir: Directory of the saved embeddings.
                NOTE(review): currently unused — the text path is
                hard-coded below; confirm whether it should derive
                from ``emb_dir``.
        """
        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
        if os.path.exists(corpus_text_file):
            with open(corpus_text_file, 'r', encoding='utf-8') as f:
                # "\n\n" is the record separator used by save_corpus_text.
                corpus_text = f.read().split("\n\n")
                # Re-index chunks densely, dropping empty/whitespace records.
                self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
                logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
        else:
            logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
543
 
544
  def save_corpus_emb(self):
545
  dir_name = self.get_file_hash(self.corpus_files)
@@ -551,10 +531,8 @@ class ChatPDF:
551
 
552
  def load_corpus_emb(self, emb_dir: str):
553
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
554
- logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
555
  self.sim_model.load_corpus_embeddings(emb_dir)
556
- # Cargar el texto del corpus
557
- self.load_corpus_text(emb_dir)
558
 
559
 
560
  if __name__ == "__main__":
@@ -564,7 +542,7 @@ if __name__ == "__main__":
564
  parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
565
  parser.add_argument("--lora_model", type=str, default=None)
566
  parser.add_argument("--rerank_model_name", type=str, default="")
567
- parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
568
  parser.add_argument("--device", type=str, default=None)
569
  parser.add_argument("--int4", action='store_true', help="use int4 quantization")
570
  parser.add_argument("--int8", action='store_true', help="use int8 quantization")
@@ -574,7 +552,7 @@ if __name__ == "__main__":
574
  args = parser.parse_args()
575
  print(args)
576
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
577
- m = ChatPDF(
578
  similarity_model=sim_model,
579
  generate_model_type=args.gen_model_type,
580
  generate_model_name_or_path=args.gen_model_name,
@@ -588,29 +566,5 @@ if __name__ == "__main__":
588
  num_expand_context_chunk=args.num_expand_context_chunk,
589
  rerank_model_name_or_path=args.rerank_model_name,
590
  )
591
-
592
- # Comprobar si existen incrustaciones guardadas
593
- dir_name = m.get_file_hash(args.corpus_files.split(','))
594
- save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
595
-
596
- if os.path.exists(save_dir):
597
- # Cargar las incrustaciones guardadas
598
- m.load_corpus_emb(save_dir)
599
- print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
600
- else:
601
- # Procesar el corpus y guardar las incrustaciones
602
- m.add_corpus(args.corpus_files.split(','))
603
- save_dir = m.save_corpus_emb()
604
- # Guardar el texto del corpus
605
- m.save_corpus_text()
606
- print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
607
-
608
- while True:
609
- query = input("\nEnter a query: ")
610
- if query == "exit":
611
- break
612
- if query.strip() == "":
613
- continue
614
- r, refs = m.predict(query)
615
- print(r, refs)
616
- print("\nRespuesta: ", r)
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
  import argparse
7
  import hashlib
8
  import os
 
18
  EnsembleSimilarity,
19
  BertSimilarity,
20
  BM25Similarity,
 
21
  )
22
  from similarities.similarity import SimilarityABC
23
  from transformers import (
 
54
  {query_str}
55
  """
56
 
57
+
58
  class SentenceSplitter:
59
  def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
60
  self.chunk_size = chunk_size
 
67
  return self._split_english_text(text)
68
 
69
  def _split_chinese_text(self, text: str) -> List[str]:
70
+ sentence_endings = {'\n', '。', '!', '?', ';', '…'} # 句末标点符号
71
  chunks, current_chunk = [], ''
72
  for word in jieba.cut(text):
73
  if len(current_chunk) + len(word) > self.chunk_size:
 
85
  return chunks
86
 
87
  def _split_english_text(self, text: str) -> List[str]:
88
+ # 使用正则表达式按句子分割英文文本
89
  sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
90
+ chunks = []
91
+ current_chunk = ''
92
  for sentence in sentences:
93
+ if len(current_chunk) + len(sentence) <= self.chunk_size:
94
  current_chunk += (' ' if current_chunk else '') + sentence
95
  else:
96
+ if len(sentence) > self.chunk_size:
97
+ for i in range(0, len(sentence), self.chunk_size):
98
+ chunks.append(sentence[i:i + self.chunk_size])
99
+ current_chunk = ''
100
+ else:
101
+ chunks.append(current_chunk)
102
+ current_chunk = sentence
103
+ if current_chunk: # Add the last chunk
104
  chunks.append(current_chunk)
105
 
106
  if self.chunk_overlap > 0 and len(chunks) > 1:
 
109
  return chunks
110
 
111
  def _is_has_chinese(self, text: str) -> bool:
112
+ # check if contains chinese characters
113
  if any("\u4e00" <= ch <= "\u9fff" for ch in text):
114
  return True
115
  else:
 
125
  return overlapped_chunks
126
 
127
 
128
+ class Rag:
129
  def __init__(
130
  self,
131
  similarity_model: SimilarityABC = None,
 
133
  generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
134
  lora_model_name_or_path: str = None,
135
  corpus_files: Union[str, List[str]] = None,
136
+ save_corpus_emb_dir: str = "./corpus_embs/",
137
  device: str = None,
138
  int8: bool = False,
139
  int4: bool = False,
 
142
  rerank_model_name_or_path: str = None,
143
  enable_history: bool = False,
144
  num_expand_context_chunk: int = 2,
145
+ similarity_top_k: int = 10,
146
+ rerank_top_k: int = 3,
147
  ):
148
  """
149
  Init RAG model.
 
182
  m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2", device=self.device)
183
  m2 = BM25Similarity()
184
  m3 = TfidfSimilarity()
185
+ default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5], c=2) # Ajuste los pesos según los resultados
 
186
  self.sim_model = default_sim_model
187
  self.gen_model, self.tokenizer = self._init_gen_model(
188
  generate_model_type,
 
247
  try:
248
  model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
249
  except Exception as e:
250
+ logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
251
  if peft_name:
252
  model = PeftModel.from_pretrained(
253
  model,
254
  peft_name,
255
  torch_dtype="auto",
256
  )
257
+ logger.info(f"Loaded peft model from {peft_name}")
258
  model.eval()
259
  return model, tokenizer
260
 
 
345
  raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
346
  new_text = ''
347
  for text in raw_text:
 
348
  if new_text:
349
  new_text += ' '
350
  new_text += text
 
417
  # Si se encuentra una coincidencia exacta, devolverla como contexto
418
  return [exact_match]
419
 
 
 
 
 
420
  reference_results = []
421
+ sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
422
+ # Get reference results from corpus
423
  hit_chunk_dict = dict()
424
  threshold_score = 0.5 # Establece un umbral para filtrar fragmentos irrelevantes
425
 
 
429
  hit_chunk = self.sim_model.corpus[corpus_id]
430
  reference_results.append(hit_chunk)
431
  hit_chunk_dict[corpus_id] = hit_chunk
432
+
433
  if reference_results:
434
  if self.rerank_model is not None:
435
  # Rerank reference results
 
454
  def predict_stream(
455
  self,
456
  query: str,
457
+ max_length: int = 512,
458
+ context_len: int = 2048,
459
+ temperature: float = 0.7,
460
  ):
461
  """Generate predictions stream."""
462
  stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
 
464
  self.history = []
465
  if self.sim_model.corpus:
466
  reference_results = self.get_reference_results(query)
467
+ if reference_results:
468
+ reference_results = self._add_source_numbers(reference_results)
469
+ context_str = '\n'.join(reference_results)[:]
470
+ else:
471
+ context_str = ''
472
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
473
  else:
474
  prompt = query
475
+ logger.debug(f"prompt: {prompt}")
476
  self.history.append([prompt, ''])
477
  response = ""
478
  for new_text in self.stream_generate_answer(
 
487
  def predict(
488
  self,
489
  query: str,
490
+ max_length: int = 512,
491
+ context_len: int = 2048,
492
+ temperature: float = 0.7,
493
  ):
494
  """Query from corpus."""
495
  reference_results = []
 
497
  self.history = []
498
  if self.sim_model.corpus:
499
  reference_results = self.get_reference_results(query)
500
+ if reference_results:
501
+ reference_results = self._add_source_numbers(reference_results)
502
+ context_str = '\n'.join(reference_results)[:]
503
+ else:
504
+ context_str = ''
 
 
 
 
 
505
  prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
 
506
  else:
507
  prompt = query
508
+ logger.debug(f"prompt: {prompt}")
509
  self.history.append([prompt, ''])
510
  response = ""
511
  for new_text in self.stream_generate_answer(
 
518
  self.history[-1][1] = response
519
  return response, reference_results
520
 
521
    def query(self, query: str, **kwargs):
        """Convenience alias for :meth:`predict`.

        Forwards *query* and any keyword arguments unchanged and returns
        ``predict``'s result: a ``(response, reference_results)`` pair.
        """
        return self.predict(query, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  def save_corpus_emb(self):
525
  dir_name = self.get_file_hash(self.corpus_files)
 
531
 
532
  def load_corpus_emb(self, emb_dir: str):
533
  if hasattr(self.sim_model, 'load_corpus_embeddings'):
534
+ logger.debug(f"Loading corpus embeddings from {emb_dir}")
535
  self.sim_model.load_corpus_embeddings(emb_dir)
 
 
536
 
537
 
538
  if __name__ == "__main__":
 
542
  parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
543
  parser.add_argument("--lora_model", type=str, default=None)
544
  parser.add_argument("--rerank_model_name", type=str, default="")
545
+ parser.add_argument("--corpus_files", type=str, default="data/sample.pdf")
546
  parser.add_argument("--device", type=str, default=None)
547
  parser.add_argument("--int4", action='store_true', help="use int4 quantization")
548
  parser.add_argument("--int8", action='store_true', help="use int8 quantization")
 
552
  args = parser.parse_args()
553
  print(args)
554
  sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
555
+ m = Rag(
556
  similarity_model=sim_model,
557
  generate_model_type=args.gen_model_type,
558
  generate_model_name_or_path=args.gen_model_name,
 
566
  num_expand_context_chunk=args.num_expand_context_chunk,
567
  rerank_model_name_or_path=args.rerank_model_name,
568
  )
569
+ r, refs = m.predict('自然语言中的非平行迁移是指什么?')
570
+ print(r, refs)