Spaces:
Running
Running
SUBHRAJIT MOHANTY
commited on
Commit
Β·
6cb77d4
1
Parent(s):
bcc14d5
Fixing groq
Browse files
app.py
CHANGED
@@ -71,45 +71,97 @@ app_state = ApplicationState()
|
|
71 |
@asynccontextmanager
|
72 |
async def lifespan(app: FastAPI):
|
73 |
# Startup
|
74 |
-
global groq_client, qdrant_client, embedding_service
|
75 |
-
|
76 |
if not Config.GROQ_API_KEY:
|
77 |
raise ValueError("GROQ_API_KEY environment variable is required")
|
78 |
|
79 |
-
|
80 |
-
qdrant_client = AsyncQdrantClient(
|
81 |
-
url=Config.QDRANT_URL,
|
82 |
-
api_key=Config.QDRANT_API_KEY
|
83 |
-
)
|
84 |
|
85 |
-
# Initialize
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
#
|
89 |
try:
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
92 |
except Exception as e:
|
93 |
-
print(f"
|
|
|
94 |
|
95 |
-
#
|
96 |
try:
|
97 |
-
|
98 |
-
|
99 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
except Exception as e:
|
101 |
-
print(f"Warning: Could not
|
|
|
|
|
|
|
102 |
|
103 |
yield
|
104 |
|
105 |
# Shutdown
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
# Initialize FastAPI app
|
110 |
app = FastAPI(
|
111 |
title="RAG API with Groq and Qdrant",
|
112 |
-
description="OpenAI-compatible API for RAG using Groq
|
113 |
version="1.0.0",
|
114 |
lifespan=lifespan
|
115 |
)
|
@@ -192,8 +244,6 @@ class EmbeddingService:
|
|
192 |
"error": str(e)
|
193 |
}
|
194 |
|
195 |
-
embedding_service = EmbeddingService()
|
196 |
-
|
197 |
class RAGService:
|
198 |
"""Service for retrieval-augmented generation"""
|
199 |
|
@@ -436,6 +486,7 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
|
|
436 |
}
|
437 |
yield f"data: {json.dumps(error_chunk)}\n\n"
|
438 |
|
|
|
439 |
@app.post("/v1/embeddings/add")
|
440 |
async def add_document(content: str, metadata: Optional[Dict] = None):
|
441 |
"""Add a document to the vector database"""
|
|
|
71 |
@asynccontextmanager
|
72 |
async def lifespan(app: FastAPI):
|
73 |
# Startup
|
|
|
|
|
74 |
if not Config.GROQ_API_KEY:
|
75 |
raise ValueError("GROQ_API_KEY environment variable is required")
|
76 |
|
77 |
+
print("Initializing services...")
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
# Initialize OpenAI client with Groq endpoint
|
80 |
+
try:
|
81 |
+
app_state.openai_client = AsyncOpenAI(
|
82 |
+
api_key=Config.GROQ_API_KEY,
|
83 |
+
base_url=Config.GROQ_BASE_URL
|
84 |
+
)
|
85 |
+
print("β OpenAI client initialized with Groq endpoint")
|
86 |
+
except Exception as e:
|
87 |
+
print(f"β Error initializing OpenAI client: {e}")
|
88 |
+
raise e
|
89 |
+
|
90 |
+
# Initialize Qdrant client
|
91 |
+
try:
|
92 |
+
app_state.qdrant_client = AsyncQdrantClient(
|
93 |
+
url=Config.QDRANT_URL,
|
94 |
+
api_key=Config.QDRANT_API_KEY
|
95 |
+
)
|
96 |
+
print("β Qdrant client initialized")
|
97 |
+
except Exception as e:
|
98 |
+
print(f"β Error initializing Qdrant client: {e}")
|
99 |
+
raise e
|
100 |
|
101 |
+
# Initialize embedding service
|
102 |
try:
|
103 |
+
print("Loading embedding model...")
|
104 |
+
app_state.embedding_service = EmbeddingService()
|
105 |
+
print(f"β Embedding model loaded: {Config.EMBEDDING_MODEL}")
|
106 |
+
print(f"β Model device: {Config.DEVICE}")
|
107 |
+
print(f"β Vector dimension: {app_state.embedding_service.dimension}")
|
108 |
except Exception as e:
|
109 |
+
print(f"β Error initializing embedding service: {e}")
|
110 |
+
raise e # Fail fast if embedding service can't be initialized
|
111 |
|
112 |
+
# Verify Qdrant connection and auto-create collection
|
113 |
try:
|
114 |
+
collections = await app_state.qdrant_client.get_collections()
|
115 |
+
collection_names = [c.name for c in collections.collections]
|
116 |
+
print(f"β Connected to Qdrant. Available collections: {collection_names}")
|
117 |
+
|
118 |
+
# Check if our collection exists, if not create it
|
119 |
+
if Config.COLLECTION_NAME not in collection_names:
|
120 |
+
print(f"π Collection '{Config.COLLECTION_NAME}' not found. Creating automatically...")
|
121 |
+
try:
|
122 |
+
from qdrant_client.models import VectorParams, Distance
|
123 |
+
|
124 |
+
await app_state.qdrant_client.create_collection(
|
125 |
+
collection_name=Config.COLLECTION_NAME,
|
126 |
+
vectors_config=VectorParams(
|
127 |
+
size=app_state.embedding_service.dimension,
|
128 |
+
distance=Distance.COSINE
|
129 |
+
)
|
130 |
+
)
|
131 |
+
print(f"β Collection '{Config.COLLECTION_NAME}' created successfully!")
|
132 |
+
print(f"β Vector dimension: {app_state.embedding_service.dimension}")
|
133 |
+
print(f"β Distance metric: COSINE")
|
134 |
+
except Exception as create_error:
|
135 |
+
print(f"β Failed to create collection: {create_error}")
|
136 |
+
print("β You may need to create the collection manually")
|
137 |
+
else:
|
138 |
+
print(f"β Collection '{Config.COLLECTION_NAME}' already exists")
|
139 |
+
|
140 |
except Exception as e:
|
141 |
+
print(f"β Warning: Could not connect to Qdrant: {e}")
|
142 |
+
print("β Collection auto-creation skipped")
|
143 |
+
|
144 |
+
print("π All services initialized successfully!")
|
145 |
|
146 |
yield
|
147 |
|
148 |
# Shutdown
|
149 |
+
print("Shutting down services...")
|
150 |
+
if app_state.qdrant_client:
|
151 |
+
await app_state.qdrant_client.close()
|
152 |
+
print("β Qdrant client closed")
|
153 |
+
if app_state.openai_client:
|
154 |
+
await app_state.openai_client.close()
|
155 |
+
print("β OpenAI client closed")
|
156 |
+
if app_state.embedding_service and hasattr(app_state.embedding_service, 'executor'):
|
157 |
+
app_state.embedding_service.executor.shutdown(wait=True)
|
158 |
+
print("β Embedding service executor shutdown")
|
159 |
+
print("β Shutdown complete")
|
160 |
|
161 |
# Initialize FastAPI app
|
162 |
app = FastAPI(
|
163 |
title="RAG API with Groq and Qdrant",
|
164 |
+
description="OpenAI-compatible API for RAG using Groq and Qdrant",
|
165 |
version="1.0.0",
|
166 |
lifespan=lifespan
|
167 |
)
|
|
|
244 |
"error": str(e)
|
245 |
}
|
246 |
|
|
|
|
|
247 |
class RAGService:
|
248 |
"""Service for retrieval-augmented generation"""
|
249 |
|
|
|
486 |
}
|
487 |
yield f"data: {json.dumps(error_chunk)}\n\n"
|
488 |
|
489 |
+
# Additional endpoints for managing the vector database
|
490 |
@app.post("/v1/embeddings/add")
|
491 |
async def add_document(content: str, metadata: Optional[Dict] = None):
|
492 |
"""Add a document to the vector database"""
|