upgraded openai version
Files changed:
- app.py (+88, -52)
- requirements.txt (+1, -1)

app.py (CHANGED)
@@ -26,47 +26,61 @@ hf_home_dir = os.environ["HF_HOME"]
 if not os.path.exists(hf_home_dir):
     os.makedirs(hf_home_dir)

-collection_name = os.getenv(
+collection_name = os.getenv("QDRANT_COLLECTION_NAME")
 logging.info(f"Collection name: {collection_name}")
 # Setup logging using Python's standard logging library
 logging.basicConfig(level=logging.INFO)

 # Load Hugging Face token from environment variable
-huggingface_token = os.getenv(
+huggingface_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
 if huggingface_token:
     try:
         login(token=huggingface_token, add_to_git_credential=True)
         logging.info("Successfully logged into Hugging Face Hub.")
     except Exception as e:
         logging.error(f"Failed to log into Hugging Face Hub: {e}")
-        raise HTTPException(
+        raise HTTPException(
+            status_code=500, detail="Failed to log into Hugging Face Hub."
+        )
 else:
-    raise ValueError(
+    raise ValueError(
+        "Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable."
+    )

 # Initialize the Qdrant searcher
-qdrant_url = os.getenv(
-access_token = os.getenv(
+qdrant_url = os.getenv("QDRANT_URL")
+access_token = os.getenv("QDRANT_ACCESS_TOKEN")

 if not qdrant_url or not access_token:
-    raise ValueError(
+    raise ValueError(
+        "Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables."
+    )

 # Load the model and tokenizer with trust_remote_code=True
 try:
     cache_folder = os.path.join(hf_home_dir, "transformers_cache")
-
+
     # Load the tokenizer and model with trust_remote_code=True
-    tokenizer = AutoTokenizer.from_pretrained(
-
+    tokenizer = AutoTokenizer.from_pretrained(
+        "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True
+    )
+    model = AutoModel.from_pretrained(
+        "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True
+    )

     logging.info("Successfully loaded the model and tokenizer with transformers.")
-
+
     # Initialize the Qdrant searcher after the model is successfully loaded
     global searcher # Ensure searcher is accessible globally if needed
     searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)

 except Exception as e:
     logging.error(f"Failed to load the model or initialize searcher: {e}")
-    raise HTTPException(
+    raise HTTPException(
+        status_code=500,
+        detail="Failed to load the custom model or initialize searcher.",
+    )
+

 # Function to embed text using the model
 def embed_text(text):

@@ -75,43 +89,45 @@ def embed_text(text):
     embeddings = outputs.last_hidden_state.mean(dim=1) # Example: mean pooling
     return embeddings.detach().numpy()

+
 # Define the request body models
 class SearchDocumentsRequest(BaseModel):
     query: str
     limit: int = 3
     file_id: str = None

+
 class GenerateRAGRequest(BaseModel):
     search_query: str
     file_id: str = None

+
 class XApiKeyRequest(BaseModel):
     organization_id: str
     user_id: str
-    search_query: str
+    search_query: str
     file_id: str = None

-import os
-
-for name, value in os.environ.items():
-    print("{0}: {1}".format(name, value))
-

 @app.get("/")
 async def root():
-    return {
+    return {
+        "message": "Welcome to the Search and RAG API!, go to relevant address for API request"
+    }
+

 # Define the search documents endpoint
 @app.post("/api/search-documents")
 async def search_documents(
-    body: SearchDocumentsRequest,
-    credentials: tuple = Depends(token_required)
+    body: SearchDocumentsRequest, credentials: tuple = Depends(token_required)
 ):
     customer_id, user_id = credentials
     start_time = time.time()
     if not customer_id or not user_id:
         logging.error("Failed to extract customer_id or user_id from the JWT token.")
-        raise HTTPException(
+        raise HTTPException(
+            status_code=401, detail="Invalid token: missing customer_id or user_id"
+        )

     logging.info("Received request to search documents")
     try:

@@ -120,14 +136,22 @@ async def search_documents(
         # Encode the query using the custom embedding function
         query_embedding = embed_text(body.query)
         print(body.query)
-        #collection_name = "embed" # Use the collection name where the embeddings are stored
+        # collection_name = "embed" # Use the collection name where the embeddings are stored
         logging.info("Performing search using the precomputed embeddings")
         if body.file_id:
-            hits, error = searcher.search_documents(
+            hits, error = searcher.search_documents(
+                collection_name,
+                query_embedding,
+                user_id,
+                body.limit,
+                file_id=body.file_id,
+            )
         else:
             # Perform search using the precomputed embeddings
-            hits, error = searcher.search_documents(
-
+            hits, error = searcher.search_documents(
+                collection_name, query_embedding, user_id, body.limit
+            )
+
         if error:
             logging.error(f"Search documents error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -138,33 +162,39 @@ async def search_documents(
         logging.error(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail=str(e))

+
 # Define the generate RAG response endpoint
 @app.post("/api/generate-rag-response")
 async def generate_rag_response_api(
-    body: GenerateRAGRequest,
-    credentials: tuple = Depends(token_required)
+    body: GenerateRAGRequest, credentials: tuple = Depends(token_required)
 ):
     customer_id, user_id = credentials
     start_time = time.time()
     if not customer_id or not user_id:
         logging.error("Failed to extract customer_id or user_id from the JWT token.")
-        raise HTTPException(
+        raise HTTPException(
+            status_code=401, detail="Invalid token: missing customer_id or user_id"
+        )

     logging.info("Received request to generate RAG response")
-
+
     try:
         search_time = time.time()
         logging.info("Starting document search")
         # Encode the query using the custom embedding function
         query_embedding = embed_text(body.search_query)
         print(body.search_query)
-        #collection_name = "embed" # Use the collection name where the embeddings are stored
+        # collection_name = "embed" # Use the collection name where the embeddings are stored
         # Perform search using the precomputed embeddings
         if body.file_id:
-            hits, error = searcher.search_documents(
+            hits, error = searcher.search_documents(
+                collection_name, query_embedding, user_id, file_id=body.file_id
+            )
         else:
-            hits, error = searcher.search_documents(
-
+            hits, error = searcher.search_documents(
+                collection_name, query_embedding, user_id
+            )
+
         if error:
             logging.error(f"Search documents error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -177,9 +207,11 @@ async def generate_rag_response_api(
         response, error = generate_rag_response(hits, body.search_query)
         rag_end_time = time.time()
         rag_time_taken = rag_end_time - rag_start_time
-        end_time= time.time()
+        end_time = time.time()
         total_time = end_time - start_time
-        logging.info(
+        logging.info(
+            f"Search time: {search_time_taken}, RAG time: {rag_time_taken}, Total time: {total_time}"
+        )
         if error:
             logging.error(f"Generate RAG response error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -189,10 +221,10 @@ async def generate_rag_response_api(
         logging.error(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail=str(e))

+
 @app.post("/api/search-documents/v1")
 async def search_documents_x_api_key(
-    body: XApiKeyRequest,
-    authorized: bool = Depends(x_api_key_auth)
+    body: XApiKeyRequest, authorized: bool = Depends(x_api_key_auth)
 ):
     if not authorized:
         raise HTTPException(status_code=401, detail="Unauthorized")

@@ -201,7 +233,7 @@ async def search_documents_x_api_key(
     user_id = body.user_id
     file_id = body.file_id

-    logging.info(f
+    logging.info(f"search query {body.search_query}")
     logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
     logging.info("Received request to search documents with x-api-key auth")
     try:

@@ -209,11 +241,13 @@ async def search_documents_x_api_key(

         # Encode the query using the custom embedding function
         query_embedding = embed_text(body.search_query)
-        #collection_name = "embed" # Use the collection name where the embeddings are stored
+        # collection_name = "embed" # Use the collection name where the embeddings are stored

         # Perform search using the precomputed embeddings
-        hits, error = searcher.search_documents(
-
+        hits, error = searcher.search_documents(
+            collection_name, query_embedding, user_id, limit=3, file_id=file_id
+        )
+
         if error:
             logging.error(f"Search documents error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -226,10 +260,10 @@ async def search_documents_x_api_key(
         logging.error(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail=str(e))

+
 @app.post("/api/generate-rag-response/v1")
 async def generate_rag_response_x_api_key(
-    body: XApiKeyRequest,
-    authorized: bool = Depends(x_api_key_auth)
+    body: XApiKeyRequest, authorized: bool = Depends(x_api_key_auth)
 ):
     # Assuming x_api_key_auth validates the key
     if not authorized:

@@ -239,7 +273,7 @@ async def generate_rag_response_x_api_key(
     user_id = body.user_id
     file_id = body.file_id

-    logging.info(f
+    logging.info(f"search query {body.search_query}")
     logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
     logging.info("Received request to generate RAG response with x-api-key auth")
     try:

@@ -247,11 +281,13 @@ async def generate_rag_response_x_api_key(

         # Encode the query using the custom embedding function
         query_embedding = embed_text(body.search_query)
-        #collection_name = "embed" # Use the collection name where the embeddings are stored
+        # collection_name = "embed" # Use the collection name where the embeddings are stored

         # Perform search using the precomputed embeddings
-        hits, error = searcher.search_documents(
-
+        hits, error = searcher.search_documents(
+            collection_name, query_embedding, user_id, file_id=file_id
+        )
+
         if error:
             logging.error(f"Search documents error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -260,7 +296,7 @@ async def generate_rag_response_x_api_key(

         # Generate the RAG response using the retrieved documents
         response, error = generate_rag_response(hits, body.search_query)
-
+
         if error:
             logging.error(f"Generate RAG response error: {error}")
             raise HTTPException(status_code=500, detail=error)

@@ -272,7 +308,7 @@ async def generate_rag_response_x_api_key(
         raise HTTPException(status_code=500, detail=str(e))


-
-if __name__ == '__main__':
+if __name__ == "__main__":
     import uvicorn
-
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
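For reference, here is a minimal client-side sketch of how the two x-api-key endpoints touched above might be exercised once the app is serving on port 8000 (as configured in uvicorn.run). Only the route paths and the XApiKeyRequest fields come from the diff; the "x-api-key" header name (assumed from the x_api_key_auth dependency), the key value, and the payload values are placeholders, and the requests library is not part of the pinned requirements.

import requests  # assumed HTTP client, not listed in this repository's requirements

payload = {
    "organization_id": "org-123",  # placeholder value
    "user_id": "user-456",         # placeholder value
    "search_query": "What does the contract say about termination?",
}
headers = {"x-api-key": "YOUR_API_KEY"}  # header name assumed from x_api_key_auth

# Search endpoint with x-api-key auth
search = requests.post(
    "http://localhost:8000/api/search-documents/v1", json=payload, headers=headers
)
print(search.status_code, search.json())

# RAG endpoint with x-api-key auth
rag = requests.post(
    "http://localhost:8000/api/generate-rag-response/v1", json=payload, headers=headers
)
print(rag.status_code, rag.json())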
requirements.txt (CHANGED)

@@ -2,7 +2,7 @@ fastapi==0.111.1
 fastapi-cli==0.0.4
 uvicorn==0.17.6
 cryptography>=3.4.7
-openai==1.
+openai==1.75.0
 PyJWT==2.6.0
 nltk==3.6.7
 numpy==1.24.0
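The requirements change pins openai to 1.75.0, a post-1.0 release that exposes the client-object interface. A minimal sketch of that interface follows; whether generate_rag_response in this repository actually calls OpenAI this way is not visible in the diff, and the model name, prompt wiring, and the shape of the retrieved hits are illustrative assumptions.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def answer_from_hits(hits, search_query):
    # Join the retrieved snippets into one context block
    # (the structure of `hits` is assumed for illustration).
    context = "\n\n".join(str(hit) for hit in hits)
    completion = client.chat.completions.create(
        model="gpt-4o-mini",  # hypothetical model choice
        messages=[
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {search_query}"},
        ],
    )
    return completion.choices[0].message.content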