"""Retrieval-Augmented Generation (RAG) chatbot.

Retrieves Wikipedia passages via a FAISS index over pre-computed sentence
embeddings, then answers questions with microsoft/phi-2 through a Gradio UI.
"""

import os

import gradio as gr
import torch
from accelerate import Accelerator
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_api_key = os.getenv("HF_API_KEY")

model_id = "microsoft/phi-2"
# model_id = "microsoft/Phi-3-mini-128k-instruct"

# Defined before any function that reads it (the original defined this at the
# bottom of the file, after generate() already referenced it).
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."

# Tokenizer and model setup.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)

# phi-2 ships without a padding token; reuse EOS so attention masks work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

accelerator = Accelerator()
model = accelerator.prepare(model)

# Embedding model; must match the one used to build the dataset's
# "embedded" revision so query/passage vectors live in the same space.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")


def generate(formatted_prompt):
    """Generate an answer string for an already-formatted RAG prompt.

    Args:
        formatted_prompt: "Question:...\\nContext:..." text from format_prompt().

    Returns:
        The model's answer, excluding the echoed prompt.
    """
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
    # Dynamic-length encoding: padding a single sequence to a fixed
    # max_length wastes compute, and the original 512-token cap silently
    # truncated the retrieved context. 2048 is phi-2's context window.
    encoding = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    input_ids = encoding["input_ids"].to(accelerator.device)
    attention_mask = encoding["attention_mask"].to(accelerator.device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Decode only the newly generated tokens: decoding outputs[0] in full
    # (as the original did) echoes the prompt and context in every answer.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def search(query: str, k: int = 3):
    """Return (scores, examples) for the k nearest Wikipedia passages."""
    embedded_query = ST.encode(query)
    scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
    return scores, retrieved_examples


def format_prompt(prompt, retrieved_documents, k):
    """Build a "Question + Context" prompt from retrieved documents.

    Slices to at most k documents instead of indexing range(k), so it cannot
    raise IndexError when the index returns fewer than k matches.
    """
    context = "".join(f"{text}\n" for text in retrieved_documents["text"][:k])
    return f"Question:{prompt}\nContext:{context}"


def rag_chatbot_interface(prompt: str, k: int = 2):
    """Gradio entry point: retrieve k passages, then generate an answer."""
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)


iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot provides more accurate answers by searching relevant documents and generating responses.",
)

iface.launch(share=True)