# Hugging Face Spaces status when this file was captured: Runtime error
import os | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from sentence_transformers import SentenceTransformer | |
from datasets import load_dataset | |
import faiss | |
import gradio as gr | |
from accelerate import Accelerator | |
# Hugging Face access token from the environment; None when HF_API_KEY is unset.
hf_api_key = os.environ.get("HF_API_KEY")

# Checkpoint used for generation.
# Alternative checkpoint: "microsoft/Phi-3-mini-128k-instruct"
model_id = "microsoft/phi-2"

# Tokenizer and model setup.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
)

# Padding token setup: reuse the end-of-sequence token when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

# Hand the model to Accelerate; generate() moves its inputs to
# accelerator.device before calling the model.
accelerator = Accelerator()
model = accelerator.prepare(model)

# Sentence-embedding model used by search() to embed incoming queries.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Load the "embedded" revision of the dataset and index its "embeddings"
# column with FAISS so search() can run nearest-neighbour lookups on it.
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"].add_faiss_index("embeddings")
def generate(formatted_prompt):
    """Generate an answer for a retrieval-augmented prompt.

    The module-level ``SYS_PROMPT`` is prepended to ``formatted_prompt``; the
    combined text is tokenized (truncated to at most 512 tokens) and passed to
    ``model.generate`` with sampling enabled.

    Args:
        formatted_prompt: Question-plus-context text, e.g. the output of
            ``format_prompt``.

    Returns:
        The newly generated text only (the prompt tokens are stripped),
        decoded with special tokens skipped.
    """
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"

    # A single sequence needs no padding. Padding it out to max_length would
    # make the model continue from a run of pad tokens instead of the prompt.
    encoding = tokenizer(
        prompt_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    input_ids = encoding["input_ids"].to(accelerator.device)
    attention_mask = encoding["attention_mask"].to(accelerator.device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        # Pass the pad id explicitly rather than relying on generate()'s fallback.
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )

    # outputs[0] holds prompt + continuation; keep only the continuation so
    # the answer does not echo the system prompt and retrieved context.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def search(query: str, k: int = 3):
    """Look up the dataset rows nearest to ``query`` in embedding space.

    Args:
        query: Text to embed with the module-level ``ST`` encoder.
        k: Number of nearest examples to request from the index.

    Returns:
        A ``(scores, retrieved_examples)`` pair as produced by
        ``data.get_nearest_examples`` on the "embeddings" index.
    """
    query_vector = ST.encode(query)
    distances, nearest_rows = data.get_nearest_examples(
        "embeddings",
        query_vector,
        k=k,
    )
    return distances, nearest_rows
def format_prompt(prompt, retrieved_documents, k):
    """Build the question-plus-context prompt text.

    Args:
        prompt: The user's question.
        retrieved_documents: Mapping whose ``"text"`` entry is a sequence of
            document strings.
        k: Maximum number of documents to include.

    Returns:
        ``"Question:<prompt>\\nContext:"`` followed by up to ``k`` document
        texts, each terminated by a newline.
    """
    # Slice instead of indexing range(k): retrieval may return fewer than k
    # documents, and indexing past the end raised IndexError. max() keeps a
    # negative k meaning "no documents" rather than "all but the last |k|".
    texts = retrieved_documents["text"][:max(k, 0)]
    context = "".join(f"{text}\n" for text in texts)
    return f"Question:{prompt}\nContext:{context}"
def rag_chatbot_interface(prompt: str, k: int = 2):
    """Answer ``prompt`` by retrieving documents and generating a response.

    Args:
        prompt: The user's question.
        k: Number of documents to retrieve and place in the prompt.

    Returns:
        Whatever ``generate`` returns for the assembled prompt.
    """
    # The similarity scores are not needed to build the prompt.
    _scores, documents = search(prompt, k)
    return generate(format_prompt(prompt, documents, k))
# System instruction that generate() prepends to every prompt.
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

# Text-in / text-out Gradio front end for the RAG pipeline.
iface = gr.Interface(
    rag_chatbot_interface,
    "text",
    "text",
    title="Retrieval-Augmented Generation Chatbot",
    description=(
        "This chatbot provides more accurate answers by searching "
        "relevant documents and generating responses."
    ),
)

iface.launch(share=True)