Spaces:
Sleeping
Sleeping
Update chatpdf.py
Browse files- chatpdf.py +36 -105
chatpdf.py
CHANGED
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import hashlib
|
3 |
import os
|
@@ -38,24 +43,6 @@ MODEL_CLASSES = {
|
|
38 |
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
39 |
}
|
40 |
|
41 |
-
PROMPT_TEMPLATE1 = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
42 |
-
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
43 |
-
|
44 |
-
已知内容:
|
45 |
-
{context_str}
|
46 |
-
|
47 |
-
问题:
|
48 |
-
{query_str}
|
49 |
-
"""
|
50 |
-
PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
|
51 |
-
Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
|
52 |
-
|
53 |
-
Contexto: {context_str}
|
54 |
-
Pregunta: {query_str}
|
55 |
-
|
56 |
-
Devuelve sólo la respuesta útil que aparece a continuación y nada más, y ésta debe estar en Español.
|
57 |
-
Respuesta útil:
|
58 |
-
"""
|
59 |
PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
|
60 |
concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
|
61 |
información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
|
@@ -69,6 +56,7 @@ Pregunta:
|
|
69 |
"""
|
70 |
|
71 |
|
|
|
72 |
class SentenceSplitter:
|
73 |
def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
|
74 |
self.chunk_size = chunk_size
|
@@ -134,7 +122,7 @@ class SentenceSplitter:
|
|
134 |
|
135 |
|
136 |
|
137 |
-
class
|
138 |
def __init__(
|
139 |
self,
|
140 |
similarity_model: SimilarityABC = None,
|
@@ -151,8 +139,8 @@ class ChatPDF:
|
|
151 |
rerank_model_name_or_path: str = None,
|
152 |
enable_history: bool = False,
|
153 |
num_expand_context_chunk: int = 2,
|
154 |
-
similarity_top_k: int =
|
155 |
-
rerank_top_k: int =3,
|
156 |
):
|
157 |
"""
|
158 |
Init RAG model.
|
@@ -188,7 +176,7 @@ class ChatPDF:
|
|
188 |
if similarity_model is not None:
|
189 |
self.sim_model = similarity_model
|
190 |
else:
|
191 |
-
m1 = BertSimilarity(model_name_or_path="
|
192 |
m2 = BM25Similarity()
|
193 |
default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
|
194 |
self.sim_model = default_sim_model
|
@@ -205,7 +193,7 @@ class ChatPDF:
|
|
205 |
self.add_corpus(corpus_files)
|
206 |
self.save_corpus_emb_dir = save_corpus_emb_dir
|
207 |
if rerank_model_name_or_path is None:
|
208 |
-
rerank_model_name_or_path = "BAAI/bge-reranker-
|
209 |
if rerank_model_name_or_path:
|
210 |
self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
|
211 |
self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
|
@@ -255,14 +243,14 @@ class ChatPDF:
|
|
255 |
try:
|
256 |
model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
|
257 |
except Exception as e:
|
258 |
-
logger.warning(f"
|
259 |
if peft_name:
|
260 |
model = PeftModel.from_pretrained(
|
261 |
model,
|
262 |
peft_name,
|
263 |
torch_dtype="auto",
|
264 |
)
|
265 |
-
logger.info(f"
|
266 |
model.eval()
|
267 |
return model, tokenizer
|
268 |
|
@@ -353,9 +341,6 @@ class ChatPDF:
|
|
353 |
raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
|
354 |
new_text = ''
|
355 |
for text in raw_text:
|
356 |
-
# Añadir un espacio antes de concatenar si new_text no está vacío
|
357 |
-
if new_text:
|
358 |
-
new_text += ' '
|
359 |
new_text += text
|
360 |
if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
|
361 |
'』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
|
@@ -422,9 +407,10 @@ class ChatPDF:
|
|
422 |
sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
|
423 |
# Get reference results from corpus
|
424 |
hit_chunk_dict = dict()
|
425 |
-
for
|
426 |
-
for
|
427 |
-
|
|
|
428 |
reference_results.append(hit_chunk)
|
429 |
hit_chunk_dict[corpus_id] = hit_chunk
|
430 |
|
@@ -462,16 +448,15 @@ class ChatPDF:
|
|
462 |
self.history = []
|
463 |
if self.sim_model.corpus:
|
464 |
reference_results = self.get_reference_results(query)
|
465 |
-
if
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
|
471 |
-
logger.debug(f"prompt: {prompt}")
|
472 |
else:
|
473 |
prompt = query
|
474 |
-
|
475 |
self.history.append([prompt, ''])
|
476 |
response = ""
|
477 |
for new_text in self.stream_generate_answer(
|
@@ -496,20 +481,15 @@ class ChatPDF:
|
|
496 |
self.history = []
|
497 |
if self.sim_model.corpus:
|
498 |
reference_results = self.get_reference_results(query)
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
|
505 |
-
print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
|
506 |
-
print(".......................................................")
|
507 |
-
context_str = '\n'.join(reference_results)[:]
|
508 |
-
print("context_str: ", context_str)
|
509 |
prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
|
510 |
-
logger.debug(f"prompt: {prompt}")
|
511 |
else:
|
512 |
prompt = query
|
|
|
513 |
self.history.append([prompt, ''])
|
514 |
response = ""
|
515 |
for new_text in self.stream_generate_answer(
|
@@ -522,29 +502,8 @@ class ChatPDF:
|
|
522 |
self.history[-1][1] = response
|
523 |
return response, reference_results
|
524 |
|
525 |
-
def
|
526 |
-
|
527 |
-
logger.warning("No hay archivos de corpus para guardar.")
|
528 |
-
return
|
529 |
-
|
530 |
-
corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
|
531 |
-
|
532 |
-
with open(corpus_text_file, 'w', encoding='utf-8') as f:
|
533 |
-
for chunk in self.sim_model.corpus.values():
|
534 |
-
f.write(chunk + "\n\n") # Añade dos saltos de línea entre chunks para mejor legibilidad
|
535 |
-
|
536 |
-
logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
|
537 |
-
return corpus_text_file
|
538 |
-
|
539 |
-
def load_corpus_text(self, emb_dir: str):
|
540 |
-
corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
|
541 |
-
if os.path.exists(corpus_text_file):
|
542 |
-
with open(corpus_text_file, 'r', encoding='utf-8') as f:
|
543 |
-
corpus_text = f.read().split("\n\n") # Asumiendo que usamos dos saltos de línea como separador
|
544 |
-
self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
|
545 |
-
logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
|
546 |
-
else:
|
547 |
-
logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
|
548 |
|
549 |
def save_corpus_emb(self):
|
550 |
dir_name = self.get_file_hash(self.corpus_files)
|
@@ -556,20 +515,18 @@ class ChatPDF:
|
|
556 |
|
557 |
def load_corpus_emb(self, emb_dir: str):
|
558 |
if hasattr(self.sim_model, 'load_corpus_embeddings'):
|
559 |
-
logger.debug(f"
|
560 |
self.sim_model.load_corpus_embeddings(emb_dir)
|
561 |
-
# Cargar el texto del corpus
|
562 |
-
self.load_corpus_text(emb_dir)
|
563 |
|
564 |
|
565 |
if __name__ == "__main__":
|
566 |
parser = argparse.ArgumentParser()
|
567 |
-
parser.add_argument("--sim_model_name", type=str, default="
|
568 |
parser.add_argument("--gen_model_type", type=str, default="auto")
|
569 |
parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
|
570 |
parser.add_argument("--lora_model", type=str, default=None)
|
571 |
parser.add_argument("--rerank_model_name", type=str, default="")
|
572 |
-
parser.add_argument("--corpus_files", type=str, default="
|
573 |
parser.add_argument("--device", type=str, default=None)
|
574 |
parser.add_argument("--int4", action='store_true', help="use int4 quantization")
|
575 |
parser.add_argument("--int8", action='store_true', help="use int8 quantization")
|
@@ -579,7 +536,7 @@ if __name__ == "__main__":
|
|
579 |
args = parser.parse_args()
|
580 |
print(args)
|
581 |
sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
|
582 |
-
m =
|
583 |
similarity_model=sim_model,
|
584 |
generate_model_type=args.gen_model_type,
|
585 |
generate_model_name_or_path=args.gen_model_name,
|
@@ -592,30 +549,4 @@ if __name__ == "__main__":
|
|
592 |
corpus_files=args.corpus_files.split(','),
|
593 |
num_expand_context_chunk=args.num_expand_context_chunk,
|
594 |
rerank_model_name_or_path=args.rerank_model_name,
|
595 |
-
)
|
596 |
-
|
597 |
-
# Comprobar si existen incrustaciones guardadas
|
598 |
-
dir_name = m.get_file_hash(args.corpus_files.split(','))
|
599 |
-
save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
|
600 |
-
|
601 |
-
if os.path.exists(save_dir):
|
602 |
-
# Cargar las incrustaciones guardadas
|
603 |
-
m.load_corpus_emb(save_dir)
|
604 |
-
print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
|
605 |
-
else:
|
606 |
-
# Procesar el corpus y guardar las incrustaciones
|
607 |
-
m.add_corpus(args.corpus_files.split(','))
|
608 |
-
save_dir = m.save_corpus_emb()
|
609 |
-
# Guardar el texto del corpus
|
610 |
-
m.save_corpus_text()
|
611 |
-
print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
|
612 |
-
|
613 |
-
while True:
|
614 |
-
query = input("\nEnter a query: ")
|
615 |
-
if query == "exit":
|
616 |
-
break
|
617 |
-
if query.strip() == "":
|
618 |
-
continue
|
619 |
-
r, refs = m.predict(query)
|
620 |
-
print(r, refs)
|
621 |
-
print("\nRespuesta: ", r)
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
import argparse
|
7 |
import hashlib
|
8 |
import os
|
|
|
43 |
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
44 |
}
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
|
47 |
concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
|
48 |
información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
|
|
|
56 |
"""
|
57 |
|
58 |
|
59 |
+
|
60 |
class SentenceSplitter:
|
61 |
def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
|
62 |
self.chunk_size = chunk_size
|
|
|
122 |
|
123 |
|
124 |
|
125 |
+
class Rag:
|
126 |
def __init__(
|
127 |
self,
|
128 |
similarity_model: SimilarityABC = None,
|
|
|
139 |
rerank_model_name_or_path: str = None,
|
140 |
enable_history: bool = False,
|
141 |
num_expand_context_chunk: int = 2,
|
142 |
+
similarity_top_k: int = 10,
|
143 |
+
rerank_top_k: int = 3,
|
144 |
):
|
145 |
"""
|
146 |
Init RAG model.
|
|
|
176 |
if similarity_model is not None:
|
177 |
self.sim_model = similarity_model
|
178 |
else:
|
179 |
+
m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual", device=self.device)
|
180 |
m2 = BM25Similarity()
|
181 |
default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
|
182 |
self.sim_model = default_sim_model
|
|
|
193 |
self.add_corpus(corpus_files)
|
194 |
self.save_corpus_emb_dir = save_corpus_emb_dir
|
195 |
if rerank_model_name_or_path is None:
|
196 |
+
rerank_model_name_or_path = "BAAI/bge-reranker-base"
|
197 |
if rerank_model_name_or_path:
|
198 |
self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
|
199 |
self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
|
|
|
243 |
try:
|
244 |
model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
|
245 |
except Exception as e:
|
246 |
+
logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
|
247 |
if peft_name:
|
248 |
model = PeftModel.from_pretrained(
|
249 |
model,
|
250 |
peft_name,
|
251 |
torch_dtype="auto",
|
252 |
)
|
253 |
+
logger.info(f"Loaded peft model from {peft_name}")
|
254 |
model.eval()
|
255 |
return model, tokenizer
|
256 |
|
|
|
341 |
raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
|
342 |
new_text = ''
|
343 |
for text in raw_text:
|
|
|
|
|
|
|
344 |
new_text += text
|
345 |
if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
|
346 |
'』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
|
|
|
407 |
sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
|
408 |
# Get reference results from corpus
|
409 |
hit_chunk_dict = dict()
|
410 |
+
for c in sim_contents:
|
411 |
+
for id_score_dict in c:
|
412 |
+
corpus_id = id_score_dict['corpus_id']
|
413 |
+
hit_chunk = id_score_dict["corpus_doc"]
|
414 |
reference_results.append(hit_chunk)
|
415 |
hit_chunk_dict[corpus_id] = hit_chunk
|
416 |
|
|
|
448 |
self.history = []
|
449 |
if self.sim_model.corpus:
|
450 |
reference_results = self.get_reference_results(query)
|
451 |
+
if reference_results:
|
452 |
+
reference_results = self._add_source_numbers(reference_results)
|
453 |
+
context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
|
454 |
+
else:
|
455 |
+
context_str = ''
|
456 |
prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
|
|
|
457 |
else:
|
458 |
prompt = query
|
459 |
+
logger.debug(f"prompt: {prompt}")
|
460 |
self.history.append([prompt, ''])
|
461 |
response = ""
|
462 |
for new_text in self.stream_generate_answer(
|
|
|
481 |
self.history = []
|
482 |
if self.sim_model.corpus:
|
483 |
reference_results = self.get_reference_results(query)
|
484 |
+
if reference_results:
|
485 |
+
reference_results = self._add_source_numbers(reference_results)
|
486 |
+
context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
|
487 |
+
else:
|
488 |
+
context_str = ''
|
|
|
|
|
|
|
|
|
|
|
489 |
prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
|
|
|
490 |
else:
|
491 |
prompt = query
|
492 |
+
logger.debug(f"prompt: {prompt}")
|
493 |
self.history.append([prompt, ''])
|
494 |
response = ""
|
495 |
for new_text in self.stream_generate_answer(
|
|
|
502 |
self.history[-1][1] = response
|
503 |
return response, reference_results
|
504 |
|
505 |
+
def query(self, query: str, **kwargs):
|
506 |
+
return self.predict(query, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
|
508 |
def save_corpus_emb(self):
|
509 |
dir_name = self.get_file_hash(self.corpus_files)
|
|
|
515 |
|
516 |
def load_corpus_emb(self, emb_dir: str):
|
517 |
if hasattr(self.sim_model, 'load_corpus_embeddings'):
|
518 |
+
logger.debug(f"Loading corpus embeddings from {emb_dir}")
|
519 |
self.sim_model.load_corpus_embeddings(emb_dir)
|
|
|
|
|
520 |
|
521 |
|
522 |
if __name__ == "__main__":
|
523 |
parser = argparse.ArgumentParser()
|
524 |
+
parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
|
525 |
parser.add_argument("--gen_model_type", type=str, default="auto")
|
526 |
parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
|
527 |
parser.add_argument("--lora_model", type=str, default=None)
|
528 |
parser.add_argument("--rerank_model_name", type=str, default="")
|
529 |
+
parser.add_argument("--corpus_files", type=str, default="data/sample.pdf")
|
530 |
parser.add_argument("--device", type=str, default=None)
|
531 |
parser.add_argument("--int4", action='store_true', help="use int4 quantization")
|
532 |
parser.add_argument("--int8", action='store_true', help="use int8 quantization")
|
|
|
536 |
args = parser.parse_args()
|
537 |
print(args)
|
538 |
sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
|
539 |
+
m = Rag(
|
540 |
similarity_model=sim_model,
|
541 |
generate_model_type=args.gen_model_type,
|
542 |
generate_model_name_or_path=args.gen_model_name,
|
|
|
549 |
corpus_files=args.corpus_files.split(','),
|
550 |
num_expand_context_chunk=args.num_expand_context_chunk,
|
551 |
rerank_model_name_or_path=args.rerank_model_name,
|
552 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|