Spaces:
Running
Running
SUBHRAJIT MOHANTY
committed on
Commit
·
347dbd1
1
Parent(s):
09225f8
Chore: Bug fixing
Browse files
app.py
CHANGED
@@ -57,10 +57,15 @@ class Config:
|
|
57 |
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
58 |
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
@asynccontextmanager
|
66 |
async def lifespan(app: FastAPI):
|
@@ -196,15 +201,15 @@ class RAGService:
|
|
196 |
"""Retrieve relevant document chunks from Qdrant"""
|
197 |
try:
|
198 |
# Check if embedding service is initialized
|
199 |
-
if embedding_service is None:
|
200 |
print("Error: Embedding service is not initialized")
|
201 |
return []
|
202 |
|
203 |
# Get query embedding - all-MiniLM works well without special prefixes
|
204 |
-
query_embedding = await embedding_service.get_query_embedding(query)
|
205 |
|
206 |
# Search in Qdrant
|
207 |
-
search_results = await qdrant_client.search(
|
208 |
collection_name=Config.COLLECTION_NAME,
|
209 |
query_vector=query_embedding,
|
210 |
limit=top_k,
|
@@ -254,23 +259,26 @@ async def health_check():
|
|
254 |
"""Health check endpoint"""
|
255 |
try:
|
256 |
# Test Qdrant connection
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
259 |
except Exception as e:
|
260 |
qdrant_status = f"error: {str(e)}"
|
261 |
|
262 |
# Test embedding service
|
263 |
-
if embedding_service is None:
|
264 |
embedding_health = {"status": "not_initialized", "error": "EmbeddingService is None"}
|
265 |
else:
|
266 |
try:
|
267 |
-
embedding_health = embedding_service.health_check()
|
268 |
except Exception as e:
|
269 |
embedding_health = {"status": "error", "error": str(e)}
|
270 |
|
271 |
return {
|
272 |
-
"status": "healthy" if embedding_service is not None else "unhealthy",
|
273 |
-
"groq": "connected" if groq_client else "not configured",
|
274 |
"qdrant": qdrant_status,
|
275 |
"embedding_service": embedding_health,
|
276 |
"collection": Config.COLLECTION_NAME,
|
@@ -281,7 +289,7 @@ async def health_check():
|
|
281 |
async def chat_completions(request: ChatCompletionRequest):
|
282 |
"""OpenAI-compatible chat completions endpoint with RAG"""
|
283 |
|
284 |
-
if not groq_client:
|
285 |
raise HTTPException(status_code=500, detail="Groq client not initialized")
|
286 |
|
287 |
try:
|
@@ -321,7 +329,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
321 |
async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
|
322 |
"""Create a non-streaming chat completion"""
|
323 |
try:
|
324 |
-
response = await groq_client.chat.completions.create(
|
325 |
model=request.model,
|
326 |
messages=messages,
|
327 |
max_tokens=request.max_tokens,
|
@@ -359,7 +367,7 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
|
|
359 |
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
360 |
created = int(datetime.now().timestamp())
|
361 |
|
362 |
-
stream = await groq_client.chat.completions.create(
|
363 |
model=request.model,
|
364 |
messages=messages,
|
365 |
max_tokens=request.max_tokens,
|
@@ -418,11 +426,11 @@ async def add_document(content: str, metadata: Optional[Dict] = None):
|
|
418 |
"""Add a document to the vector database"""
|
419 |
try:
|
420 |
# Check if embedding service is initialized
|
421 |
-
if embedding_service is None:
|
422 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
423 |
|
424 |
# Generate embedding for document
|
425 |
-
embedding = await embedding_service.get_document_embedding(content)
|
426 |
|
427 |
# Create point
|
428 |
point = PointStruct(
|
@@ -436,7 +444,7 @@ async def add_document(content: str, metadata: Optional[Dict] = None):
|
|
436 |
)
|
437 |
|
438 |
# Insert into Qdrant
|
439 |
-
await qdrant_client.upsert(
|
440 |
collection_name=Config.COLLECTION_NAME,
|
441 |
points=[point]
|
442 |
)
|
@@ -451,7 +459,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
|
|
451 |
"""Add multiple documents to the vector database"""
|
452 |
try:
|
453 |
# Check if embedding service is initialized
|
454 |
-
if embedding_service is None:
|
455 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
456 |
|
457 |
# Extract texts and metadata
|
@@ -459,7 +467,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
|
|
459 |
metadatas = [doc.get("metadata", {}) for doc in documents]
|
460 |
|
461 |
# Generate embeddings for all documents
|
462 |
-
embeddings = await embedding_service.batch_embed(texts)
|
463 |
|
464 |
# Create points
|
465 |
points = []
|
@@ -476,7 +484,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
|
|
476 |
points.append(point)
|
477 |
|
478 |
# Insert all points into Qdrant
|
479 |
-
await qdrant_client.upsert(
|
480 |
collection_name=Config.COLLECTION_NAME,
|
481 |
points=points
|
482 |
)
|
@@ -494,22 +502,22 @@ async def create_collection():
|
|
494 |
"""Create a new collection in Qdrant with the correct vector size"""
|
495 |
try:
|
496 |
# Check if embedding service is initialized
|
497 |
-
if embedding_service is None:
|
498 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
499 |
|
500 |
from qdrant_client.models import VectorParams, Distance
|
501 |
|
502 |
-
await qdrant_client.create_collection(
|
503 |
collection_name=Config.COLLECTION_NAME,
|
504 |
vectors_config=VectorParams(
|
505 |
-
size=embedding_service.dimension, # 384 for all-MiniLM-L6-v2
|
506 |
distance=Distance.COSINE
|
507 |
)
|
508 |
)
|
509 |
|
510 |
return {
|
511 |
"message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
|
512 |
-
"vector_size": embedding_service.dimension,
|
513 |
"distance": "cosine"
|
514 |
}
|
515 |
|
@@ -520,7 +528,10 @@ async def create_collection():
|
|
520 |
async def get_collection_info():
|
521 |
"""Get information about the collection"""
|
522 |
try:
|
523 |
-
|
|
|
|
|
|
|
524 |
return {
|
525 |
"name": Config.COLLECTION_NAME,
|
526 |
"vectors_count": collection_info.vectors_count,
|
|
|
57 |
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
58 |
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
59 |
|
60 |
+
class ApplicationState:
    """Container for the application's shared service handles.

    Every handle starts out as ``None`` and is presumably populated during
    startup (the FastAPI ``lifespan`` hook shown in the surrounding diff —
    confirm against the full file). A ``None`` value therefore means
    "service not initialized yet", which the request handlers check before use.
    """

    def __init__(self):
        # Initialize each handle to None; the names mirror the services the
        # handlers reach for via the module-level `app_state` below.
        for handle in ("groq_client", "qdrant_client", "embedding_service"):
            setattr(self, handle, None)


# Single shared state instance used by all request handlers.
app_state = ApplicationState()
|
69 |
|
70 |
@asynccontextmanager
|
71 |
async def lifespan(app: FastAPI):
|
|
|
201 |
"""Retrieve relevant document chunks from Qdrant"""
|
202 |
try:
|
203 |
# Check if embedding service is initialized
|
204 |
+
if app_state.embedding_service is None:
|
205 |
print("Error: Embedding service is not initialized")
|
206 |
return []
|
207 |
|
208 |
# Get query embedding - all-MiniLM works well without special prefixes
|
209 |
+
query_embedding = await app_state.embedding_service.get_query_embedding(query)
|
210 |
|
211 |
# Search in Qdrant
|
212 |
+
search_results = await app_state.qdrant_client.search(
|
213 |
collection_name=Config.COLLECTION_NAME,
|
214 |
query_vector=query_embedding,
|
215 |
limit=top_k,
|
|
|
259 |
"""Health check endpoint"""
|
260 |
try:
|
261 |
# Test Qdrant connection
|
262 |
+
if app_state.qdrant_client:
|
263 |
+
collections = await app_state.qdrant_client.get_collections()
|
264 |
+
qdrant_status = "connected"
|
265 |
+
else:
|
266 |
+
qdrant_status = "not_initialized"
|
267 |
except Exception as e:
|
268 |
qdrant_status = f"error: {str(e)}"
|
269 |
|
270 |
# Test embedding service
|
271 |
+
if app_state.embedding_service is None:
|
272 |
embedding_health = {"status": "not_initialized", "error": "EmbeddingService is None"}
|
273 |
else:
|
274 |
try:
|
275 |
+
embedding_health = app_state.embedding_service.health_check()
|
276 |
except Exception as e:
|
277 |
embedding_health = {"status": "error", "error": str(e)}
|
278 |
|
279 |
return {
|
280 |
+
"status": "healthy" if app_state.embedding_service is not None else "unhealthy",
|
281 |
+
"groq": "connected" if app_state.groq_client else "not configured",
|
282 |
"qdrant": qdrant_status,
|
283 |
"embedding_service": embedding_health,
|
284 |
"collection": Config.COLLECTION_NAME,
|
|
|
289 |
async def chat_completions(request: ChatCompletionRequest):
|
290 |
"""OpenAI-compatible chat completions endpoint with RAG"""
|
291 |
|
292 |
+
if not app_state.groq_client:
|
293 |
raise HTTPException(status_code=500, detail="Groq client not initialized")
|
294 |
|
295 |
try:
|
|
|
329 |
async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
|
330 |
"""Create a non-streaming chat completion"""
|
331 |
try:
|
332 |
+
response = await app_state.groq_client.chat.completions.create(
|
333 |
model=request.model,
|
334 |
messages=messages,
|
335 |
max_tokens=request.max_tokens,
|
|
|
367 |
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
368 |
created = int(datetime.now().timestamp())
|
369 |
|
370 |
+
stream = await app_state.groq_client.chat.completions.create(
|
371 |
model=request.model,
|
372 |
messages=messages,
|
373 |
max_tokens=request.max_tokens,
|
|
|
426 |
"""Add a document to the vector database"""
|
427 |
try:
|
428 |
# Check if embedding service is initialized
|
429 |
+
if app_state.embedding_service is None:
|
430 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
431 |
|
432 |
# Generate embedding for document
|
433 |
+
embedding = await app_state.embedding_service.get_document_embedding(content)
|
434 |
|
435 |
# Create point
|
436 |
point = PointStruct(
|
|
|
444 |
)
|
445 |
|
446 |
# Insert into Qdrant
|
447 |
+
await app_state.qdrant_client.upsert(
|
448 |
collection_name=Config.COLLECTION_NAME,
|
449 |
points=[point]
|
450 |
)
|
|
|
459 |
"""Add multiple documents to the vector database"""
|
460 |
try:
|
461 |
# Check if embedding service is initialized
|
462 |
+
if app_state.embedding_service is None:
|
463 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
464 |
|
465 |
# Extract texts and metadata
|
|
|
467 |
metadatas = [doc.get("metadata", {}) for doc in documents]
|
468 |
|
469 |
# Generate embeddings for all documents
|
470 |
+
embeddings = await app_state.embedding_service.batch_embed(texts)
|
471 |
|
472 |
# Create points
|
473 |
points = []
|
|
|
484 |
points.append(point)
|
485 |
|
486 |
# Insert all points into Qdrant
|
487 |
+
await app_state.qdrant_client.upsert(
|
488 |
collection_name=Config.COLLECTION_NAME,
|
489 |
points=points
|
490 |
)
|
|
|
502 |
"""Create a new collection in Qdrant with the correct vector size"""
|
503 |
try:
|
504 |
# Check if embedding service is initialized
|
505 |
+
if app_state.embedding_service is None:
|
506 |
raise HTTPException(status_code=500, detail="Embedding service is not initialized")
|
507 |
|
508 |
from qdrant_client.models import VectorParams, Distance
|
509 |
|
510 |
+
await app_state.qdrant_client.create_collection(
|
511 |
collection_name=Config.COLLECTION_NAME,
|
512 |
vectors_config=VectorParams(
|
513 |
+
size=app_state.embedding_service.dimension, # 384 for all-MiniLM-L6-v2
|
514 |
distance=Distance.COSINE
|
515 |
)
|
516 |
)
|
517 |
|
518 |
return {
|
519 |
"message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
|
520 |
+
"vector_size": app_state.embedding_service.dimension,
|
521 |
"distance": "cosine"
|
522 |
}
|
523 |
|
|
|
528 |
async def get_collection_info():
|
529 |
"""Get information about the collection"""
|
530 |
try:
|
531 |
+
if app_state.qdrant_client is None:
|
532 |
+
raise HTTPException(status_code=500, detail="Qdrant client is not initialized")
|
533 |
+
|
534 |
+
collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
|
535 |
return {
|
536 |
"name": Config.COLLECTION_NAME,
|
537 |
"vectors_count": collection_info.vectors_count,
|