Update chatpdf.py
chatpdf.py CHANGED (+55 -101)
@@ -1,3 +1,8 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing([email protected])
+@description:
+"""
 import argparse
 import hashlib
 import os
@@ -13,7 +18,6 @@ from similarities import (
     EnsembleSimilarity,
     BertSimilarity,
     BM25Similarity,
-    TfidfSimilarity
 )
 from similarities.similarity import SimilarityABC
 from transformers import (
@@ -50,6 +54,7 @@ Pregunta:
 {query_str}
 """
 
+
 class SentenceSplitter:
     def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
         self.chunk_size = chunk_size
@@ -62,7 +67,7 @@ class SentenceSplitter:
             return self._split_english_text(text)
 
     def _split_chinese_text(self, text: str) -> List[str]:
-        sentence_endings = {'\n', '。', '!', '?', ';', '…'}  #
+        sentence_endings = {'\n', '。', '!', '?', ';', '…'}  # sentence-ending punctuation marks
         chunks, current_chunk = [], ''
         for word in jieba.cut(text):
             if len(current_chunk) + len(word) > self.chunk_size:
@@ -80,16 +85,22 @@ class SentenceSplitter:
         return chunks
 
     def _split_english_text(self, text: str) -> List[str]:
-        #
+        # split English text into sentences with a regular expression
         sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
-        chunks, current_chunk = [], ''
+        chunks = []
+        current_chunk = ''
         for sentence in sentences:
-            if len(current_chunk) + len(sentence) <= self.chunk_size:
+            if len(current_chunk) + len(sentence) <= self.chunk_size:
                 current_chunk += (' ' if current_chunk else '') + sentence
             else:
-                chunks.append(current_chunk)
-                current_chunk = sentence
-        if current_chunk:
+                if len(sentence) > self.chunk_size:
+                    for i in range(0, len(sentence), self.chunk_size):
+                        chunks.append(sentence[i:i + self.chunk_size])
+                    current_chunk = ''
+                else:
+                    chunks.append(current_chunk)
+                    current_chunk = sentence
+        if current_chunk:  # Add the last chunk
             chunks.append(current_chunk)
 
         if self.chunk_overlap > 0 and len(chunks) > 1:
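The rewritten `_split_english_text` adds a hard-wrap path for any single sentence longer than `chunk_size`. Below is a minimal standalone sketch of the added logic (the overlap pass is omitted; the sample text is invented). Note that when the oversized-sentence branch fires, the pending `current_chunk` is discarded rather than flushed, exactly as in the committed code.

```python
import re
from typing import List


def split_english_text(text: str, chunk_size: int = 250) -> List[str]:
    # mirror of the updated _split_english_text, minus the overlap pass
    sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
    chunks = []
    current_chunk = ''
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += (' ' if current_chunk else '') + sentence
        else:
            if len(sentence) > chunk_size:
                # hard-wrap an oversized sentence into fixed-size slices;
                # the pending current_chunk is dropped here, as in the diff
                for i in range(0, len(sentence), chunk_size):
                    chunks.append(sentence[i:i + chunk_size])
                current_chunk = ''
            else:
                chunks.append(current_chunk)
                current_chunk = sentence
    if current_chunk:  # add the last chunk
        chunks.append(current_chunk)
    return chunks


# the 600-character "sentence" gets hard-wrapped; the short one is dropped
print(split_english_text("A short sentence. " + "x" * 600 + ".", chunk_size=100))
```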
@@ -98,7 +109,7 @@ class SentenceSplitter:
         return chunks
 
     def _is_has_chinese(self, text: str) -> bool:
-        #
+        # check if contains chinese characters
         if any("\u4e00" <= ch <= "\u9fff" for ch in text):
             return True
         else:
@@ -114,7 +125,7 @@ class SentenceSplitter:
         return overlapped_chunks
 
 
-class ChatPDF:
+class Rag:
     def __init__(
             self,
             similarity_model: SimilarityABC = None,
@@ -122,7 +133,7 @@ class ChatPDF:
             generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
             lora_model_name_or_path: str = None,
             corpus_files: Union[str, List[str]] = None,
-            save_corpus_emb_dir: str = "corpus_embs/",
+            save_corpus_emb_dir: str = "./corpus_embs/",
             device: str = None,
             int8: bool = False,
             int4: bool = False,
@@ -131,8 +142,8 @@ class ChatPDF:
             rerank_model_name_or_path: str = None,
             enable_history: bool = False,
             num_expand_context_chunk: int = 2,
-            similarity_top_k: int =
-            rerank_top_k: int =
+            similarity_top_k: int = 10,
+            rerank_top_k: int = 3,
     ):
         """
         Init RAG model.
@@ -171,8 +182,7 @@ class ChatPDF:
             m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2", device=self.device)
             m2 = BM25Similarity()
             m3 = TfidfSimilarity()
-            default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5],
-                                                   c=2)  # Adjust the weights based on results
+            default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5], c=2)  # Adjust the weights based on results
             self.sim_model = default_sim_model
             self.gen_model, self.tokenizer = self._init_gen_model(
                 generate_model_type,
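For context, a sketch of how the three retrievers compose through `EnsembleSimilarity`, using the `similarities` API as it appears in this file; the toy corpus strings are invented. One caveat grounded in this diff: the import hunk above removes `TfidfSimilarity` from the `from similarities import (...)` list while `m3 = TfidfSimilarity()` is still constructed here, so the sketch keeps the import explicit.

```python
from similarities import BertSimilarity, BM25Similarity, TfidfSimilarity, EnsembleSimilarity

# the same three retrievers the constructor wires up, equally weighted
m1 = BertSimilarity(model_name_or_path="sentence-transformers/all-mpnet-base-v2")
m2 = BM25Similarity()
m3 = TfidfSimilarity()
sim = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5], c=2)

# index a toy corpus and query it
sim.add_corpus(["chunk about embeddings", "chunk about keyword search"])
print(sim.most_similar("keyword search", topn=2))
```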
@@ -237,14 +247,14 @@ class ChatPDF:
         try:
             model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
         except Exception as e:
-            logger.warning(f"
+            logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
         if peft_name:
             model = PeftModel.from_pretrained(
                 model,
                 peft_name,
                 torch_dtype="auto",
             )
-            logger.info(f"
+            logger.info(f"Loaded peft model from {peft_name}")
         model.eval()
         return model, tokenizer
 
@@ -335,7 +345,6 @@ class ChatPDF:
         raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
         new_text = ''
         for text in raw_text:
-            # Add a space before concatenating if new_text is not empty
             if new_text:
                 new_text += ' '
             new_text += text
@@ -408,12 +417,9 @@ class ChatPDF:
             # If an exact match is found, return it as the context
             return [exact_match]
 
-        # If no exact match is found, continue with the general search
-        sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
-
-        # Process the similarity results
         reference_results = []
-
+        sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
+        # Get reference results from corpus
         hit_chunk_dict = dict()
         threshold_score = 0.5  # Set a threshold to filter out irrelevant chunks
 
@@ -423,6 +429,7 @@ class ChatPDF:
                 hit_chunk = self.sim_model.corpus[corpus_id]
                 reference_results.append(hit_chunk)
                 hit_chunk_dict[corpus_id] = hit_chunk
+
         if reference_results:
             if self.rerank_model is not None:
                 # Rerank reference results
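The retrieval path above filters hits against `threshold_score = 0.5` before anything reaches the reranker. A sketch of that filtering step in isolation, assuming `most_similar` returns a `{query_id: {corpus_id: score}}` mapping, which is the shape the surrounding loop indexes into:

```python
def filter_hits(sim_model, query: str, topn: int = 10, threshold_score: float = 0.5):
    # collect corpus chunks whose similarity score clears the threshold
    reference_results = []
    hit_chunk_dict = dict()
    sim_contents = sim_model.most_similar(query, topn=topn)
    for query_id, id_score_dict in sim_contents.items():
        for corpus_id, score in id_score_dict.items():
            if score < threshold_score:
                continue  # skip chunks below the relevance cutoff
            hit_chunk = sim_model.corpus[corpus_id]
            reference_results.append(hit_chunk)
            hit_chunk_dict[corpus_id] = hit_chunk
    return reference_results, hit_chunk_dict
```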
@@ -447,9 +454,9 @@ class ChatPDF:
     def predict_stream(
             self,
             query: str,
-            max_length: int =
-            context_len: int =
-            temperature: float = 0.
+            max_length: int = 512,
+            context_len: int = 2048,
+            temperature: float = 0.7,
     ):
         """Generate predictions stream."""
         stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
@@ -457,16 +464,15 @@ class ChatPDF:
         self.history = []
         if self.sim_model.corpus:
             reference_results = self.get_reference_results(query)
-            if reference_results:
-                reference_results = self._add_source_numbers(reference_results)
-                context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
-            else:
-                context_str = ''
+            if reference_results:
+                reference_results = self._add_source_numbers(reference_results)
+                context_str = '\n'.join(reference_results)[:]
+            else:
+                context_str = ''
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
-            logger.debug(f"prompt: {prompt}")
         else:
             prompt = query
-
+        logger.debug(f"prompt: {prompt}")
         self.history.append([prompt, ''])
         response = ""
         for new_text in self.stream_generate_answer(
@@ -481,9 +487,9 @@ class ChatPDF:
     def predict(
             self,
             query: str,
-            max_length: int =
-            context_len: int =
-            temperature: float = 0.
+            max_length: int = 512,
+            context_len: int = 2048,
+            temperature: float = 0.7,
     ):
         """Query from corpus."""
         reference_results = []
@@ -491,20 +497,15 @@ class ChatPDF:
         self.history = []
         if self.sim_model.corpus:
             reference_results = self.get_reference_results(query)
-            if reference_results:
-                reference_results = self._add_source_numbers(reference_results)
-
-
-
-            context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
-            #print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
-            print(".......................................................")
-            context_str = '\n'.join(reference_results)[:]
-            #print("context_str: ", context_str)
+            if reference_results:
+                reference_results = self._add_source_numbers(reference_results)
+                context_str = '\n'.join(reference_results)[:]
+            else:
+                context_str = ''
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
-            logger.debug(f"prompt: {prompt}")
         else:
             prompt = query
+        logger.debug(f"prompt: {prompt}")
         self.history.append([prompt, ''])
         response = ""
         for new_text in self.stream_generate_answer(
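Both `predict` and `predict_stream` now build the context identically: number the references, join them, and format `PROMPT_TEMPLATE`. Note the commit also drops the old `[:(context_len - len(PROMPT_TEMPLATE))]` truncation in favor of `[:]`, so the full joined context is sent regardless of `context_len`. A sketch of the shared assembly; `_add_source_numbers` is not shown in this diff, so the `[i]` prefix format here is an assumption:

```python
def build_prompt(reference_results, query, prompt_template):
    # assumed numbering format for _add_source_numbers: "[1] chunk text"
    if reference_results:
        numbered = [f"[{i + 1}] {r}" for i, r in enumerate(reference_results)]
        context_str = '\n'.join(numbered)
    else:
        context_str = ''
    return prompt_template.format(context_str=context_str, query_str=query)
```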
@@ -517,29 +518,8 @@ class ChatPDF:
         self.history[-1][1] = response
         return response, reference_results
 
-    def save_corpus_text(self):
-        if not self.sim_model.corpus:
-            logger.warning("No hay archivos de corpus para guardar.")
-            return
-
-        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
-
-        with open(corpus_text_file, 'w', encoding='utf-8') as f:
-            for chunk in self.sim_model.corpus.values():
-                f.write(chunk + "\n\n")  # add two newlines between chunks for readability
-
-        logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
-        return corpus_text_file
-
-    def load_corpus_text(self, emb_dir: str):
-        corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
-        if os.path.exists(corpus_text_file):
-            with open(corpus_text_file, 'r', encoding='utf-8') as f:
-                corpus_text = f.read().split("\n\n")  # assuming two newlines are used as the separator
-            self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
-            logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
-        else:
-            logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
+    def query(self, query: str, **kwargs):
+        return self.predict(query, **kwargs)
 
     def save_corpus_emb(self):
         dir_name = self.get_file_hash(self.corpus_files)
@@ -551,10 +531,8 @@ class ChatPDF:
 
     def load_corpus_emb(self, emb_dir: str):
         if hasattr(self.sim_model, 'load_corpus_embeddings'):
-            logger.debug(f"
+            logger.debug(f"Loading corpus embeddings from {emb_dir}")
             self.sim_model.load_corpus_embeddings(emb_dir)
-            # Load the corpus text
-            self.load_corpus_text(emb_dir)
 
 
 if __name__ == "__main__":
@@ -564,7 +542,7 @@ if __name__ == "__main__":
     parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
     parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
-    parser.add_argument("--corpus_files", type=str, default="
+    parser.add_argument("--corpus_files", type=str, default="data/sample.pdf")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--int4", action='store_true', help="use int4 quantization")
     parser.add_argument("--int8", action='store_true', help="use int8 quantization")
@@ -574,7 +552,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     print(args)
     sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
-    m = ChatPDF(
+    m = Rag(
         similarity_model=sim_model,
         generate_model_type=args.gen_model_type,
         generate_model_name_or_path=args.gen_model_name,
@@ -588,29 +566,5 @@ if __name__ == "__main__":
         num_expand_context_chunk=args.num_expand_context_chunk,
         rerank_model_name_or_path=args.rerank_model_name,
     )
-
-
-    dir_name = m.get_file_hash(args.corpus_files.split(','))
-    save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
-
-    if os.path.exists(save_dir):
-        # Load the saved embeddings
-        m.load_corpus_emb(save_dir)
-        print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
-    else:
-        # Process the corpus and save the embeddings
-        m.add_corpus(args.corpus_files.split(','))
-        save_dir = m.save_corpus_emb()
-        # Save the corpus text
-        m.save_corpus_text()
-        print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
-
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-        if query.strip() == "":
-            continue
-        r, refs = m.predict(query)
-        print(r, refs)
-        print("\nRespuesta: ", r)
+    r, refs = m.predict('自然语言中的非平行迁移是指什么?')
+    print(r, refs)
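The interactive input loop is gone; `__main__` now runs a single hard-coded Chinese query. A sketch of driving the renamed class directly through the new `query` alias instead; the file path and question are placeholders, and it assumes the constructor indexes `corpus_files` on init, as the new `__main__` path implies:

```python
from chatpdf import Rag

# defaults fall back to the ensemble retriever and Qwen2-0.5B-Instruct
m = Rag(corpus_files="data/sample.pdf")

# query() forwards to predict() and returns (response, reference_results)
response, references = m.query("What is this document about?")
print(response)
for ref in references:
    print(ref)
```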