import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from datasets import load_dataset, Dataset
import faiss
import torch
import gradio as gr
# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')
# Model and tokenizer setup (4-bit quantization so the model fits in GPU memory)
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_api_key,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)
# Load the dataset and build a FAISS index over the precomputed embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
# Retrieval and response generation functions
def search(query: str, k: int = 3):
embedded_query = ST.encode(query)
scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
return scores, retrieved_examples
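# Build the LLM prompt by appending the top-k retrieved passages after the question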
def format_prompt(prompt, retrieved_documents, k):
PROMPT = f"Question:{prompt}\nContext:"
for idx in range(k):
PROMPT += f"{retrieved_documents['text'][idx]}\n"
return PROMPT
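# SYS_PROMPT is referenced in generate() but never defined in the original file;
# the wording below is an assumed placeholder for a RAG-style system instruction.
SYS_PROMPT = (
    "You are an assistant that answers questions using the provided context. "
    "If the answer cannot be found in the context, say you do not know."
)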
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to stay within GPU memory limits
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # Apply the model's chat template (fold SYS_PROMPT into the user turn if the template rejects a system role)
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=1024,
eos_token_id=tokenizer.eos_token_id,
do_sample=True,
temperature=0.6,
top_p=0.9
)
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
return response
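# End-to-end RAG pipeline: retrieve the top-k documents, build the prompt, and generate an answer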
def rag_chatbot_interface(prompt: str, k: int = 2):
scores, retrieved_documents = search(prompt, k)
formatted_prompt = format_prompt(prompt, retrieved_documents, k)
return generate(formatted_prompt)
# Gradio interface setup
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
title="Retrieval-Augmented Generation Chatbot",
description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)
iface.launch()
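# Assumed usage (hypothetical query): the pipeline can also be called directly, bypassing the UI, e.g.
#   print(rag_chatbot_interface("What is retrieval-augmented generation?", k=2))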