from common import DATA, MODEL, TERMINATORS, TOKENIZER, format_prompt, search_topk
from config import MAX_TOKENS_INPUT, SYS_PROMPT_HF
from preprocessing import FEATURE_EXTRACTOR
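
# Assumptions about the imported helpers (defined elsewhere, not shown here):
# `common` exposes the chunked dataset (DATA), the chat model (MODEL), its
# tokenizer (TOKENIZER), and the end-of-turn token ids (TERMINATORS);
# `preprocessing` exposes the sentence-embedding callable used for retrieval.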


def generate(formatted_prompt):
    # Crude character-level truncation to avoid GPU OOM. Note the slice counts
    # characters, not tokens, despite the name MAX_TOKENS_INPUT.
    formatted_prompt = formatted_prompt[:MAX_TOKENS_INPUT]
    messages = [
        {"role": "system", "content": SYS_PROMPT_HF},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = TOKENIZER.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(MODEL.device)
    outputs = MODEL.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=TERMINATORS,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
    )
    response = outputs[0]
    # Strip the prompt tokens so only the newly generated answer is decoded.
    return TOKENIZER.decode(response[input_ids.shape[-1] :], skip_special_tokens=True)
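

# A token-aware alternative to the character slice in `generate` (a minimal
# sketch, not part of the original pipeline): it bounds the prompt by actual
# token count via TOKENIZER, so a limit like MAX_TOKENS_INPUT would apply to
# tokens rather than characters. The helper name is hypothetical.
def _truncate_to_max_tokens(text: str, max_tokens: int) -> str:
    # Tokenize, keep at most `max_tokens` ids, and decode back to text.
    ids = TOKENIZER(text, truncation=True, max_length=max_tokens)["input_ids"]
    return TOKENIZER.decode(ids, skip_special_tokens=True)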


def rag_chatbot(prompt: str, k: int = 2, return_user: bool = False):
    # Retrieve the k nearest chunks from the FAISS index, build the augmented
    # prompt, and generate a grounded answer.
    _, retrieved_documents = search_topk(
        DATA, FEATURE_EXTRACTOR, prompt, k, embedding_col="embedding"
    )
    formatted_prompt = format_prompt(prompt, retrieved_documents, k, text_col="chunk")
    bot_response = generate(formatted_prompt)
    return (
        f"[USER]: {prompt}\n\n[ASSISTANT]: {bot_response}"
        if return_user
        else bot_response
    )
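

# For reference, minimal sketches of what the `common` helpers are assumed to
# do (hypothetical reconstructions; the real implementations may differ).
# `search_topk` presumably embeds the query with the feature extractor and
# queries the dataset's FAISS index via `Dataset.get_nearest_examples`:
#
#     def search_topk(data, feature_extractor, query, k, embedding_col="embedding"):
#         query_embedding = feature_extractor(query)  # assumed 1-D numpy vector
#         scores, examples = data.get_nearest_examples(embedding_col, query_embedding, k=k)
#         return scores, examples
#
# `format_prompt` presumably concatenates the top-k retrieved chunks into the
# context string that `generate` receives:
#
#     def format_prompt(prompt, retrieved_documents, k, text_col="chunk"):
#         context = "\n".join(retrieved_documents[text_col][:k])
#         return f"Context:\n{context}\n\nQuestion: {prompt}"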


if __name__ == "__main__":
    # Example RAG pipeline using Hugging Face: index the corpus, then answer.
    # `add_faiss_index` names the index after the column ("embedding").
    DATA = DATA.add_faiss_index("embedding")
    # User question (Spanish): "Tell me what will happen under the pension
    # reform to the funds in the contributory average-premium pillar; will I
    # be able to ask for the money back when I reach retirement age if I don't
    # meet the required weeks of contributions?"
    prompt = """indicame qué va a pasar en la reforma pensional con los fondos en el pilar
contributivo de prima media, podré pedir el dinero de vuelta cuando tenga la edad si no
cumplo con las semanas cotizadas?"""
    print(rag_chatbot(prompt, k=3, return_user=True))