import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from datasets import load_dataset, Dataset
import faiss
import torch
import gradio as gr


# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# Model and tokenizer setup (the API key is passed when loading gated models)
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_api_key,
    # 4-bit NF4 quantization so the model fits in limited GPU memory
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)

# Load the embedded Wikipedia dataset and build a FAISS index over its "embeddings" column
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

# Retrieval and response-generation functions
def search(query: str, k: int = 3):
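    # Embed the query and return the k nearest Wikipedia passages from the FAISS index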
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples

def format_prompt(prompt, retrieved_documents, k):
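    # Concatenate the question with the k retrieved passages into a single prompt string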
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

# System prompt that frames the assistant's behaviour (adjust the wording to taste)
SYS_PROMPT = (
    "You are an assistant for answering questions using the provided context. "
    "If the answer is not in the context, say that you do not know."
)

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # crude character-level cap to respect GPU memory limits
    # Mixtral's chat template expects alternating user/assistant turns, so the
    # system prompt is prepended to the user message instead of using a system role.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def rag_chatbot_interface(prompt: str, k: int = 2):
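    # End-to-end RAG pipeline: retrieve documents, build the prompt, then generate an answer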
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

# Gradio interface setup
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)

iface.launch()