SUBHRAJIT MOHANTY committed on
Commit
347dbd1
·
1 Parent(s): 09225f8

Chore: Bug fixing

Browse files
Files changed (1) hide show
  1. app.py +38 -27
app.py CHANGED
@@ -57,10 +57,15 @@ class Config:
57
  SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
58
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
59
 
60
- # Global clients
61
- groq_client = None
62
- qdrant_client = None
63
- embedding_service = None
 
 
 
 
 
64
 
65
  @asynccontextmanager
66
  async def lifespan(app: FastAPI):
@@ -196,15 +201,15 @@ class RAGService:
196
  """Retrieve relevant document chunks from Qdrant"""
197
  try:
198
  # Check if embedding service is initialized
199
- if embedding_service is None:
200
  print("Error: Embedding service is not initialized")
201
  return []
202
 
203
  # Get query embedding - all-MiniLM works well without special prefixes
204
- query_embedding = await embedding_service.get_query_embedding(query)
205
 
206
  # Search in Qdrant
207
- search_results = await qdrant_client.search(
208
  collection_name=Config.COLLECTION_NAME,
209
  query_vector=query_embedding,
210
  limit=top_k,
@@ -254,23 +259,26 @@ async def health_check():
254
  """Health check endpoint"""
255
  try:
256
  # Test Qdrant connection
257
- collections = await qdrant_client.get_collections()
258
- qdrant_status = "connected"
 
 
 
259
  except Exception as e:
260
  qdrant_status = f"error: {str(e)}"
261
 
262
  # Test embedding service
263
- if embedding_service is None:
264
  embedding_health = {"status": "not_initialized", "error": "EmbeddingService is None"}
265
  else:
266
  try:
267
- embedding_health = embedding_service.health_check()
268
  except Exception as e:
269
  embedding_health = {"status": "error", "error": str(e)}
270
 
271
  return {
272
- "status": "healthy" if embedding_service is not None else "unhealthy",
273
- "groq": "connected" if groq_client else "not configured",
274
  "qdrant": qdrant_status,
275
  "embedding_service": embedding_health,
276
  "collection": Config.COLLECTION_NAME,
@@ -281,7 +289,7 @@ async def health_check():
281
  async def chat_completions(request: ChatCompletionRequest):
282
  """OpenAI-compatible chat completions endpoint with RAG"""
283
 
284
- if not groq_client:
285
  raise HTTPException(status_code=500, detail="Groq client not initialized")
286
 
287
  try:
@@ -321,7 +329,7 @@ async def chat_completions(request: ChatCompletionRequest):
321
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
322
  """Create a non-streaming chat completion"""
323
  try:
324
- response = await groq_client.chat.completions.create(
325
  model=request.model,
326
  messages=messages,
327
  max_tokens=request.max_tokens,
@@ -359,7 +367,7 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
359
  completion_id = f"chatcmpl-{uuid.uuid4().hex}"
360
  created = int(datetime.now().timestamp())
361
 
362
- stream = await groq_client.chat.completions.create(
363
  model=request.model,
364
  messages=messages,
365
  max_tokens=request.max_tokens,
@@ -418,11 +426,11 @@ async def add_document(content: str, metadata: Optional[Dict] = None):
418
  """Add a document to the vector database"""
419
  try:
420
  # Check if embedding service is initialized
421
- if embedding_service is None:
422
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
423
 
424
  # Generate embedding for document
425
- embedding = await embedding_service.get_document_embedding(content)
426
 
427
  # Create point
428
  point = PointStruct(
@@ -436,7 +444,7 @@ async def add_document(content: str, metadata: Optional[Dict] = None):
436
  )
437
 
438
  # Insert into Qdrant
439
- await qdrant_client.upsert(
440
  collection_name=Config.COLLECTION_NAME,
441
  points=[point]
442
  )
@@ -451,7 +459,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
451
  """Add multiple documents to the vector database"""
452
  try:
453
  # Check if embedding service is initialized
454
- if embedding_service is None:
455
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
456
 
457
  # Extract texts and metadata
@@ -459,7 +467,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
459
  metadatas = [doc.get("metadata", {}) for doc in documents]
460
 
461
  # Generate embeddings for all documents
462
- embeddings = await embedding_service.batch_embed(texts)
463
 
464
  # Create points
465
  points = []
@@ -476,7 +484,7 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
476
  points.append(point)
477
 
478
  # Insert all points into Qdrant
479
- await qdrant_client.upsert(
480
  collection_name=Config.COLLECTION_NAME,
481
  points=points
482
  )
@@ -494,22 +502,22 @@ async def create_collection():
494
  """Create a new collection in Qdrant with the correct vector size"""
495
  try:
496
  # Check if embedding service is initialized
497
- if embedding_service is None:
498
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
499
 
500
  from qdrant_client.models import VectorParams, Distance
501
 
502
- await qdrant_client.create_collection(
503
  collection_name=Config.COLLECTION_NAME,
504
  vectors_config=VectorParams(
505
- size=embedding_service.dimension, # 384 for all-MiniLM-L6-v2
506
  distance=Distance.COSINE
507
  )
508
  )
509
 
510
  return {
511
  "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
512
- "vector_size": embedding_service.dimension,
513
  "distance": "cosine"
514
  }
515
 
@@ -520,7 +528,10 @@ async def create_collection():
520
  async def get_collection_info():
521
  """Get information about the collection"""
522
  try:
523
- collection_info = await qdrant_client.get_collection(Config.COLLECTION_NAME)
 
 
 
524
  return {
525
  "name": Config.COLLECTION_NAME,
526
  "vectors_count": collection_info.vectors_count,
 
57
  SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
58
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
59
 
60
+ class ApplicationState:
61
+ """Application state container"""
62
+ def __init__(self):
63
+ self.groq_client = None
64
+ self.qdrant_client = None
65
+ self.embedding_service = None
66
+
67
+ # Global state instance
68
+ app_state = ApplicationState()
69
 
70
  @asynccontextmanager
71
  async def lifespan(app: FastAPI):
 
201
  """Retrieve relevant document chunks from Qdrant"""
202
  try:
203
  # Check if embedding service is initialized
204
+ if app_state.embedding_service is None:
205
  print("Error: Embedding service is not initialized")
206
  return []
207
 
208
  # Get query embedding - all-MiniLM works well without special prefixes
209
+ query_embedding = await app_state.embedding_service.get_query_embedding(query)
210
 
211
  # Search in Qdrant
212
+ search_results = await app_state.qdrant_client.search(
213
  collection_name=Config.COLLECTION_NAME,
214
  query_vector=query_embedding,
215
  limit=top_k,
 
259
  """Health check endpoint"""
260
  try:
261
  # Test Qdrant connection
262
+ if app_state.qdrant_client:
263
+ collections = await app_state.qdrant_client.get_collections()
264
+ qdrant_status = "connected"
265
+ else:
266
+ qdrant_status = "not_initialized"
267
  except Exception as e:
268
  qdrant_status = f"error: {str(e)}"
269
 
270
  # Test embedding service
271
+ if app_state.embedding_service is None:
272
  embedding_health = {"status": "not_initialized", "error": "EmbeddingService is None"}
273
  else:
274
  try:
275
+ embedding_health = app_state.embedding_service.health_check()
276
  except Exception as e:
277
  embedding_health = {"status": "error", "error": str(e)}
278
 
279
  return {
280
+ "status": "healthy" if app_state.embedding_service is not None else "unhealthy",
281
+ "groq": "connected" if app_state.groq_client else "not configured",
282
  "qdrant": qdrant_status,
283
  "embedding_service": embedding_health,
284
  "collection": Config.COLLECTION_NAME,
 
289
  async def chat_completions(request: ChatCompletionRequest):
290
  """OpenAI-compatible chat completions endpoint with RAG"""
291
 
292
+ if not app_state.groq_client:
293
  raise HTTPException(status_code=500, detail="Groq client not initialized")
294
 
295
  try:
 
329
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
330
  """Create a non-streaming chat completion"""
331
  try:
332
+ response = await app_state.groq_client.chat.completions.create(
333
  model=request.model,
334
  messages=messages,
335
  max_tokens=request.max_tokens,
 
367
  completion_id = f"chatcmpl-{uuid.uuid4().hex}"
368
  created = int(datetime.now().timestamp())
369
 
370
+ stream = await app_state.groq_client.chat.completions.create(
371
  model=request.model,
372
  messages=messages,
373
  max_tokens=request.max_tokens,
 
426
  """Add a document to the vector database"""
427
  try:
428
  # Check if embedding service is initialized
429
+ if app_state.embedding_service is None:
430
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
431
 
432
  # Generate embedding for document
433
+ embedding = await app_state.embedding_service.get_document_embedding(content)
434
 
435
  # Create point
436
  point = PointStruct(
 
444
  )
445
 
446
  # Insert into Qdrant
447
+ await app_state.qdrant_client.upsert(
448
  collection_name=Config.COLLECTION_NAME,
449
  points=[point]
450
  )
 
459
  """Add multiple documents to the vector database"""
460
  try:
461
  # Check if embedding service is initialized
462
+ if app_state.embedding_service is None:
463
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
464
 
465
  # Extract texts and metadata
 
467
  metadatas = [doc.get("metadata", {}) for doc in documents]
468
 
469
  # Generate embeddings for all documents
470
+ embeddings = await app_state.embedding_service.batch_embed(texts)
471
 
472
  # Create points
473
  points = []
 
484
  points.append(point)
485
 
486
  # Insert all points into Qdrant
487
+ await app_state.qdrant_client.upsert(
488
  collection_name=Config.COLLECTION_NAME,
489
  points=points
490
  )
 
502
  """Create a new collection in Qdrant with the correct vector size"""
503
  try:
504
  # Check if embedding service is initialized
505
+ if app_state.embedding_service is None:
506
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
507
 
508
  from qdrant_client.models import VectorParams, Distance
509
 
510
+ await app_state.qdrant_client.create_collection(
511
  collection_name=Config.COLLECTION_NAME,
512
  vectors_config=VectorParams(
513
+ size=app_state.embedding_service.dimension, # 384 for all-MiniLM-L6-v2
514
  distance=Distance.COSINE
515
  )
516
  )
517
 
518
  return {
519
  "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
520
+ "vector_size": app_state.embedding_service.dimension,
521
  "distance": "cosine"
522
  }
523
 
 
528
  async def get_collection_info():
529
  """Get information about the collection"""
530
  try:
531
+ if app_state.qdrant_client is None:
532
+ raise HTTPException(status_code=500, detail="Qdrant client is not initialized")
533
+
534
+ collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
535
  return {
536
  "name": Config.COLLECTION_NAME,
537
  "vectors_count": collection_info.vectors_count,