{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "from typing import List, Any\n",
    "from langchain.chains import RetrievalQA\n",
    "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
    "from langchain.document_loaders import TextLoader\n",
    "from langchain.indexes import VectorstoreIndexCreator\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import FAISS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = []\n",
    "metadata = []\n",
    "for p in Path(\"./datasets/huggingface_docs/\").iterdir():\n",
    "    if not p.is_dir():\n",
    "        with open(p) as f:\n",
    "            # the first line is the source of the text\n",
    "            source = f.readline().strip().replace('source: ', '')\n",
    "            docs.append(f.read())\n",
    "            metadata.append({\"source\": source})\n",
    "\n",
    "print(f'number of documents: {len(docs)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_splitter = CharacterTextSplitter(\n",
    "    separator=\"\",\n",
    "    chunk_size=812,\n",
    "    chunk_overlap=100,\n",
    "    length_function=len,\n",
    ")\n",
    "docs = text_splitter.create_documents(docs, metadata)\n",
    "print(f'number of chunks: {len(docs)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"hkunlp/instructor-large\"\n",
    "embed_instruction = \"Represent the Hugging Face library documentation\"\n",
    "query_instruction = \"Query the most relevant piece of information from the Hugging Face documentation\"\n",
    "\n",
    "# embedding_model = HuggingFaceInstructEmbeddings(\n",
    "#     model_name=model_name,\n",
    "#     embed_instruction=embed_instruction,\n",
    "#     query_instruction=query_instruction,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
    "    max_length: int = None\n",
    "\n",
    "    def __init__(self, max_length: int = 512, **kwargs: Any):\n",
    "        super().__init__(**kwargs)\n",
    "        self.max_length = max_length\n",
    "        if self.max_length < 0:\n",
    "            print('max_length is not specified, using model default max_seq_length')\n",
    "\n",
    "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
    "        all_embeddings = []\n",
    "        for text in tqdm(texts, desc=\"Embedding documents\"):\n",
    "            if len(text) > self.max_length and self.max_length > -1:\n",
    "                n_chunks = math.ceil(len(text)/self.max_length)\n",
    "                chunks = [\n",
    "                    text[i*self.max_length:(i+1)*self.max_length]\n",
    "                    for i in range(n_chunks)\n",
    "                ]\n",
    "                instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
    "                chunk_embeddings = self.client.encode(instruction_pairs)\n",
    "                avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
    "                all_embeddings.append(avg_embedding.tolist())\n",
    "            else:\n",
    "                instruction_pairs = [[self.embed_instruction, text]]\n",
    "                embeddings = self.client.encode(instruction_pairs)\n",
    "                all_embeddings.append(embeddings[0].tolist())\n",
    "\n",
    "        return all_embeddings\n",
    "\n",
    "\n",
    "embedding_model = AverageInstructEmbeddings(  \n",
    "    model_name=model_name,\n",
    "    embed_instruction=embed_instruction,\n",
    "    query_instruction=query_instruction,\n",
    "    max_length=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = embedding_model.embed_documents(texts=[d.page_content for d in docs[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = FAISS.from_documents(docs, embedding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.save_local('../indexes/index-large-notebooks/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = FAISS.load_local(f'../indexes/index-large-notebooks/', embedding_model)\n",
    "docs = index.similarity_search(query='how to create a pipeline object?', k=5)\n",
    "docs[0].page_content\n",
    "docs[0].metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for index, doc in enumerate(docs, start=1):\n",
    "    print(f\"\\n{'='*100}\\n\")\n",
    "    print(f\"Document {index} of {len(docs)}\")\n",
    "    print(\"Page Content:\")\n",
    "    print(f\"\\n{'-'*100}\\n\")\n",
    "    print(doc.page_content, '\\n')\n",
    "    print(doc.metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "\n",
    "index_name = 'index-large-notebooks'\n",
    "\n",
    "api = HfApi()\n",
    "api.create_repo(\n",
    "    repo_id=f'KonradSzafer/{index_name}',\n",
    "    repo_type='dataset',\n",
    "    private=False,\n",
    "    exist_ok=True\n",
    ")\n",
    "api.upload_folder(\n",
    "    folder_path=f'../indexes/{index_name}',\n",
    "    repo_id=f'KonradSzafer/{index_name}',\n",
    "    repo_type='dataset',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_qa_bot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}