Update chatpdf.py
chatpdf.py  +122 -60  CHANGED
@@ -1,8 +1,3 @@
-# -*- coding: utf-8 -*-
-"""
-@author:XuMing([email protected])
-@description:
-"""
 import argparse
 import hashlib
 import os
@@ -18,6 +13,7 @@ from similarities import (
     EnsembleSimilarity,
     BertSimilarity,
     BM25Similarity,
+    TfidfSimilarity
 )
 from similarities.similarity import SimilarityABC
 from transformers import (
@@ -43,10 +39,9 @@ MODEL_CLASSES = {
     "auto": (AutoModelForCausalLM, AutoTokenizer),
 }
 
-PROMPT_TEMPLATE = """Basándose en la
-
-
-inventados en la respuesta, y ésta debe estar en Español.
+PROMPT_TEMPLATE = """Basándose únicamente en la información proporcionada a continuación, responda a las preguntas del usuario de manera concisa y profesional.
+No se debe responder a preguntas relacionadas con sentimientos, emociones, temas personales o cualquier información que no esté explícitamente presente en el contenido proporcionado.
+Si la pregunta se refiere a un artículo específico y no se encuentra en el contenido proporcionado, diga: "No se puede encontrar el artículo solicitado en la información conocida".
 
 Contenido conocido:
 {context_str}
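For reference, an English gloss of the new Spanish prompt template (the code keeps the Spanish string): "Based solely on the information provided below, answer the user's questions concisely and professionally. Do not answer questions about feelings, emotions, personal topics, or any information that is not explicitly present in the provided content. If the question refers to a specific article and it is not found in the provided content, say: 'The requested article cannot be found in the known information.'"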
@@ -55,8 +50,6 @@ Pregunta:
 {query_str}
 """
 
-
-
 class SentenceSplitter:
     def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
         self.chunk_size = chunk_size
@@ -121,8 +114,7 @@ class SentenceSplitter:
         return overlapped_chunks
 
 
-
-class Rag:
+class ChatPDF:
     def __init__(
             self,
             similarity_model: SimilarityABC = None,
@@ -139,8 +131,8 @@ class Rag:
             rerank_model_name_or_path: str = None,
             enable_history: bool = False,
             num_expand_context_chunk: int = 2,
-            similarity_top_k: int =
-            rerank_top_k: int =
+            similarity_top_k: int = 15,
+            rerank_top_k: int = 5,
     ):
         """
         Init RAG model.
@@ -176,9 +168,11 @@ class Rag:
         if similarity_model is not None:
             self.sim_model = similarity_model
         else:
-            m1 = BertSimilarity(model_name_or_path="
+            m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2", device=self.device)
             m2 = BM25Similarity()
-
+            m3 = TfidfSimilarity()
+            default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5],
+                                                   c=2)  # Ajuste los pesos según los resultados
             self.sim_model = default_sim_model
         self.gen_model, self.tokenizer = self._init_gen_model(
             generate_model_type,
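The default retriever built above combines dense (BertSimilarity), BM25 and TF-IDF scoring with equal weights. A minimal standalone sketch of the same construction, assuming the similarities package's add_corpus() API and an invented two-chunk corpus (the sentence-transformers model is downloaded on first run):

from similarities import BertSimilarity, BM25Similarity, TfidfSimilarity, EnsembleSimilarity

# Same constructors, weights and c=2 as in the diff above
m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2")
m2 = BM25Similarity()
m3 = TfidfSimilarity()
sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5], c=2)

# add_corpus()/most_similar() usage assumed from the similarities API; corpus is made up
sim_model.add_corpus(["Artículo 1. Texto del primer artículo.",
                      "Artículo 2. Texto del segundo artículo."])
print(sim_model.most_similar("¿Qué dice el Artículo 2?", topn=5))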
@@ -243,14 +237,14 @@ class Rag:
         try:
             model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
         except Exception as e:
-            logger.warning(f"
+            logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
         if peft_name:
             model = PeftModel.from_pretrained(
                 model,
                 peft_name,
                 torch_dtype="auto",
             )
-            logger.info(f"
+            logger.info(f"Modelo peft cargado desde {peft_name}")
         model.eval()
         return model, tokenizer
@@ -341,6 +335,7 @@ class Rag:
         raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
         new_text = ''
         for text in raw_text:
+            # Añadir un espacio antes de concatenar si new_text no está vacío
             if new_text:
                 new_text += ' '
             new_text += text
@@ -397,25 +392,37 @@ class Rag:
         return scores
 
     def get_reference_results(self, query: str):
-        ""
-
-
-
-
-
-
-
-
+        # Verificar si la consulta incluye un "Artículo X"
+        exact_match = None
+        if re.search(r'Artículo\s*\d+', query, re.IGNORECASE):
+            # Buscar el término específico "Artículo X" en el corpus de manera más precisa
+            term = re.search(r'Artículo\s*\d+', query, re.IGNORECASE).group()
+            # Buscar coincidencias exactas en el corpus
+            for corpus_id, content in self.sim_model.corpus.items():
+                # Agregar espacio o signo de puntuación alrededor de "term" para evitar coincidencias parciales
+                if re.search(r'\b' + re.escape(term) + r'\b', content, re.IGNORECASE):
+                    exact_match = content
+                    break
+
+        if exact_match:
+            # Si se encuentra una coincidencia exacta, devolverla como contexto
+            return [exact_match]
+
+        # Si no se encuentra una coincidencia exacta, continuar con la búsqueda general
         sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
-        # Get reference results from corpus
-        hit_chunk_dict = dict()
-        for c in sim_contents:
-            for id_score_dict in c:
-                corpus_id = id_score_dict['corpus_id']
-                hit_chunk = id_score_dict["corpus_doc"]
-                reference_results.append(hit_chunk)
-                hit_chunk_dict[corpus_id] = hit_chunk
 
+        # Procesar los resultados de similitud
+        reference_results = []
+
+        hit_chunk_dict = dict()
+        threshold_score = 0.5  # Establece un umbral para filtrar fragmentos irrelevantes
+
+        for query_id, id_score_dict in sim_contents.items():
+            for corpus_id, s in id_score_dict.items():
+                if s > threshold_score:  # Filtrar por puntuación de similitud
+                    hit_chunk = self.sim_model.corpus[corpus_id]
+                    reference_results.append(hit_chunk)
+                    hit_chunk_dict[corpus_id] = hit_chunk
         if reference_results:
             if self.rerank_model is not None:
                 # Rerank reference results
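A minimal sketch (not part of this commit) of the two retrieval paths added above: an exact "Artículo N" lookup first, then a score-thresholded similarity fallback. It assumes re is imported at module level in chatpdf.py; the corpus, query and scores below are invented:

import re

corpus = {0: "Artículo 12. El contrato podrá rescindirse por ambas partes.",
          1: "Disposiciones generales sobre plazos de entrega."}
query = "¿Qué establece el Artículo 12?"

match = re.search(r'Artículo\s*\d+', query, re.IGNORECASE)
if match:
    term = match.group()
    exact = [c for c in corpus.values()
             if re.search(r'\b' + re.escape(term) + r'\b', c, re.IGNORECASE)]
    print(exact)  # the chunk containing "Artículo 12"

# Fallback path: keep only similarity hits above the 0.5 threshold used in the commit
sim_contents = {0: {0: 0.81, 1: 0.32}}  # {query_id: {corpus_id: score}}, fabricated scores
hits = [corpus[cid] for cid, s in sim_contents[0].items() if s > 0.5]
print(hits)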
@@ -440,9 +447,9 @@ class Rag:
     def predict_stream(
             self,
             query: str,
-            max_length: int =
-            context_len: int =
-            temperature: float = 0.
+            max_length: int = 256,
+            context_len: int = 1024,
+            temperature: float = 0.5,
     ):
         """Generate predictions stream."""
         stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
@@ -450,15 +457,16 @@
         self.history = []
         if self.sim_model.corpus:
             reference_results = self.get_reference_results(query)
-            if reference_results:
-
-
-
-
+            if not reference_results:
+                yield 'No se ha proporcionado suficiente información relevante', reference_results
+            reference_results = self._add_source_numbers(reference_results)
+            context_str = '\n'.join(reference_results)[:]
+            print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
+            logger.debug(f"prompt: {prompt}")
         else:
             prompt = query
-
+            logger.debug(prompt)
         self.history.append([prompt, ''])
         response = ""
         for new_text in self.stream_generate_answer(
@@ -473,9 +481,9 @@
     def predict(
             self,
             query: str,
-            max_length: int =
-            context_len: int =
-            temperature: float = 0.
+            max_length: int = 256,
+            context_len: int = 1024,
+            temperature: float = 0.5,
     ):
         """Query from corpus."""
         reference_results = []
@@ -483,15 +491,20 @@
         self.history = []
         if self.sim_model.corpus:
             reference_results = self.get_reference_results(query)
-
-
-
-
-
+
+            if not reference_results:
+                return 'No se ha proporcionado suficiente información relevante', reference_results
+            reference_results = self._add_source_numbers(reference_results)
+            # context_str = '\n'.join(reference_results)  # Usa todos los fragmentos
+            context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
+            # print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
+            print(".......................................................")
+            context_str = '\n'.join(reference_results)[:]
+            # print("context_str: ", context_str)
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
+            logger.debug(f"prompt: {prompt}")
         else:
             prompt = query
-            logger.debug(f"prompt: {prompt}")
         self.history.append([prompt, ''])
         response = ""
         for new_text in self.stream_generate_answer(
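The predict() changes above keep the character-based context cap only in the commented-out context_st line and pass the full joined references as context_str. A small sketch of what that cap computes, using an invented shortened template and fragments; note it counts characters rather than tokens and includes the template's placeholders in the budget:

# Stand-in values; the real PROMPT_TEMPLATE and context_len=1024 are defined above
PROMPT_TEMPLATE = "Contenido conocido:\n{context_str}\n\nPregunta:\n{query_str}\n"
reference_results = ["[1] primer fragmento de ejemplo", "[2] segundo fragmento de ejemplo"]
context_len = 1024

budget = context_len - len(PROMPT_TEMPLATE)           # characters left for the references
context_str = '\n'.join(reference_results)[:budget]   # the commented-out truncation
print(budget, len(context_str))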
@@ -504,8 +517,29 @@
         self.history[-1][1] = response
         return response, reference_results
 
-    def
-
+    def save_corpus_text(self):
+        if not self.corpus_files:
+            logger.warning("No hay archivos de corpus para guardar.")
+            return
+
+        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
+
+        with open(corpus_text_file, 'w', encoding='utf-8') as f:
+            for chunk in self.sim_model.corpus.values():
+                f.write(chunk + "\n\n")  # Añade dos saltos de línea entre chunks para mejor legibilidad
+
+        logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
+        return corpus_text_file
+
+    def load_corpus_text(self, emb_dir: str):
+        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
+        if os.path.exists(corpus_text_file):
+            with open(corpus_text_file, 'r', encoding='utf-8') as f:
+                corpus_text = f.read().split("\n\n")  # Asumiendo que usamos dos saltos de línea como separador
+            self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
+            logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
+        else:
+            logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
 
     def save_corpus_emb(self):
         dir_name = self.get_file_hash(self.corpus_files)
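A standalone sketch (not from the commit) of the plain-text round trip that save_corpus_text() and load_corpus_text() implement. It assumes chunks never contain blank lines, since "\n\n" is the separator, and that the hard-coded corpus_embs/ directory exists:

import os

os.makedirs("corpus_embs", exist_ok=True)  # the methods above assume this directory exists
path = os.path.join("corpus_embs", "corpus_text.txt")

corpus = {0: "Artículo 1. Primer fragmento.", 1: "Artículo 2. Segundo fragmento."}  # invented
with open(path, "w", encoding="utf-8") as f:
    for chunk in corpus.values():
        f.write(chunk + "\n\n")

with open(path, "r", encoding="utf-8") as f:
    chunks = f.read().split("\n\n")
restored = {i: c.strip() for i, c in enumerate(chunks) if c.strip()}
print(restored == corpus)  # True for this data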
@@ -517,13 +551,15 @@
 
     def load_corpus_emb(self, emb_dir: str):
         if hasattr(self.sim_model, 'load_corpus_embeddings'):
-            logger.debug(f"
+            logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
             self.sim_model.load_corpus_embeddings(emb_dir)
+            # Cargar el texto del corpus
+            self.load_corpus_text(emb_dir)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--sim_model_name", type=str, default="
+    parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/all-mpnet-base-v2")
     parser.add_argument("--gen_model_type", type=str, default="auto")
     parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
     parser.add_argument("--lora_model", type=str, default=None)
@@ -538,7 +574,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     print(args)
     sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
-    m =
+    m = ChatPDF(
         similarity_model=sim_model,
         generate_model_type=args.gen_model_type,
         generate_model_name_or_path=args.gen_model_name,
@@ -551,4 +587,30 @@ if __name__ == "__main__":
         corpus_files=args.corpus_files.split(','),
         num_expand_context_chunk=args.num_expand_context_chunk,
         rerank_model_name_or_path=args.rerank_model_name,
-    )
+    )
+
+    # Comprobar si existen incrustaciones guardadas
+    dir_name = m.get_file_hash(args.corpus_files.split(','))
+    save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
+
+    if os.path.exists(save_dir):
+        # Cargar las incrustaciones guardadas
+        m.load_corpus_emb(save_dir)
+        print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
+    else:
+        # Procesar el corpus y guardar las incrustaciones
+        m.add_corpus(args.corpus_files.split(','))
+        save_dir = m.save_corpus_emb()
+        # Guardar el texto del corpus
+        m.save_corpus_text()
+        print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
+
+    while True:
+        query = input("\nEnter a query: ")
+        if query == "exit":
+            break
+        if query.strip() == "":
+            continue
+        r, refs = m.predict(query)
+        print(r, refs)
+        print("\nRespuesta: ", r)
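A hedged usage sketch of the updated entry point, driving ChatPDF from Python instead of the interactive loop above. The import path, the PDF name and relying on the constructor's corpus_files argument are assumptions; the keyword arguments mirror the argparse defaults visible in this diff:

from chatpdf import ChatPDF  # assuming this file is importable as chatpdf.py

m = ChatPDF(
    generate_model_type="auto",
    generate_model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
    corpus_files=["sample.pdf"],  # hypothetical document
    similarity_top_k=15,
    rerank_top_k=5,
)
r, refs = m.predict("¿Qué dice el Artículo 3?")
print("Respuesta:", r)
print("Referencias:", refs)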