File size: 1,706 Bytes
f493920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from common import DATA, MODEL, TERMINATORS, TOKENIZER, format_prompt, search_topk
from config import MAX_TOKENS_INPUT, SYS_PROMPT_HF
from preprocessing import FEATURE_EXTRACTOR


def generate(
    formatted_prompt,
    *,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
    top_p: float = 0.9,
):
    """Run the chat model on a RAG-formatted prompt and return only the new text.

    The prompt is wrapped in a two-message chat (``SYS_PROMPT_HF`` as the system
    message, the prompt as the user message), rendered with the tokenizer's chat
    template, and completed with sampling via ``MODEL.generate``.

    Args:
        formatted_prompt: User-turn text, typically the output of ``format_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling cumulative probability cutoff.

    Returns:
        The decoded completion with the prompt tokens and special tokens removed.
    """
    # NOTE(review): this slices *characters*, although the constant is named
    # MAX_TOKENS_INPUT -- confirm the unit configured in config.py. Kept as-is so
    # the amount of context sent to the GPU does not change.
    formatted_prompt = formatted_prompt[:MAX_TOKENS_INPUT]  # to avoid GPU OOM
    messages = [
        {"role": "system", "content": SYS_PROMPT_HF},
        {"role": "user", "content": formatted_prompt},
    ]

    input_ids = TOKENIZER.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(MODEL.device)
    # One unpadded sequence: every position is real input. Passing the mask
    # explicitly stops generate() from inferring it from the pad token (which
    # transformers warns about when no mask is given).
    attention_mask = input_ids.new_ones(input_ids.shape)
    outputs = MODEL.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=TERMINATORS,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # generate() returns prompt + completion for each row; drop the prompt part.
    prompt_length = input_ids.shape[-1]
    return TOKENIZER.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def rag_chatbot(prompt: str, k: int = 2, return_user: bool = False):
    """Answer ``prompt`` with retrieval-augmented generation.

    Retrieves the ``k`` nearest rows of ``DATA`` (searched on the ``"embedding"``
    column via ``search_topk``), builds a prompt from their ``"chunk"`` text with
    ``format_prompt``, and passes it to ``generate``.

    Args:
        prompt: The user's question.
        k: Number of documents to retrieve and include as context.
        return_user: When true, return a transcript that shows the question
            followed by the answer instead of the bare answer.

    Returns:
        The model's answer, or a ``"[USER]: ...\\n\\n[ASSISTANT]: ..."`` string
        when ``return_user`` is set.
    """
    # Only the second element of search_topk's result (the documents) is used.
    _ignored, documents = search_topk(
        DATA,
        FEATURE_EXTRACTOR,
        prompt,
        k,
        embedding_col="embedding",
    )
    augmented_prompt = format_prompt(prompt, documents, k, text_col="chunk")
    answer = generate(augmented_prompt)

    if not return_user:
        return answer
    return f"[USER]: {prompt}\n\n[ASSISTANT]: {answer}"


if __name__ == "__main__":
    # Example RAG pipeline using HuggingFace models.
    # Index the "embedding" column that rag_chatbot() searches (presumably
    # search_topk needs this FAISS index -- confirm in common.py). Rebinding the
    # module-level DATA makes the indexed dataset the one rag_chatbot() reads.
    DATA = DATA.add_faiss_index("embedding")
    # Implicit string concatenation keeps the question on one logical line; the
    # previous triple-quoted literal embedded the source's newlines and
    # indentation in both the model input and the echoed "[USER]:" output.
    prompt = (
        "indicame qué va a pasar en la reforma pensional con los fondos en el "
        "pilar contributivo de prima media, podré pedir el dinero de vuelta "
        "cuando tenga la edad si no cumplo con las semanas cotizadas?"
    )
    print(rag_chatbot(prompt, k=3, return_user=True))