from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator
import os
import torch

# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)

# ν† ν¬λ‚˜μ΄μ €μ— νŒ¨λ”© 토큰 μ„€μ •
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # EOS 토큰을 νŒ¨λ”© ν† ν°μœΌλ‘œ μ‚¬μš©

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32
)

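# Let Accelerate place the model on whatever device is available (CPU or GPU).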
accelerator = Accelerator()
model = accelerator.prepare(model)

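# Retrieval setup: a sentence-embedding model plus a Wikipedia dataset that
# ships with precomputed embeddings, indexed with FAISS for similarity search.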
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

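# Embed the query and look up the k nearest passages in the FAISS index.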
def search(query: str, k: int = 3):
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples

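# Stitch the question and the retrieved passages into a single prompt string.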
def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

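# Build the full prompt (system prompt + question + context) and sample a completion.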
def generate(formatted_prompt):
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
    # A single prompt needs no padding; keep the attention mask and truncate
    # overly long contexts instead of padding to a fixed length.
    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512).to(accelerator.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

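# End-to-end RAG pipeline: retrieve supporting passages, format the prompt,
# and generate an answer.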
def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

# System prompt prepended to every request; generate() looks this up at call time.
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

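# Minimal Gradio UI: one text box in, the generated answer out.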
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
)

iface.launch(share=True)
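
# Note: share=True also serves a temporary public link alongside the local URL.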