File size: 6,519 Bytes
99f7ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n",
      "Created 915 chunks from 2 PDF files\n",
      "Query: What are the key principles of the AI Bill of Rights?\n",
      "\n",
      "Response:\n",
      "The key principles of the AI Bill of Rights are civil rights, civil liberties, and privacy.\n",
      "\n",
      "Context used:\n",
      "1. use, and deployment of automated systems to protect the rights of the American public in the age of ...\n",
      "2. civil rights, civil liberties, and privacy. The Blueprint for an AI Bill of Rights includes this For...\n"
     ]
    }
   ],
   "source": [
    "# Cell 1: Install required packages\n",
    "%pip install langchain openai chromadb PyPDF2 tiktoken -qU\n",
    "\n",
    "# Cell 2: Import necessary modules\n",
    "import os\n",
    "import tempfile\n",
    "import aiohttp\n",
    "import asyncio\n",
    "import getpass\n",
    "from io import BytesIO\n",
    "from typing import List\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.document_loaders import PyPDFLoader\n",
    "from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
    "from langchain.vectorstores import Chroma\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from PyPDF2 import PdfReader\n",
    "\n",
    "\n",
    "# Cell 4: Set up prompts\n",
    "# System message: tells the model to answer only from the supplied context.\n",
    "system_template = (\n",
    "    \"Use the following context to answer a user's question. \"\n",
    "    \"If you cannot find the answer in the context, say you don't know the answer.\"\n",
    ")\n",
    "system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)\n",
    "\n",
    "# Human message: the retrieved context first, then the user's question.\n",
    "user_prompt_template = \"Context:\\n{context}\\n\\nQuestion:\\n{question}\"\n",
    "user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)\n",
    "\n",
    "# Cell 5: Define RetrievalAugmentedQAPipeline class\n",
    "class RetrievalAugmentedQAPipeline:\n",
    "    \"\"\"Answer questions with an LLM, grounded in passages retrieved from a vector store.\n",
    "\n",
    "    Uses the module-level ``system_role_prompt`` and ``user_role_prompt`` templates.\n",
    "    \"\"\"\n",
    "\n",
    "    # Defaults reproduce the previous hard-coded behaviour: top-2 passages, 12,000-character cap.\n",
    "    DEFAULT_K = 2\n",
    "    DEFAULT_MAX_CONTEXT_CHARS = 12000\n",
    "\n",
    "    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:\n",
    "        self.llm = llm\n",
    "        self.vector_db = vector_db\n",
    "\n",
    "    async def arun_pipeline(\n",
    "        self,\n",
    "        user_query: str,\n",
    "        k: int = DEFAULT_K,\n",
    "        max_context_length: int = DEFAULT_MAX_CONTEXT_CHARS,\n",
    "    ):\n",
    "        \"\"\"Retrieve context for ``user_query`` and stream the model's answer.\n",
    "\n",
    "        Args:\n",
    "            user_query: The question to answer.\n",
    "            k: Number of passages to fetch from the vector store.\n",
    "            max_context_length: Maximum number of *characters* (not tokens) of\n",
    "                joined context sent to the model; anything beyond is cut off.\n",
    "\n",
    "        Returns:\n",
    "            A dict with ``\"response\"`` (an async generator yielding the answer's\n",
    "            text chunks) and ``\"context\"`` (the retrieved passages, untruncated).\n",
    "        \"\"\"\n",
    "        context_docs = self.vector_db.similarity_search(user_query, k=k)\n",
    "        context_list = [doc.page_content for doc in context_docs]\n",
    "\n",
    "        # Character-based cut-off: a cheap stand-in for the model's token limit.\n",
    "        context_prompt = \"\\n\".join(context_list)[:max_context_length]\n",
    "\n",
    "        formatted_system_prompt = system_role_prompt.format()\n",
    "        formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)\n",
    "\n",
    "        async def generate_response():\n",
    "            # Nothing is sent to the LLM until the caller starts iterating.\n",
    "            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):\n",
    "                yield chunk.content\n",
    "\n",
    "        return {\"response\": generate_response(), \"context\": context_list}\n",
    "\n",
    "# Cell 6: PDF processing functions\n",
    "async def fetch_pdf(session, url):\n",
    "    \"\"\"Download ``url`` with ``session``; return the body bytes, or None on failure.\n",
    "\n",
    "    Client/timeout errors are reported and mapped to None (like a non-200 status)\n",
    "    so one unreachable URL does not abort an ``asyncio.gather`` of several fetches.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        async with session.get(url) as response:\n",
    "            if response.status != 200:\n",
    "                print(f\"Failed to fetch PDF from {url}\")\n",
    "                return None\n",
    "            return await response.read()\n",
    "    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:\n",
    "        print(f\"Failed to fetch PDF from {url}: {exc}\")\n",
    "        return None\n",
    "\n",
    "async def process_pdf(pdf_content, chunk_size=500, chunk_overlap=40):\n",
    "    \"\"\"Extract the text of every page of a PDF and split it into chunks.\n",
    "\n",
    "    Args:\n",
    "        pdf_content: Raw PDF bytes.\n",
    "        chunk_size: Target chunk length, passed to the text splitter.\n",
    "        chunk_overlap: Overlap between consecutive chunks, passed to the splitter.\n",
    "\n",
    "    Returns:\n",
    "        The list of text chunks.\n",
    "    \"\"\"\n",
    "    # NOTE: async only so callers can ``await`` it; the parsing below is blocking.\n",
    "    pdf_reader = PdfReader(BytesIO(pdf_content))\n",
    "    # Defensive: treat a page that yields no text object as empty instead of failing the join.\n",
    "    text = \"\\n\".join(page.extract_text() or \"\" for page in pdf_reader.pages)\n",
    "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
    "    return text_splitter.split_text(text)\n",
    "\n",
    "# Cell 7: Main execution\n",
    "async def main():\n",
    "    \"\"\"Fetch the PDFs, index their text in Chroma, and answer one example query.\"\"\"\n",
    "    # Ensure API key is set: use OPENAI_API_KEY if present, otherwise prompt once\n",
    "    # and export it so later calls in this kernel reuse it.\n",
    "    api_key = os.environ.get(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n",
    "    os.environ[\"OPENAI_API_KEY\"] = api_key\n",
    "\n",
    "    # List of PDF URLs\n",
    "    pdf_urls = [\n",
    "        \"https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf\",\n",
    "        \"https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf\",\n",
    "    ]\n",
    "\n",
    "    async with aiohttp.ClientSession() as session:\n",
    "        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])\n",
    "\n",
    "    # fetch_pdf returns None for a failed download; keep only the PDFs actually received.\n",
    "    fetched_pdfs = [content for content in pdf_contents if content]\n",
    "\n",
    "    all_chunks = []\n",
    "    for pdf_content in fetched_pdfs:\n",
    "        all_chunks.extend(await process_pdf(pdf_content))\n",
    "\n",
    "    # Report how many PDFs were processed, not how many URLs were attempted.\n",
    "    print(f\"Created {len(all_chunks)} chunks from {len(fetched_pdfs)} PDF files\")\n",
    "\n",
    "    if not all_chunks:\n",
    "        print(\"No PDF text was loaded; skipping indexing and the example query.\")\n",
    "        return\n",
    "\n",
    "    embeddings = OpenAIEmbeddings(openai_api_key=api_key)\n",
    "    vector_db = Chroma.from_texts(all_chunks, embeddings)\n",
    "\n",
    "    chat_openai = ChatOpenAI(openai_api_key=api_key)\n",
    "    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)\n",
    "\n",
    "    # Example query\n",
    "    query = \"What are the key principles of the AI Bill of Rights?\"\n",
    "    result = await retrieval_augmented_qa_pipeline.arun_pipeline(query)\n",
    "\n",
    "    print(\"Query:\", query)\n",
    "    print(\"\\nResponse:\")\n",
    "    # The response is streamed: print each chunk as it arrives.\n",
    "    async for chunk in result[\"response\"]:\n",
    "        print(chunk, end=\"\")\n",
    "    print(\"\\n\\nContext used:\")\n",
    "    for i, context in enumerate(result[\"context\"], 1):\n",
    "        print(f\"{i}. {context[:100]}...\")\n",
    "\n",
    "# Cell 8: Run the main function\n",
    "# Top-level `await` works here because the Jupyter/IPython kernel already runs an event loop;\n",
    "# in a plain Python script use `asyncio.run(main())` instead.\n",
    "await main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}