Spaces:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss  # required by datasets' add_faiss_index
import gradio as gr
from accelerate import Accelerator
# Hub token and generator model (phi-2) used to produce the final answer
hf_api_key = os.getenv('HF_API_KEY')
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32
)

# Let accelerate place the model on whatever hardware is available
accelerator = Accelerator()
model = accelerator.prepare(model)

# Embedding model for queries; it must match the model used to embed the dataset
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Pre-embedded Wikipedia dataset, indexed with FAISS over its "embeddings" column
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
def search(query: str, k: int = 3):
    """Encode the query and retrieve the k nearest documents from the FAISS index."""
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
def format_prompt(prompt, retrieved_documents, k):
    """Prepend the k retrieved passages to the user question as context."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

def generate(formatted_prompt):
    """Produce an answer with phi-2. This helper is missing from the original listing, so the body below is a minimal, assumed completion."""
    full_prompt = f"{SYS_PROMPT}\n{formatted_prompt}\nAnswer:"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

def rag_chatbot_interface(prompt: str, k: int = 2):
    """Retrieve k documents for the question, build the prompt, and generate an answer."""
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
# Minimal Gradio UI: one text box in, the generated answer out
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
)

iface.launch(share=True)
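
Once the app is running, the same endpoint can be queried programmatically. A minimal sketch using gradio_client, assuming the default local address and the Interface's default /predict endpoint (the public share link printed by launch() works as well):

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumed local URL; replace with the printed share link if needed
answer = client.predict("What is retrieval-augmented generation?", api_name="/predict")
print(answer)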