SUBHRAJIT MOHANTY committed
Commit 6cb77d4 · 1 Parent(s): bcc14d5

Fixing groq

Files changed (1)
  1. app.py +74 -23
app.py CHANGED
@@ -71,45 +71,97 @@ app_state = ApplicationState()
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
-    global groq_client, qdrant_client, embedding_service
-
     if not Config.GROQ_API_KEY:
         raise ValueError("GROQ_API_KEY environment variable is required")
 
-    groq_client = AsyncGroq(api_key=Config.GROQ_API_KEY)
-    qdrant_client = AsyncQdrantClient(
-        url=Config.QDRANT_URL,
-        api_key=Config.QDRANT_API_KEY
-    )
+    print("Initializing services...")
 
-    # Initialize embedding service
-    embedding_service = None
+    # Initialize OpenAI client with Groq endpoint
+    try:
+        app_state.openai_client = AsyncOpenAI(
+            api_key=Config.GROQ_API_KEY,
+            base_url=Config.GROQ_BASE_URL
+        )
+        print("✓ OpenAI client initialized with Groq endpoint")
+    except Exception as e:
+        print(f"✗ Error initializing OpenAI client: {e}")
+        raise e
+
+    # Initialize Qdrant client
+    try:
+        app_state.qdrant_client = AsyncQdrantClient(
+            url=Config.QDRANT_URL,
+            api_key=Config.QDRANT_API_KEY
+        )
+        print("✓ Qdrant client initialized")
+    except Exception as e:
+        print(f"✗ Error initializing Qdrant client: {e}")
+        raise e
 
-    # Verify connections
+    # Initialize embedding service
     try:
-        collections = await qdrant_client.get_collections()
-        print(f"Connected to Qdrant. Available collections: {[c.name for c in collections.collections]}")
+        print("Loading embedding model...")
+        app_state.embedding_service = EmbeddingService()
+        print(f"✓ Embedding model loaded: {Config.EMBEDDING_MODEL}")
+        print(f"✓ Model device: {Config.DEVICE}")
+        print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
     except Exception as e:
-        print(f"Warning: Could not connect to Qdrant: {e}")
+        print(f"✗ Error initializing embedding service: {e}")
+        raise e  # Fail fast if embedding service can't be initialized
 
-    # Check embedding model
+    # Verify Qdrant connection and auto-create collection
     try:
-        print(f"Embedding model loaded: {Config.EMBEDDING_MODEL}")
-        print(f"Model device: {Config.DEVICE}")
-        print(f"Vector dimension: {embedding_service.dimension}")
+        collections = await app_state.qdrant_client.get_collections()
+        collection_names = [c.name for c in collections.collections]
+        print(f"✓ Connected to Qdrant. Available collections: {collection_names}")
+
+        # Check if our collection exists, if not create it
+        if Config.COLLECTION_NAME not in collection_names:
+            print(f"📝 Collection '{Config.COLLECTION_NAME}' not found. Creating automatically...")
+            try:
+                from qdrant_client.models import VectorParams, Distance
+
+                await app_state.qdrant_client.create_collection(
+                    collection_name=Config.COLLECTION_NAME,
+                    vectors_config=VectorParams(
+                        size=app_state.embedding_service.dimension,
+                        distance=Distance.COSINE
+                    )
+                )
+                print(f"✓ Collection '{Config.COLLECTION_NAME}' created successfully!")
+                print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
+                print(f"✓ Distance metric: COSINE")
+            except Exception as create_error:
+                print(f"✗ Failed to create collection: {create_error}")
+                print("⚠ You may need to create the collection manually")
+        else:
+            print(f"✓ Collection '{Config.COLLECTION_NAME}' already exists")
+
     except Exception as e:
-        print(f"Warning: Could not load embedding model: {e}")
+        print(f"⚠ Warning: Could not connect to Qdrant: {e}")
+        print("⚠ Collection auto-creation skipped")
+
+    print("🚀 All services initialized successfully!")
 
     yield
 
     # Shutdown
-    if qdrant_client:
-        await qdrant_client.close()
+    print("Shutting down services...")
+    if app_state.qdrant_client:
+        await app_state.qdrant_client.close()
+        print("✓ Qdrant client closed")
+    if app_state.openai_client:
+        await app_state.openai_client.close()
+        print("✓ OpenAI client closed")
+    if app_state.embedding_service and hasattr(app_state.embedding_service, 'executor'):
+        app_state.embedding_service.executor.shutdown(wait=True)
+        print("✓ Embedding service executor shutdown")
+    print("✓ Shutdown complete")
 
 # Initialize FastAPI app
 app = FastAPI(
     title="RAG API with Groq and Qdrant",
-    description="OpenAI-compatible API for RAG using Groq LLM and Qdrant vector database",
+    description="OpenAI-compatible API for RAG using Groq and Qdrant",
     version="1.0.0",
     lifespan=lifespan
 )
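
The startup sequence above points the stock AsyncOpenAI client at Groq's OpenAI-compatible endpoint and reads several Config attributes (GROQ_BASE_URL, COLLECTION_NAME, EMBEDDING_MODEL, DEVICE) that are defined elsewhere in app.py and not shown in this diff. Below is a minimal sketch of what such a Config might look like; the attribute defaults, including the Groq base URL, are assumptions for illustration rather than the values from the actual file:

import os

class Config:
    # Illustrative sketch only; the real app.py defines these attributes elsewhere.
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")
    GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")  # Groq's OpenAI-compatible endpoint
    QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
    QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
    COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    DEVICE = os.getenv("DEVICE", "cpu")

Because Groq exposes an OpenAI-compatible REST surface, swapping base_url on the official openai client is enough to route chat completions through Groq models without a separate SDK.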
@@ -192,8 +244,6 @@ class EmbeddingService:
                 "error": str(e)
             }
 
-embedding_service = EmbeddingService()
-
 class RAGService:
     """Service for retrieval-augmented generation"""
 
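Dropping the module-level embedding_service singleton means the shared instances now live only on app_state, which the first hunk's header shows being created as ApplicationState(). That class sits outside this diff; a plausible minimal shape, assuming plain optional attributes (the real definition may differ):

from dataclasses import dataclass
from typing import Any, Optional

from openai import AsyncOpenAI
from qdrant_client import AsyncQdrantClient

@dataclass
class ApplicationState:
    # Assumed sketch of the shared-state holder referenced as app_state in the diff.
    openai_client: Optional[AsyncOpenAI] = None
    qdrant_client: Optional[AsyncQdrantClient] = None
    embedding_service: Optional[Any] = None  # EmbeddingService instance, created during lifespan startup

Creating the EmbeddingService inside lifespan rather than at import time avoids loading the model when the module is merely imported, and lets startup fail fast with a clear error if the model cannot be loaded.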
 
@@ -436,6 +486,7 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
         }
         yield f"data: {json.dumps(error_chunk)}\n\n"
 
+# Additional endpoints for managing the vector database
 @app.post("/v1/embeddings/add")
 async def add_document(content: str, metadata: Optional[Dict] = None):
     """Add a document to the vector database"""
 