NameIsJACK commited on
Commit
8188392
·
1 Parent(s): 73e33bc

testing rag

Browse files
Files changed (2) hide show
  1. app.py +193 -6
  2. requirements.txt +17 -0
app.py CHANGED
@@ -1,6 +1,32 @@
1
- from fastapi import FastAPI
2
- from config import settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
4
 
5
  HUGGINGFACE_API_KEY = settings.huggingface_key
6
  ASTRA_DB_APPLICATION_TOKEN = settings.astra_db_application_token
@@ -10,7 +36,168 @@ GITHUB_TOKEN = settings.github_token
10
  AZURE_OPENAI_ENDPOINT = settings.azure_openai_endpoint
11
  AZURE_OPENAI_MODELNAME = settings.azure_openai_modelname
12
  AZURE_OPENAI_EMBEDMODELNAME = settings.azure_openai_embedmodelname
13
- @app.get("/")
14
- def greet_json():
15
- return {"HUGGINGFACE_API_KEY": HUGGINGFACE_API_KEY,
16
- "ASTRA_DB_APPLICATION_TOKEN": ASTRA_DB_APPLICATION_TOKEN}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from fastapi import FastAPI, UploadFile, File, HTTPException
4
+ from fastapi.responses import HTMLResponse, JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.staticfiles import StaticFiles
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.llms import OpenAI
9
+ from langchain.vectorstores.cassandra import Cassandra
10
+ from langchain.indexes.vectorstore import VectorStoreIndexWrapper
11
+ from langchain.chains import RetrievalQA
12
+ from langchain.document_loaders import PyPDFLoader
13
+ from langchain.vectorstores.base import VectorStoreRetriever
14
+ from langchain.text_splitter import CharacterTextSplitter
15
+ from azure.core.credentials import AzureKeyCredential
16
+ from azure.ai.inference import EmbeddingsClient
17
+ import cassio
18
+
19
  app = FastAPI()
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"],
23
+ allow_credentials=True,
24
+ allow_methods=["*"],
25
+ allow_headers=["*"],
26
+ )
27
+
28
+ app.logger.setLevel(logging.ERROR)
29
+
30
 
31
  HUGGINGFACE_API_KEY = settings.huggingface_key
32
  ASTRA_DB_APPLICATION_TOKEN = settings.astra_db_application_token
 
36
  AZURE_OPENAI_ENDPOINT = settings.azure_openai_endpoint
37
  AZURE_OPENAI_MODELNAME = settings.azure_openai_modelname
38
  AZURE_OPENAI_EMBEDMODELNAME = settings.azure_openai_embedmodelname
