Commit 483ce33 (verified) · ZoniaChatbot committed · 1 Parent(s): 2350b40

Update chatpdf.py

Files changed (1): chatpdf.py (+83, -59)

chatpdf.py CHANGED
@@ -16,8 +16,13 @@ from similarities import (
 )
 from similarities.similarity import SimilarityABC
 from transformers import (
+    AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
+    BloomForCausalLM,
+    BloomTokenizerFast,
+    LlamaTokenizer,
+    LlamaForCausalLM,
     TextIteratorStreamer,
     GenerationConfig,
     AutoModelForSequenceClassification,
@@ -26,9 +31,22 @@ from transformers import (
 jieba.setLogLevel("ERROR")
 
 MODEL_CLASSES = {
+    "bloom": (BloomForCausalLM, BloomTokenizerFast),
+    "chatglm": (AutoModel, AutoTokenizer),
+    "llama": (LlamaForCausalLM, LlamaTokenizer),
+    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
     "auto": (AutoModelForCausalLM, AutoTokenizer),
 }
 
+PROMPT_TEMPLATE1 = """基于以下已知信息,简洁和专业的来回答用户的问题。
+如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
+
+已知内容:
+{context_str}
+
+问题:
+{query_str}
+"""
 PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
 Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
 
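
Two notes on this hunk. First, the added Chinese PROMPT_TEMPLATE1 is immediately shadowed by the Spanish reassignment that follows it, so only the Spanish template is ever used. Second, a registry like the expanded MODEL_CLASSES is typically consumed by looking up the class pair for the requested model type and falling back to "auto"; a minimal sketch (the load_model helper is hypothetical, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_CLASSES = {
        "auto": (AutoModelForCausalLM, AutoTokenizer),
    }

    def load_model(model_type: str, model_name_or_path: str):
        # Unknown types fall back to the generic "auto" loader classes.
        model_class, tokenizer_class = MODEL_CLASSES.get(model_type, MODEL_CLASSES["auto"])
        tokenizer = tokenizer_class.from_pretrained(model_name_or_path, trust_remote_code=True)
        model = model_class.from_pretrained(model_name_or_path, trust_remote_code=True)
        return model, tokenizer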
@@ -41,7 +59,7 @@ Respuesta útil:
 PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
 concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
 información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
-inventados en la respuesta.
+inventados en la respuesta, y ésta debe estar en Español.
 
 Contenido conocido:
 {context_str}
@@ -81,7 +99,7 @@ class SentenceSplitter:
         return chunks
 
     def _split_english_text(self, text: str) -> List[str]:
-        # División de texto inglés por frases mediante expresiones regulares
+        # Dividir el texto inglés por frases utilizando expresiones regulares
         sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
         chunks, current_chunk = [], ''
         for sentence in sentences:
@@ -90,7 +108,7 @@ class SentenceSplitter:
             else:
                 chunks.append(current_chunk)
                 current_chunk = sentence
-        if current_chunk:  # Add the last chunk
+        if current_chunk:  # Añade el último trozo
             chunks.append(current_chunk)
 
         if self.chunk_overlap > 0 and len(chunks) > 1:
@@ -99,14 +117,14 @@ class SentenceSplitter:
         return chunks
 
     def _is_has_chinese(self, text: str) -> bool:
-        # check if contains chinese characters
+        # comprobar si contiene caracteres chinos
        if any("\u4e00" <= ch <= "\u9fff" for ch in text):
             return True
         else:
             return False
 
     def _handle_overlap(self, chunks: List[str]) -> List[str]:
-        # Tratamiento de los solapamientos entre bloques
+        # 处理块间重叠
         overlapped_chunks = []
         for i in range(len(chunks) - 1):
             chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
@@ -115,12 +133,13 @@ class SentenceSplitter:
         return overlapped_chunks
 
 
+
 class ChatPDF:
     def __init__(
             self,
             similarity_model: SimilarityABC = None,
             generate_model_type: str = "auto",
-            generate_model_name_or_path: str = "LenguajeNaturalAI/leniachat-qwen2-1.5B-v0",
+            generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
             lora_model_name_or_path: str = None,
             corpus_files: Union[str, List[str]] = None,
             save_corpus_emb_dir: str = "corpus_embs/",
@@ -132,10 +151,28 @@ class ChatPDF:
             rerank_model_name_or_path: str = None,
             enable_history: bool = False,
             num_expand_context_chunk: int = 2,
-            similarity_top_k: int = 10,
-            rerank_top_k: int = 3
+            similarity_top_k: int = 5,
+            rerank_top_k: int = 3,
     ):
-
+        """
+        Init RAG model.
+        :param similarity_model: similarity model, default None, if set, will use it instead of EnsembleSimilarity
+        :param generate_model_type: generate model type
+        :param generate_model_name_or_path: generate model name or path
+        :param lora_model_name_or_path: lora model name or path
+        :param corpus_files: corpus files
+        :param save_corpus_emb_dir: save corpus embeddings dir, default ./corpus_embs/
+        :param device: device, default None, auto select gpu or cpu
+        :param int8: use int8 quantization, default False
+        :param int4: use int4 quantization, default False
+        :param chunk_size: chunk size, default 250
+        :param chunk_overlap: chunk overlap, default 0, can not set to > 0 if num_expand_context_chunk > 0
+        :param rerank_model_name_or_path: rerank model name or path, default 'BAAI/bge-reranker-base'
+        :param enable_history: enable history, default False
+        :param num_expand_context_chunk: num expand context chunk, default 2, if set to 0, will not expand context chunk
+        :param similarity_top_k: similarity_top_k, default 5, similarity model search k corpus chunks
+        :param rerank_top_k: rerank_top_k, default 3, rerank model keeps k corpus chunks
+        """
         if torch.cuda.is_available():
             default_device = torch.device(0)
         elif torch.backends.mps.is_available():
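
The new docstring pins down the retrieval knobs; a minimal usage sketch under the updated defaults (the corpus file name below is a placeholder, not from this commit):

    from chatpdf import ChatPDF

    m = ChatPDF(
        generate_model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
        corpus_files="sample.pdf",   # placeholder corpus file
        similarity_top_k=5,          # retrieve 5 candidate chunks
        rerank_top_k=3,              # keep the best 3 after reranking
    )
    response, references = m.predict("¿De qué trata el documento?")
    print(response)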
@@ -151,7 +188,7 @@ class ChatPDF:
         if similarity_model is not None:
             self.sim_model = similarity_model
         else:
-            m1 = BertSimilarity(model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device=self.device)
+            m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual", device=self.device)
             m2 = BM25Similarity()
             default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
             self.sim_model = default_sim_model
@@ -168,7 +205,7 @@ class ChatPDF:
             self.add_corpus(corpus_files)
         self.save_corpus_emb_dir = save_corpus_emb_dir
         if rerank_model_name_or_path is None:
-            rerank_model_name_or_path = "maidalun1020/bce-reranker-base_v1"
+            rerank_model_name_or_path = "BAAI/bge-reranker-large"
         if rerank_model_name_or_path:
             self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
             self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
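
The reranker swap (bce-reranker-base_v1 to BAAI/bge-reranker-large) keeps the same cross-encoder pattern: score each (query, chunk) pair with a sequence-classification head and keep the top rerank_top_k. A sketch of that pattern, assuming the standard bge-reranker usage rather than the exact method body in this file:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
    model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large")
    model.eval()

    query = "¿Qué establece el acuerdo?"                 # illustrative query
    chunks = ["primer fragmento", "segundo fragmento"]   # illustrative chunks
    pairs = [[query, c] for c in chunks]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
        scores = model(**inputs).logits.view(-1).float()  # one relevance score per pair
    top3 = sorted(zip(chunks, scores.tolist()), key=lambda x: x[1], reverse=True)[:3]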
@@ -252,7 +289,7 @@ class ChatPDF:
             repetition_penalty=1.0,
             context_len=2048
     ):
-        streamer = TextIteratorStreamer(self.tokenizer, timeout=520.0, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         input_ids = self._get_chat_input()
         max_src_len = context_len - max_new_tokens - 8
         input_ids = input_ids[-max_src_len:]
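
TextIteratorStreamer raises queue.Empty when no token arrives within `timeout` seconds, so dropping it from 520.0 to 60.0 mainly shortens how long a stalled generation can block the consumer. The intended consumption pattern, as a self-contained sketch (model choice mirrors the new default):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer("Hola,", return_tensors="pt")
    thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=64, streamer=streamer))
    thread.start()
    for new_text in streamer:       # yields decoded text as tokens are produced
        print(new_text, end="", flush=True)
    thread.join()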
@@ -383,29 +420,14 @@ class ChatPDF:
         """
         reference_results = []
         sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
-
-        # Verificar si sim_contents es una lista o un diccionario
-        if isinstance(sim_contents, list):
-            for item in sim_contents:
-                # Ajustar según la estructura real de item
-                corpus_id = item[0] if isinstance(item, (list, tuple)) else item  # Asegurarse de que corpus_id sea el valor correcto
-                if isinstance(corpus_id, dict):
-                    # Extraer el valor necesario si corpus_id es un diccionario
-                    corpus_id = next(iter(corpus_id.keys()))  # Tomar la primera clave como ejemplo
-                if corpus_id in self.sim_model.corpus:
-                    hit_chunk = self.sim_model.corpus[corpus_id]
-                    reference_results.append(hit_chunk)
-
-        elif isinstance(sim_contents, dict):
-            for query_id, id_score_dict in sim_contents.items():
-                for corpus_id, s in id_score_dict.items():
-                    if corpus_id in self.sim_model.corpus:
-                        hit_chunk = self.sim_model.corpus[corpus_id]
-                        reference_results.append(hit_chunk)
-        else:
-            logger.error(f"Unexpected type for sim_contents: {type(sim_contents)}")
+        # Get reference results from corpus
+        hit_chunk_dict = dict()
+        for query_id, id_score_dict in sim_contents.items():
+            for corpus_id, s in id_score_dict.items():
+                hit_chunk = self.sim_model.corpus[corpus_id]
+                reference_results.append(hit_chunk)
+                hit_chunk_dict[corpus_id] = hit_chunk
 
-
         if reference_results:
             if self.rerank_model is not None:
                 # Rerank reference results
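
The rewrite assumes most_similar always returns a nested mapping of {query_id: {corpus_id: score}}, and it drops both the old isinstance branches and the corpus_id membership guard, so a corpus_id missing from self.sim_model.corpus now raises KeyError. A small sketch of the assumed shape (values are illustrative placeholders):

    sim_contents = {0: {12: 0.81, 7: 0.66}}            # one query, two scored hits
    corpus = {12: "chunk text A", 7: "chunk text B"}   # stand-in for self.sim_model.corpus

    reference_results, hit_chunk_dict = [], {}
    for query_id, id_score_dict in sim_contents.items():
        for corpus_id, score in id_score_dict.items():
            hit_chunk = corpus[corpus_id]              # KeyError if the id is missing
            reference_results.append(hit_chunk)
            hit_chunk_dict[corpus_id] = hit_chunk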
@@ -444,7 +466,7 @@ class ChatPDF:
                 yield 'No se ha proporcionado suficiente información relevante', reference_results
             reference_results = self._add_source_numbers(reference_results)
             context_str = '\n'.join(reference_results)[:]
-            #print("context_str: " , (context_len - len(PROMPT_TEMPLATE)))
+            print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
             logger.debug(f"prompt: {prompt}")
         else:
@@ -478,12 +500,12 @@ class ChatPDF:
             if not reference_results:
                 return 'No se ha proporcionado suficiente información relevante', reference_results
             reference_results = self._add_source_numbers(reference_results)
-            #context_str = '\n'.join(reference_results)  # Usa todos los fragmentos
+            # context_str = '\n'.join(reference_results)  # Usa todos los fragmentos
             context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
-            #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
+            print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
             print(".......................................................")
             context_str = '\n'.join(reference_results)[:]
-            #print("context_str: ", context_str)
+            print("context_str: ", context_str)
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
             logger.debug(f"prompt: {prompt}")
         else:
@@ -500,19 +522,6 @@ class ChatPDF:
             self.history[-1][1] = response
         return response, reference_results
 
-    def save_corpus_emb(self):
-        dir_name = self.get_file_hash(self.corpus_files)
-        save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
-        if hasattr(self.sim_model, 'save_corpus_embeddings'):
-            self.sim_model.save_corpus_embeddings(save_dir)
-            logger.debug(f"Saving corpus embeddings to {save_dir}")
-        return save_dir
-
-    def load_corpus_emb(self, emb_dir: str):
-        if hasattr(self.sim_model, 'load_corpus_embeddings'):
-            logger.debug(f"Loading corpus embeddings from {emb_dir}")
-            self.sim_model.load_corpus_embeddings(emb_dir)
-
     def save_corpus_text(self):
         if not self.corpus_files:
             logger.warning("No hay archivos de corpus para guardar.")
@@ -537,20 +546,36 @@ class ChatPDF:
         else:
             logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
 
+    def save_corpus_emb(self):
+        dir_name = self.get_file_hash(self.corpus_files)
+        save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
+        if hasattr(self.sim_model, 'save_corpus_embeddings'):
+            self.sim_model.save_corpus_embeddings(save_dir)
+            logger.debug(f"Saving corpus embeddings to {save_dir}")
+        return save_dir
+
+    def load_corpus_emb(self, emb_dir: str):
+        if hasattr(self.sim_model, 'load_corpus_embeddings'):
+            logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
+            self.sim_model.load_corpus_embeddings(emb_dir)
+            # Cargar el texto del corpus
+            self.load_corpus_text(emb_dir)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
     parser.add_argument("--gen_model_type", type=str, default="auto")
-    parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
+    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
     parser.add_argument("--lora_model", type=str, default=None)
-    parser.add_argument("--rerank_model_name", type=str, default="maidalun1020/bce-reranker-base_v1")
-    parser.add_argument("--corpus_files", type=str, default="docs/corpus.txt")
+    parser.add_argument("--rerank_model_name", type=str, default="")
+    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--int4", action='store_true', help="use int4 quantization")
     parser.add_argument("--int8", action='store_true', help="use int8 quantization")
     parser.add_argument("--chunk_size", type=int, default=220)
-    parser.add_argument("--chunk_overlap", type=int, default=50)
-    parser.add_argument("--num_expand_context_chunk", type=int, default=2)
+    parser.add_argument("--chunk_overlap", type=int, default=0)
+    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
     args = parser.parse_args()
     print(args)
     sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
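
The relocated save_corpus_emb/load_corpus_emb pair backs the hash-keyed embedding cache used by the __main__ block; the round trip looks roughly like this (corpus file name is a placeholder):

    import os
    from chatpdf import ChatPDF

    m = ChatPDF(corpus_files="sample.pdf")           # placeholder corpus file
    dir_name = m.get_file_hash(["sample.pdf"])       # cache key derived from file contents
    emb_dir = os.path.join("corpus_embs/", dir_name)
    if os.path.exists(emb_dir):
        m.load_corpus_emb(emb_dir)    # reuse cached embeddings and corpus text
    else:
        m.save_corpus_emb()           # compute and persist embeddings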
@@ -568,7 +593,6 @@ if __name__ == "__main__":
         num_expand_context_chunk=args.num_expand_context_chunk,
         rerank_model_name_or_path=args.rerank_model_name,
     )
-    logger.info(f"chatpdf model: {m}")
 
     # Comprobar si existen incrustaciones guardadas
     dir_name = m.get_file_hash(args.corpus_files.split(','))
@@ -594,4 +618,4 @@ if __name__ == "__main__":
             continue
         r, refs = m.predict(query)
         print(r, refs)
-        print("\nRespuesta: ", r)
+        print("\nRespuesta: ", r)
 