import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator

# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# Model ID (both the tokenizer and the model are loaded from this checkpoint)
model_id = "microsoft/phi-2"

# Load the model once: allow the repository's custom modeling code to run and skip
# quantization, keeping the default float32 dtype (a temporary workaround).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,    # allow custom model code to run
    torch_dtype=torch.float32  # default dtype, no quantization
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()
model = accelerator.prepare(model)

# Load the embedding model and the pre-embedded Wikipedia dataset, then build a FAISS index
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
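# Optional (not in the original script): the index can be persisted with datasets'
# save_faiss_index/load_faiss_index to avoid rebuilding it on every start, e.g.:
#   data.save_faiss_index("embeddings", "wikipedia.faiss")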

# The remaining helper functions and the Gradio interface are configured as before


# Define functions for search, prompt formatting, and generation
def search(query: str, k: int = 3):
    # Embed the query with the same model used for the dataset embeddings,
    # then retrieve the k nearest passages from the FAISS index.
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples

def format_prompt(prompt, retrieved_documents, k):
    # Concatenate the question and the k retrieved passages into a single prompt string.
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to limit GPU memory usage
    # phi-2 is a base causal LM without a chat template, so build a plain-text prompt
    # from the system instruction (SYS_PROMPT, defined below) and the question/context.
    prompt = f"{SYS_PROMPT}\n\n{formatted_prompt}\nAnswer:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens (skip the prompt portion)
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

# System prompt used by generate() when building the model input
SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
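
# Quick sanity check outside the web UI (illustrative only; the question is a placeholder):
# print(rag_chatbot_interface("What is retrieval-augmented generation?", k=2))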

# Set up Gradio interface
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)

iface.launch()
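
# Optional: iface.launch(share=True) would also create a temporary public URL.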