39
+
40
+
41
+
42
+
43
+ UPLOAD_FOLDER = '/uploads'
44
+ conversation_retrieval_chain = None
45
+ chat_history = []
46
+ llm = None
47
+ embedding = None
48
+ cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)
49
+
50
+ class AzureOpenAIEmbeddings:
51
+ def __init__(self, client):
52
+ self.client = client
53
+ self.model_name = AZURE_OPENAI_EMBEDMODELNAME # Store model name
54
+
55
+ def embed_query(self, text: str):
56
+ """Embed a query."""
57
+ response = self.client.embed(
58
+ input=[text],
59
+ model=self.model_name
60
+ )
61
+ return response.data[0].embedding
62
+
63
+ def embed_documents(self, texts: list):
64
+ """Embed a list of documents."""
65
+ response = self.client.embed(
66
+ input=texts,
67
+ model=self.model_name
68
+ )
69
+ return [item.embedding for item in response.data]
70
+
71
+ def init_llm():
72
+ global llm, embedding
73
+ llm = OpenAI(
74
+ base_url=AZURE_OPENAI_ENDPOINT,
75
+ api_key=GITHUB_TOKEN,
76
+ model=AZURE_OPENAI_MODELNAME
77
+ )
78
+ embedding = EmbeddingsClient(
79
+ endpoint=AZURE_OPENAI_ENDPOINT,
80
+ credential=AzureKeyCredential(GITHUB_TOKEN),
81
+ model=AZURE_OPENAI_EMBEDMODELNAME
82
+ )
83
+
84
+ def process_document(document_path):
85
+ init_llm()
86
+ global conversation_retrieval_chain
87
+ loader = PyPDFLoader(document_path)
88
+ documents = loader.load()
89
+ text_splitter = CharacterTextSplitter(
90
+ chunk_size=800,
91
+ chunk_overlap=200,
92
+ )
93
+ raw_text = "".join([doc.page_content for doc in documents])
94
+ texts = text_splitter.split_text(raw_text)
95
+ custom_embedding = AzureOpenAIEmbeddings(embedding)
96
+ astra_vector_store = Cassandra(
97
+ embedding=custom_embedding,
98
+ table_name="qa_mini_demo",
99
+ session=None,
100
+ keyspace=None,
101
+ )
102
+ astra_vector_store.add_texts(texts[:500])
103
+ retriever = VectorStoreRetriever(
104
+ vectorstore=astra_vector_store, search_type="mmr", search_kwargs={'k': 1, 'lambda_mult': 0.25}
105
+ )
106
+ conversation_retrieval_chain = RetrievalQA.from_chain_type(
107
+ llm=llm,
108
+ chain_type="stuff",
109
+ retriever=retriever,
110
+ return_source_documents=False,
111
+ input_key="question"
112
+ )
113
+
114
+ def process_prompt(prompt):
115
+ init_llm()
116
+ global chat_history
117
+ global conversation_retrieval_chain
118
+
119
+ output = conversation_retrieval_chain({"question": prompt, "chat_history": chat_history})
120
+ answer = output["result"]
121
+
122
+ chat_history.append((prompt, answer))
123
+ return answer
124
+
125
+ # Define the route for the index page
126
+ @app.get("/", response_class=HTMLResponse)
127
+ async def index():
128
+ return """
129
+ <!DOCTYPE html>
130
+ <html>
131
+ <head>
132
+ <title>File Upload</title>
133
+ </head>
134
+ <body>
135
+ <h2>Upload a PDF Document</h2>
136
+ <form action="/process-document" method="post" enctype="multipart/form-data">
137
+ <input type="file" name="file" required>
138
+ <button type="submit">Upload</button>
139
+ </form>
140
+ <h2>Chat with the Bot</h2>
141
+ <form id="chat-form">
142
+ <input type="text" id="userMessage" placeholder="Type your message here..." required>
143
+ <button type="submit">Send
144
+ </button>
145
+ </form>
146
+ <div id="chat-response"></div>
147
+ <script>
148
+ document.getElementById("chat-form").onsubmit = async (e) => {
149
+ e.preventDefault();
150
+ const userMessage = document.getElementById("userMessage").value;
151
+ const response = await fetch("/process-message", {
152
+ method: "POST",
153
+ headers: {
154
+ "Content-Type": "application/json",
155
+ },
156
+ body: JSON.stringify({ userMessage }),
157
+ });
158
+ const data = await response.json();
159
+ document.getElementById("chat-response").innerText = data.botResponse || data.error;
160
+ document.getElementById("userMessage").value = ""; // Clear input
161
+ };
162
+ </script>
163
+ </body>
164
+ </html>
165
+ """
166
+
167
+ # Define the route for processing messages
168
+ @app.post("/process-message")
169
+ async def process_message_route(user_message: str):
170
+ try:
171
+ if not user_message:
172
+ raise HTTPException(status_code=400, detail="User message is required.")
173
+
174
+ bot_response = process_prompt(user_message) # Process the user's message
175
+
176
+ # Return the bot's response as JSON
177
+ return JSONResponse(content={"botResponse": bot_response})
178
+ except Exception as e:
179
+ app.logger.error(f"Error processing message: {e}")
180
+ raise HTTPException(status_code=500, detail="An error occurred while processing the message.")
181
+
182
+ # Define the route for processing documents
183
+ @app.post("/process-document")
184
+ async def process_document_route(file: UploadFile = File(...)):
185
+ try:
186
+ # Check if a file was uploaded
187
+ if not file:
188
+ raise HTTPException(status_code=400, detail="File not uploaded.")
189
+
190
+ file_path = f"uploads/{file.filename}" # Define the path where the file will be saved
191
+ os.makedirs("uploads", exist_ok=True) # Create the uploads directory if it doesn't exist
192
+ with open(file_path, "wb") as buffer:
193
+ shutil.copyfileobj(file.file, buffer) # Save the file
194
+
195
+ process_document(file_path) # Process the document
196
+
197
+ # Return a success message as JSON
198
+ return JSONResponse(content={
199
+ "botResponse": "Thank you for providing your PDF document. I have analyzed it, so now you can ask me any questions regarding it!"
200
+ })
201
+ except Exception as e:
202
+ app.logger.error(f"Error processing document: {e}")
203
+ raise HTTPException(status_code=500, detail="An error occurred while processing the document.")
requirements.txt CHANGED
@@ -1,3 +1,20 @@
1
  fastapi
2
  uvicorn[standard]
3
  pydantic-settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
  pydantic-settings
4
+ langchain
5
+ langchain-community
6
+ openai
7
+ python-dotenv
8
+ azure-core
9
+ azure-ai-inference
10
+ cassio
11
+ chromadb
12
+ datasets
13
+ pypdf
14
+ tiktoken
15
+ typing-extensions
16
+ numpy
17
+ pandas
18
+ tenacity
19
+ aiohttp
20
+ requests