Technocoloredgeek commited on
Commit
99f7ccf
·
verified ·
1 Parent(s): 758cdcc

Upload 2 files

Browse files
Files changed (2) hide show
  1. HF Deploy.ipynb +159 -0
  2. requirements.txt +7 -0
HF Deploy.ipynb ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 6,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Note: you may need to restart the kernel to use updated packages.\n",
13
+ "Created 915 chunks from 2 PDF files\n",
14
+ "Query: What are the key principles of the AI Bill of Rights?\n",
15
+ "\n",
16
+ "Response:\n",
17
+ "The key principles of the AI Bill of Rights are civil rights, civil liberties, and privacy.\n",
18
+ "\n",
19
+ "Context used:\n",
20
+ "1. use, and deployment of automated systems to protect the rights of the American public in the age of ...\n",
21
+ "2. civil rights, civil liberties, and privacy. The Blueprint for an AI Bill of Rights includes this For...\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "# Cell 1: Install required packages\n",
27
+ "%pip install langchain openai chromadb PyPDF2 tiktoken -qU\n",
28
+ "\n",
29
+ "# Cell 2: Import necessary modules\n",
30
+ "import os\n",
31
+ "import tempfile\n",
32
+ "import aiohttp\n",
33
+ "import asyncio\n",
34
+ "import getpass\n",
35
+ "from io import BytesIO\n",
36
+ "from typing import List\n",
37
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
38
+ "from langchain.document_loaders import PyPDFLoader\n",
39
+ "from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
40
+ "from langchain.vectorstores import Chroma\n",
41
+ "from langchain.embeddings import OpenAIEmbeddings\n",
42
+ "from langchain.chat_models import ChatOpenAI\n",
43
+ "from PyPDF2 import PdfReader\n",
44
+ "\n",
45
+ "\n",
46
+ "# Cell 4: Set up prompts\n",
47
+ "system_template = \"Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer.\"\n",
48
+ "system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)\n",
49
+ "\n",
50
+ "user_prompt_template = \"Context:\\n{context}\\n\\nQuestion:\\n{question}\"\n",
51
+ "user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)\n",
52
+ "\n",
53
+ "# Cell 5: Define RetrievalAugmentedQAPipeline class\n",
54
+ "class RetrievalAugmentedQAPipeline:\n",
55
+ " def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:\n",
56
+ " self.llm = llm\n",
57
+ " self.vector_db = vector_db\n",
58
+ "\n",
59
+ " async def arun_pipeline(self, user_query: str):\n",
60
+ " context_docs = self.vector_db.similarity_search(user_query, k=2) # Reduced from 4 to 2\n",
61
+ " context_list = [doc.page_content for doc in context_docs]\n",
62
+ " context_prompt = \"\\n\".join(context_list)\n",
63
+ " \n",
64
+ " # Implement a simple truncation to ensure we don't exceed token limit\n",
65
+ " max_context_length = 12000 # Adjust this value as needed\n",
66
+ " if len(context_prompt) > max_context_length:\n",
67
+ " context_prompt = context_prompt[:max_context_length]\n",
68
+ " \n",
69
+ " formatted_system_prompt = system_role_prompt.format()\n",
70
+ " formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)\n",
71
+ "\n",
72
+ " async def generate_response():\n",
73
+ " async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):\n",
74
+ " yield chunk.content\n",
75
+ "\n",
76
+ " return {\"response\": generate_response(), \"context\": context_list}\n",
77
+ "\n",
78
+ "# Cell 6: PDF processing functions\n",
79
+ "async def fetch_pdf(session, url):\n",
80
+ " async with session.get(url) as response:\n",
81
+ " if response.status == 200:\n",
82
+ " return await response.read()\n",
83
+ " else:\n",
84
+ " print(f\"Failed to fetch PDF from {url}\")\n",
85
+ " return None\n",
86
+ "\n",
87
+ "async def process_pdf(pdf_content):\n",
88
+ " pdf_reader = PdfReader(BytesIO(pdf_content))\n",
89
+ " text = \"\\n\".join([page.extract_text() for page in pdf_reader.pages])\n",
90
+ " text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)\n",
91
+ " return text_splitter.split_text(text)\n",
92
+ "\n",
93
+ "# Cell 7: Main execution\n",
94
+ "async def main():\n",
95
+ " # Ensure API key is set\n",
96
+ " api_key = get_openai_api_key()\n",
97
+ "\n",
98
+ " # List of PDF URLs\n",
99
+ " pdf_urls = [\n",
100
+ " \"https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf\",\n",
101
+ " \"https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf\",\n",
102
+ " ]\n",
103
+ "\n",
104
+ " all_chunks = []\n",
105
+ " async with aiohttp.ClientSession() as session:\n",
106
+ " pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])\n",
107
+ " \n",
108
+ " for pdf_content in pdf_contents:\n",
109
+ " if pdf_content:\n",
110
+ " chunks = await process_pdf(pdf_content)\n",
111
+ " all_chunks.extend(chunks)\n",
112
+ "\n",
113
+ " print(f\"Created {len(all_chunks)} chunks from {len(pdf_urls)} PDF files\")\n",
114
+ "\n",
115
+ " embeddings = OpenAIEmbeddings(openai_api_key=api_key)\n",
116
+ " vector_db = Chroma.from_texts(all_chunks, embeddings)\n",
117
+ " \n",
118
+ " chat_openai = ChatOpenAI(openai_api_key=api_key)\n",
119
+ " retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)\n",
120
+ " \n",
121
+ " # Example query\n",
122
+ " query = \"What are the key principles of the AI Bill of Rights?\"\n",
123
+ " result = await retrieval_augmented_qa_pipeline.arun_pipeline(query)\n",
124
+ " \n",
125
+ " print(\"Query:\", query)\n",
126
+ " print(\"\\nResponse:\")\n",
127
+ " async for chunk in result[\"response\"]:\n",
128
+ " print(chunk, end=\"\")\n",
129
+ " print(\"\\n\\nContext used:\")\n",
130
+ " for i, context in enumerate(result[\"context\"], 1):\n",
131
+ " print(f\"{i}. {context[:100]}...\")\n",
132
+ "\n",
133
+ "# Cell 8: Run the main function\n",
134
+ "await main()"
135
+ ]
136
+ }
137
+ ],
138
+ "metadata": {
139
+ "kernelspec": {
140
+ "display_name": "base",
141
+ "language": "python",
142
+ "name": "python3"
143
+ },
144
+ "language_info": {
145
+ "codemirror_mode": {
146
+ "name": "ipython",
147
+ "version": 3
148
+ },
149
+ "file_extension": ".py",
150
+ "mimetype": "text/x-python",
151
+ "name": "python",
152
+ "nbconvert_exporter": "python",
153
+ "pygments_lexer": "ipython3",
154
+ "version": "3.10.14"
155
+ }
156
+ },
157
+ "nbformat": 4,
158
+ "nbformat_minor": 2
159
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ langchain
3
+ openai
4
+ chromadb
5
+ PyPDF2
6
+ tiktoken
7
+ aiohttp