import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator
# Load the Hugging Face API key from an environment variable
hf_api_key = os.getenv('HF_API_KEY')

# Model ID and tokenizer setup
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()

# Load the model once, without quantization (a workaround for the Space's
# earlier runtime error). trust_remote_code allows the repository's custom
# modeling code to run.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32  # default dtype
)
model = accelerator.prepare(model)
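
# Optional smoke test, a minimal sketch: confirm the model generates before
# wiring up retrieval. The prompt text here is illustrative only.
# _ids = tokenizer("Hello, world", return_tensors="pt").input_ids.to(accelerator.device)
# print(tokenizer.decode(model.generate(_ids, max_new_tokens=5)[0], skip_special_tokens=True))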
# Load the dataset and build a FAISS index over its precomputed embeddings
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
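
# Optional sanity check, a minimal sketch: embed a throwaway query and confirm
# the index answers. "embeddings" is the dataset column indexed above; the
# query string is illustrative.
# _scores, _hits = data.get_nearest_examples("embeddings", ST.encode("test query"), k=1)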
# The remaining helper functions and the Gradio interface are configured as before
# Define functions for search, prompt formatting, and generation
def search(query: str, k: int = 3):
    # Embed the query and retrieve the k nearest passages from the FAISS index
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples
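
# Usage sketch (the question is illustrative): `scores` is a NumPy array of
# distances, and `retrieved_examples` is a dict of columns keyed by name,
# e.g. retrieved_examples["text"][0] is the closest passage.
# scores, docs = search("What is the capital of France?", k=3)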
def format_prompt(prompt, retrieved_documents, k):
    # Prepend the question, then append the text of each retrieved document
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
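
# For example, format_prompt("Who is Ada Lovelace?", docs, 2) yields
# "Question:Who is Ada Lovelace?\nContext:<doc 1 text>\n<doc 2 text>\n".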
# System prompt for the chatbot (used by generate below)
SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # Limit due to GPU memory constraints
    # phi-2 is a base model with no chat template, so the system prompt is
    # prepended as plain text rather than passed as role/content messages.
    full_prompt = f"{SYS_PROMPT}\n{formatted_prompt}\nAnswer:"
    input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # Decode only the newly generated tokens, skipping the prompt
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
def rag_chatbot_interface(prompt: str, k: int = 2):
    # Retrieve supporting documents, build the prompt, then generate an answer
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)
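
# End-to-end usage sketch (the question text is illustrative):
# print(rag_chatbot_interface("When was the Eiffel Tower built?", k=2))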
iface = gr.Interface(
fn=rag_chatbot_interface,
inputs="text", # ν
μ€νΈ μ
λ ₯
outputs="text", # ν
μ€νΈ μΆλ ₯
title="Retrieval-Augmented Generation Chatbot",
description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
)
iface.launch()