SUBHRAJIT MOHANTY committed on
Commit 5dbc569 · 1 Parent(s): 86e4192

app.py updated
Files changed (1)
  1. app.py +486 -299

app.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import FastAPI, HTTPException, Request
  from fastapi.responses import StreamingResponse
  from pydantic import BaseModel, Field
  from typing import List, Optional, Dict, Any, AsyncGenerator
@@ -8,6 +8,8 @@ import uuid
  from datetime import datetime
  import os
  from contextlib import asynccontextmanager

  # Third-party imports
  from openai import AsyncOpenAI
@@ -17,6 +19,7 @@ from sentence_transformers import SentenceTransformer
  import torch
  import asyncio
  from concurrent.futures import ThreadPoolExecutor

  # Models for OpenAI-compatible API
  class Message(BaseModel):
@@ -46,6 +49,14 @@ class ChatCompletionChunk(BaseModel):
      model: str
      choices: List[Dict[str, Any]]

  # Configuration
  class Config:
      GROQ_API_KEY = os.getenv("GROQ_API_KEY")
@@ -64,127 +75,11 @@ class ApplicationState:
          self.openai_client = None
          self.qdrant_client = None
          self.embedding_service = None

  # Global state instance
  app_state = ApplicationState()

- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Startup
-     if not Config.GROQ_API_KEY:
-         raise ValueError("GROQ_API_KEY environment variable is required")
-
-     print("Initializing services...")
-
-     # Initialize OpenAI client with Groq endpoint
-     try:
-         print(f"Configuring OpenAI client with:")
-         print(f"  Base URL: {Config.GROQ_BASE_URL}")
-         print(f"  API Key: {'*' * 10}...{Config.GROQ_API_KEY[-4:] if Config.GROQ_API_KEY else 'None'}")
-
-         app_state.openai_client = AsyncOpenAI(
-             api_key=Config.GROQ_API_KEY,
-             base_url=Config.GROQ_BASE_URL,
-             timeout=60.0  # Add timeout
-         )
-         print("✓ OpenAI client initialized with Groq endpoint")
-
-         # Test the client with a simple request
-         try:
-             test_response = await app_state.openai_client.chat.completions.create(
-                 model="mixtral-8x7b-32768",
-                 messages=[{"role": "user", "content": "Hello"}],
-                 max_tokens=10
-             )
-             print(f"✓ OpenAI client test successful - Response ID: {test_response.id}")
-         except Exception as test_error:
-             print(f"⚠ OpenAI client test failed: {test_error}")
-             print("  This might cause issues with chat completions")
-
-     except Exception as e:
-         print(f"✗ Error initializing OpenAI client: {e}")
-         print(f"  Error type: {type(e)}")
-         raise e
-
-     # Initialize Qdrant client
-     try:
-         app_state.qdrant_client = AsyncQdrantClient(
-             url=Config.QDRANT_URL,
-             api_key=Config.QDRANT_API_KEY
-         )
-         print("✓ Qdrant client initialized")
-     except Exception as e:
-         print(f"✗ Error initializing Qdrant client: {e}")
-         raise e
-
-     # Initialize embedding service
-     try:
-         print("Loading embedding model...")
-         app_state.embedding_service = EmbeddingService()
-         print(f"✓ Embedding model loaded: {Config.EMBEDDING_MODEL}")
-         print(f"✓ Model device: {Config.DEVICE}")
-         print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
-     except Exception as e:
-         print(f"✗ Error initializing embedding service: {e}")
-         raise e  # Fail fast if embedding service can't be initialized
-
-     # Verify Qdrant connection and auto-create collection
-     try:
-         collections = await app_state.qdrant_client.get_collections()
-         collection_names = [c.name for c in collections.collections]
-         print(f"✓ Connected to Qdrant. Available collections: {collection_names}")
-
-         # Check if our collection exists, if not create it
-         if Config.COLLECTION_NAME not in collection_names:
-             print(f"📁 Collection '{Config.COLLECTION_NAME}' not found. Creating automatically...")
-             try:
-                 from qdrant_client.models import VectorParams, Distance
-
-                 await app_state.qdrant_client.create_collection(
-                     collection_name=Config.COLLECTION_NAME,
-                     vectors_config=VectorParams(
-                         size=app_state.embedding_service.dimension,
-                         distance=Distance.COSINE
-                     )
-                 )
-                 print(f"✓ Collection '{Config.COLLECTION_NAME}' created successfully!")
-                 print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
-                 print(f"✓ Distance metric: COSINE")
-             except Exception as create_error:
-                 print(f"✗ Failed to create collection: {create_error}")
-                 print("⚠ You may need to create the collection manually")
-         else:
-             print(f"✓ Collection '{Config.COLLECTION_NAME}' already exists")
-
-     except Exception as e:
-         print(f"⚠ Warning: Could not connect to Qdrant: {e}")
-         print("⚠ Collection auto-creation skipped")
-
-     print("🚀 All services initialized successfully!")
-
-     yield
-
-     # Shutdown
-     print("Shutting down services...")
-     if app_state.qdrant_client:
-         await app_state.qdrant_client.close()
-         print("✓ Qdrant client closed")
-     if app_state.openai_client:
-         await app_state.openai_client.close()
-         print("✓ OpenAI client closed")
-     if app_state.embedding_service and hasattr(app_state.embedding_service, 'executor'):
-         app_state.embedding_service.executor.shutdown(wait=True)
-         print("✓ Embedding service executor shutdown")
-     print("✓ Shutdown complete")
-
- # Initialize FastAPI app
- app = FastAPI(
-     title="RAG API with Groq and Qdrant",
-     description="OpenAI-compatible API for RAG using Groq and Qdrant",
-     version="1.0.0",
-     lifespan=lifespan
- )
-
  class EmbeddingService:
      """Service for generating embeddings using sentence-transformers"""
@@ -202,7 +97,6 @@ class EmbeddingService:
      async def get_embedding(self, text: str) -> List[float]:
          """Generate embedding for given text"""
          try:
-             # Run the synchronous model.encode in a thread pool
              loop = asyncio.get_event_loop()
              embedding = await loop.run_in_executor(
                  self.executor,
@@ -247,7 +141,6 @@ class EmbeddingService:
      def health_check(self) -> dict:
          """Check embedding service health"""
          try:
-             # Test encoding
              test_embedding = self.model.encode(["test"])
              return {
                  "status": "healthy",
@@ -263,94 +156,384 @@ class EmbeddingService:
              "error": str(e)
          }

- class RAGService:
-     """Service for retrieval-augmented generation"""
-
-     @staticmethod
-     async def retrieve_relevant_chunks(query: str, top_k: int = Config.TOP_K) -> List[str]:
-         """Retrieve relevant document chunks from Qdrant"""
          try:
-             # Check if embedding service is initialized
-             if app_state.embedding_service is None:
-                 print("Error: Embedding service is not initialized")
-                 return []
-
-             # Auto-create collection if it doesn't exist
-             await RAGService._ensure_collection_exists()
-
-             # Get query embedding - all-MiniLM works well without special prefixes
-             query_embedding = await app_state.embedding_service.get_query_embedding(query)

              # Search in Qdrant
-             search_results = await app_state.qdrant_client.search(
-                 collection_name=Config.COLLECTION_NAME,
                  query_vector=query_embedding,
-                 limit=top_k,
-                 score_threshold=Config.SIMILARITY_THRESHOLD
              )

-             # Extract content from results
-             chunks = []
              for result in search_results:
-                 if hasattr(result, 'payload') and 'content' in result.payload:
-                     chunks.append(result.payload['content'])
-                 elif hasattr(result, 'payload') and 'text' in result.payload:
-                     chunks.append(result.payload['text'])

-             print(f"Retrieved {len(chunks)} relevant chunks for query")
-             return chunks

          except Exception as e:
-             print(f"Error retrieving chunks: {e}")
              return []

      @staticmethod
-     async def _ensure_collection_exists():
-         """Ensure the collection exists, create if it doesn't"""
          try:
-             # Check if collection exists
-             collections = await app_state.qdrant_client.get_collections()
-             collection_names = [c.name for c in collections.collections]

-             if Config.COLLECTION_NAME not in collection_names:
-                 print(f"Creating collection '{Config.COLLECTION_NAME}' on-demand...")
-                 from qdrant_client.models import VectorParams, Distance
-
-                 await app_state.qdrant_client.create_collection(
-                     collection_name=Config.COLLECTION_NAME,
-                     vectors_config=VectorParams(
-                         size=app_state.embedding_service.dimension,
-                         distance=Distance.COSINE
-                     )
-                 )
-                 print(f"✓ Collection '{Config.COLLECTION_NAME}' created successfully!")
-
          except Exception as e:
-             print(f"Warning: Could not ensure collection exists: {e}")
-             # Continue anyway - the operation might still work

      @staticmethod
-     def build_context_prompt(query: str, chunks: List[str]) -> str:
          """Build a context-aware prompt with retrieved chunks"""
-         if not chunks:
              return query

-         context = "\n\n".join([f"Document {i+1}: {chunk}" for i, chunk in enumerate(chunks)])

-         prompt = f"""Based on the following documents, please answer the user's question. If the information is not available in the documents, please say so.

- Context Documents:
- {context}

- User Question: {query}

- Please provide a helpful and accurate response based on the context provided."""

          return prompt

  @app.get("/")
  async def root():
-     return {"message": "RAG API with Groq and Qdrant", "status": "running"}

  @app.get("/health")
  async def health_check():
@@ -379,7 +562,6 @@ async def health_check():
          openai_health = {"status": "not_initialized", "error": "OpenAI client is None"}
      else:
          try:
-             # Quick test of OpenAI client
              test_response = await app_state.openai_client.chat.completions.create(
                  model="mixtral-8x7b-32768",
                  messages=[{"role": "user", "content": "test"}],
@@ -394,6 +576,7 @@ async def health_check():
          "openai_client": openai_health,
          "qdrant": qdrant_status,
          "embedding_service": embedding_health,
          "collection": Config.COLLECTION_NAME,
          "embedding_model": Config.EMBEDDING_MODEL,
          "groq_endpoint": Config.GROQ_BASE_URL
@@ -401,7 +584,7 @@ async def health_check():

  @app.post("/v1/chat/completions")
  async def chat_completions(request: ChatCompletionRequest):
-     """OpenAI-compatible chat completions endpoint with RAG"""

      if not app_state.openai_client:
          raise HTTPException(status_code=500, detail="OpenAI client not initialized")
@@ -415,17 +598,17 @@ async def chat_completions(request: ChatCompletionRequest):
      last_user_message = user_messages[-1].content
      print(f"Processing query: {last_user_message[:100]}...")

-     # Retrieve relevant chunks
      try:
-         relevant_chunks = await RAGService.retrieve_relevant_chunks(last_user_message)
-         print(f"Retrieved {len(relevant_chunks)} chunks")
      except Exception as e:
          print(f"Error in retrieval: {e}")
-         relevant_chunks = []

      # Build context-aware prompt
-     if relevant_chunks:
-         context_prompt = RAGService.build_context_prompt(last_user_message, relevant_chunks)
          enhanced_messages = request.messages[:-1] + [Message(role="user", content=context_prompt)]
          print("Using context-enhanced prompt")
      else:
@@ -448,7 +631,6 @@ async def chat_completions(request: ChatCompletionRequest):
          raise
      except Exception as e:
          print(f"Unexpected error in chat_completions: {e}")
-         print(f"Error type: {type(e)}")
          import traceback
          traceback.print_exc()
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@@ -456,10 +638,6 @@ async def chat_completions(request: ChatCompletionRequest):
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
      """Create a non-streaming chat completion"""
      try:
-         print(f"Calling OpenAI API with model: {request.model}")
-         print(f"Messages count: {len(messages)}")
-         print(f"Max tokens: {request.max_tokens}")
-
          response = await app_state.openai_client.chat.completions.create(
              model=request.model,
              messages=messages,
@@ -469,12 +647,6 @@ async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
              stream=False
          )

-         print(f"Received response from OpenAI API")
-         print(f"Response ID: {response.id}")
-         print(f"Response model: {response.model}")
-         print(f"Choices count: {len(response.choices)}")
-
-         # Convert response to OpenAI format (already compatible)
          result = ChatCompletionResponse(
              id=response.id,
              created=response.created,
@@ -494,14 +666,10 @@ async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
              } if response.usage else None
          )

-         print(f"Successfully created response")
          return result

      except Exception as e:
          print(f"Error in create_chat_completion: {e}")
-         print(f"Error type: {type(e)}")
-         import traceback
-         traceback.print_exc()
          raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")

  async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
@@ -536,7 +704,6 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:

              yield f"data: {chunk_response.model_dump_json()}\n\n"

-         # Send final chunk
          yield "data: [DONE]\n\n"

      except Exception as e:
@@ -549,132 +716,153 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
          }
          yield f"data: {json.dumps(error_chunk)}\n\n"

- # Additional endpoints for managing the vector database
- @app.post("/v1/embeddings/add")
- async def add_document(content: str, metadata: Optional[Dict] = None):
-     """Add a document to the vector database"""
      try:
-         # Check if embedding service is initialized
-         if app_state.embedding_service is None:
-             raise HTTPException(status_code=500, detail="Embedding service is not initialized")

-         # Auto-create collection if it doesn't exist
-         await RAGService._ensure_collection_exists()

-         # Generate embedding for document
-         embedding = await app_state.embedding_service.get_document_embedding(content)

-         # Create point
-         point = PointStruct(
-             id=str(uuid.uuid4()),
-             vector=embedding,
-             payload={
-                 "content": content,
-                 "metadata": metadata or {},
-                 "timestamp": datetime.now().isoformat()
              }
-         )

-         # Insert into Qdrant
-         await app_state.qdrant_client.upsert(
-             collection_name=Config.COLLECTION_NAME,
-             points=[point]
          )

-         return {"message": "Document added successfully", "id": point.id}

      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")

- @app.post("/v1/embeddings/batch_add")
- async def batch_add_documents(documents: List[Dict[str, Any]]):
-     """Add multiple documents to the vector database"""
      try:
-         # Check if embedding service is initialized
-         if app_state.embedding_service is None:
-             raise HTTPException(status_code=500, detail="Embedding service is not initialized")
-
-         # Auto-create collection if it doesn't exist
-         await RAGService._ensure_collection_exists()
-
-         # Extract texts and metadata
-         texts = [doc.get("content", "") for doc in documents]
-         metadatas = [doc.get("metadata", {}) for doc in documents]
-
-         # Generate embeddings for all documents
-         embeddings = await app_state.embedding_service.batch_embed(texts)
-
-         # Create points
-         points = []
-         for i, (text, embedding, metadata) in enumerate(zip(texts, embeddings, metadatas)):
-             point = PointStruct(
-                 id=str(uuid.uuid4()),
-                 vector=embedding,
-                 payload={
-                     "content": text,
-                     "metadata": metadata,
-                     "timestamp": datetime.now().isoformat()
-                 }
-             )
-             points.append(point)

-         # Insert all points into Qdrant
-         await app_state.qdrant_client.upsert(
-             collection_name=Config.COLLECTION_NAME,
-             points=points
-         )

          return {
-             "message": f"Successfully added {len(points)} documents",
-             "ids": [point.id for point in points]
          }

      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error adding documents: {str(e)}")

- @app.post("/v1/embeddings/create_collection")
- async def create_collection():
-     """Create a new collection in Qdrant with the correct vector size"""
      try:
-         # Check if embedding service is initialized
-         if app_state.embedding_service is None:
-             raise HTTPException(status_code=500, detail="Embedding service is not initialized")

-         from qdrant_client.models import VectorParams, Distance

-         # Check if collection already exists
-         try:
-             collections = await app_state.qdrant_client.get_collections()
-             collection_names = [c.name for c in collections.collections]
-
-             if Config.COLLECTION_NAME in collection_names:
-                 return {
-                     "message": f"Collection '{Config.COLLECTION_NAME}' already exists",
-                     "vector_size": app_state.embedding_service.dimension,
-                     "distance": "cosine",
-                     "status": "exists"
-                 }
-         except Exception as e:
-             print(f"Warning: Could not check existing collections: {e}")

-         # Create the collection
-         await app_state.qdrant_client.create_collection(
              collection_name=Config.COLLECTION_NAME,
-             vectors_config=VectorParams(
-                 size=app_state.embedding_service.dimension,  # 384 for all-MiniLM-L6-v2
-                 distance=Distance.COSINE
-             )
          )

-         return {
-             "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
-             "vector_size": app_state.embedding_service.dimension,
-             "distance": "cosine",
-             "status": "created"
-         }

      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}")

  @app.get("/v1/collections/info")
  async def get_collection_info():
@@ -683,8 +871,7 @@ async def get_collection_info():
      if app_state.qdrant_client is None:
          raise HTTPException(status_code=500, detail="Qdrant client is not initialized")

-     # Auto-create collection if it doesn't exist
-     await RAGService._ensure_collection_exists()

      collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
      return {
 
+ from fastapi import FastAPI, HTTPException, Request, UploadFile, File
  from fastapi.responses import StreamingResponse
  from pydantic import BaseModel, Field
  from typing import List, Optional, Dict, Any, AsyncGenerator

  from datetime import datetime
  import os
  from contextlib import asynccontextmanager
+ import tempfile
+ import shutil

  # Third-party imports
  from openai import AsyncOpenAI

  import torch
  import asyncio
  from concurrent.futures import ThreadPoolExecutor
+ import PyPDF2

  # Models for OpenAI-compatible API
  class Message(BaseModel):

      model: str
      choices: List[Dict[str, Any]]

+ class DocumentUploadRequest(BaseModel):
+     metadata: Optional[Dict[str, Any]] = None
+
+ class DocumentSearchRequest(BaseModel):
+     query: str = Field(..., description="Search query")
+     limit: int = Field(default=5, description="Maximum number of results")
+     min_score: float = Field(default=0.1, description="Minimum similarity score")
+
  # Configuration
  class Config:
      GROQ_API_KEY = os.getenv("GROQ_API_KEY")

          self.openai_client = None
          self.qdrant_client = None
          self.embedding_service = None
+         self.document_manager = None

  # Global state instance
  app_state = ApplicationState()

  class EmbeddingService:
      """Service for generating embeddings using sentence-transformers"""
      async def get_embedding(self, text: str) -> List[float]:
          """Generate embedding for given text"""
          try:
              loop = asyncio.get_event_loop()
              embedding = await loop.run_in_executor(
                  self.executor,

      def health_check(self) -> dict:
          """Check embedding service health"""
          try:
              test_embedding = self.model.encode(["test"])
              return {
                  "status": "healthy",

                  "error": str(e)
              }
+ class DocumentManager:
+     """Enhanced document management with async support"""
+
+     def __init__(self, qdrant_client: AsyncQdrantClient, embedding_service: EmbeddingService):
+         self.qdrant_client = qdrant_client
+         self.embedding_service = embedding_service
+         self.collection_name = Config.COLLECTION_NAME
+         self.vector_size = 384
+         self.executor = ThreadPoolExecutor(max_workers=2)
+
+     async def _read_pdf(self, file_path: str) -> str:
+         """Read text from PDF file asynchronously"""
          try:
+             loop = asyncio.get_event_loop()
+             return await loop.run_in_executor(self.executor, self._sync_read_pdf, file_path)
+         except Exception as e:
+             print(f"Error reading PDF {file_path}: {e}")
+             return ""
+
+     def _sync_read_pdf(self, file_path: str) -> str:
+         """Synchronous PDF reading"""
+         try:
+             with open(file_path, 'rb') as file:
+                 pdf_reader = PyPDF2.PdfReader(file)
+                 text = ""
+                 for page in pdf_reader.pages:
+                     text += page.extract_text() + "\n"
+                 return text
+         except Exception as e:
+             print(f"Error reading PDF {file_path}: {e}")
+             return ""
+
+     def _chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
+         """Split text into chunks"""
+         if len(text) <= chunk_size:
+             return [text]
+
+         chunks = []
+         start = 0
+
+         while start < len(text):
+             end = start + chunk_size

+             if end < len(text):
+                 sentence_end = text.rfind('.', start, end)
+                 if sentence_end > start:
+                     end = sentence_end + 1
+                 else:
+                     word_end = text.rfind(' ', start, end)
+                     if word_end > start:
+                         end = word_end

+             chunk = text[start:end].strip()
+             if chunk:
+                 chunks.append(chunk)
+
+             # Step back by `overlap`, but always advance at least one character:
+             # if a sentence boundary lands within `overlap` chars of `start`,
+             # `end - overlap` would not move forward and the loop would never end
+             start = max(end - overlap, start + 1)
+
+         return chunks
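The chunker above walks the text in chunk_size windows, preferring to cut at a sentence boundary, then a word boundary, and re-reading the last overlap characters of each chunk. A minimal smoke test (not part of the commit; it assumes this file is importable as app.py and exploits the fact that _chunk_text never touches self):

# Hypothetical sanity check for _chunk_text; run in a REPL next to app.py.
from app import DocumentManager

sample = "This is a sentence. " * 120  # ~2400 characters of toy input
chunks = DocumentManager._chunk_text(None, sample, chunk_size=500, overlap=50)

print(len(chunks))                         # several chunks, not one blob
print(all(len(c) <= 500 for c in chunks))  # True: a chunk never exceeds chunk_size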
+     async def _ensure_collection_exists(self):
+         """Ensure the collection exists, create if it doesn't"""
+         try:
+             collections = await self.qdrant_client.get_collections()
+             collection_names = [c.name for c in collections.collections]
+
+             if self.collection_name not in collection_names:
+                 print(f"Creating collection '{self.collection_name}' on-demand...")
+                 await self.qdrant_client.create_collection(
+                     collection_name=self.collection_name,
+                     vectors_config=VectorParams(
+                         size=self.vector_size,
+                         distance=Distance.COSINE
+                     )
+                 )
+                 print(f"✓ Collection '{self.collection_name}' created successfully!")
+         except Exception as e:
+             print(f"Warning: Could not ensure collection exists: {e}")
+
+     async def add_document(self, file_path: str, metadata: Dict[str, Any] = None) -> str:
+         """Add a PDF document to the collection"""
+         try:
+             await self._ensure_collection_exists()
+
+             # Read PDF
+             text = await self._read_pdf(file_path)
+             if not text:
+                 print(f"Could not extract text from {file_path}")
+                 return ""
+
+             # Create chunks
+             chunks = self._chunk_text(text)
+             if not chunks:
+                 print(f"No chunks created from {file_path}")
+                 return ""
+
+             # Generate document ID
+             document_id = str(uuid.uuid4())
+
+             # Create embeddings for all chunks
+             embeddings = await self.embedding_service.batch_embed(chunks)
+
+             # Create points for each chunk
+             points = []
+             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+                 payload = {
+                     "document_id": document_id,
+                     "file_path": file_path,
+                     "chunk_index": i,
+                     "content": chunk,  # Use 'content' as the main field
+                     "chunk_text": chunk,  # Keep for compatibility
+                     "total_chunks": len(chunks),
+                     "timestamp": datetime.now().isoformat()
+                 }
+
+                 if metadata:
+                     payload["metadata"] = metadata
+
+                 point = PointStruct(
+                     id=str(uuid.uuid4()),
+                     vector=embedding,
+                     payload=payload
+                 )
+                 points.append(point)
+
+             # Insert into Qdrant
+             await self.qdrant_client.upsert(collection_name=self.collection_name, points=points)
+
+             print(f"✓ Added document: {file_path}")
+             print(f"  Document ID: {document_id}")
+             print(f"  Chunks: {len(chunks)}")
+
+             return document_id
+
+         except Exception as e:
+             print(f"Error adding document {file_path}: {e}")
+             return ""
+
+     async def search_documents(self, query: str, limit: int = 5, min_score: float = 0.1) -> List[Dict[str, Any]]:
+         """Search for relevant document chunks"""
+         try:
+             await self._ensure_collection_exists()
+
+             # Generate query embedding
+             query_embedding = await self.embedding_service.get_query_embedding(query)

              # Search in Qdrant
+             search_results = await self.qdrant_client.search(
+                 collection_name=self.collection_name,
                  query_vector=query_embedding,
+                 limit=limit,
+                 score_threshold=min_score
              )

+             # Format results
+             results = []
              for result in search_results:
+                 results.append({
+                     "score": result.score,
+                     "text": result.payload.get("content", result.payload.get("chunk_text", "")),
+                     "file_path": result.payload.get("file_path", ""),
+                     "document_id": result.payload.get("document_id", ""),
+                     "chunk_index": result.payload.get("chunk_index", 0)
+                 })

+             print(f"  Found {len(results)} results for query: '{query}'")
+             return results

          except Exception as e:
+             print(f"Error searching: {e}")
+             return []
+
+     async def list_documents(self) -> List[Dict[str, Any]]:
+         """List all documents in the collection"""
+         try:
+             await self._ensure_collection_exists()
+
+             # Get all points
+             points, _ = await self.qdrant_client.scroll(
+                 collection_name=self.collection_name,
+                 limit=10000,
+                 with_payload=True,
+                 with_vectors=False
+             )
+
+             # Group by document_id
+             documents = {}
+             for point in points:
+                 doc_id = point.payload.get("document_id")
+                 if doc_id and doc_id not in documents:
+                     documents[doc_id] = {
+                         "document_id": doc_id,
+                         "file_path": point.payload.get("file_path", ""),
+                         "total_chunks": point.payload.get("total_chunks", 0),
+                         "timestamp": point.payload.get("timestamp", ""),
+                         "metadata": point.payload.get("metadata", {})
+                     }
+
+             doc_list = list(documents.values())
+             print(f"✓ Found {len(doc_list)} documents")
+             return doc_list
+
+         except Exception as e:
+             print(f"Error listing documents: {e}")
              return []

+     async def delete_document(self, document_id: str) -> bool:
+         """Delete a document and all its chunks"""
+         try:
+             await self._ensure_collection_exists()
+
+             # Find all points for this document
+             points, _ = await self.qdrant_client.scroll(
+                 collection_name=self.collection_name,
+                 limit=10000,
+                 with_payload=True,
+                 with_vectors=False
+             )
+
+             # Collect point IDs to delete
+             points_to_delete = []
+             for point in points:
+                 if point.payload.get("document_id") == document_id:
+                     points_to_delete.append(point.id)
+
+             if not points_to_delete:
+                 print(f"No document found with ID: {document_id}")
+                 return False
+
+             # Delete points
+             await self.qdrant_client.delete(
+                 collection_name=self.collection_name,
+                 points_selector=points_to_delete
+             )
+
+             print(f"✓ Deleted document: {document_id} ({len(points_to_delete)} chunks)")
+             return True
+
+         except Exception as e:
+             print(f"Error deleting document: {e}")
+             return False
+
+ class RAGService:
+     """Service for retrieval-augmented generation"""
+
      @staticmethod
+     async def retrieve_relevant_chunks(query: str, top_k: int = Config.TOP_K) -> List[Dict[str, Any]]:
+         """Retrieve relevant document chunks using the document manager"""
          try:
+             if app_state.document_manager is None:
+                 print("Error: Document manager is not initialized")
+                 return []
+
+             # Use the document manager's search functionality
+             results = await app_state.document_manager.search_documents(
+                 query=query,
+                 limit=top_k,
+                 min_score=Config.SIMILARITY_THRESHOLD
+             )
+
+             return results

          except Exception as e:
+             print(f"Error retrieving chunks: {e}")
+             return []

      @staticmethod
+     def build_context_prompt(query: str, results: List[Dict[str, Any]]) -> str:
          """Build a context-aware prompt with retrieved chunks"""
+         if not results:
              return query

+         # Build one context block per retrieved chunk
+         context_parts = []
+         for result in results:
+             context_parts.append(f"Source: {result['file_path']}\n{result['text']}")
+
+         combined_context = "\n\n---\n\n".join(context_parts)

+         prompt = f"""Based on the following context, answer the user's question:

+ Context:
+ {combined_context}

+ Question: {query}

+ Please provide a comprehensive answer based on the context provided."""

          return prompt
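For one hypothetical hit, the template above expands to the following (illustrative values only, not part of the commit):

# Illustrative only; shows the exact shape of the prompt the LLM receives.
from app import RAGService

results = [{"file_path": "report.pdf", "text": "Revenue grew 12% in Q3.",
            "score": 0.42, "document_id": "doc-1", "chunk_index": 0}]
print(RAGService.build_context_prompt("How did revenue change?", results))
# Based on the following context, answer the user's question:
#
# Context:
# Source: report.pdf
# Revenue grew 12% in Q3.
#
# Question: How did revenue change?
#
# Please provide a comprehensive answer based on the context provided.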
 
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup
+     if not Config.GROQ_API_KEY:
+         raise ValueError("GROQ_API_KEY environment variable is required")
+
+     print("Initializing services...")
+
+     # Initialize OpenAI client with Groq endpoint
+     try:
+         print(f"Configuring OpenAI client with:")
+         print(f"  Base URL: {Config.GROQ_BASE_URL}")
+         print(f"  API Key: {'*' * 10}...{Config.GROQ_API_KEY[-4:] if Config.GROQ_API_KEY else 'None'}")
+
+         app_state.openai_client = AsyncOpenAI(
+             api_key=Config.GROQ_API_KEY,
+             base_url=Config.GROQ_BASE_URL,
+             timeout=60.0
+         )
+         print("✓ OpenAI client initialized with Groq endpoint")
+     except Exception as e:
+         print(f"✗ Error initializing OpenAI client: {e}")
+         raise e
+
+     # Initialize Qdrant client
+     try:
+         app_state.qdrant_client = AsyncQdrantClient(
+             url=Config.QDRANT_URL,
+             api_key=Config.QDRANT_API_KEY
+         )
+         print("✓ Qdrant client initialized")
+     except Exception as e:
+         print(f"✗ Error initializing Qdrant client: {e}")
+         raise e
+
+     # Initialize embedding service
+     try:
+         print("Loading embedding model...")
+         app_state.embedding_service = EmbeddingService()
+         print(f"✓ Embedding model loaded: {Config.EMBEDDING_MODEL}")
+         print(f"✓ Model device: {Config.DEVICE}")
+         print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
+     except Exception as e:
+         print(f"✗ Error initializing embedding service: {e}")
+         raise e
+
+     # Initialize document manager
+     try:
+         app_state.document_manager = DocumentManager(
+             qdrant_client=app_state.qdrant_client,
+             embedding_service=app_state.embedding_service
+         )
+         print("✓ Document manager initialized")
+     except Exception as e:
+         print(f"✗ Error initializing document manager: {e}")
+         raise e
+
+     print("🚀 All services initialized successfully!")
+
+     yield
+
+     # Shutdown
+     print("Shutting down services...")
+     if app_state.qdrant_client:
+         await app_state.qdrant_client.close()
+         print("✓ Qdrant client closed")
+     if app_state.openai_client:
+         await app_state.openai_client.close()
+         print("✓ OpenAI client closed")
+     if app_state.embedding_service and hasattr(app_state.embedding_service, 'executor'):
+         app_state.embedding_service.executor.shutdown(wait=True)
+         print("✓ Embedding service executor shutdown")
+     if app_state.document_manager and hasattr(app_state.document_manager, 'executor'):
+         app_state.document_manager.executor.shutdown(wait=True)
+         print("✓ Document manager executor shutdown")
+     print("✓ Shutdown complete")
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="Enhanced RAG API with Document Management",
+     description="OpenAI-compatible API for RAG with document management using Groq and Qdrant",
+     version="1.0.0",
+     lifespan=lifespan
+ )
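To exercise the new routes locally, a launcher along these lines should work (a sketch, not part of the commit; assumes the module is named app.py and that GROQ_API_KEY, QDRANT_URL, and QDRANT_API_KEY are set in the environment):

# Hypothetical local launcher; equivalent to running `uvicorn app:app` directly.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000)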
  @app.get("/")
  async def root():
+     return {"message": "Enhanced RAG API with Document Management", "status": "running"}

  @app.get("/health")
  async def health_check():
 
          openai_health = {"status": "not_initialized", "error": "OpenAI client is None"}
      else:
          try:
              test_response = await app_state.openai_client.chat.completions.create(
                  model="mixtral-8x7b-32768",
                  messages=[{"role": "user", "content": "test"}],

          "openai_client": openai_health,
          "qdrant": qdrant_status,
          "embedding_service": embedding_health,
+         "document_manager": "initialized" if app_state.document_manager else "not_initialized",
          "collection": Config.COLLECTION_NAME,
          "embedding_model": Config.EMBEDDING_MODEL,
          "groq_endpoint": Config.GROQ_BASE_URL

  @app.post("/v1/chat/completions")
  async def chat_completions(request: ChatCompletionRequest):
+     """OpenAI-compatible chat completions endpoint with enhanced RAG"""

      if not app_state.openai_client:
          raise HTTPException(status_code=500, detail="OpenAI client not initialized")
 
      last_user_message = user_messages[-1].content
      print(f"Processing query: {last_user_message[:100]}...")

+     # Retrieve relevant chunks using enhanced search
      try:
+         relevant_results = await RAGService.retrieve_relevant_chunks(last_user_message)
+         print(f"Retrieved {len(relevant_results)} chunks")
      except Exception as e:
          print(f"Error in retrieval: {e}")
+         relevant_results = []

      # Build context-aware prompt
+     if relevant_results:
+         context_prompt = RAGService.build_context_prompt(last_user_message, relevant_results)
          enhanced_messages = request.messages[:-1] + [Message(role="user", content=context_prompt)]
          print("Using context-enhanced prompt")
      else:

          raise
      except Exception as e:
          print(f"Unexpected error in chat_completions: {e}")
          import traceback
          traceback.print_exc()
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
      """Create a non-streaming chat completion"""
      try:
          response = await app_state.openai_client.chat.completions.create(
              model=request.model,
              messages=messages,

              stream=False
          )

          result = ChatCompletionResponse(
              id=response.id,
              created=response.created,

              } if response.usage else None
          )

          return result

      except Exception as e:
          print(f"Error in create_chat_completion: {e}")
          raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")

  async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:

              yield f"data: {chunk_response.model_dump_json()}\n\n"

          yield "data: [DONE]\n\n"

      except Exception as e:

          }
          yield f"data: {json.dumps(error_chunk)}\n\n"
+ # Document management endpoints
+ @app.post("/v1/documents/upload")
+ async def upload_document(file: UploadFile = File(...), metadata: str = None):
+     """Upload a PDF document"""
      try:
+         if not app_state.document_manager:
+             raise HTTPException(status_code=500, detail="Document manager not initialized")

+         # Validate file type
+         if not file.filename.lower().endswith('.pdf'):
+             raise HTTPException(status_code=400, detail="Only PDF files are supported")

+         # Parse metadata if provided
+         parsed_metadata = {}
+         if metadata:
+             try:
+                 parsed_metadata = json.loads(metadata)
+             except json.JSONDecodeError:
+                 raise HTTPException(status_code=400, detail="Invalid metadata JSON")

+         # Save uploaded file temporarily
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
+             shutil.copyfileobj(file.file, tmp_file)
+             tmp_path = tmp_file.name
+
+         try:
+             # Add document to the collection
+             document_id = await app_state.document_manager.add_document(
+                 file_path=tmp_path,
+                 metadata={
+                     **parsed_metadata,
+                     "original_filename": file.filename,
+                     "upload_timestamp": datetime.now().isoformat()
+                 }
+             )
+
+             if not document_id:
+                 raise HTTPException(status_code=500, detail="Failed to add document")
+
+             return {
+                 "message": "Document uploaded successfully",
+                 "document_id": document_id,
+                 "filename": file.filename
              }
+
+         finally:
+             # Clean up temporary file
+             os.unlink(tmp_path)
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         print(f"Error uploading document: {e}")
+         raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
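A hypothetical client call for this route, using the requests library (assumes the server is on localhost:8000 and sample.pdf exists). Note that because metadata is declared as a bare str rather than Form(...), FastAPI reads it from the query string, not the multipart body:

# Sketch of a client for POST /v1/documents/upload.
import json
import requests

with open("sample.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/documents/upload",
        params={"metadata": json.dumps({"source": "local-test"})},  # query param, see note above
        files={"file": ("sample.pdf", f, "application/pdf")},
    )
print(resp.json())  # {"message": ..., "document_id": ..., "filename": "sample.pdf"}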
+ @app.post("/v1/documents/search")
+ async def search_documents(request: DocumentSearchRequest):
+     """Search for documents"""
+     try:
+         if not app_state.document_manager:
+             raise HTTPException(status_code=500, detail="Document manager not initialized")

+         results = await app_state.document_manager.search_documents(
+             query=request.query,
+             limit=request.limit,
+             min_score=request.min_score
          )

+         return {
+             "query": request.query,
+             "results": results,
+             "count": len(results)
+         }

      except Exception as e:
+         print(f"Error searching documents: {e}")
+         raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
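And a matching sketch for the search route (same localhost assumption, not part of the commit):

# Sketch of a client for POST /v1/documents/search.
import requests

resp = requests.post(
    "http://localhost:8000/v1/documents/search",
    json={"query": "quarterly revenue", "limit": 3, "min_score": 0.2},
)
for hit in resp.json()["results"]:
    print(round(hit["score"], 3), hit["file_path"], hit["text"][:80])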
 
+ @app.get("/v1/documents/list")
+ async def list_documents():
+     """List all documents"""
      try:
+         if not app_state.document_manager:
+             raise HTTPException(status_code=500, detail="Document manager not initialized")

+         documents = await app_state.document_manager.list_documents()

          return {
+             "documents": documents,
+             "count": len(documents)
          }

      except Exception as e:
+         print(f"Error listing documents: {e}")
+         raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")

+ @app.delete("/v1/documents/{document_id}")
+ async def delete_document(document_id: str):
+     """Delete a document"""
      try:
+         if not app_state.document_manager:
+             raise HTTPException(status_code=500, detail="Document manager not initialized")

+         success = await app_state.document_manager.delete_document(document_id)

+         if not success:
+             raise HTTPException(status_code=404, detail="Document not found")
+
+         return {"message": "Document deleted successfully", "document_id": document_id}
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         print(f"Error deleting document: {e}")
+         raise HTTPException(status_code=500, detail=f"Error deleting document: {str(e)}")
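The list and delete routes compose naturally; a hypothetical cleanup snippet (same assumptions as the sketches above):

# Sketch: enumerate stored documents, then delete the first one by ID.
import requests

docs = requests.get("http://localhost:8000/v1/documents/list").json()["documents"]
if docs:
    doc_id = docs[0]["document_id"]
    resp = requests.delete(f"http://localhost:8000/v1/documents/{doc_id}")
    print(resp.json())  # {"message": "Document deleted successfully", ...}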
+ # Legacy compatibility endpoints
+ @app.post("/v1/embeddings/add")
+ async def add_document_legacy(content: str, metadata: Optional[Dict] = None):
+     """Legacy endpoint for adding documents (text content)"""
+     try:
+         if not app_state.embedding_service or not app_state.qdrant_client:
+             raise HTTPException(status_code=500, detail="Services not initialized")

+         await app_state.document_manager._ensure_collection_exists()
+
+         embedding = await app_state.embedding_service.get_document_embedding(content)
+
+         point = PointStruct(
+             id=str(uuid.uuid4()),
+             vector=embedding,
+             payload={
+                 "content": content,
+                 "metadata": metadata or {},
+                 "timestamp": datetime.now().isoformat()
+             }
+         )
+
+         await app_state.qdrant_client.upsert(
              collection_name=Config.COLLECTION_NAME,
+             points=[point]
          )

+         return {"message": "Document added successfully", "id": point.id}

      except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")

  @app.get("/v1/collections/info")
  async def get_collection_info():

      if app_state.qdrant_client is None:
          raise HTTPException(status_code=500, detail="Qdrant client is not initialized")

+     await app_state.document_manager._ensure_collection_exists()

      collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
      return {