File size: 2,630 Bytes
d726220
 
d2de08e
456ec91
 
9ae4071
9918198
d2de08e
1f97769
44a6b17
e9071d1
8870678
e9071d1
dbd7c99
 
3d962a1
dbd7c99
 
 
 
1b6e08f
 
2d84b3b
12218a1
 
1b6e08f
3d962a1
 
9ae4071
953debe
1b6e08f
f4e7415
953debe
1b6e08f
 
4ad9b62
 
 
 
 
 
 
 
dbd7c99
4ad9b62
 
 
 
 
 
 
 
1b6e08f
 
9918198
12218a1
953debe
1b6e08f
 
 
 
 
953debe
9918198
1b6e08f
 
 
 
3d962a1
000e3df
9918198
 
3d962a1
 
9918198
85b887d
2829eb5
 
9872f0b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator

# --- Model / tokenizer / retrieval setup (executed at import time) ---
# Hugging Face access token taken from the environment; None if HF_API_KEY is unset.
hf_api_key = os.getenv('HF_API_KEY')
model_id = "microsoft/phi-2"

# Alternative model kept for reference (currently unused):
# model_id = "microsoft/Phi-3-mini-128k-instruct"

# Tokenizer and model setup.
# NOTE(review): trust_remote_code=True runs Python code shipped with the model
# repository; only use it with repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
# Padding token setup: reuse the EOS token when the tokenizer defines no pad token,
# so any padding request has a token to use.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Causal LM loaded in full float32 precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32
)

# Let Accelerate prepare the model for the current device/process setup;
# accelerator.device is where inputs must be moved before generation.
accelerator = Accelerator()
model = accelerator.prepare(model)

# Sentence-embedding model used to encode user queries for retrieval.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Wikipedia passages from the "embedded" revision; the "embeddings" column is
# indexed with FAISS for nearest-neighbour search.
# NOTE(review): assumes those embeddings were produced with the same model as ST
# (same vector dimensionality) — confirm against the dataset card.
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

def generate(formatted_prompt, *, max_new_tokens=1024, temperature=0.6, top_p=0.9):
    """Generate an answer for ``formatted_prompt`` with the loaded causal LM.

    The module-level ``SYS_PROMPT`` is prepended, the text is tokenized
    (truncated to at most 512 tokens) and the model samples a continuation.

    Args:
        formatted_prompt: Question plus retrieved context, as built by
            ``format_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text. The prompt tokens are stripped so the
        system prompt and retrieved context are not echoed back to the user.
    """
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
    # A single prompt needs no padding. Padding to a fixed length would insert
    # pad tokens (EOS when no pad token exists) between the prompt and the
    # generated text on a right-padding tokenizer, which degrades
    # decoder-only generation.
    encoding = tokenizer(prompt_text, return_tensors="pt", max_length=512, truncation=True)
    input_ids = encoding["input_ids"].to(accelerator.device)
    attention_mask = encoding["attention_mask"].to(accelerator.device)

    # `accelerator.prepare` may wrap the model (e.g. for distributed runs);
    # `generate` is defined on the underlying model.
    base_model = accelerator.unwrap_model(model)
    outputs = base_model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # The returned sequence begins with the prompt; decode only the new tokens.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def search(query: str, k: int = 3):
    """Find the ``k`` dataset rows whose embeddings are closest to ``query``.

    The query is embedded with ``ST`` and looked up in the FAISS index on
    the ``"embeddings"`` column of ``data``.

    Returns:
        A ``(scores, retrieved_examples)`` tuple as produced by
        ``Dataset.get_nearest_examples``.
    """
    query_vector = ST.encode(query)
    lookup = data.get_nearest_examples("embeddings", query_vector, k=k)
    nearest_scores, nearest_rows = lookup
    return (nearest_scores, nearest_rows)

def format_prompt(prompt, retrieved_documents, k):
    """Build the question-plus-context prompt that is passed to ``generate``.

    Args:
        prompt: The user's question.
        retrieved_documents: Mapping whose ``"text"`` entry is a sequence of
            retrieved passages (as returned by ``search``).
        k: Maximum number of passages to include. If fewer than ``k``
            passages were retrieved, all of them are used instead of raising
            ``IndexError``; ``k <= 0`` includes none.

    Returns:
        The string "Question:" + prompt + newline + "Context:", followed by
        each included passage terminated by a newline.
    """
    # Clamp at 0 so a negative k selects nothing (a negative slice bound
    # would otherwise drop passages from the end instead).
    passages = retrieved_documents["text"][: max(k, 0)]
    context = "".join(f"{passage}\n" for passage in passages)
    return f"Question:{prompt}\nContext:{context}"

def rag_chatbot_interface(prompt: str, k: int = 2):
    """Answer ``prompt`` with retrieval-augmented generation.

    Retrieves the ``k`` nearest passages, folds them into a prompt with
    ``format_prompt`` and returns the text produced by ``generate``.
    This is the callback wired into the Gradio interface.
    """
    _, nearest_docs = search(prompt, k)
    augmented_prompt = format_prompt(prompt, nearest_docs, k)
    answer = generate(augmented_prompt)
    return answer

# System instruction prepended to every prompt by generate().
# It only has to exist before the first request; generate() reads it at call time.
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

# Text-in / text-out Gradio UI wrapping the RAG pipeline.
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
)

# share=True requests a public share link in addition to the local server.
# NOTE(review): this launches at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
iface.launch(share=True)