Commit 8188392 · Parent: 73e33bc
testing rag

Files changed:
- app.py +193 -6
- requirements.txt +17 -0
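In short: this commit wires a first RAG pipeline into the Space. app.py gains an upload endpoint that reads a PDF with PyPDFLoader, splits it into overlapping chunks, embeds the chunks through the Azure AI Inference EmbeddingsClient, stores them in an Astra DB (Cassandra) vector table via cassio, and answers chat messages with a LangChain RetrievalQA chain over an MMR retriever. requirements.txt picks up the matching LangChain, Azure, and vector-store dependencies.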
app.py
CHANGED
@@ -1,6 +1,32 @@
-
-
+import logging
+import os
+import shutil  # used by the /process-document handler below to save the upload
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+from langchain.vectorstores import Chroma
+from langchain.llms import OpenAI
+from langchain.vectorstores.cassandra import Cassandra
+from langchain.indexes.vectorstore import VectorStoreIndexWrapper
+from langchain.chains import RetrievalQA
+from langchain.document_loaders import PyPDFLoader
+from langchain.vectorstores.base import VectorStoreRetriever
+from langchain.text_splitter import CharacterTextSplitter
+from azure.core.credentials import AzureKeyCredential
+from azure.ai.inference import EmbeddingsClient
+import cassio
+
 app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# FastAPI applications have no .logger attribute (that is Flask's API),
+# so configure a module-level logger instead.
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
+

 HUGGINGFACE_API_KEY = settings.huggingface_key
 ASTRA_DB_APPLICATION_TOKEN = settings.astra_db_application_token
@@ -10,7 +36,168 @@ GITHUB_TOKEN = settings.github_token
 AZURE_OPENAI_ENDPOINT = settings.azure_openai_endpoint
 AZURE_OPENAI_MODELNAME = settings.azure_openai_modelname
 AZURE_OPENAI_EMBEDMODELNAME = settings.azure_openai_embedmodelname
-
-
-
-
+
+
+
+
+UPLOAD_FOLDER = '/uploads'
+conversation_retrieval_chain = None
+chat_history = []
+llm = None
+embedding = None
+cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)
+
+class AzureOpenAIEmbeddings:
+    def __init__(self, client):
+        self.client = client
+        self.model_name = AZURE_OPENAI_EMBEDMODELNAME  # Store model name
+
+    def embed_query(self, text: str):
+        """Embed a query."""
+        response = self.client.embed(
+            input=[text],
+            model=self.model_name
+        )
+        return response.data[0].embedding
+
+    def embed_documents(self, texts: list):
+        """Embed a list of documents."""
+        response = self.client.embed(
+            input=texts,
+            model=self.model_name
+        )
+        return [item.embedding for item in response.data]
+
+def init_llm():
+    global llm, embedding
+    llm = OpenAI(
+        base_url=AZURE_OPENAI_ENDPOINT,
+        api_key=GITHUB_TOKEN,
+        model=AZURE_OPENAI_MODELNAME
+    )
+    embedding = EmbeddingsClient(
+        endpoint=AZURE_OPENAI_ENDPOINT,
+        credential=AzureKeyCredential(GITHUB_TOKEN),
+        model=AZURE_OPENAI_EMBEDMODELNAME
+    )
+
+def process_document(document_path):
+    init_llm()
+    global conversation_retrieval_chain
+    loader = PyPDFLoader(document_path)
+    documents = loader.load()
+    text_splitter = CharacterTextSplitter(
+        chunk_size=800,
+        chunk_overlap=200,
+    )
+    raw_text = "".join([doc.page_content for doc in documents])
+    texts = text_splitter.split_text(raw_text)
+    custom_embedding = AzureOpenAIEmbeddings(embedding)
+    astra_vector_store = Cassandra(
+        embedding=custom_embedding,
+        table_name="qa_mini_demo",
+        session=None,
+        keyspace=None,
+    )
+    astra_vector_store.add_texts(texts[:500])
+    retriever = VectorStoreRetriever(
+        vectorstore=astra_vector_store, search_type="mmr", search_kwargs={'k': 1, 'lambda_mult': 0.25}
+    )
+    conversation_retrieval_chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=False,
+        input_key="question"
+    )
+
+def process_prompt(prompt):
+    init_llm()
+    global chat_history
+    global conversation_retrieval_chain
+
+    # The chain only exists after a document has been processed.
+    if conversation_retrieval_chain is None:
+        return "Please upload a PDF document first."
+
+    output = conversation_retrieval_chain({"question": prompt, "chat_history": chat_history})
+    answer = output["result"]
+
+    chat_history.append((prompt, answer))
+    return answer
+
+# Define the route for the index page
+@app.get("/", response_class=HTMLResponse)
+async def index():
+    return """
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <title>File Upload</title>
+    </head>
+    <body>
+        <h2>Upload a PDF Document</h2>
+        <form action="/process-document" method="post" enctype="multipart/form-data">
+            <input type="file" name="file" required>
+            <button type="submit">Upload</button>
+        </form>
+        <h2>Chat with the Bot</h2>
+        <form id="chat-form">
+            <input type="text" id="userMessage" placeholder="Type your message here..." required>
+            <button type="submit">Send</button>
+        </form>
+        <div id="chat-response"></div>
+        <script>
+            document.getElementById("chat-form").onsubmit = async (e) => {
+                e.preventDefault();
+                const userMessage = document.getElementById("userMessage").value;
+                const response = await fetch("/process-message", {
+                    method: "POST",
+                    headers: {
+                        "Content-Type": "application/json",
+                    },
+                    body: JSON.stringify({ userMessage }),
+                });
+                const data = await response.json();
+                document.getElementById("chat-response").innerText = data.botResponse || data.error;
+                document.getElementById("userMessage").value = ""; // Clear input
+            };
+        </script>
+    </body>
+    </html>
+    """
+
+# Request body for /process-message; the field name matches the JSON key the page sends.
+# (A bare `user_message: str` parameter would be parsed as a query parameter, not the body.)
+class ChatRequest(BaseModel):
+    userMessage: str
+
+# Define the route for processing messages
+@app.post("/process-message")
+async def process_message_route(request: ChatRequest):
+    try:
+        user_message = request.userMessage
+        if not user_message:
+            raise HTTPException(status_code=400, detail="User message is required.")
+
+        bot_response = process_prompt(user_message)  # Process the user's message
+
+        # Return the bot's response as JSON
+        return JSONResponse(content={"botResponse": bot_response})
+    except HTTPException:
+        raise  # Let deliberate 4xx responses through instead of masking them as 500s
+    except Exception as e:
+        logger.error(f"Error processing message: {e}")
+        raise HTTPException(status_code=500, detail="An error occurred while processing the message.")
+
+# Define the route for processing documents
+@app.post("/process-document")
+async def process_document_route(file: UploadFile = File(...)):
+    try:
+        # Check if a file was uploaded
+        if not file:
+            raise HTTPException(status_code=400, detail="File not uploaded.")
+
+        file_path = f"uploads/{file.filename}"  # Define the path where the file will be saved
+        os.makedirs("uploads", exist_ok=True)  # Create the uploads directory if it doesn't exist
+        with open(file_path, "wb") as buffer:
+            shutil.copyfileobj(file.file, buffer)  # Save the file
+
+        process_document(file_path)  # Process the document
+
+        # Return a success message as JSON
+        return JSONResponse(content={
+            "botResponse": "Thank you for providing your PDF document. I have analyzed it, so now you can ask me any questions regarding it!"
+        })
+    except Exception as e:
+        logger.error(f"Error processing document: {e}")
+        raise HTTPException(status_code=500, detail="An error occurred while processing the document.")
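Note on configuration: every credential above is read off a `settings` object defined outside the changed lines, so this diff alone will not run. A minimal sketch of what such a pydantic-settings class could look like, inferred from the attributes the diff uses — the module name, the env-file location, and the `astra_db_id` field (the presumed source of `ASTRA_DB_ID`) are assumptions, not part of this commit:

# config.py (hypothetical; not part of this commit)
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    # Field names mirror the attributes referenced in app.py;
    # pydantic-settings fills them from environment variables or .env.
    model_config = SettingsConfigDict(env_file=".env")

    huggingface_key: str
    astra_db_application_token: str
    astra_db_id: str  # assumed source of the ASTRA_DB_ID passed to cassio.init
    github_token: str
    azure_openai_endpoint: str
    azure_openai_modelname: str
    azure_openai_embedmodelname: str

settings = Settings()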
requirements.txt
CHANGED
@@ -1,3 +1,20 @@
 fastapi
 uvicorn[standard]
 pydantic-settings
+langchain
+langchain-community
+openai
+python-dotenv
+azure-core
+azure-ai-inference
+cassio
+chromadb
+datasets
+pypdf
+tiktoken
+typing-extensions
+numpy
+pandas
+tenacity
+aiohttp
+requests
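With the server running (e.g. `uvicorn app:app`), the two endpoints can be smoke-tested from Python using `requests`, which is already in the list above. This is only a sketch: the base URL and PDF filename are placeholders, and port 7860 is just the Hugging Face Spaces convention.

# smoke_test.py (hypothetical client, not part of this commit)
import requests

BASE = "http://localhost:7860"  # placeholder; adjust to wherever the app runs

# 1. Upload a PDF so the vector store and QA chain get built.
with open("sample.pdf", "rb") as f:
    r = requests.post(f"{BASE}/process-document",
                      files={"file": ("sample.pdf", f, "application/pdf")})
print(r.json()["botResponse"])

# 2. Ask a question; the key must be userMessage, matching the request model.
r = requests.post(f"{BASE}/process-message",
                  json={"userMessage": "What is this document about?"})
print(r.json()["botResponse"])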