Spaces:
Sleeping
Sleeping
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Note: you may need to restart the kernel to use updated packages.\n", | |
"Created 915 chunks from 2 PDF files\n", | |
"Query: What are the key principles of the AI Bill of Rights?\n", | |
"\n", | |
"Response:\n", | |
"The key principles of the AI Bill of Rights are civil rights, civil liberties, and privacy.\n", | |
"\n", | |
"Context used:\n", | |
"1. use, and deployment of automated systems to protect the rights of the American public in the age of ...\n", | |
"2. civil rights, civil liberties, and privacy. The Blueprint for an AI Bill of Rights includes this For...\n" | |
] | |
} | |
], | |
"source": [ | |
"# Cell 1: Install required packages\n", | |
"%pip install langchain openai chromadb PyPDF2 tiktoken -qU\n", | |
"\n", | |
"# Cell 2: Import necessary modules\n", | |
"import os\n", | |
"import tempfile\n", | |
"import aiohttp\n", | |
"import asyncio\n", | |
"import getpass\n", | |
"from io import BytesIO\n", | |
"from typing import List\n", | |
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n", | |
"from langchain.document_loaders import PyPDFLoader\n", | |
"from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate\n", | |
"from langchain.vectorstores import Chroma\n", | |
"from langchain.embeddings import OpenAIEmbeddings\n", | |
"from langchain.chat_models import ChatOpenAI\n", | |
"from PyPDF2 import PdfReader\n", | |
"\n", | |
"\n", | |
# Cell 4: Set up prompts
# System message: constrains the model to answer only from the retrieved context.
SYSTEM_TEMPLATE = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_role_prompt = SystemMessagePromptTemplate.from_template(SYSTEM_TEMPLATE)

# Human message: carries the stuffed context followed by the user's question.
USER_TEMPLATE = "Context:\n{context}\n\nQuestion:\n{question}"
user_role_prompt = HumanMessagePromptTemplate.from_template(USER_TEMPLATE)
"\n", | |
# Cell 5: Define RetrievalAugmentedQAPipeline class
class RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA: fetch the top-k chunks from the vector store,
    stuff them into the prompt, and stream the LLM's answer.

    The retrieval depth and context budget were previously hard-coded magic
    numbers (2 and 12000); they are now constructor parameters with the same
    defaults, so existing callers are unaffected.
    """

    def __init__(self, llm: ChatOpenAI, vector_db: Chroma,
                 k: int = 2, max_context_chars: int = 12000) -> None:
        # k: number of chunks retrieved per query.
        # max_context_chars: crude budget for the stuffed context. NOTE: this
        # truncates by *characters*, not tokens — it only approximates a token
        # limit for the model.
        self.llm = llm
        self.vector_db = vector_db
        self.k = k
        self.max_context_chars = max_context_chars

    async def arun_pipeline(self, user_query: str):
        """Run one query.

        Returns a dict with:
          - "response": an async generator yielding answer text chunks
          - "context": the list of retrieved chunk texts (untruncated)
        """
        context_docs = self.vector_db.similarity_search(user_query, k=self.k)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Keep the stuffed context within budget (character-based truncation).
        if len(context_prompt) > self.max_context_chars:
            context_prompt = context_prompt[:self.max_context_chars]

        formatted_system_prompt = system_role_prompt.format()
        formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)

        async def generate_response():
            # Stream the model's answer chunk by chunk so callers can print
            # incrementally instead of waiting for the full completion.
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk.content

        return {"response": generate_response(), "context": context_list}
"\n", | |
# Cell 6: PDF processing functions
async def fetch_pdf(session, url):
    """Download one PDF over the given aiohttp session.

    Returns the raw response bytes, or None (after logging) on any
    non-200 status — callers are expected to skip None results.
    """
    async with session.get(url) as response:
        if response.status != 200:
            print(f"Failed to fetch PDF from {url}")
            return None
        return await response.read()
"\n", | |
async def process_pdf(pdf_content):
    """Extract text from PDF bytes and split it into overlapping chunks.

    Declared async for symmetry with fetch_pdf; the work itself is CPU-bound.
    Returns a list of text chunks (500 chars, 40-char overlap).
    """
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # PyPDF2's extract_text() returns None for image-only/empty pages; map
    # those to "" so str.join doesn't raise TypeError mid-document.
    text = "\n".join([(page.extract_text() or "") for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
"\n", | |
# Cell 7: Main execution
def get_openai_api_key() -> str:
    """Return the OpenAI API key without ever hardcoding it.

    Reads OPENAI_API_KEY from the environment; falls back to a one-time
    getpass prompt and caches the answer back into the environment so later
    cells don't re-prompt. NOTE: main() called this helper but it was never
    defined anywhere in the notebook, so every run died with a NameError —
    this definition is the fix.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        api_key = getpass.getpass("Enter your OpenAI API key: ")
        os.environ["OPENAI_API_KEY"] = api_key
    return api_key


async def main():
    """Download the source PDFs, build the vector index, and run one example query."""
    # Ensure API key is set
    api_key = get_openai_api_key()

    # Source documents: the AI Bill of Rights blueprint and NIST AI 600-1.
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]

    all_chunks = []
    async with aiohttp.ClientSession() as session:
        # Fetch all PDFs concurrently; failed downloads come back as None.
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])

        for pdf_content in pdf_contents:
            if pdf_content:  # skip PDFs that failed to download
                chunks = await process_pdf(pdf_content)
                all_chunks.extend(chunks)

    print(f"Created {len(all_chunks)} chunks from {len(pdf_urls)} PDF files")

    # Embed the chunks and build an in-memory Chroma index.
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    vector_db = Chroma.from_texts(all_chunks, embeddings)

    chat_openai = ChatOpenAI(openai_api_key=api_key)
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)

    # Example query
    query = "What are the key principles of the AI Bill of Rights?"
    result = await retrieval_augmented_qa_pipeline.arun_pipeline(query)

    print("Query:", query)
    print("\nResponse:")
    async for chunk in result["response"]:
        print(chunk, end="")
    print("\n\nContext used:")
    for i, context in enumerate(result["context"], 1):
        print(f"{i}. {context[:100]}...")
# Cell 8: Run the main function (top-level await is valid inside Jupyter's running event loop)
await main()
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "base", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} | |