SUBHRAJIT MOHANTY committed on
Commit
247920d
·
1 Parent(s): d1f7294

initial commit

Files changed (3)
  1. Dockerfile +0 -0
  2. app.py +511 -0
  3. requirements.txt +15 -0
Dockerfile ADDED
File without changes
app.py ADDED
@@ -0,0 +1,511 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Dict, Any, AsyncGenerator
+ import asyncio
+ import json
+ import uuid
+ from datetime import datetime
+ import os
+ from contextlib import asynccontextmanager
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Third-party imports
+ from groq import AsyncGroq
+ from qdrant_client import AsyncQdrantClient
+ from qdrant_client.models import Distance, VectorParams, PointStruct
+ from sentence_transformers import SentenceTransformer
+ import torch
+
+ # Models for OpenAI-compatible API
+ class Message(BaseModel):
+     role: str = Field(..., description="The role of the message author")
+     content: str = Field(..., description="The content of the message")
+
+ class ChatCompletionRequest(BaseModel):
+     model: str = Field(default="mixtral-8x7b-32768", description="Model to use")
+     messages: List[Message] = Field(..., description="List of messages")
+     max_tokens: Optional[int] = Field(default=1024, description="Maximum tokens to generate")
+     temperature: Optional[float] = Field(default=0.7, description="Temperature for sampling")
+     stream: Optional[bool] = Field(default=False, description="Whether to stream responses")
+     top_p: Optional[float] = Field(default=1.0, description="Top-p sampling parameter")
+
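+ # A request body this schema accepts looks like the following (illustrative
+ # values only; any Groq-hosted model name can be substituted):
+ # {
+ #     "model": "mixtral-8x7b-32768",
+ #     "messages": [{"role": "user", "content": "What do the docs say about X?"}],
+ #     "stream": false
+ # }
+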
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int
+     model: str
+     choices: List[Dict[str, Any]]
+     usage: Optional[Dict[str, int]] = None
+
+ class ChatCompletionChunk(BaseModel):
+     id: str
+     object: str = "chat.completion.chunk"
+     created: int
+     model: str
+     choices: List[Dict[str, Any]]
+
+ # Configuration
+ class Config:
+     GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+     QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
+     QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+     COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
+     EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+     TOP_K = int(os.getenv("TOP_K", "5"))
+     SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
+     DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
+
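+ # Example environment for local development (placeholder values, not
+ # defaults shipped with this app):
+ #   export GROQ_API_KEY=gsk_...
+ #   export QDRANT_URL=http://localhost:6333
+ #   export COLLECTION_NAME=documents
+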
+ # Global clients
+ groq_client = None
+ qdrant_client = None
+ embedding_service = None
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup
+     global groq_client, qdrant_client
+
+     if not Config.GROQ_API_KEY:
+         raise ValueError("GROQ_API_KEY environment variable is required")
+
+     groq_client = AsyncGroq(api_key=Config.GROQ_API_KEY)
+     qdrant_client = AsyncQdrantClient(
+         url=Config.QDRANT_URL,
+         api_key=Config.QDRANT_API_KEY
+     )
+
+     # The embedding service is instantiated at module import time (see
+     # EmbeddingService below); resetting it to None here would break every
+     # endpoint that depends on it, so it is left untouched.
+
+     # Verify connections
+     try:
+         collections = await qdrant_client.get_collections()
+         print(f"Connected to Qdrant. Available collections: {[c.name for c in collections.collections]}")
+     except Exception as e:
+         print(f"Warning: Could not connect to Qdrant: {e}")
+
+     # Check embedding model
+     try:
+         print(f"Embedding model loaded: {Config.EMBEDDING_MODEL}")
+         print(f"Model device: {Config.DEVICE}")
+         print(f"Vector dimension: {embedding_service.dimension}")
+     except Exception as e:
+         print(f"Warning: Could not load embedding model: {e}")
+
+     yield
+
+     # Shutdown
+     if qdrant_client:
+         await qdrant_client.close()
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="RAG API with Groq and Qdrant",
+     description="OpenAI-compatible API for RAG using Groq LLM and Qdrant vector database",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ class EmbeddingService:
+     """Service for generating embeddings using sentence-transformers"""
+
+     def __init__(self):
+         self.model_name = Config.EMBEDDING_MODEL
+         self.device = Config.DEVICE
+         self.dimension = 384  # all-MiniLM-L6-v2 dimension
+         self.executor = ThreadPoolExecutor(max_workers=4)
+
+         # Load the model
+         print(f"Loading embedding model: {self.model_name}")
+         self.model = SentenceTransformer(self.model_name, device=self.device)
+         print(f"Model loaded successfully on device: {self.device}")
+
+     async def get_embedding(self, text: str) -> List[float]:
+         """Generate embedding for given text"""
+         try:
+             # Run the synchronous model.encode in a thread pool
+             loop = asyncio.get_running_loop()
+             embedding = await loop.run_in_executor(
+                 self.executor,
+                 self._encode_text,
+                 text
+             )
+             return embedding.tolist()
+         except Exception as e:
+             print(f"Error generating embedding: {e}")
+             # Fallback: a constant dummy vector keeps callers from crashing,
+             # but any similarity search against it is meaningless.
+             return [0.1] * self.dimension
+
+     def _encode_text(self, text: str):
+         """Synchronous text encoding - runs in thread pool"""
+         return self.model.encode([text])[0]
+
+     async def get_document_embedding(self, text: str) -> List[float]:
+         """Generate embedding for document text"""
+         return await self.get_embedding(text)
+
+     async def get_query_embedding(self, text: str) -> List[float]:
+         """Generate embedding for query text"""
+         return await self.get_embedding(text)
+
+     async def batch_embed(self, texts: List[str]) -> List[List[float]]:
+         """Generate embeddings for multiple texts efficiently"""
+         try:
+             loop = asyncio.get_running_loop()
+             embeddings = await loop.run_in_executor(
+                 self.executor,
+                 self._batch_encode_texts,
+                 texts
+             )
+             return embeddings.tolist()
+         except Exception as e:
+             print(f"Error in batch embedding: {e}")
+             return [[0.1] * self.dimension for _ in texts]
+
+     def _batch_encode_texts(self, texts: List[str]):
+         """Synchronous batch encoding - runs in thread pool"""
+         return self.model.encode(texts)
+
+     def health_check(self) -> dict:
+         """Check embedding service health"""
+         try:
+             # Test encoding
+             test_embedding = self.model.encode(["test"])
+             return {
+                 "status": "healthy",
+                 "model": self.model_name,
+                 "device": self.device,
+                 "dimension": self.dimension,
+                 "test_embedding_shape": test_embedding.shape
+             }
+         except Exception as e:
+             return {
+                 "status": "unhealthy",
+                 "model": self.model_name,
+                 "error": str(e)
+             }
+
+ embedding_service = EmbeddingService()
+
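+ # Usage sketch (illustrative; must run inside an event loop):
+ #   vec = await embedding_service.get_query_embedding("hello world")
+ #   assert len(vec) == embedding_service.dimension  # 384
+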
+ class RAGService:
+     """Service for retrieval-augmented generation"""
+
+     @staticmethod
+     async def retrieve_relevant_chunks(query: str, top_k: int = Config.TOP_K) -> List[str]:
+         """Retrieve relevant document chunks from Qdrant"""
+         try:
+             # Get query embedding - all-MiniLM works well without special prefixes
+             query_embedding = await embedding_service.get_query_embedding(query)
+
+             # Search in Qdrant
+             search_results = await qdrant_client.search(
+                 collection_name=Config.COLLECTION_NAME,
+                 query_vector=query_embedding,
+                 limit=top_k,
+                 score_threshold=Config.SIMILARITY_THRESHOLD
+             )
+
+             # Extract content from results
+             chunks = []
+             for result in search_results:
+                 if hasattr(result, 'payload') and 'content' in result.payload:
+                     chunks.append(result.payload['content'])
+                 elif hasattr(result, 'payload') and 'text' in result.payload:
+                     chunks.append(result.payload['text'])
+
+             print(f"Retrieved {len(chunks)} relevant chunks for query")
+             return chunks
+
+         except Exception as e:
+             print(f"Error retrieving chunks: {e}")
+             return []
+
+     @staticmethod
+     def build_context_prompt(query: str, chunks: List[str]) -> str:
+         """Build a context-aware prompt with retrieved chunks"""
+         if not chunks:
+             return query
+
+         context = "\n\n".join([f"Document {i+1}: {chunk}" for i, chunk in enumerate(chunks)])
+
+         prompt = f"""Based on the following documents, please answer the user's question. If the information is not available in the documents, please say so.
+
+ Context Documents:
+ {context}
+
+ User Question: {query}
+
+ Please provide a helpful and accurate response based on the context provided."""
+
+         return prompt
+
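+ # With one retrieved chunk, the assembled prompt looks roughly like:
+ #   Based on the following documents, please answer the user's question. ...
+ #   Context Documents:
+ #   Document 1: <chunk text>
+ #   User Question: <query>
+ #   Please provide a helpful and accurate response based on the context provided.
+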
+ @app.get("/")
+ async def root():
+     return {"message": "RAG API with Groq and Qdrant", "status": "running"}
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     try:
+         # Test Qdrant connection
+         collections = await qdrant_client.get_collections()
+         qdrant_status = "connected"
+     except Exception as e:
+         qdrant_status = f"error: {str(e)}"
+
+     # Test embedding service
+     embedding_health = embedding_service.health_check()
+
+     return {
+         "status": "healthy",
+         "groq": "connected" if groq_client else "not configured",
+         "qdrant": qdrant_status,
+         "embedding_service": embedding_health,
+         "collection": Config.COLLECTION_NAME,
+         "embedding_model": Config.EMBEDDING_MODEL
+     }
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     """OpenAI-compatible chat completions endpoint with RAG"""
+
+     if not groq_client:
+         raise HTTPException(status_code=500, detail="Groq client not initialized")
+
+     try:
+         # Get the last user message for retrieval
+         user_messages = [msg for msg in request.messages if msg.role == "user"]
+         if not user_messages:
+             raise HTTPException(status_code=400, detail="No user message found")
+
+         last_user_message = user_messages[-1].content
+
+         # Retrieve relevant chunks
+         relevant_chunks = await RAGService.retrieve_relevant_chunks(last_user_message)
+
+         # Build context-aware prompt
+         if relevant_chunks:
+             context_prompt = RAGService.build_context_prompt(last_user_message, relevant_chunks)
+
+             # Replace the last user message with context-enhanced version
+             enhanced_messages = request.messages[:-1] + [Message(role="user", content=context_prompt)]
+         else:
+             enhanced_messages = request.messages
+
+         # Convert to Groq format
+         groq_messages = [{"role": msg.role, "content": msg.content} for msg in enhanced_messages]
+
+         if request.stream:
+             return StreamingResponse(
+                 stream_chat_completion(groq_messages, request),
+                 # The generator yields "data: ..." frames, so advertise
+                 # server-sent events rather than plain text
+                 media_type="text/event-stream"
+             )
+         else:
+             return await create_chat_completion(groq_messages, request)
+
+     except HTTPException:
+         # Let intentional HTTP errors (e.g. the 400 above) pass through
+         # instead of being re-wrapped as a 500
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
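+ # Example call (illustrative; adjust host/port to your deployment):
+ #   curl -s http://localhost:8000/v1/chat/completions \
+ #     -H "Content-Type: application/json" \
+ #     -d '{"model": "mixtral-8x7b-32768", "messages": [{"role": "user", "content": "Hi"}]}'
+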
+ async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
+     """Create a non-streaming chat completion"""
+     try:
+         response = await groq_client.chat.completions.create(
+             model=request.model,
+             messages=messages,
+             max_tokens=request.max_tokens,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             stream=False
+         )
+
+         # Convert Groq response to OpenAI format
+         return ChatCompletionResponse(
+             id=f"chatcmpl-{uuid.uuid4().hex}",
+             created=int(datetime.now().timestamp()),
+             model=request.model,
+             choices=[{
+                 "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": response.choices[0].message.content
+                 },
+                 "finish_reason": response.choices[0].finish_reason
+             }],
+             usage={
+                 "prompt_tokens": response.usage.prompt_tokens,
+                 "completion_tokens": response.usage.completion_tokens,
+                 "total_tokens": response.usage.total_tokens
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error calling Groq API: {str(e)}")
+
+ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
+     """Stream chat completion responses"""
+     try:
+         completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+         created = int(datetime.now().timestamp())
+
+         stream = await groq_client.chat.completions.create(
+             model=request.model,
+             messages=messages,
+             max_tokens=request.max_tokens,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             stream=True
+         )
+
+         async for chunk in stream:
+             if chunk.choices and chunk.choices[0].delta:
+                 delta = chunk.choices[0].delta
+
+                 chunk_response = ChatCompletionChunk(
+                     id=completion_id,
+                     created=created,
+                     model=request.model,
+                     choices=[{
+                         "index": 0,
+                         "delta": {
+                             "role": delta.role if hasattr(delta, 'role') and delta.role else None,
+                             "content": delta.content if hasattr(delta, 'content') else None
+                         },
+                         "finish_reason": chunk.choices[0].finish_reason
+                     }]
+                 )
+
+                 yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+         # Send final chunk
+         final_chunk = ChatCompletionChunk(
+             id=completion_id,
+             created=created,
+             model=request.model,
+             choices=[{
+                 "index": 0,
+                 "delta": {},
+                 "finish_reason": "stop"
+             }]
+         )
+
+         yield f"data: {final_chunk.model_dump_json()}\n\n"
+         yield "data: [DONE]\n\n"
+
+     except Exception as e:
+         error_chunk = {
+             "error": {
+                 "message": str(e),
+                 "type": "internal_error"
+             }
+         }
+         yield f"data: {json.dumps(error_chunk)}\n\n"
+
+ # Additional endpoints for managing the vector database
+ @app.post("/v1/embeddings/add")
+ async def add_document(content: str, metadata: Optional[Dict] = None):
+     """Add a document to the vector database.
+
+     Note: FastAPI treats the bare `content` parameter as a query parameter,
+     while `metadata` (a dict) is read from the JSON request body.
+     """
+     try:
+         # Generate embedding for document
+         embedding = await embedding_service.get_document_embedding(content)
+
+         # Create point
+         point = PointStruct(
+             id=str(uuid.uuid4()),
+             vector=embedding,
+             payload={
+                 "content": content,
+                 "metadata": metadata or {},
+                 "timestamp": datetime.now().isoformat()
+             }
+         )
+
+         # Insert into Qdrant
+         await qdrant_client.upsert(
+             collection_name=Config.COLLECTION_NAME,
+             points=[point]
+         )
+
+         return {"message": "Document added successfully", "id": point.id}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
+
+ @app.post("/v1/embeddings/batch_add")
+ async def batch_add_documents(documents: List[Dict[str, Any]]):
+     """Add multiple documents to the vector database"""
+     try:
+         # Extract texts and metadata
+         texts = [doc.get("content", "") for doc in documents]
+         metadatas = [doc.get("metadata", {}) for doc in documents]
+
+         # Generate embeddings for all documents
+         embeddings = await embedding_service.batch_embed(texts)
+
+         # Create points
+         points = []
+         for text, embedding, metadata in zip(texts, embeddings, metadatas):
+             point = PointStruct(
+                 id=str(uuid.uuid4()),
+                 vector=embedding,
+                 payload={
+                     "content": text,
+                     "metadata": metadata,
+                     "timestamp": datetime.now().isoformat()
+                 }
+             )
+             points.append(point)
+
+         # Insert all points into Qdrant
+         await qdrant_client.upsert(
+             collection_name=Config.COLLECTION_NAME,
+             points=points
+         )
+
+         return {
+             "message": f"Successfully added {len(points)} documents",
+             "ids": [point.id for point in points]
+         }
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error adding documents: {str(e)}")
+
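+ # Example batch payload (illustrative):
+ #   [{"content": "First passage ...", "metadata": {"source": "doc1.pdf"}},
+ #    {"content": "Second passage ...", "metadata": {"source": "doc2.pdf"}}]
+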
+ @app.post("/v1/embeddings/create_collection")
+ async def create_collection():
+     """Create a new collection in Qdrant with the correct vector size"""
+     try:
+         await qdrant_client.create_collection(
+             collection_name=Config.COLLECTION_NAME,
+             vectors_config=VectorParams(
+                 size=embedding_service.dimension,  # 384 for all-MiniLM-L6-v2
+                 distance=Distance.COSINE
+             )
+         )
+
+         return {
+             "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
+             "vector_size": embedding_service.dimension,
+             "distance": "cosine"
+         }
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}")
+
+ @app.get("/v1/collections/info")
+ async def get_collection_info():
+     """Get information about the collection"""
+     try:
+         collection_info = await qdrant_client.get_collection(Config.COLLECTION_NAME)
+         return {
+             "name": Config.COLLECTION_NAME,
+             "vectors_count": collection_info.vectors_count,
+             "status": collection_info.status
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ groq==0.4.1
+ qdrant-client==1.7.0
+ sentence-transformers==2.2.2
+ torch==2.1.1
+ pydantic==2.5.0
+ httpx==0.25.2
+ numpy==1.24.3
+ transformers==4.36.0
+ tokenizers==0.15.0
+ huggingface-hub==0.19.4
+ scipy==1.11.4
+ scikit-learn==1.3.2
+ python-multipart==0.0.6