{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wDmLkBbZmJvB" }, "outputs": [], "source": [ "# ===============================\n", "# 1. 라이브러리 설치 (Google Colab)\n", "# ===============================\n", "!pip install unsloth xformers faiss-gpu-cu12 -U\n", "!pip install --no-deps --upgrade \"flash-attn>=2.6.3\"\n", "!pip install -U hf_transfer" ] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 2. 환경 설정\n", "# ===============================\n", "import os\n", "import torch\n", "import numpy as np\n", "import faiss\n", "import json\n", "import ast\n", "from transformers import TextStreamer\n", "from sentence_transformers import SentenceTransformer\n", "from unsloth import FastLanguageModel\n", "from huggingface_hub import hf_hub_download\n", "\n", "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" ], "metadata": { "id": "OsEBB0aKmhBy" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 3. 모델 로드\n", "# ===============================\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=\"Austin9/gemma-2-9b-it-Ko-RAG\",\n", " max_seq_length=8192,\n", " dtype=torch.float16,\n", " load_in_4bit=True\n", ")\n", "FastLanguageModel.for_inference(model)" ], "metadata": { "id": "ENT1FgZZmizd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 4. FAISS 인덱스 로드 (Hugging Face Hub에서 직접 다운로드)\n", "# ===============================\n", "repo_id = \"Austin9/gemma-2-9b-it-Ko-RAG\" # 허깅페이스 저장소 ID\n", "filename = \"chunked_data_vectors.npz\" # 저장된 npz 파일 이름\n", "\n", "vector_db_path = hf_hub_download(repo_id=repo_id, filename=filename)\n", "data = np.load(vector_db_path)\n", "vectors, texts, titles = data[\"vectors\"], data[\"texts\"], data[\"titles\"]\n", "\n", "gpu_resources = faiss.StandardGpuResources()\n", "faiss_index = faiss.GpuIndexFlatL2(gpu_resources, vectors.shape[1])\n", "faiss_index.add(vectors)" ], "metadata": { "id": "9H7Xcc9GmkQ8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 5. 임베딩 모델 로드\n", "# ===============================\n", "embedding_model = SentenceTransformer(\"nlpai-lab/KURE-v1\", device=\"cuda\").to(torch.float16)" ], "metadata": { "id": "EwpMV0kXmpSX" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 6. JSON 파싱 함수\n", "# ===============================\n", "def robust_parse_json(response_text):\n", " response_text = response_text.strip().strip(\"'\").strip('\"').replace(\"'\", '\"')\n", " try:\n", " return json.loads(response_text)\n", " except:\n", " try:\n", " return ast.literal_eval(response_text)\n", " except:\n", " return {\"search\": \"\"}\n", "\n", "# ===============================\n", "# 7. 검색 쿼리 생성 (QCR 단계)\n", "# ===============================\n", "def generate_search_query(conversation_history, user_input):\n", " instruction = (\n", " \"다음은 대화 기록(Context)와 사용자의 질문(Input)입니다. \"\n", " \"사용자의 질문에 답을 제공하기 위해 필요한 단일 문자열 검색 쿼리를 생성하세요. \"\n", " \"검색이 필요하지 않거나 검색이 불필요한 경우(인사나, 겉치레, 농담) 빈 문자열을 반환하세요.\\n\\n\"\n", " \"최종 출력 형식은 {'search': '<검색 쿼리>'}입니다.\"\n", " )\n", " prompt = f\"\"\"\n", " # Query Rewriter\n", " ### Instruction:\n", " {instruction}\n", " ### Conversation:\n", " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n", " ### Input:\n", " {user_input}\n", " ### Response:\n", " \"\"\"\n", "\n", " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n", " output_tokens = model.generate(**inputs, max_new_tokens=300)\n", " response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()\n", " return robust_parse_json(response_text).get(\"search\", \"\")\n", "\n", "# ===============================\n", "# 8. FAISS 검색\n", "# ===============================\n", "def search_documents(query, k=3):\n", " if not query:\n", " return \"\"\n", " query_vector = embedding_model.encode([query])[0]\n", " _, indices = faiss_index.search(np.array([query_vector]), k)\n", " return \"\\n\\n\".join([f\"# Index [{i+1}]: {titles[idx]}\\n{texts[idx]}\" for i, idx in enumerate(indices[0])])\n", "\n", "# ===============================\n", "# 9. 답변 생성\n", "# ===============================\n", "def generate_response(conversation_history, context, user_input):\n", " instruction = (\n", " \"당신은 외부검색을 이용하여 사용자에게 도움을 주는 인공지능 조수입니다.\\n\"\n", " \"- Context는 외부검색을 통해 반환된 사용자 요청과 관련된 정보들입니다.\\n\"\n", " \"- Context를 활용할 때 문장 끝에 사용한 문서 조각의 [Index]를 붙이고 자연스러운 답변을 작성하세요. (e.g. [1])\\n\"\n", " \"- Context의 정보가 사용자 요청과 관련이 없거나 도움이 안될수도 있습니다. 관련있는 정보만 활용하고, 없는 정보를 절대 지어내지 마세요.\\n\"\n", " \"- 되도록이면 일반 지식으로 답변하지말고, 최대한 Context를 통해서 답변을 하려고 하세요. Context에 없을 경우에는 이 점을 언급하며 사죄하고 다른 주제나 질문을 추천해주세요.\\n\"\n", " \"- 사용자 요청에 알맞는 자연스러운 대화를 하세요.\\n\"\n", " \"- 항상 존댓말로 답변하세요.\"\n", " )\n", "\n", " prompt = f\"\"\"\n", " # Generator\n", " ### Instruction:\n", " {instruction}\n", " ### Conversation:\n", " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n", " ### Context:\n", " {context}\n", " ### Input:\n", " {user_input}\n", " ### Response:\n", " \"\"\"\n", "\n", " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n", " output_tokens = model.generate(**inputs, max_new_tokens=2500, do_sample=True)\n", " return tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()" ], "metadata": { "id": "Nsv2Xp2kmp1S" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ===============================\n", "# 10. 대화 루프\n", "# ===============================\n", "def chat_loop():\n", " conversation_history = []\n", " print(\"대화를 시작합니다. 'exit' 입력 시 종료.\")\n", "\n", " while True:\n", " user_input = input(\"\\nUser> \").strip()\n", " if user_input.lower() in [\"exit\", \"quit\"]:\n", " print(\"대화를 종료합니다.\")\n", " break\n", "\n", " print(\"\\n[검색 쿼리 생성 중...]\")\n", " search_query = generate_search_query(conversation_history, user_input)\n", " context = search_documents(search_query, k=5) if search_query else \"\"\n", "\n", " print(\"\\n[답변 생성 중...]\")\n", " response = generate_response(conversation_history, context, user_input)\n", "\n", " conversation_history.append((\"User\", user_input))\n", " conversation_history.append((\"Assistant\", response))\n", " print(f\"\\nAssistant> {response}\")\n", "\n", "if __name__ == \"__main__\":\n", " chat_loop()" ], "metadata": { "id": "4XD0UDZImsuE" }, "execution_count": null, "outputs": [] } ] }