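# RAG chatbot: retrieve Wikipedia passages with a FAISS index and answer questions
# with Mixtral-8x7B-Instruct, served through a simple Gradio interface.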
import os
import torch  # used below for bfloat16 dtypes
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import gradio as gr
from accelerate import Accelerator

# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# λͺ¨λΈ ID 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()  # create an Accelerator instance

# λͺ¨λΈ λ‘œλ”©
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.bfloat16,  # set the weight dtype via torch
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)
model = accelerator.prepare(model)  # prepare the model with Accelerator

# Load the dataset and build a FAISS index over the precomputed embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

# Search and response generation functions
def search(query: str, k: int = 3):
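    """Embed the query and return the k nearest Wikipedia passages from the FAISS index."""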
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples

# The rest of the code is kept the same as before


def format_prompt(prompt, retrieved_documents, k):
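    """Concatenate the question with the text of the top-k retrieved documents."""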
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

# System instructions prepended to every prompt
SYS_PROMPT = (
    "You are an assistant for answering questions. You are given the extracted parts "
    "of a long document and a question. Provide a conversational answer. If you don't "
    "know the answer, just say 'I do not know.' Don't make up an answer."
)

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to stay within GPU memory limits
    # Mixtral's chat template only accepts alternating user/assistant turns,
    # so the system instructions are folded into the user message
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def rag_chatbot_interface(prompt: str, k: int = 2):
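    """Full RAG pipeline: retrieve documents, build the prompt, and generate an answer."""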
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)


# Gradio UI: a question text box in, the generated answer out
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)

iface.launch()
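# When running on a remote machine or in a notebook, a temporary public link
# can be requested instead with: iface.launch(share=True)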