import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator
# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')
# Set up the model ID and tokenizer
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()
# Load the model without quantization (a temporary measure while troubleshooting)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.float32  # use the default dtype
)
model = accelerator.prepare(model)
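# Alternative (a sketch, not part of the original app): if the runtime error is
# GPU out-of-memory, Mixtral can instead be loaded 4-bit quantized with
# bitsandbytes; in that case let device_map place the weights and skip
# accelerator.prepare(). Uncomment to try:
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     token=hf_api_key,
#     quantization_config=BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#     ),
#     device_map="auto",
# )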
# Load the data and build the faiss index
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")  # index the precomputed "embeddings" column
# The remaining functions and the Gradio interface are configured as before
# Define functions for search, prompt formatting, and generation
def search(query: str, k: int = 3):
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
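# Example (hypothetical query): 'retrieved' is a dict of columns aligned with
# 'scores', e.g. retrieved["text"][0] is the best-matching passage.
# scores, retrieved = search("What is a black hole?", k=3)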
def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
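# For k=2 the returned prompt has this shape:
# Question:<user question>
# Context:<text of document 1>
# <text of document 2>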
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # Limit due to GPU memory constraints
    # SYS_PROMPT is defined at module level below and resolved at call time.
    # Mixtral's chat template does not define a separate "system" role, so the
    # system prompt is folded into the user turn.
    messages = [{"role": "user", "content": f"{SYS_PROMPT}\n\n{formatted_prompt}"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
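# Example usage (hypothetical question), handy for a quick test outside Gradio:
# print(rag_chatbot_interface("Who proposed the theory of general relativity?", k=2))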
# System prompt for the chatbot (used by generate above)
SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
# Set up the Gradio interface (gr.inputs/gr.outputs are deprecated; use gr.Textbox)
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)
iface.launch()