import os
import torch
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import gradio as gr

# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# λͺ¨λΈ ID 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_api_key)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_auth_token=hf_api_key,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, 
        bnb_4bit_use_double_quant=True, 
        bnb_4bit_quant_type="nf4", 
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)

# Load the pre-embedded Wikipedia dataset and build a FAISS index over its embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

# Retrieval and response generation functions
def search(query: str, k: int = 3):
    # Embed the query with the same model that produced the dataset embeddings,
    # then look up the k nearest passages in the FAISS index
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples

def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to stay within GPU memory limits
    # Mixtral's chat template only accepts alternating user/assistant turns, so the
    # system instructions are prepended to the user message rather than sent as a
    # separate system role.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
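
# A minimal usage sketch (assumption: these calls are run directly, outside the Gradio UI,
# as a quick end-to-end check; the question string is just an illustrative example):
#
#   scores, docs = search("What is the capital of France?", k=2)
#   print(docs["text"][0][:200])                                   # first retrieved passage
#   print(rag_chatbot_interface("What is the capital of France?", k=2))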

# Gradio interface setup
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)

iface.launch()