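"""Minimal RAG demo: FAISS retrieval over a pre-embedded Wikipedia dump,
answer generation with microsoft/phi-2, served through a Gradio interface."""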
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss  # noqa: F401  -- backend required by datasets' add_faiss_index
import gradio as gr
from accelerate import Accelerator
# Hugging Face access token, read from the environment (e.g. a Space secret).
hf_api_key = os.getenv("HF_API_KEY")

# Generator model and tokenizer.
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

# Let accelerate handle device placement for the model.
accelerator = Accelerator()
model = accelerator.prepare(model)
# Query encoder: must match the model used to produce the dataset's stored embeddings.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Wikipedia dump with precomputed passage embeddings; index them with FAISS.
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
def search(query: str, k: int = 3):
    """Embed the query and return the scores and the k nearest dataset rows."""
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
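
# Example (hypothetical query): scores, docs = search("Who wrote Hamlet?", k=3)
# docs["text"] then holds the three best-matching Wikipedia passages.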
def format_prompt(prompt, retrieved_documents, k):
    """Concatenate the question and the k retrieved passages into one prompt."""
    PROMPT = f"Question: {prompt}\nContext:\n"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
def rag_chatbot_interface(prompt: str, k: int = 2):
    """Retrieve k passages for the prompt, build the RAG prompt, and generate an answer."""
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."
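
# The original script calls generate() without ever defining it. A minimal sketch
# of a plausible implementation, assuming a simple sampled decode with the phi-2
# model loaded above (max_new_tokens and temperature are illustrative choices):
def generate(formatted_prompt: str, max_new_tokens: int = 256) -> str:
    # Prepend the system prompt so the model answers in a conversational style.
    full_prompt = f"{SYS_PROMPT}\n{formatted_prompt}Answer:"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)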
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="Retrieves relevant Wikipedia passages with FAISS and grounds the model's answer in them.",
)
iface.launch(share=True)