# Gradio web demo: chat with a PDF corpus through a RAG pipeline (chatpdf.Rag).
import argparse
import os

import gradio as gr
from loguru import logger
from similarities import BertSimilarity, BM25Similarity

from chatpdf import Rag

# Absolute path of the directory containing this script; used to resolve
# bundled UI assets (avatar images) regardless of the current working directory.
pwd_path = os.path.abspath(os.path.dirname(__file__))
if __name__ == '__main__':
    # Command-line configuration for the RAG chat demo.
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=8082)
    # NOTE(review): with default=True this store_true flag can never be turned
    # off from the CLI; launch() below also ignores it. Kept for compatibility.
    parser.add_argument("--share", action='store_true', default=True, help="share model")
    args = parser.parse_args()
    logger.info(args)

    # Initialize the embedding similarity model and the RAG pipeline.
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    model = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        corpus_files=args.corpus_files.split(','),
        device=args.device,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    logger.info(f"chatpdf model: {model}")

    def predict_stream(message, history):
        """Yield response chunks for *message*, seeding the model with the chat history.

        *history* arrives from Gradio as [human, assistant] pairs.
        """
        model.history = [[human, assistant] for human, assistant in history]
        for chunk in model.predict_stream(message):
            yield chunk

    def predict(message, history):
        """Return the model's answer followed by its reference passages as one string."""
        logger.debug(message)
        response, reference_results = model.predict(message)
        r = response + "\n\n" + '\n'.join(reference_results)
        logger.debug(r)
        return r

    chatbot_stream = gr.Chatbot(
        height=600,
        avatar_images=(
            os.path.join(pwd_path, "assets/user.png"),
            os.path.join(pwd_path, "assets/Logo1.png"),
        ),
        bubble_full_width=False,
    )
    # UI title and (disabled) description.
    # NOTE(review): the title string looks mojibake-damaged (mis-decoded emoji
    # bytes); confirm the intended characters before changing it.
    title = " 馃ChatPDF Zonia馃 "
    # description = "Enlace en Github: [shibing624/ChatPDF](https://github.com/shibing624/ChatPDF)"
    # Fixed: '!importante' is not a valid CSS token ('!important'), so the
    # toast-hiding rule was silently ignored by the browser.
    css = """.toast-wrap { display: none !important } """
    examples = ['Puede hablarme del PNL?', 'Introducci贸n a la PNL']
    chat_interface_stream = gr.ChatInterface(
        predict,
        textbox=gr.Textbox(lines=4, placeholder="Ask me question", scale=7),
        title=title,
        # description=description,
        chatbot=chatbot_stream,
        css=css,
        examples=examples,
        theme='soft',
    )
    # Launch deliberately without `server_name`/`server_port` (Gradio defaults
    # apply); the CLI --server_* / --share arguments are not used here.
    chat_interface_stream.launch()