Update chatpdf.py

chatpdf.py  CHANGED  (+83 -59)
@@ -16,8 +16,13 @@ from similarities import (
 )
 from similarities.similarity import SimilarityABC
 from transformers import (
+    AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
+    BloomForCausalLM,
+    BloomTokenizerFast,
+    LlamaTokenizer,
+    LlamaForCausalLM,
     TextIteratorStreamer,
     GenerationConfig,
     AutoModelForSequenceClassification,
@@ -26,9 +31,22 @@ from transformers import (
 jieba.setLogLevel("ERROR")
 
 MODEL_CLASSES = {
+    "bloom": (BloomForCausalLM, BloomTokenizerFast),
+    "chatglm": (AutoModel, AutoTokenizer),
+    "llama": (LlamaForCausalLM, LlamaTokenizer),
+    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
     "auto": (AutoModelForCausalLM, AutoTokenizer),
 }
 
+PROMPT_TEMPLATE1 = """基于以下已知信息,简洁和专业的来回答用户的问题。
+如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
+
+已知内容:
+{context_str}
+
+问题:
+{query_str}
+"""
 PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
 Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
 
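Note on the MODEL_CLASSES change: the registry maps the generate-model type to a (model class, tokenizer class) pair, so each architecture loads through its dedicated classes while "auto" remains the generic path. A minimal sketch of how such a registry is typically consumed (the helper name and the fallback-to-"auto" behavior are assumptions, not shown in this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_CLASSES = {
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}

def load_generate_model(model_type: str, model_name_or_path: str):
    # Look up the class pair for this architecture; unknown types fall back to "auto" (assumption)
    model_class, tokenizer_class = MODEL_CLASSES.get(model_type, MODEL_CLASSES["auto"])
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = model_class.from_pretrained(model_name_or_path, trust_remote_code=True)
    return model, tokenizer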
@@ -41,7 +59,7 @@ Respuesta útil:
 PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
 concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
 información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
-inventados en la respuesta.
+inventados en la respuesta, y ésta debe estar en Español.
 
 Contenido conocido:
 {context_str}
@@ -81,7 +99,7 @@ class SentenceSplitter:
         return chunks
 
     def _split_english_text(self, text: str) -> List[str]:
-        #
+        # Dividir el texto inglés por frases utilizando expresiones regulares
         sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
         chunks, current_chunk = [], ''
         for sentence in sentences:
@@ -90,7 +108,7 @@ class SentenceSplitter:
             else:
                 chunks.append(current_chunk)
                 current_chunk = sentence
-        if current_chunk:  #
+        if current_chunk:  # Añade el último trozo
             chunks.append(current_chunk)
 
         if self.chunk_overlap > 0 and len(chunks) > 1:
@@ -99,14 +117,14 @@ class SentenceSplitter:
         return chunks
 
     def _is_has_chinese(self, text: str) -> bool:
-        #
+        # comprobar si contiene caracteres chinos
         if any("\u4e00" <= ch <= "\u9fff" for ch in text):
             return True
         else:
             return False
 
     def _handle_overlap(self, chunks: List[str]) -> List[str]:
-        #
+        # 处理块间重叠
         overlapped_chunks = []
         for i in range(len(chunks) - 1):
             chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
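Note on the SentenceSplitter hunks: the English path splits text into sentences on terminal punctuation, packs them into chunks of at most chunk_size, and _handle_overlap appends the first chunk_overlap characters of the following chunk to each chunk. A standalone sketch of the regex split used above:

import re
from typing import List

def split_english_sentences(text: str) -> List[str]:
    # Split on whitespace that follows '.', '!' or '?', after flattening newlines
    return re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))

print(split_english_sentences("First sentence. Second one! A third?"))
# -> ['First sentence.', 'Second one!', 'A third?']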
@@ -115,12 +133,13 @@ class SentenceSplitter:
         return overlapped_chunks
 
 
+
 class ChatPDF:
     def __init__(
             self,
             similarity_model: SimilarityABC = None,
             generate_model_type: str = "auto",
-            generate_model_name_or_path: str = "
+            generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
             lora_model_name_or_path: str = None,
             corpus_files: Union[str, List[str]] = None,
             save_corpus_emb_dir: str = "corpus_embs/",
@@ -132,10 +151,28 @@ class ChatPDF:
             rerank_model_name_or_path: str = None,
             enable_history: bool = False,
             num_expand_context_chunk: int = 2,
-            similarity_top_k: int =
-            rerank_top_k: int =
+            similarity_top_k: int = 5,
+            rerank_top_k: int = 3,
     ):
-
+        """
+        Init RAG model.
+        :param similarity_model: similarity model, default None, if set, will use it instead of EnsembleSimilarity
+        :param generate_model_type: generate model type
+        :param generate_model_name_or_path: generate model name or path
+        :param lora_model_name_or_path: lora model name or path
+        :param corpus_files: corpus files
+        :param save_corpus_emb_dir: save corpus embeddings dir, default ./corpus_embs/
+        :param device: device, default None, auto select gpu or cpu
+        :param int8: use int8 quantization, default False
+        :param int4: use int4 quantization, default False
+        :param chunk_size: chunk size, default 250
+        :param chunk_overlap: chunk overlap, default 0, can not set to > 0 if num_expand_context_chunk > 0
+        :param rerank_model_name_or_path: rerank model name or path, default 'BAAI/bge-reranker-base'
+        :param enable_history: enable history, default False
+        :param num_expand_context_chunk: num expand context chunk, default 2, if set to 0, will not expand context chunk
+        :param similarity_top_k: similarity_top_k, default 5, similarity model search k corpus chunks
+        :param rerank_top_k: rerank_top_k, default 3, rerank model search k corpus chunks
+        """
         if torch.cuda.is_available():
             default_device = torch.device(0)
         elif torch.backends.mps.is_available():
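With the defaults documented in the new docstring, the class can be constructed and queried in a few lines; a usage sketch based on this signature (the corpus file name and the query are illustrative):

m = ChatPDF(
    generate_model_type="auto",
    generate_model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
    corpus_files="sample.pdf",  # illustrative file name
    chunk_size=220,
    similarity_top_k=5,
    rerank_top_k=3,
)
response, references = m.predict("¿De qué trata el documento?")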
@@ -151,7 +188,7 @@ class ChatPDF:
         if similarity_model is not None:
             self.sim_model = similarity_model
         else:
-            m1 = BertSimilarity(model_name_or_path="
+            m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual", device=self.device)
             m2 = BM25Similarity()
             default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
             self.sim_model = default_sim_model
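Note on the default retriever: EnsembleSimilarity combines dense BertSimilarity scores with sparse BM25 scores at weights [0.5, 0.5]. The similarities library's exact fusion is not shown here; a weighted combination of this general shape is one way to picture it (a sketch, not the library's implementation):

def fuse_scores(dense: dict, sparse: dict, w_dense: float = 0.5, w_sparse: float = 0.5) -> dict:
    # Combine per-chunk scores from two retrievers; a chunk missing from one side scores 0.0 there
    ids = set(dense) | set(sparse)
    return {i: w_dense * dense.get(i, 0.0) + w_sparse * sparse.get(i, 0.0) for i in ids}

print(fuse_scores({"c1": 0.9, "c2": 0.4}, {"c1": 0.2, "c3": 0.7}))
# -> {'c1': 0.55, 'c2': 0.2, 'c3': 0.35} (key order may vary)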
@@ -168,7 +205,7 @@ class ChatPDF:
             self.add_corpus(corpus_files)
         self.save_corpus_emb_dir = save_corpus_emb_dir
         if rerank_model_name_or_path is None:
-            rerank_model_name_or_path = "
+            rerank_model_name_or_path = "BAAI/bge-reranker-large"
         if rerank_model_name_or_path:
             self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
             self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
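Note on the reranker default: BAAI/bge-reranker-large is a cross-encoder, so it scores each (query, chunk) pair jointly instead of comparing separate embeddings. A minimal scoring sketch using the same transformers classes imported above (query and chunks are illustrative; this follows the model family's standard usage, not code from this file):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large")
model.eval()

query = "¿Qué establece el acuerdo?"
chunks = ["primer fragmento...", "segundo fragmento..."]
pairs = [[query, c] for c in chunks]
inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    # One relevance logit per (query, chunk) pair; higher means more relevant
    scores = model(**inputs).logits.view(-1).float()
ranked = [c for _, c in sorted(zip(scores.tolist(), chunks), key=lambda p: p[0], reverse=True)]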
@@ -252,7 +289,7 @@ class ChatPDF:
             repetition_penalty=1.0,
             context_len=2048
     ):
-        streamer = TextIteratorStreamer(self.tokenizer, timeout=
+        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         input_ids = self._get_chat_input()
         max_src_len = context_len - max_new_tokens - 8
         input_ids = input_ids[-max_src_len:]
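Note on the streamer change: TextIteratorStreamer turns generate() into an iterator; generation runs on a background thread and decoded fragments are consumed as they arrive, with skip_prompt=True dropping the echoed prompt and timeout=60.0 bounding each wait. The standard pattern looks like this (model choice and prompt are illustrative):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
input_ids = tokenizer("Hola, ¿qué es un PDF?", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, max_new_tokens=64, streamer=streamer))
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)  # fragments arrive as soon as they are decoded
thread.join()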
@@ -383,29 +420,14 @@ class ChatPDF:
         """
         reference_results = []
         sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
-
-
-
-        for
-
-
-
-                # Extraer el valor necesario si corpus_id es un diccionario
-                corpus_id = next(iter(corpus_id.keys()))  # Tomar la primera clave como ejemplo
-                if corpus_id in self.sim_model.corpus:
-                    hit_chunk = self.sim_model.corpus[corpus_id]
-                    reference_results.append(hit_chunk)
-
-        elif isinstance(sim_contents, dict):
-            for query_id, id_score_dict in sim_contents.items():
-                for corpus_id, s in id_score_dict.items():
-                    if corpus_id in self.sim_model.corpus:
-                        hit_chunk = self.sim_model.corpus[corpus_id]
-                        reference_results.append(hit_chunk)
-        else:
-            logger.error(f"Unexpected type for sim_contents: {type(sim_contents)}")
+        # Get reference results from corpus
+        hit_chunk_dict = dict()
+        for query_id, id_score_dict in sim_contents.items():
+            for corpus_id, s in id_score_dict.items():
+                hit_chunk = self.sim_model.corpus[corpus_id]
+                reference_results.append(hit_chunk)
+                hit_chunk_dict[corpus_id] = hit_chunk
 
-
         if reference_results:
             if self.rerank_model is not None:
                 # Rerank reference results
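Note on the retrieval rewrite: most_similar is assumed here to return a nested mapping of the form {query_id: {corpus_id: score}}, which the deleted code handled through type dispatch and the new code flattens directly. A small self-contained sketch of that shape and the flattening (scores and chunk texts are illustrative):

# Shape assumed by the new loop: {query_id: {corpus_id: score}}
sim_contents = {0: {12: 0.83, 7: 0.79, 31: 0.66}}
corpus = {7: "fragmento siete...", 12: "fragmento doce...", 31: "fragmento treinta y uno..."}

reference_results, hit_chunk_dict = [], {}
for query_id, id_score_dict in sim_contents.items():
    for corpus_id, score in id_score_dict.items():
        hit_chunk = corpus[corpus_id]
        reference_results.append(hit_chunk)
        hit_chunk_dict[corpus_id] = hit_chunk
print(reference_results)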
@@ -444,7 +466,7 @@ class ChatPDF:
                 yield 'No se ha proporcionado suficiente información relevante', reference_results
             reference_results = self._add_source_numbers(reference_results)
             context_str = '\n'.join(reference_results)[:]
-
+            print("gggggg: ", (context_len - len(PROMPT_TEMPLATE)))
             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
             logger.debug(f"prompt: {prompt}")
         else:
@@ -478,12 +500,12 @@ class ChatPDF:
         if not reference_results:
             return 'No se ha proporcionado suficiente información relevante', reference_results
         reference_results = self._add_source_numbers(reference_results)
-        #context_str = '\n'.join(reference_results)  # Usa todos los fragmentos
+        # context_str = '\n'.join(reference_results)  # Usa todos los fragmentos
         context_st = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))]
-
+        print("Context: ", (context_len - len(PROMPT_TEMPLATE)))
         print(".......................................................")
         context_str = '\n'.join(reference_results)[:]
-
+        print("context_str: ", context_str)
         prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
         logger.debug(f"prompt: {prompt}")
     else:
@@ -500,19 +522,6 @@ class ChatPDF:
             self.history[-1][1] = response
         return response, reference_results
 
-    def save_corpus_emb(self):
-        dir_name = self.get_file_hash(self.corpus_files)
-        save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
-        if hasattr(self.sim_model, 'save_corpus_embeddings'):
-            self.sim_model.save_corpus_embeddings(save_dir)
-            logger.debug(f"Saving corpus embeddings to {save_dir}")
-        return save_dir
-
-    def load_corpus_emb(self, emb_dir: str):
-        if hasattr(self.sim_model, 'load_corpus_embeddings'):
-            logger.debug(f"Loading corpus embeddings from {emb_dir}")
-            self.sim_model.load_corpus_embeddings(emb_dir)
-
     def save_corpus_text(self):
         if not self.corpus_files:
             logger.warning("No hay archivos de corpus para guardar.")
@@ -537,20 +546,36 @@ class ChatPDF:
         else:
             logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
 
+    def save_corpus_emb(self):
+        dir_name = self.get_file_hash(self.corpus_files)
+        save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
+        if hasattr(self.sim_model, 'save_corpus_embeddings'):
+            self.sim_model.save_corpus_embeddings(save_dir)
+            logger.debug(f"Saving corpus embeddings to {save_dir}")
+        return save_dir
+
+    def load_corpus_emb(self, emb_dir: str):
+        if hasattr(self.sim_model, 'load_corpus_embeddings'):
+            logger.debug(f"Cargando incrustaciones del corpus desde {emb_dir}")
+            self.sim_model.load_corpus_embeddings(emb_dir)
+            # Cargar el texto del corpus
+            self.load_corpus_text(emb_dir)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--sim_model_name", type=str, default="
+    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
     parser.add_argument("--gen_model_type", type=str, default="auto")
-    parser.add_argument("--gen_model_name", type=str, default="
+    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
     parser.add_argument("--lora_model", type=str, default=None)
-    parser.add_argument("--rerank_model_name", type=str, default="
-    parser.add_argument("--corpus_files", type=str, default="
+    parser.add_argument("--rerank_model_name", type=str, default="")
+    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--int4", action='store_true', help="use int4 quantization")
     parser.add_argument("--int8", action='store_true', help="use int8 quantization")
     parser.add_argument("--chunk_size", type=int, default=220)
-    parser.add_argument("--chunk_overlap", type=int, default=
-    parser.add_argument("--num_expand_context_chunk", type=int, default=
+    parser.add_argument("--chunk_overlap", type=int, default=0)
+    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
     args = parser.parse_args()
     print(args)
     sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
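Note on the relocated cache helpers: save_corpus_emb keys the embeddings directory on get_file_hash(self.corpus_files), so the same corpus always resolves to the same cache directory and embeddings can be reused across runs. get_file_hash itself is not part of this diff; a plausible content-hash implementation might look like this (an assumption, not the file's actual code):

import hashlib
from typing import List

def get_file_hash(file_paths: List[str]) -> str:
    # Hash file contents so the cache key changes whenever any corpus file changes
    h = hashlib.md5()
    for path in sorted(file_paths):
        with open(path, "rb") as f:
            h.update(f.read())
    return h.hexdigest()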
@@ -568,7 +593,6 @@ if __name__ == "__main__":
         num_expand_context_chunk=args.num_expand_context_chunk,
         rerank_model_name_or_path=args.rerank_model_name,
     )
-    logger.info(f"chatpdf model: {m}")
 
     # Comprobar si existen incrustaciones guardadas
     dir_name = m.get_file_hash(args.corpus_files.split(','))
@@ -594,4 +618,4 @@ if __name__ == "__main__":
             continue
         r, refs = m.predict(query)
         print(r, refs)
-        print("\nRespuesta: ", r)
+        print("\nRespuesta: ", r)
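Putting the new defaults together, a typical invocation of the updated script looks like this (the PDF name is just the argparse default; point --corpus_files at any document):

python chatpdf.py --gen_model_type auto --gen_model_name Qwen/Qwen2-0.5B-Instruct \
    --corpus_files Acuerdo009.pdf --chunk_size 220 --chunk_overlap 0 --num_expand_context_chunk 1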