# rag / app.py — Hugging Face Space by seawolf2357 (commit e9071d1, verified)
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator
# Hugging Face API token, read from the Space's environment/secrets.
hf_api_key = os.getenv('HF_API_KEY')
model_id = "microsoft/phi-2"
# model_id = "microsoft/Phi-3-mini-128k-instruct"
# Tokenizer and model setup.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
# Padding token setup: reuse EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32
)
# Let accelerate pick and place the model on the available device.
accelerator = Accelerator()
model = accelerator.prepare(model)
# Sentence-embedding model used to embed user queries for retrieval.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Pre-embedded Wikipedia passages; build a FAISS index over the
# "embeddings" column for nearest-neighbour search.
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
def generate(formatted_prompt):
    """Sample an answer from the causal LM for an already-formatted RAG prompt.

    Prepends the global ``SYS_PROMPT``, tokenizes the prompt (truncated to
    512 tokens), and samples up to 1024 new tokens (temperature 0.6,
    top-p 0.9).

    Returns the decoded answer text with special tokens stripped.
    """
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
    # Fix: the original used padding="max_length", which right-pads the single
    # prompt to 512 tokens; a causal LM then continues *after* the pad run,
    # corrupting the generation. Truncation alone is correct for one prompt.
    encoding = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encoding['input_ids'].to(accelerator.device)
    attention_mask = encoding['attention_mask'].to(accelerator.device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        # Explicit pad id avoids the "pad_token_id not set" runtime warning.
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens so the system prompt and the
    # retrieved context are not echoed back to the user.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def search(query: str, k: int = 3):
    """Retrieve the *k* Wikipedia passages nearest to *query*.

    Embeds the query with the sentence-transformer and looks it up in the
    dataset's FAISS index over the "embeddings" column.

    Returns a ``(scores, examples)`` pair.
    """
    query_embedding = ST.encode(query)
    scores, examples = data.get_nearest_examples("embeddings", query_embedding, k=k)
    return scores, examples
def format_prompt(prompt, retrieved_documents, k):
    """Build the RAG prompt: the question followed by up to *k* passages.

    ``retrieved_documents`` is a column mapping as returned by
    ``Dataset.get_nearest_examples`` (its "text" key holds the passages).

    Fix: the original indexed ``range(k)`` unconditionally and raised
    ``IndexError`` whenever fewer than *k* documents were retrieved; slicing
    caps the loop at however many passages are actually available.
    """
    docs = retrieved_documents['text'][:k]
    return f"Question:{prompt}\nContext:" + "".join(f"{doc}\n" for doc in docs)
def rag_chatbot_interface(prompt: str, k: int = 2):
    """End-to-end RAG step: retrieve *k* passages for *prompt*, then answer."""
    _, documents = search(prompt, k)
    return generate(format_prompt(prompt, documents, k))
# System instruction prepended to every prompt by generate().
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

# Wire the RAG pipeline into a simple text-in / text-out web UI and expose
# it via a public share link.
iface = gr.Interface(
    rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description=(
        "This chatbot provides more accurate answers by searching relevant"
        " documents and generating responses."
    ),
)
iface.launch(share=True)