ZoniaChatbot committed (verified)
Commit 736c53f
1 Parent(s): d6f258c

Update chatpdf.py

Files changed (1)
  1. chatpdf.py +590 -581
chatpdf.py CHANGED
@@ -1,582 +1,591 @@
1
- import argparse
2
- import hashlib
3
- import os
4
- import re
5
- from threading import Thread
6
- from typing import Union, List
7
-
8
- import jieba
9
- import torch
10
- from loguru import logger
11
- from peft import PeftModel
12
- from similarities import (
13
- EnsembleSimilarity,
14
- BertSimilarity,
15
- BM25Similarity,
16
- )
17
- from similarities.similarity import SimilarityABC
18
- from transformers import (
19
- AutoModelForCausalLM,
20
- AutoTokenizer,
21
- TextIteratorStreamer,
22
- GenerationConfig,
23
- AutoModelForSequenceClassification,
24
- )
25
-
26
- jieba.setLogLevel("ERROR")
27
-
28
- MODEL_CLASSES = {
29
- "auto": (AutoModelForCausalLM, AutoTokenizer),
30
- }
31
-
32
- PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
33
- Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
34
-
35
- Contexto: {context_str}
36
- Pregunta: {query_str}
37
-
38
- Devuelve sólo la respuesta útil que aparece a continuación y nada más, y ésta debe estar en Español.
39
- Respuesta útil:
40
- """
41
- PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
42
- concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
43
- información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
44
- inventados en la respuesta.
45
-
46
- Contenido conocido:
47
- {context_str}
48
-
49
- Pregunta:
50
- {query_str}
51
- """
52
-
53
-
54
- class SentenceSplitter:
55
- def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
56
- self.chunk_size = chunk_size
57
- self.chunk_overlap = chunk_overlap
58
-
59
- def split_text(self, text: str) -> List[str]:
60
- if self._is_has_chinese(text):
61
- return self._split_chinese_text(text)
62
- else:
63
- return self._split_english_text(text)
64
-
65
- def _split_chinese_text(self, text: str) -> List[str]:
66
- sentence_endings = {'\n', '。', '!', '?', ';', '…'} # puntuación al final de una frase
67
- chunks, current_chunk = [], ''
68
- for word in jieba.cut(text):
69
- if len(current_chunk) + len(word) > self.chunk_size:
70
- chunks.append(current_chunk.strip())
71
- current_chunk = word
72
- else:
73
- current_chunk += word
74
- if word[-1] in sentence_endings and len(current_chunk) > self.chunk_size - self.chunk_overlap:
75
- chunks.append(current_chunk.strip())
76
- current_chunk = ''
77
- if current_chunk:
78
- chunks.append(current_chunk.strip())
79
- if self.chunk_overlap > 0 and len(chunks) > 1:
80
- chunks = self._handle_overlap(chunks)
81
- return chunks
82
-
83
- def _split_english_text(self, text: str) -> List[str]:
84
- # División de texto inglés por frases mediante expresiones regulares
85
- sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
86
- chunks, current_chunk = [], ''
87
- for sentence in sentences:
88
- if len(current_chunk) + len(sentence) <= self.chunk_size or not current_chunk:
89
- current_chunk += (' ' if current_chunk else '') + sentence
90
- else:
91
- chunks.append(current_chunk)
92
- current_chunk = sentence
93
- if current_chunk: # Add the last chunk
94
- chunks.append(current_chunk)
95
-
96
- if self.chunk_overlap > 0 and len(chunks) > 1:
97
- chunks = self._handle_overlap(chunks)
98
-
99
- return chunks
100
-
101
- def _is_has_chinese(self, text: str) -> bool:
102
- # check if contains chinese characters
103
- if any("\u4e00" <= ch <= "\u9fff" for ch in text):
104
- return True
105
- else:
106
- return False
107
-
108
- def _handle_overlap(self, chunks: List[str]) -> List[str]:
109
- # Tratamiento de los solapamientos entre bloques
110
- overlapped_chunks = []
111
- for i in range(len(chunks) - 1):
112
- chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
113
- overlapped_chunks.append(chunk.strip())
114
- overlapped_chunks.append(chunks[-1])
115
- return overlapped_chunks
116
-
117
-
118
- class ChatPDF:
119
- def __init__(
120
- self,
121
- similarity_model: SimilarityABC = None,
122
- generate_model_type: str = "auto",
123
- generate_model_name_or_path: str = "LenguajeNaturalAI/leniachat-qwen2-1.5B-v0",
124
- lora_model_name_or_path: str = None,
125
- corpus_files: Union[str, List[str]] = None,
126
- save_corpus_emb_dir: str = "corpus_embs/",
127
- device: str = None,
128
- int8: bool = False,
129
- int4: bool = False,
130
- chunk_size: int = 250,
131
- chunk_overlap: int = 0,
132
- rerank_model_name_or_path: str = None,
133
- enable_history: bool = False,
134
- num_expand_context_chunk: int = 2,
135
- similarity_top_k: int = 10,
136
- rerank_top_k: int = 3
137
- ):
138
-
139
- if torch.cuda.is_available():
140
- default_device = torch.device(0)
141
- elif torch.backends.mps.is_available():
142
- default_device = torch.device('cpu')
143
- else:
144
- default_device = torch.device('cpu')
145
- self.device = device or default_device
146
- if num_expand_context_chunk > 0 and chunk_overlap > 0:
147
- logger.warning(f" 'num_expand_context_chunk' and 'chunk_overlap' cannot both be greater than zero. "
148
- f" 'chunk_overlap' has been set to zero by default.")
149
- chunk_overlap = 0
150
- self.text_splitter = SentenceSplitter(chunk_size, chunk_overlap)
151
- if similarity_model is not None:
152
- self.sim_model = similarity_model
153
- else:
154
- m1 = BertSimilarity(model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device=self.device)
155
- m2 = BM25Similarity()
156
- default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
157
- self.sim_model = default_sim_model
158
- self.gen_model, self.tokenizer = self._init_gen_model(
159
- generate_model_type,
160
- generate_model_name_or_path,
161
- peft_name=lora_model_name_or_path,
162
- int8=int8,
163
- int4=int4,
164
- )
165
- self.history = []
166
- self.corpus_files = corpus_files
167
- if corpus_files:
168
- self.add_corpus(corpus_files)
169
- self.save_corpus_emb_dir = save_corpus_emb_dir
170
- if rerank_model_name_or_path is None:
171
- rerank_model_name_or_path = "maidalun1020/bce-reranker-base_v1"
172
- if rerank_model_name_or_path:
173
- self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
174
- self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
175
- self.rerank_model.to(self.device)
176
- self.rerank_model.eval()
177
- else:
178
- self.rerank_model = None
179
- self.rerank_tokenizer = None
180
- self.enable_history = enable_history
181
- self.similarity_top_k = similarity_top_k
182
- self.num_expand_context_chunk = num_expand_context_chunk
183
- self.rerank_top_k = rerank_top_k
184
-
185
- def __str__(self):
186
- return f"Similarity model: {self.sim_model}, Generate model: {self.gen_model}"
187
-
188
- def _init_gen_model(
189
- self,
190
- gen_model_type: str,
191
- gen_model_name_or_path: str,
192
- peft_name: str = None,
193
- int8: bool = False,
194
- int4: bool = False,
195
- ):
196
- """Init generate model."""
197
- if int8 or int4:
198
- device_map = None
199
- else:
200
- device_map = "auto"
201
- model_class, tokenizer_class = MODEL_CLASSES[gen_model_type]
202
- tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
203
- model = model_class.from_pretrained(
204
- gen_model_name_or_path,
205
- load_in_8bit=int8 if gen_model_type not in ['baichuan', 'chatglm'] else False,
206
- load_in_4bit=int4 if gen_model_type not in ['baichuan', 'chatglm'] else False,
207
- torch_dtype="auto",
208
- device_map=device_map,
209
- trust_remote_code=True,
210
- )
211
- if self.device == torch.device('cpu'):
212
- model.float()
213
- if gen_model_type in ['baichuan', 'chatglm']:
214
- if int4:
215
- model = model.quantize(4).cuda()
216
- elif int8:
217
- model = model.quantize(8).cuda()
218
- try:
219
- model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
220
- except Exception as e:
221
- logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
222
- if peft_name:
223
- model = PeftModel.from_pretrained(
224
- model,
225
- peft_name,
226
- torch_dtype="auto",
227
- )
228
- logger.info(f"Modelo peft cargado desde {peft_name}")
229
- model.eval()
230
- return model, tokenizer
231
-
232
- def _get_chat_input(self):
233
- messages = []
234
- for conv in self.history:
235
- if conv and len(conv) > 0 and conv[0]:
236
- messages.append({'role': 'user', 'content': conv[0]})
237
- if conv and len(conv) > 1 and conv[1]:
238
- messages.append({'role': 'assistant', 'content': conv[1]})
239
- input_ids = self.tokenizer.apply_chat_template(
240
- conversation=messages,
241
- tokenize=True,
242
- add_generation_prompt=True,
243
- return_tensors='pt'
244
- )
245
- return input_ids.to(self.gen_model.device)
246
-
247
- @torch.inference_mode()
248
- def stream_generate_answer(
249
- self,
250
- max_new_tokens=512,
251
- temperature=0.7,
252
- repetition_penalty=1.0,
253
- context_len=2048
254
- ):
255
- streamer = TextIteratorStreamer(self.tokenizer, timeout=520.0, skip_prompt=True, skip_special_tokens=True)
256
- input_ids = self._get_chat_input()
257
- max_src_len = context_len - max_new_tokens - 8
258
- input_ids = input_ids[-max_src_len:]
259
- generation_kwargs = dict(
260
- input_ids=input_ids,
261
- max_new_tokens=max_new_tokens,
262
- temperature=temperature,
263
- do_sample=True,
264
- repetition_penalty=repetition_penalty,
265
- streamer=streamer,
266
- )
267
- thread = Thread(target=self.gen_model.generate, kwargs=generation_kwargs)
268
- thread.start()
269
-
270
- yield from streamer
271
-
272
- def add_corpus(self, files: Union[str, List[str]]):
273
- """Load document files."""
274
- if isinstance(files, str):
275
- files = [files]
276
- for doc_file in files:
277
- if doc_file.endswith('.pdf'):
278
- corpus = self.extract_text_from_pdf(doc_file)
279
- elif doc_file.endswith('.docx'):
280
- corpus = self.extract_text_from_docx(doc_file)
281
- elif doc_file.endswith('.md'):
282
- corpus = self.extract_text_from_markdown(doc_file)
283
- else:
284
- corpus = self.extract_text_from_txt(doc_file)
285
- full_text = '\n'.join(corpus)
286
- chunks = self.text_splitter.split_text(full_text)
287
- self.sim_model.add_corpus(chunks)
288
- self.corpus_files = files
289
- logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: "
290
- f"{list(self.sim_model.corpus.values())[:3]}")
291
-
292
- @staticmethod
293
- def get_file_hash(fpaths):
294
- hasher = hashlib.md5()
295
- target_file_data = bytes()
296
- if isinstance(fpaths, str):
297
- fpaths = [fpaths]
298
- for fpath in fpaths:
299
- with open(fpath, 'rb') as file:
300
- chunk = file.read(1024 * 1024) # read only first 1MB
301
- hasher.update(chunk)
302
- target_file_data += chunk
303
-
304
- hash_name = hasher.hexdigest()[:32]
305
- return hash_name
306
-
307
- @staticmethod
308
- def extract_text_from_pdf(file_path: str):
309
- """Extract text content from a PDF file."""
310
- import PyPDF2
311
- contents = []
312
- with open(file_path, 'rb') as f:
313
- pdf_reader = PyPDF2.PdfReader(f)
314
- for page in pdf_reader.pages:
315
- page_text = page.extract_text().strip()
316
- raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
317
- new_text = ''
318
- for text in raw_text:
319
- # Añadir un espacio antes de concatenar si new_text no está vacío
320
- if new_text:
321
- new_text += ' '
322
- new_text += text
323
- if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
324
- '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
325
- contents.append(new_text)
326
- new_text = ''
327
- if new_text:
328
- contents.append(new_text)
329
- return contents
330
-
331
- @staticmethod
332
- def extract_text_from_txt(file_path: str):
333
- """Extract text content from a TXT file."""
334
- with open(file_path, 'r', encoding='utf-8') as f:
335
- contents = [text.strip() for text in f.readlines() if text.strip()]
336
- return contents
337
-
338
- @staticmethod
339
- def extract_text_from_docx(file_path: str):
340
- """Extract text content from a DOCX file."""
341
- import docx
342
- document = docx.Document(file_path)
343
- contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
344
- return contents
345
-
346
- @staticmethod
347
- def extract_text_from_markdown(file_path: str):
348
- """Extract text content from a Markdown file."""
349
- import markdown
350
- from bs4 import BeautifulSoup
351
- with open(file_path, 'r', encoding='utf-8') as f:
352
- markdown_text = f.read()
353
- html = markdown.markdown(markdown_text)
354
- soup = BeautifulSoup(html, 'html.parser')
355
- contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()]
356
- return contents
357
-
358
- @staticmethod
359
- def _add_source_numbers(lst):
360
- """Add source numbers to a list of strings."""
361
- return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)]
362
-
363
- def _get_reranker_score(self, query: str, reference_results: List[str]):
364
- """Get reranker score."""
365
- pairs = []
366
- for reference in reference_results:
367
- pairs.append([query, reference])
368
- with torch.no_grad():
369
- inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
370
- inputs_on_device = {k: v.to(self.rerank_model.device) for k, v in inputs.items()}
371
- scores = self.rerank_model(**inputs_on_device, return_dict=True).logits.view(-1, ).float()
372
-
373
- return scores
374
-
375
- def get_reference_results(self, query: str):
376
- """
377
- Get reference results.
378
- 1. Similarity model get similar chunks
379
- 2. Rerank similar chunks
380
- 3. Expand reference context chunk
381
- :param query:
382
- :return:
383
- """
384
- reference_results = []
385
- sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
386
- # Get reference results from corpus
387
- hit_chunk_dict = dict()
388
- for query_id, id_score_dict in sim_contents.items():
389
- for corpus_id, s in id_score_dict.items():
390
- hit_chunk = self.sim_model.corpus[corpus_id]
391
- reference_results.append(hit_chunk)
392
- hit_chunk_dict[corpus_id] = hit_chunk
393
-
394
- if reference_results:
395
- if self.rerank_model is not None:
396
- # Rerank reference results
397
- rerank_scores = self._get_reranker_score(query, reference_results)
398
- logger.debug(f"rerank_scores: {rerank_scores}")
399
- # Get rerank top k chunks
400
- reference_results = [reference for reference, score in sorted(
401
- zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
402
- hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
403
- hit_chunk in reference_results}
404
- # Expand reference context chunk
405
- if self.num_expand_context_chunk > 0:
406
- new_reference_results = []
407
- for corpus_id, hit_chunk in hit_chunk_dict.items():
408
- expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
409
- for i in range(self.num_expand_context_chunk):
410
- expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
411
- new_reference_results.append(expanded_reference)
412
- reference_results = new_reference_results
413
- return reference_results
414
-
415
- def predict_stream(
416
- self,
417
- query: str,
418
- max_length: int = 512,
419
- context_len: int = 2048,
420
- temperature: float = 0.7,
421
- ):
422
- """Generate predictions stream."""
423
- stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
424
- if not self.enable_history:
425
- self.history = []
426
- if self.sim_model.corpus:
427
- reference_results = self.get_reference_results(query)
428
- if not reference_results:
429
- yield 'No se ha proporcionado suficiente información relevante', reference_results
430
- reference_results = self._add_source_numbers(reference_results)
431
- context_str = '\n'.join(reference_results)[:]
432
- #print("context_str: " , (context_len - len(PROMPT_TEMPLATE)))
433
- prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
434
- logger.debug(f"prompt: {prompt}")
435
- else:
436
- prompt = query
437
- logger.debug(prompt)
438
- self.history.append([prompt, ''])
439
- response = ""
440
- for new_text in self.stream_generate_answer(
441
- max_new_tokens=max_length,
442
- temperature=temperature,
443
- context_len=context_len,
444
- ):
445
- if new_text != stop_str:
446
- response += new_text
447
- yield response
448
-
449
- def predict(
450
- self,
451
- query: str,
452
- max_length: int = 512,
453
- context_len: int = 2048,
454
- temperature: float = 0.7,
455
- ):
456
- """Query from corpus."""
457
- reference_results = []
458
- if not self.enable_history:
459
- self.history = []
460
- if self.sim_model.corpus:
461
- reference_results = self.get_reference_results(query)
462
-
463
- if not reference_results:
464
- return 'No se ha proporcionado suficiente información relevante', reference_results
465
- reference_results = self._add_source_numbers(reference_results)
466
- #context_str = '\n'.join(reference_results) # Usa todos los fragmentos
467
- context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
468
- #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
469
- print(".......................................................")
470
- context_str = '\n'.join(reference_results)[:]
471
- #print("context_str: ", context_str)
472
- prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
473
- logger.debug(f"prompt: {prompt}")
474
- else:
475
- prompt = query
476
- self.history.append([prompt, ''])
477
- response = ""
478
- for new_text in self.stream_generate_answer(
479
- max_new_tokens=max_length,
480
- temperature=temperature,
481
- context_len=context_len,
482
- ):
483
- response += new_text
484
- response = response.strip()
485
- self.history[-1][1] = response
486
- return response, reference_results
487
-
488
- def save_corpus_emb(self):
489
- dir_name = self.get_file_hash(self.corpus_files)
490
- save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
491
- if hasattr(self.sim_model, 'save_corpus_embeddings'):
492
- self.sim_model.save_corpus_embeddings(save_dir)
493
- logger.debug(f"Saving corpus embeddings to {save_dir}")
494
- return save_dir
495
-
496
- def load_corpus_emb(self, emb_dir: str):
497
- if hasattr(self.sim_model, 'load_corpus_embeddings'):
498
- logger.debug(f"Loading corpus embeddings from {emb_dir}")
499
- self.sim_model.load_corpus_embeddings(emb_dir)
500
-
501
- def save_corpus_text(self):
502
- if not self.corpus_files:
503
- logger.warning("No hay archivos de corpus para guardar.")
504
- return
505
-
506
- corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
507
-
508
- with open(corpus_text_file, 'w', encoding='utf-8') as f:
509
- for chunk in self.sim_model.corpus.values():
510
- f.write(chunk + "\n\n") # Añade dos saltos de línea entre chunks para mejor legibilidad
511
-
512
- logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
513
- return corpus_text_file
514
-
515
- def load_corpus_text(self, emb_dir: str):
516
- corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
517
- if os.path.exists(corpus_text_file):
518
- with open(corpus_text_file, 'r', encoding='utf-8') as f:
519
- corpus_text = f.read().split("\n\n") # Asumiendo que usamos dos saltos de línea como separador
520
- self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
521
- logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
522
- else:
523
- logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
524
-
525
- if __name__ == "__main__":
526
- parser = argparse.ArgumentParser()
527
- parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
528
- parser.add_argument("--gen_model_type", type=str, default="auto")
529
- parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
530
- parser.add_argument("--lora_model", type=str, default=None)
531
- parser.add_argument("--rerank_model_name", type=str, default="maidalun1020/bce-reranker-base_v1")
532
- parser.add_argument("--corpus_files", type=str, default="docs/corpus.txt")
533
- parser.add_argument("--device", type=str, default=None)
534
- parser.add_argument("--int4", action='store_true', help="use int4 quantization")
535
- parser.add_argument("--int8", action='store_true', help="use int8 quantization")
536
- parser.add_argument("--chunk_size", type=int, default=220)
537
- parser.add_argument("--chunk_overlap", type=int, default=50)
538
- parser.add_argument("--num_expand_context_chunk", type=int, default=2)
539
- args = parser.parse_args()
540
- print(args)
541
- sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
542
- m = ChatPDF(
543
- similarity_model=sim_model,
544
- generate_model_type=args.gen_model_type,
545
- generate_model_name_or_path=args.gen_model_name,
546
- lora_model_name_or_path=args.lora_model,
547
- device=args.device,
548
- int4=args.int4,
549
- int8=args.int8,
550
- chunk_size=args.chunk_size,
551
- chunk_overlap=args.chunk_overlap,
552
- corpus_files=args.corpus_files.split(','),
553
- num_expand_context_chunk=args.num_expand_context_chunk,
554
- rerank_model_name_or_path=args.rerank_model_name,
555
- )
556
- logger.info(f"chatpdf model: {m}")
557
-
558
- # Comprobar si existen incrustaciones guardadas
559
- dir_name = m.get_file_hash(args.corpus_files.split(','))
560
- save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
561
-
562
- if os.path.exists(save_dir):
563
- # Cargar las incrustaciones guardadas
564
- m.load_corpus_emb(save_dir)
565
- print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
566
- else:
567
- # Procesar el corpus y guardar las incrustaciones
568
- m.add_corpus(args.corpus_files.split(','))
569
- save_dir = m.save_corpus_emb()
570
- # Guardar el texto del corpus
571
- m.save_corpus_text()
572
- print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
573
-
574
- while True:
575
- query = input("\nEnter a query: ")
576
- if query == "exit":
577
- break
578
- if query.strip() == "":
579
- continue
580
- r, refs = m.predict(query)
581
- print(r, refs)
582
  print("\nRespuesta: ", r)
 
1
+ import argparse
2
+ import hashlib
3
+ import os
4
+ import re
5
+ from threading import Thread
6
+ from typing import Union, List
7
+
8
+ import jieba
9
+ import torch
10
+ from loguru import logger
11
+ from peft import PeftModel
12
+ from similarities import (
13
+ EnsembleSimilarity,
14
+ BertSimilarity,
15
+ BM25Similarity,
16
+ )
17
+ from similarities.similarity import SimilarityABC
18
+ from transformers import (
19
+ AutoModelForCausalLM,
20
+ AutoTokenizer,
21
+ TextIteratorStreamer,
22
+ GenerationConfig,
23
+ AutoModelForSequenceClassification,
24
+ )
25
+
26
+ jieba.setLogLevel("ERROR")
27
+
28
+ MODEL_CLASSES = {
29
+ "auto": (AutoModelForCausalLM, AutoTokenizer),
30
+ }
31
+
32
+ PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
33
+ Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
34
+
35
+ Contexto: {context_str}
36
+ Pregunta: {query_str}
37
+
38
+ Devuelve sólo la respuesta útil que aparece a continuación y nada más, y ésta debe estar en Español.
39
+ Respuesta útil:
40
+ """
41
+ PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
42
+ concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
43
+ información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
44
+ inventados en la respuesta.
45
+
46
+ Contenido conocido:
47
+ {context_str}
48
+
49
+ Pregunta:
50
+ {query_str}
51
+ """
52
+
53
+
54
+ class SentenceSplitter:
55
+ def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
56
+ self.chunk_size = chunk_size
57
+ self.chunk_overlap = chunk_overlap
58
+
59
+ def split_text(self, text: str) -> List[str]:
60
+ if self._is_has_chinese(text):
61
+ return self._split_chinese_text(text)
62
+ else:
63
+ return self._split_english_text(text)
64
+
65
+ def _split_chinese_text(self, text: str) -> List[str]:
66
+ sentence_endings = {'\n', '。', '!', '?', ';', '…'} # puntuación al final de una frase
67
+ chunks, current_chunk = [], ''
68
+ for word in jieba.cut(text):
69
+ if len(current_chunk) + len(word) > self.chunk_size:
70
+ chunks.append(current_chunk.strip())
71
+ current_chunk = word
72
+ else:
73
+ current_chunk += word
74
+ if word[-1] in sentence_endings and len(current_chunk) > self.chunk_size - self.chunk_overlap:
75
+ chunks.append(current_chunk.strip())
76
+ current_chunk = ''
77
+ if current_chunk:
78
+ chunks.append(current_chunk.strip())
79
+ if self.chunk_overlap > 0 and len(chunks) > 1:
80
+ chunks = self._handle_overlap(chunks)
81
+ return chunks
82
+
83
+ def _split_english_text(self, text: str) -> List[str]:
84
+ # División de texto inglés por frases mediante expresiones regulares
85
+ sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
86
+ chunks, current_chunk = [], ''
87
+ for sentence in sentences:
88
+ if len(current_chunk) + len(sentence) <= self.chunk_size or not current_chunk:
89
+ current_chunk += (' ' if current_chunk else '') + sentence
90
+ else:
91
+ chunks.append(current_chunk)
92
+ current_chunk = sentence
93
+ if current_chunk: # Add the last chunk
94
+ chunks.append(current_chunk)
95
+
96
+ if self.chunk_overlap > 0 and len(chunks) > 1:
97
+ chunks = self._handle_overlap(chunks)
98
+
99
+ return chunks
100
+
101
+ def _is_has_chinese(self, text: str) -> bool:
102
+ # check if contains chinese characters
103
+ if any("\u4e00" <= ch <= "\u9fff" for ch in text):
104
+ return True
105
+ else:
106
+ return False
107
+
108
+ def _handle_overlap(self, chunks: List[str]) -> List[str]:
109
+ # Tratamiento de los solapamientos entre bloques
110
+ overlapped_chunks = []
111
+ for i in range(len(chunks) - 1):
112
+ chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
113
+ overlapped_chunks.append(chunk.strip())
114
+ overlapped_chunks.append(chunks[-1])
115
+ return overlapped_chunks
116
+
117
+
118
+ class ChatPDF:
119
+ def __init__(
120
+ self,
121
+ similarity_model: SimilarityABC = None,
122
+ generate_model_type: str = "auto",
123
+ generate_model_name_or_path: str = "LenguajeNaturalAI/leniachat-qwen2-1.5B-v0",
124
+ lora_model_name_or_path: str = None,
125
+ corpus_files: Union[str, List[str]] = None,
126
+ save_corpus_emb_dir: str = "corpus_embs/",
127
+ device: str = None,
128
+ int8: bool = False,
129
+ int4: bool = False,
130
+ chunk_size: int = 250,
131
+ chunk_overlap: int = 0,
132
+ rerank_model_name_or_path: str = None,
133
+ enable_history: bool = False,
134
+ num_expand_context_chunk: int = 2,
135
+ similarity_top_k: int = 10,
136
+ rerank_top_k: int = 3
137
+ ):
138
+
139
+ if torch.cuda.is_available():
140
+ default_device = torch.device(0)
141
+ elif torch.backends.mps.is_available():
142
+ default_device = torch.device('cpu')
143
+ else:
144
+ default_device = torch.device('cpu')
145
+ self.device = device or default_device
146
+ if num_expand_context_chunk > 0 and chunk_overlap > 0:
147
+ logger.warning(f" 'num_expand_context_chunk' and 'chunk_overlap' cannot both be greater than zero. "
148
+ f" 'chunk_overlap' has been set to zero by default.")
149
+ chunk_overlap = 0
150
+ self.text_splitter = SentenceSplitter(chunk_size, chunk_overlap)
151
+ if similarity_model is not None:
152
+ self.sim_model = similarity_model
153
+ else:
154
+ m1 = BertSimilarity(model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device=self.device)
155
+ m2 = BM25Similarity()
156
+ default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
157
+ self.sim_model = default_sim_model
158
+ self.gen_model, self.tokenizer = self._init_gen_model(
159
+ generate_model_type,
160
+ generate_model_name_or_path,
161
+ peft_name=lora_model_name_or_path,
162
+ int8=int8,
163
+ int4=int4,
164
+ )
165
+ self.history = []
166
+ self.corpus_files = corpus_files
167
+ if corpus_files:
168
+ self.add_corpus(corpus_files)
169
+ self.save_corpus_emb_dir = save_corpus_emb_dir
170
+ if rerank_model_name_or_path is None:
171
+ rerank_model_name_or_path = "maidalun1020/bce-reranker-base_v1"
172
+ if rerank_model_name_or_path:
173
+ self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
174
+ self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
175
+ self.rerank_model.to(self.device)
176
+ self.rerank_model.eval()
177
+ else:
178
+ self.rerank_model = None
179
+ self.rerank_tokenizer = None
180
+ self.enable_history = enable_history
181
+ self.similarity_top_k = similarity_top_k
182
+ self.num_expand_context_chunk = num_expand_context_chunk
183
+ self.rerank_top_k = rerank_top_k
184
+
185
+ def __str__(self):
186
+ return f"Similarity model: {self.sim_model}, Generate model: {self.gen_model}"
187
+
188
+ def _init_gen_model(
189
+ self,
190
+ gen_model_type: str,
191
+ gen_model_name_or_path: str,
192
+ peft_name: str = None,
193
+ int8: bool = False,
194
+ int4: bool = False,
195
+ ):
196
+ """Init generate model."""
197
+ if int8 or int4:
198
+ device_map = None
199
+ else:
200
+ device_map = "auto"
201
+ model_class, tokenizer_class = MODEL_CLASSES[gen_model_type]
202
+ tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
203
+ model = model_class.from_pretrained(
204
+ gen_model_name_or_path,
205
+ load_in_8bit=int8 if gen_model_type not in ['baichuan', 'chatglm'] else False,
206
+ load_in_4bit=int4 if gen_model_type not in ['baichuan', 'chatglm'] else False,
207
+ torch_dtype="auto",
208
+ device_map=device_map,
209
+ trust_remote_code=True,
210
+ )
211
+ if self.device == torch.device('cpu'):
212
+ model.float()
213
+ if gen_model_type in ['baichuan', 'chatglm']:
214
+ if int4:
215
+ model = model.quantize(4).cuda()
216
+ elif int8:
217
+ model = model.quantize(8).cuda()
218
+ try:
219
+ model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
220
+ except Exception as e:
221
+ logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
222
+ if peft_name:
223
+ model = PeftModel.from_pretrained(
224
+ model,
225
+ peft_name,
226
+ torch_dtype="auto",
227
+ )
228
+ logger.info(f"Modelo peft cargado desde {peft_name}")
229
+ model.eval()
230
+ return model, tokenizer
231
+
232
+ def _get_chat_input(self):
233
+ messages = []
234
+ for conv in self.history:
235
+ if conv and len(conv) > 0 and conv[0]:
236
+ messages.append({'role': 'user', 'content': conv[0]})
237
+ if conv and len(conv) > 1 and conv[1]:
238
+ messages.append({'role': 'assistant', 'content': conv[1]})
239
+ input_ids = self.tokenizer.apply_chat_template(
240
+ conversation=messages,
241
+ tokenize=True,
242
+ add_generation_prompt=True,
243
+ return_tensors='pt'
244
+ )
245
+ return input_ids.to(self.gen_model.device)
246
+
247
+ @torch.inference_mode()
248
+ def stream_generate_answer(
249
+ self,
250
+ max_new_tokens=512,
251
+ temperature=0.7,
252
+ repetition_penalty=1.0,
253
+ context_len=2048
254
+ ):
255
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=520.0, skip_prompt=True, skip_special_tokens=True)
256
+ input_ids = self._get_chat_input()
257
+ max_src_len = context_len - max_new_tokens - 8
258
+ input_ids = input_ids[-max_src_len:]
259
+ generation_kwargs = dict(
260
+ input_ids=input_ids,
261
+ max_new_tokens=max_new_tokens,
262
+ temperature=temperature,
263
+ do_sample=True,
264
+ repetition_penalty=repetition_penalty,
265
+ streamer=streamer,
266
+ )
267
+ thread = Thread(target=self.gen_model.generate, kwargs=generation_kwargs)
268
+ thread.start()
269
+
270
+ yield from streamer
271
+
272
+ def add_corpus(self, files: Union[str, List[str]]):
273
+ """Load document files."""
274
+ if isinstance(files, str):
275
+ files = [files]
276
+ for doc_file in files:
277
+ if doc_file.endswith('.pdf'):
278
+ corpus = self.extract_text_from_pdf(doc_file)
279
+ elif doc_file.endswith('.docx'):
280
+ corpus = self.extract_text_from_docx(doc_file)
281
+ elif doc_file.endswith('.md'):
282
+ corpus = self.extract_text_from_markdown(doc_file)
283
+ else:
284
+ corpus = self.extract_text_from_txt(doc_file)
285
+ full_text = '\n'.join(corpus)
286
+ chunks = self.text_splitter.split_text(full_text)
287
+ self.sim_model.add_corpus(chunks)
288
+ self.corpus_files = files
289
+ logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: "
290
+ f"{list(self.sim_model.corpus.values())[:3]}")
291
+
292
+ @staticmethod
293
+ def get_file_hash(fpaths):
294
+ hasher = hashlib.md5()
295
+ target_file_data = bytes()
296
+ if isinstance(fpaths, str):
297
+ fpaths = [fpaths]
298
+ for fpath in fpaths:
299
+ with open(fpath, 'rb') as file:
300
+ chunk = file.read(1024 * 1024) # read only first 1MB
301
+ hasher.update(chunk)
302
+ target_file_data += chunk
303
+
304
+ hash_name = hasher.hexdigest()[:32]
305
+ return hash_name
306
+
307
+ @staticmethod
308
+ def extract_text_from_pdf(file_path: str):
309
+ """Extract text content from a PDF file."""
310
+ import PyPDF2
311
+ contents = []
312
+ with open(file_path, 'rb') as f:
313
+ pdf_reader = PyPDF2.PdfReader(f)
314
+ for page in pdf_reader.pages:
315
+ page_text = page.extract_text().strip()
316
+ raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
317
+ new_text = ''
318
+ for text in raw_text:
319
+ # Añadir un espacio antes de concatenar si new_text no está vacío
320
+ if new_text:
321
+ new_text += ' '
322
+ new_text += text
323
+ if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
324
+ '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
325
+ contents.append(new_text)
326
+ new_text = ''
327
+ if new_text:
328
+ contents.append(new_text)
329
+ return contents
330
+
331
+ @staticmethod
332
+ def extract_text_from_txt(file_path: str):
333
+ """Extract text content from a TXT file."""
334
+ with open(file_path, 'r', encoding='utf-8') as f:
335
+ contents = [text.strip() for text in f.readlines() if text.strip()]
336
+ return contents
337
+
338
+ @staticmethod
339
+ def extract_text_from_docx(file_path: str):
340
+ """Extract text content from a DOCX file."""
341
+ import docx
342
+ document = docx.Document(file_path)
343
+ contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
344
+ return contents
345
+
346
+ @staticmethod
347
+ def extract_text_from_markdown(file_path: str):
348
+ """Extract text content from a Markdown file."""
349
+ import markdown
350
+ from bs4 import BeautifulSoup
351
+ with open(file_path, 'r', encoding='utf-8') as f:
352
+ markdown_text = f.read()
353
+ html = markdown.markdown(markdown_text)
354
+ soup = BeautifulSoup(html, 'html.parser')
355
+ contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()]
356
+ return contents
357
+
358
+ @staticmethod
359
+ def _add_source_numbers(lst):
360
+ """Add source numbers to a list of strings."""
361
+ return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)]
362
+
363
+ def _get_reranker_score(self, query: str, reference_results: List[str]):
364
+ """Get reranker score."""
365
+ pairs = []
366
+ for reference in reference_results:
367
+ pairs.append([query, reference])
368
+ with torch.no_grad():
369
+ inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
370
+ inputs_on_device = {k: v.to(self.rerank_model.device) for k, v in inputs.items()}
371
+ scores = self.rerank_model(**inputs_on_device, return_dict=True).logits.view(-1, ).float()
372
+
373
+ return scores
374
+
375
+ def get_reference_results(self, query: str):
376
+ """
377
+ Get reference results.
378
+ 1. Similarity model get similar chunks
379
+ 2. Rerank similar chunks
380
+ 3. Expand reference context chunk
381
+ :param query:
382
+ :return:
383
+ """
384
+ reference_results = []
385
+ sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
386
+
387
+ # Ajustar según el tipo de retorno de sim_contents
388
+ if isinstance(sim_contents, dict): # Si es un diccionario
389
+ for query_id, id_score_dict in sim_contents.items():
390
+ for corpus_id, s in id_score_dict.items():
391
+ hit_chunk = self.sim_model.corpus[corpus_id]
392
+ reference_results.append(hit_chunk)
393
+ elif isinstance(sim_contents, list): # Si es una lista
394
+ for item in sim_contents:
395
+ # Ajusta esto dependiendo de la estructura de los elementos de la lista
396
+ # Ejemplo: si es una lista de (corpus_id, score) tuplas
397
+ corpus_id, _ = item
398
+ hit_chunk = self.sim_model.corpus[corpus_id]
399
+ reference_results.append(hit_chunk)
400
+
401
+ # Resto del código...
402
+ if reference_results:
403
+ if self.rerank_model is not None:
404
+ # Rerank reference results
405
+ rerank_scores = self._get_reranker_score(query, reference_results)
406
+ logger.debug(f"rerank_scores: {rerank_scores}")
407
+ # Get rerank top k chunks
408
+ reference_results = [reference for reference, score in sorted(
409
+ zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
410
+ hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
411
+ hit_chunk in reference_results}
412
+ # Expand reference context chunk
413
+ if self.num_expand_context_chunk > 0:
414
+ new_reference_results = []
415
+ for corpus_id, hit_chunk in hit_chunk_dict.items():
416
+ expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
417
+ for i in range(self.num_expand_context_chunk):
418
+ expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
419
+ new_reference_results.append(expanded_reference)
420
+ reference_results = new_reference_results
421
+ return reference_results
422
+
423
+
424
+ def predict_stream(
425
+ self,
426
+ query: str,
427
+ max_length: int = 512,
428
+ context_len: int = 2048,
429
+ temperature: float = 0.7,
430
+ ):
431
+ """Generate predictions stream."""
432
+ stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
433
+ if not self.enable_history:
434
+ self.history = []
435
+ if self.sim_model.corpus:
436
+ reference_results = self.get_reference_results(query)
437
+ if not reference_results:
438
+ yield 'No se ha proporcionado suficiente información relevante', reference_results
439
+ reference_results = self._add_source_numbers(reference_results)
440
+ context_str = '\n'.join(reference_results)[:]
441
+ #print("context_str: " , (context_len - len(PROMPT_TEMPLATE)))
442
+ prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
443
+ logger.debug(f"prompt: {prompt}")
444
+ else:
445
+ prompt = query
446
+ logger.debug(prompt)
447
+ self.history.append([prompt, ''])
448
+ response = ""
449
+ for new_text in self.stream_generate_answer(
450
+ max_new_tokens=max_length,
451
+ temperature=temperature,
452
+ context_len=context_len,
453
+ ):
454
+ if new_text != stop_str:
455
+ response += new_text
456
+ yield response
457
+
458
+ def predict(
459
+ self,
460
+ query: str,
461
+ max_length: int = 512,
462
+ context_len: int = 2048,
463
+ temperature: float = 0.7,
464
+ ):
465
+ """Query from corpus."""
466
+ reference_results = []
467
+ if not self.enable_history:
468
+ self.history = []
469
+ if self.sim_model.corpus:
470
+ reference_results = self.get_reference_results(query)
471
+
472
+ if not reference_results:
473
+ return 'No se ha proporcionado suficiente información relevante', reference_results
474
+ reference_results = self._add_source_numbers(reference_results)
475
+ #context_str = '\n'.join(reference_results) # Usa todos los fragmentos
476
+ context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
477
+ #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
478
+ print(".......................................................")
479
+ context_str = '\n'.join(reference_results)[:]
480
+ #print("context_str: ", context_str)
481
+ prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
482
+ logger.debug(f"prompt: {prompt}")
483
+ else:
484
+ prompt = query
485
+ self.history.append([prompt, ''])
486
+ response = ""
487
+ for new_text in self.stream_generate_answer(
488
+ max_new_tokens=max_length,
489
+ temperature=temperature,
490
+ context_len=context_len,
491
+ ):
492
+ response += new_text
493
+ response = response.strip()
494
+ self.history[-1][1] = response
495
+ return response, reference_results
496
+
497
+ def save_corpus_emb(self):
498
+ dir_name = self.get_file_hash(self.corpus_files)
499
+ save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
500
+ if hasattr(self.sim_model, 'save_corpus_embeddings'):
501
+ self.sim_model.save_corpus_embeddings(save_dir)
502
+ logger.debug(f"Saving corpus embeddings to {save_dir}")
503
+ return save_dir
504
+
505
+ def load_corpus_emb(self, emb_dir: str):
506
+ if hasattr(self.sim_model, 'load_corpus_embeddings'):
507
+ logger.debug(f"Loading corpus embeddings from {emb_dir}")
508
+ self.sim_model.load_corpus_embeddings(emb_dir)
509
+
510
+ def save_corpus_text(self):
511
+ if not self.corpus_files:
512
+ logger.warning("No hay archivos de corpus para guardar.")
513
+ return
514
+
515
+ corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
516
+
517
+ with open(corpus_text_file, 'w', encoding='utf-8') as f:
518
+ for chunk in self.sim_model.corpus.values():
519
+ f.write(chunk + "\n\n") # Añade dos saltos de línea entre chunks para mejor legibilidad
520
+
521
+ logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
522
+ return corpus_text_file
523
+
524
+ def load_corpus_text(self, emb_dir: str):
525
+ corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
526
+ if os.path.exists(corpus_text_file):
527
+ with open(corpus_text_file, 'r', encoding='utf-8') as f:
528
+ corpus_text = f.read().split("\n\n") # Asumiendo que usamos dos saltos de línea como separador
529
+ self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
530
+ logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
531
+ else:
532
+ logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
533
+
534
+ if __name__ == "__main__":
535
+ parser = argparse.ArgumentParser()
536
+ parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
537
+ parser.add_argument("--gen_model_type", type=str, default="auto")
538
+ parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
539
+ parser.add_argument("--lora_model", type=str, default=None)
540
+ parser.add_argument("--rerank_model_name", type=str, default="maidalun1020/bce-reranker-base_v1")
541
+ parser.add_argument("--corpus_files", type=str, default="docs/corpus.txt")
542
+ parser.add_argument("--device", type=str, default=None)
543
+ parser.add_argument("--int4", action='store_true', help="use int4 quantization")
544
+ parser.add_argument("--int8", action='store_true', help="use int8 quantization")
545
+ parser.add_argument("--chunk_size", type=int, default=220)
546
+ parser.add_argument("--chunk_overlap", type=int, default=50)
547
+ parser.add_argument("--num_expand_context_chunk", type=int, default=2)
548
+ args = parser.parse_args()
549
+ print(args)
550
+ sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
551
+ m = ChatPDF(
552
+ similarity_model=sim_model,
553
+ generate_model_type=args.gen_model_type,
554
+ generate_model_name_or_path=args.gen_model_name,
555
+ lora_model_name_or_path=args.lora_model,
556
+ device=args.device,
557
+ int4=args.int4,
558
+ int8=args.int8,
559
+ chunk_size=args.chunk_size,
560
+ chunk_overlap=args.chunk_overlap,
561
+ corpus_files=args.corpus_files.split(','),
562
+ num_expand_context_chunk=args.num_expand_context_chunk,
563
+ rerank_model_name_or_path=args.rerank_model_name,
564
+ )
565
+ logger.info(f"chatpdf model: {m}")
566
+
567
+ # Comprobar si existen incrustaciones guardadas
568
+ dir_name = m.get_file_hash(args.corpus_files.split(','))
569
+ save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
570
+
571
+ if os.path.exists(save_dir):
572
+ # Cargar las incrustaciones guardadas
573
+ m.load_corpus_emb(save_dir)
574
+ print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
575
+ else:
576
+ # Procesar el corpus y guardar las incrustaciones
577
+ m.add_corpus(args.corpus_files.split(','))
578
+ save_dir = m.save_corpus_emb()
579
+ # Guardar el texto del corpus
580
+ m.save_corpus_text()
581
+ print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
582
+
583
+ while True:
584
+ query = input("\nEnter a query: ")
585
+ if query == "exit":
586
+ break
587
+ if query.strip() == "":
588
+ continue
589
+ r, refs = m.predict(query)
590
+ print(r, refs)
591
  print("\nRespuesta: ", r)