SUBHRAJIT MOHANTY committed on
Commit
bcc14d5
·
1 Parent(s): 347dbd1

Groq SDK is replaced with the OpenAI SDK (pointed at Groq's OpenAI-compatible endpoint)

Browse files
Files changed (2) hide show
  1. app.py +100 -58
  2. requirements.txt +1 -1
app.py CHANGED
@@ -10,7 +10,7 @@ import os
10
  from contextlib import asynccontextmanager
11
 
12
  # Third-party imports
13
- from groq import AsyncGroq
14
  from qdrant_client import AsyncQdrantClient
15
  from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
16
  from sentence_transformers import SentenceTransformer
@@ -49,6 +49,7 @@ class ChatCompletionChunk(BaseModel):
49
  # Configuration
50
  class Config:
51
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
52
  QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
53
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
54
  COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
@@ -60,7 +61,7 @@ class Config:
60
  class ApplicationState:
61
  """Application state container"""
62
  def __init__(self):
63
- self.groq_client = None
64
  self.qdrant_client = None
65
  self.embedding_service = None
66
 
@@ -205,6 +206,9 @@ class RAGService:
205
  print("Error: Embedding service is not initialized")
206
  return []
207
 
 
 
 
208
  # Get query embedding - all-MiniLM works well without special prefixes
209
  query_embedding = await app_state.embedding_service.get_query_embedding(query)
210
 
@@ -231,6 +235,31 @@ class RAGService:
231
  print(f"Error retrieving chunks: {e}")
232
  return []
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  @staticmethod
235
  def build_context_prompt(query: str, chunks: List[str]) -> str:
236
  """Build a context-aware prompt with retrieved chunks"""
@@ -278,19 +307,20 @@ async def health_check():
278
 
279
  return {
280
  "status": "healthy" if app_state.embedding_service is not None else "unhealthy",
281
- "groq": "connected" if app_state.groq_client else "not configured",
282
  "qdrant": qdrant_status,
283
  "embedding_service": embedding_health,
284
  "collection": Config.COLLECTION_NAME,
285
- "embedding_model": Config.EMBEDDING_MODEL
 
286
  }
287
 
288
  @app.post("/v1/chat/completions")
289
  async def chat_completions(request: ChatCompletionRequest):
290
  """OpenAI-compatible chat completions endpoint with RAG"""
291
 
292
- if not app_state.groq_client:
293
- raise HTTPException(status_code=500, detail="Groq client not initialized")
294
 
295
  try:
296
  # Get the last user message for retrieval
@@ -312,16 +342,16 @@ async def chat_completions(request: ChatCompletionRequest):
312
  else:
313
  enhanced_messages = request.messages
314
 
315
- # Convert to Groq format
316
- groq_messages = [{"role": msg.role, "content": msg.content} for msg in enhanced_messages]
317
 
318
  if request.stream:
319
  return StreamingResponse(
320
- stream_chat_completion(groq_messages, request),
321
  media_type="text/event-stream"
322
  )
323
  else:
324
- return await create_chat_completion(groq_messages, request)
325
 
326
  except Exception as e:
327
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@@ -329,7 +359,7 @@ async def chat_completions(request: ChatCompletionRequest):
329
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
330
  """Create a non-streaming chat completion"""
331
  try:
332
- response = await app_state.groq_client.chat.completions.create(
333
  model=request.model,
334
  messages=messages,
335
  max_tokens=request.max_tokens,
@@ -338,36 +368,33 @@ async def create_chat_completion(messages: List[Dict], request: ChatCompletionRe
338
  stream=False
339
  )
340
 
341
- # Convert Groq response to OpenAI format
342
  return ChatCompletionResponse(
343
- id=f"chatcmpl-{uuid.uuid4().hex}",
344
- created=int(datetime.now().timestamp()),
345
- model=request.model,
346
  choices=[{
347
- "index": 0,
348
  "message": {
349
- "role": "assistant",
350
- "content": response.choices[0].message.content
351
  },
352
- "finish_reason": response.choices[0].finish_reason
353
- }],
354
  usage={
355
  "prompt_tokens": response.usage.prompt_tokens,
356
  "completion_tokens": response.usage.completion_tokens,
357
  "total_tokens": response.usage.total_tokens
358
- }
359
  )
360
 
361
  except Exception as e:
362
- raise HTTPException(status_code=500, detail=f"Error calling Groq API: {str(e)}")
363
 
364
  async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
365
  """Stream chat completion responses"""
366
  try:
367
- completion_id = f"chatcmpl-{uuid.uuid4().hex}"
368
- created = int(datetime.now().timestamp())
369
-
370
- stream = await app_state.groq_client.chat.completions.create(
371
  model=request.model,
372
  messages=messages,
373
  max_tokens=request.max_tokens,
@@ -377,38 +404,26 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
377
  )
378
 
379
  async for chunk in stream:
380
- if chunk.choices and chunk.choices[0].delta:
381
- delta = chunk.choices[0].delta
382
-
383
- chunk_response = ChatCompletionChunk(
384
- id=completion_id,
385
- created=created,
386
- model=request.model,
387
- choices=[{
388
- "index": 0,
389
- "delta": {
390
- "role": delta.role if hasattr(delta, 'role') and delta.role else None,
391
- "content": delta.content if hasattr(delta, 'content') else None
392
- },
393
- "finish_reason": chunk.choices[0].finish_reason
394
- }]
395
- )
396
-
397
- yield f"data: {chunk_response.model_dump_json()}\n\n"
398
 
399
  # Send final chunk
400
- final_chunk = ChatCompletionChunk(
401
- id=completion_id,
402
- created=created,
403
- model=request.model,
404
- choices=[{
405
- "index": 0,
406
- "delta": {},
407
- "finish_reason": "stop"
408
- }]
409
- )
410
-
411
- yield f"data: {final_chunk.model_dump_json()}\n\n"
412
  yield "data: [DONE]\n\n"
413
 
414
  except Exception as e:
@@ -429,6 +444,9 @@ async def add_document(content: str, metadata: Optional[Dict] = None):
429
  if app_state.embedding_service is None:
430
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
431
 
 
 
 
432
  # Generate embedding for document
433
  embedding = await app_state.embedding_service.get_document_embedding(content)
434
 
@@ -462,6 +480,9 @@ async def batch_add_documents(documents: List[Dict[str, Any]]):
462
  if app_state.embedding_service is None:
463
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
464
 
 
 
 
465
  # Extract texts and metadata
466
  texts = [doc.get("content", "") for doc in documents]
467
  metadatas = [doc.get("metadata", {}) for doc in documents]
@@ -507,6 +528,22 @@ async def create_collection():
507
 
508
  from qdrant_client.models import VectorParams, Distance
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  await app_state.qdrant_client.create_collection(
511
  collection_name=Config.COLLECTION_NAME,
512
  vectors_config=VectorParams(
@@ -518,7 +555,8 @@ async def create_collection():
518
  return {
519
  "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
520
  "vector_size": app_state.embedding_service.dimension,
521
- "distance": "cosine"
 
522
  }
523
 
524
  except Exception as e:
@@ -531,11 +569,15 @@ async def get_collection_info():
531
  if app_state.qdrant_client is None:
532
  raise HTTPException(status_code=500, detail="Qdrant client is not initialized")
533
 
 
 
 
534
  collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
535
  return {
536
  "name": Config.COLLECTION_NAME,
537
  "vectors_count": collection_info.vectors_count,
538
- "status": collection_info.status
 
539
  }
540
  except Exception as e:
541
  raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")
 
10
  from contextlib import asynccontextmanager
11
 
12
  # Third-party imports
13
+ from openai import AsyncOpenAI
14
  from qdrant_client import AsyncQdrantClient
15
  from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
16
  from sentence_transformers import SentenceTransformer
 
49
  # Configuration
50
  class Config:
51
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
52
+ GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
53
  QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
54
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
55
  COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
 
61
  class ApplicationState:
62
  """Application state container"""
63
  def __init__(self):
64
+ self.openai_client = None
65
  self.qdrant_client = None
66
  self.embedding_service = None
67
 
 
206
  print("Error: Embedding service is not initialized")
207
  return []
208
 
209
+ # Auto-create collection if it doesn't exist
210
+ await RAGService._ensure_collection_exists()
211
+
212
  # Get query embedding - all-MiniLM works well without special prefixes
213
  query_embedding = await app_state.embedding_service.get_query_embedding(query)
214
 
 
235
  print(f"Error retrieving chunks: {e}")
236
  return []
237
 
238
+ @staticmethod
239
+ async def _ensure_collection_exists():
240
+ """Ensure the collection exists, create if it doesn't"""
241
+ try:
242
+ # Check if collection exists
243
+ collections = await app_state.qdrant_client.get_collections()
244
+ collection_names = [c.name for c in collections.collections]
245
+
246
+ if Config.COLLECTION_NAME not in collection_names:
247
+ print(f"Creating collection '{Config.COLLECTION_NAME}' on-demand...")
248
+ from qdrant_client.models import VectorParams, Distance
249
+
250
+ await app_state.qdrant_client.create_collection(
251
+ collection_name=Config.COLLECTION_NAME,
252
+ vectors_config=VectorParams(
253
+ size=app_state.embedding_service.dimension,
254
+ distance=Distance.COSINE
255
+ )
256
+ )
257
+ print(f"✓ Collection '{Config.COLLECTION_NAME}' created successfully!")
258
+
259
+ except Exception as e:
260
+ print(f"Warning: Could not ensure collection exists: {e}")
261
+ # Continue anyway - the operation might still work
262
+
263
  @staticmethod
264
  def build_context_prompt(query: str, chunks: List[str]) -> str:
265
  """Build a context-aware prompt with retrieved chunks"""
 
307
 
308
  return {
309
  "status": "healthy" if app_state.embedding_service is not None else "unhealthy",
310
+ "openai_client": "connected" if app_state.openai_client else "not configured",
311
  "qdrant": qdrant_status,
312
  "embedding_service": embedding_health,
313
  "collection": Config.COLLECTION_NAME,
314
+ "embedding_model": Config.EMBEDDING_MODEL,
315
+ "groq_endpoint": Config.GROQ_BASE_URL
316
  }
317
 
318
  @app.post("/v1/chat/completions")
319
  async def chat_completions(request: ChatCompletionRequest):
320
  """OpenAI-compatible chat completions endpoint with RAG"""
321
 
322
+ if not app_state.openai_client:
323
+ raise HTTPException(status_code=500, detail="OpenAI client not initialized")
324
 
325
  try:
326
  # Get the last user message for retrieval
 
342
  else:
343
  enhanced_messages = request.messages
344
 
345
+ # Convert to OpenAI format
346
+ openai_messages = [{"role": msg.role, "content": msg.content} for msg in enhanced_messages]
347
 
348
  if request.stream:
349
  return StreamingResponse(
350
+ stream_chat_completion(openai_messages, request),
351
  media_type="text/event-stream"
352
  )
353
  else:
354
+ return await create_chat_completion(openai_messages, request)
355
 
356
  except Exception as e:
357
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
359
  async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
360
  """Create a non-streaming chat completion"""
361
  try:
362
+ response = await app_state.openai_client.chat.completions.create(
363
  model=request.model,
364
  messages=messages,
365
  max_tokens=request.max_tokens,
 
368
  stream=False
369
  )
370
 
371
+ # Convert response to OpenAI format (already compatible)
372
  return ChatCompletionResponse(
373
+ id=response.id,
374
+ created=response.created,
375
+ model=response.model,
376
  choices=[{
377
+ "index": choice.index,
378
  "message": {
379
+ "role": choice.message.role,
380
+ "content": choice.message.content
381
  },
382
+ "finish_reason": choice.finish_reason
383
+ } for choice in response.choices],
384
  usage={
385
  "prompt_tokens": response.usage.prompt_tokens,
386
  "completion_tokens": response.usage.completion_tokens,
387
  "total_tokens": response.usage.total_tokens
388
+ } if response.usage else None
389
  )
390
 
391
  except Exception as e:
392
+ raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")
393
 
394
  async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
395
  """Stream chat completion responses"""
396
  try:
397
+ stream = await app_state.openai_client.chat.completions.create(
 
 
 
398
  model=request.model,
399
  messages=messages,
400
  max_tokens=request.max_tokens,
 
404
  )
405
 
406
  async for chunk in stream:
407
+ if chunk.choices and len(chunk.choices) > 0:
408
+ choice = chunk.choices[0]
409
+ if choice.delta:
410
+ chunk_response = ChatCompletionChunk(
411
+ id=chunk.id,
412
+ created=chunk.created,
413
+ model=chunk.model,
414
+ choices=[{
415
+ "index": choice.index,
416
+ "delta": {
417
+ "role": choice.delta.role if choice.delta.role else None,
418
+ "content": choice.delta.content if choice.delta.content else None
419
+ },
420
+ "finish_reason": choice.finish_reason
421
+ }]
422
+ )
423
+
424
+ yield f"data: {chunk_response.model_dump_json()}\n\n"
425
 
426
  # Send final chunk
 
 
 
 
 
 
 
 
 
 
 
 
427
  yield "data: [DONE]\n\n"
428
 
429
  except Exception as e:
 
444
  if app_state.embedding_service is None:
445
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
446
 
447
+ # Auto-create collection if it doesn't exist
448
+ await RAGService._ensure_collection_exists()
449
+
450
  # Generate embedding for document
451
  embedding = await app_state.embedding_service.get_document_embedding(content)
452
 
 
480
  if app_state.embedding_service is None:
481
  raise HTTPException(status_code=500, detail="Embedding service is not initialized")
482
 
483
+ # Auto-create collection if it doesn't exist
484
+ await RAGService._ensure_collection_exists()
485
+
486
  # Extract texts and metadata
487
  texts = [doc.get("content", "") for doc in documents]
488
  metadatas = [doc.get("metadata", {}) for doc in documents]
 
528
 
529
  from qdrant_client.models import VectorParams, Distance
530
 
531
+ # Check if collection already exists
532
+ try:
533
+ collections = await app_state.qdrant_client.get_collections()
534
+ collection_names = [c.name for c in collections.collections]
535
+
536
+ if Config.COLLECTION_NAME in collection_names:
537
+ return {
538
+ "message": f"Collection '{Config.COLLECTION_NAME}' already exists",
539
+ "vector_size": app_state.embedding_service.dimension,
540
+ "distance": "cosine",
541
+ "status": "exists"
542
+ }
543
+ except Exception as e:
544
+ print(f"Warning: Could not check existing collections: {e}")
545
+
546
+ # Create the collection
547
  await app_state.qdrant_client.create_collection(
548
  collection_name=Config.COLLECTION_NAME,
549
  vectors_config=VectorParams(
 
555
  return {
556
  "message": f"Collection '{Config.COLLECTION_NAME}' created successfully",
557
  "vector_size": app_state.embedding_service.dimension,
558
+ "distance": "cosine",
559
+ "status": "created"
560
  }
561
 
562
  except Exception as e:
 
569
  if app_state.qdrant_client is None:
570
  raise HTTPException(status_code=500, detail="Qdrant client is not initialized")
571
 
572
+ # Auto-create collection if it doesn't exist
573
+ await RAGService._ensure_collection_exists()
574
+
575
  collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
576
  return {
577
  "name": Config.COLLECTION_NAME,
578
  "vectors_count": collection_info.vectors_count,
579
+ "status": collection_info.status,
580
+ "vector_size": app_state.embedding_service.dimension if app_state.embedding_service else "unknown"
581
  }
582
  except Exception as e:
583
  raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  fastapi==0.104.1
2
  uvicorn[standard]==0.24.0
3
- groq==0.4.1
4
  qdrant-client==1.7.0
5
  sentence-transformers==2.2.2
6
  torch==2.1.1
 
1
  fastapi==0.104.1
2
  uvicorn[standard]==0.24.0
3
+ openai==1.3.7
4
  qdrant-client==1.7.0
5
  sentence-transformers==2.2.2
6
  torch==2.1.1