import os
import torch  # needed for torch.bfloat16 below
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from accelerate import Accelerator  # import Accelerate separately
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr

hf_api_key = os.getenv('HF_API_KEY')  # load the API key from an environment variable

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()  # create an Accelerator instance
# Load the model with 4-bit NF4 quantization to reduce GPU memory usage
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,  # nested quantization for extra memory savings
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)
model = accelerator.prepare(model)  # hand the model to Accelerator for device placement

# Embedding model; it should match the model used to produce the dataset's precomputed embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")  # revision with an "embeddings" column
data = dataset["train"]
data = data.add_faiss_index("embeddings")  # build a FAISS index over the embedding column

def search(query: str, k: int = 3):
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
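# Example usage (hypothetical query):
#   scores, docs = search("What is a large language model?", k=3)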

def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
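# The formatted prompt looks like: "Question:<prompt>\nContext:<doc 1>\n<doc 2>\n..."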

SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to stay within the GPU memory budget
    # Mixtral's chat template expects alternating user/assistant turns, so the system
    # prompt is prepended to the user message rather than sent as a separate "system" role.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
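# Example (hypothetical question): print(rag_chatbot_interface("What is machine learning?", k=2))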

iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)

iface.launch()