from sentence_transformers import SentenceTransformer
from datasets import load_dataset, Dataset
import faiss  # needed so `datasets` can build and query the FAISS index
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import gradio as gr
import os

tokenkey = os.getenv("HF_API_KEY")  # Hugging Face access token, in case the model is gated
# Model and tokenizer setup
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=tokenkey)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=tokenkey,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
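
# Optional sanity check (not part of the original app): with 4-bit NF4
# quantization, Mixtral-8x7B needs roughly 25 GB of GPU memory; printing
# the actual footprint and placement catches sizing problems early.
print(f"Model footprint: {model.get_memory_footprint() / 1e9:.1f} GB")
print(f"Device map: {model.hf_device_map}")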
# Load the dataset and build a FAISS index over its precomputed embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")  # index the "embeddings" column
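
# Optional (an assumption about deployment, not in the original code):
# persist the FAISS index so a Space restart doesn't rebuild it from
# scratch. The file name "wiki_index.faiss" is arbitrary.
# data.save_faiss_index("embeddings", "wiki_index.faiss")   # after building once
# data.load_faiss_index("embeddings", "wiki_index.faiss")   # on later runs, instead of add_faiss_index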
# Retrieval: embed the query and fetch the k nearest passages
def search(query: str, k: int = 3):
    # the query must be embedded with the same model used for the corpus
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples(
        "embeddings", embedded_query, k=k
    )
    return scores, retrieved_examples
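
# Quick retrieval check (illustrative, not in the original): assumes the
# dataset exposes "title" and "text" columns, as the embedded revision of
# not-lain/wikipedia does.
scores, docs = search("What is anarchy?", k=3)
for score, title in zip(scores, docs["title"]):
    print(f"{score:.3f}  {title}")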
def format_prompt(prompt, retrieved_documents, k):
    """Concatenate the question and the k retrieved passages into one prompt."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to respect GPU memory limits
    # Mixtral's chat template does not accept a separate system role,
    # so the instructions are prepended to the user turn instead.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    # apply_chat_template renders the messages in the model's expected chat
    # format; a plain tokenizer(...) call cannot consume a list of message dicts.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response
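
# Sketch (not in the original code): if the Space should stream tokens as
# they are produced rather than block until the full answer is ready,
# transformers' TextIteratorStreamer can run generate() in a worker thread.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream(formatted_prompt):
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt[:2000]}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids, streamer=streamer, max_new_tokens=1024,
        do_sample=True, temperature=0.6, top_p=0.9,
    )).start()
    for chunk in streamer:  # yields text fragments as soon as they decode
        yield chunk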
def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

print(rag_chatbot("What is anarchy?", k=2))
# Gradio UI: wire the RAG pipeline into a simple text-in, text-out interface
iface = gr.Interface(
    fn=rag_chatbot,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description=(
        "This is a chatbot that uses a retrieval-augmented generation approach "
        "to provide more accurate answers. It first searches for relevant documents "
        "and then generates a response based on the prompt and the retrieved documents."
    ),
)
iface.launch()
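
# Deployment note (an assumption about Spaces, not in the original code):
# long Mixtral generations can outlive the default HTTP timeout, so enabling
# Gradio's queue is a common alternative to the bare launch() above:
# iface.queue().launch()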