import os
import torch
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import gradio as gr
from accelerate import Accelerator
# Load the Hugging Face API key from the environment
hf_api_key = os.getenv('HF_API_KEY')
# Model ID and tokenizer setup
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()  # create an Accelerator instance
# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.bfloat16,  # set the data type with torch
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
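# NF4 4-bit loading with double quantization stores the 8x7B weights in roughly
# a quarter of their bfloat16 footprint, while matmuls still run in bfloat16
# via bnb_4bit_compute_dtype.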
model = accelerator.prepare(model)  # hand the model to Accelerator
# Load the data and build a faiss index
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
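# The "embedded" revision ships a precomputed "embeddings" column; the
# SentenceTransformer above must be the same model those vectors came from,
# or nearest-neighbor search will return poor matches.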
# Search and response-generation functions
def search(query: str, k: int = 3):
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
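# Illustrative usage (hypothetical query): search("What is a transformer?", k=3)
# returns k similarity scores plus a dict of columns (e.g. "text") for the top hits.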
# The remaining code is kept the same as before
def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
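# The formatted prompt then looks like:
#   Question:<user question>
#   Context:<text of document 1>
#   <text of document 2>
#   ...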
SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to respect GPU memory limits
    # Mixtral's chat template does not accept a separate system role, so the
    # system prompt is folded into the user turn.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]
    # Build the model input with the chat template; calling tokenizer() directly
    # on a list of role/content dicts is not supported.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response
def rag_chatbot_interface(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
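# Illustrative direct call, bypassing the UI (hypothetical question):
#   print(rag_chatbot_interface("Who proposed general relativity?", k=2))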
# Current Gradio exposes components directly; the old gr.inputs/gr.outputs
# namespaces have been removed.
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents.",
)
iface.launch()