amaye15 commited on
Commit
2cb9dec
·
1 Parent(s): fd8f07a

Initial Deployment

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *pycache*
2
+ *.env*
Dockerfile ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Stage 1: Build stage
FROM python:3.12-slim AS builder

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Create a non-root user
RUN useradd -m -u 1000 user

# Set the working directory
WORKDIR /app

# Copy only the requirements file first to leverage Docker cache
COPY --chown=user ./requirements.txt /app/requirements.txt

# Install dependencies in a virtual environment
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code
COPY --chown=user . /app

# Stage 2: Runtime stage
FROM python:3.12-slim

# Create a non-root user
RUN useradd -m -u 1000 user
USER user

# Copy the virtual environment from the builder stage
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# Set the working directory
WORKDIR /app

# Copy only the necessary files from the builder stage
COPY --from=builder --chown=user /app /app

# Expose the port the app runs on
EXPOSE 7860

# Health check to ensure the application is running.
# python:3.12-slim does not include curl, so a curl-based probe would
# always fail and mark the container unhealthy; use the stdlib instead.
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health', timeout=5)" || exit 1

# Command to run the application with hot reloading (pairs with the
# docker-compose ./src volume mount; drop --reload for production images)
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860", "--reload"]
README.md CHANGED
@@ -6,6 +6,9 @@ colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  license: mit
 
 
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
+ python_version: 3.12
10
+ app_port: 7860
11
+ app_file: src/main.py
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
docker-compose.yml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE: the top-level `version` key is obsolete in Compose v2+ and is ignored.
version: "3.9"

services:
  app:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: similarity-search-app
    ports:
      - "7860:7860"
    volumes:
      - ./src:/app/src # Mount the local src directory for hot reloading
    environment:
      - PYTHONUNBUFFERED=1
    restart: unless-stopped
    healthcheck:
      # The python:3.12-slim base image ships without curl, so a curl-based
      # test would always fail; probe with the interpreter already present.
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:7860/health', timeout=5)"]
      interval: 30s
      timeout: 10s
      retries: 3
    # depends_on:
    #   - db # If you have a database service, add it here

  # # Example database service (optional)
  # db:
  #   image: postgres:latest
  #   container_name: similarity-search-db
  #   environment:
  #     POSTGRES_USER: user
  #     POSTGRES_PASSWORD: password
  #     POSTGRES_DB: mydatabase
  #   ports:
  #     - "5432:5432"
  #   volumes:
  #     - postgres_data:/var/lib/postgresql/data
  #   healthcheck:
  #     test: ["CMD-SHELL", "pg_isready -U user -d mydatabase"]
  #     interval: 5s
  #     timeout: 5s
  #     retries: 5

# volumes:
#   postgres_data:
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
pg8000
pydantic
pydantic-settings
uvicorn
fastapi
openai
pandas
datasets
python-dotenv
src/api/database.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import logging
2
+ # from typing import Dict, List, Optional, AsyncGenerator
3
+ # from pydantic import BaseSettings, PostgresDsn
4
+ # import pg8000
5
+ # from pg8000 import Connection, Cursor
6
+ # from pg8000.exceptions import DatabaseError
7
+ # import asyncio
8
+ # from contextlib import asynccontextmanager
9
+ # from dataclasses import dataclass
10
+ # from threading import Lock
11
+
12
+ # # Set up structured logging
13
+ # logging.basicConfig(
14
+ # level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
15
+ # )
16
+ # logger = logging.getLogger(__name__)
17
+
18
+
19
+ # class DatabaseSettings(BaseSettings):
20
+ # db_url: PostgresDsn
21
+ # pool_size: int = 5
22
+
23
+ # class Config:
24
+ # env_file = ".env"
25
+
26
+
27
+ # @dataclass
28
+ # class DatabaseConfig:
29
+ # username: str
30
+ # password: str
31
+ # hostname: str
32
+ # port: int
33
+ # database: str
34
+
35
+
36
+ # class DatabaseError(Exception):
37
+ # """Custom exception for database errors."""
38
+
39
+ # pass
40
+
41
+
42
+ # class Database:
43
+ # def __init__(self, db_url: str, pool_size: int):
44
+ # self.db_url = db_url
45
+ # self.pool_size = pool_size
46
+ # self.pool: List[Connection] = []
47
+ # self.lock = Lock()
48
+ # self.config = self._parse_db_url()
49
+
50
+ # def _parse_db_url(self) -> DatabaseConfig:
51
+ # """Parse the database URL into components."""
52
+ # result = urlparse(self.db_url)
53
+ # return DatabaseConfig(
54
+ # username=result.username,
55
+ # password=result.password,
56
+ # hostname=result.hostname,
57
+ # port=result.port or 5432,
58
+ # database=result.path.lstrip("/"),
59
+ # )
60
+
61
+ # async def connect(self) -> None:
62
+ # """Create a connection pool."""
63
+ # try:
64
+ # for _ in range(self.pool_size):
65
+ # conn = await self._create_connection()
66
+ # self.pool.append(conn)
67
+ # logger.info(
68
+ # f"Database connection pool created with {self.pool_size} connections."
69
+ # )
70
+ # except DatabaseError as e:
71
+ # logger.error(f"Failed to create database connection pool: {e}")
72
+ # raise
73
+
74
+ # async def _create_connection(self) -> Connection:
75
+ # """Create a single database connection."""
76
+ # try:
77
+ # conn = pg8000.connect(
78
+ # user=self.config.username,
79
+ # password=self.config.password,
80
+ # host=self.config.hostname,
81
+ # port=self.config.port,
82
+ # database=self.config.database,
83
+ # )
84
+ # return conn
85
+ # except DatabaseError as e:
86
+ # logger.error(f"Failed to create database connection: {e}")
87
+ # raise DatabaseError("Failed to create database connection.")
88
+
89
+ # async def disconnect(self) -> None:
90
+ # """Close all connections in the pool."""
91
+ # with self.lock:
92
+ # for conn in self.pool:
93
+ # conn.close()
94
+ # self.pool.clear()
95
+ # logger.info("Database connection pool closed.")
96
+
97
+ # @asynccontextmanager
98
+ # async def get_connection(self) -> AsyncGenerator[Connection, None]:
99
+ # """Acquire a connection from the pool."""
100
+ # with self.lock:
101
+ # if not self.pool:
102
+ # raise DatabaseError("Database connection pool is empty.")
103
+ # conn = self.pool.pop()
104
+ # try:
105
+ # yield conn
106
+ # finally:
107
+ # with self.lock:
108
+ # self.pool.append(conn)
109
+
110
+ # async def fetch(self, query: str, *args) -> List[Dict]:
111
+ # """
112
+ # Execute a SELECT query and return the results as a list of dictionaries.
113
+
114
+ # Args:
115
+ # query (str): The SQL query to execute.
116
+ # *args: Query parameters.
117
+
118
+ # Returns:
119
+ # List[Dict]: A list of dictionaries where keys are column names and values are column values.
120
+ # """
121
+ # try:
122
+ # async with self.get_connection() as conn:
123
+ # cursor: Cursor = conn.cursor()
124
+ # cursor.execute(query, args)
125
+ # rows = cursor.fetchall()
126
+ # columns = [desc[0] for desc in cursor.description]
127
+ # return [dict(zip(columns, row)) for row in rows]
128
+ # except DatabaseError as e:
129
+ # logger.error(f"Error executing query: {query}. Error: {e}")
130
+ # raise DatabaseError(f"Failed to execute query: {query}")
131
+
132
+ # async def execute(self, query: str, *args) -> None:
133
+ # """
134
+ # Execute an INSERT, UPDATE, or DELETE query.
135
+
136
+ # Args:
137
+ # query (str): The SQL query to execute.
138
+ # *args: Query parameters.
139
+ # """
140
+ # try:
141
+ # async with self.get_connection() as conn:
142
+ # cursor: Cursor = conn.cursor()
143
+ # cursor.execute(query, args)
144
+ # conn.commit()
145
+ # except DatabaseError as e:
146
+ # logger.error(f"Error executing query: {query}. Error: {e}")
147
+ # raise DatabaseError(f"Failed to execute query: {query}")
148
+
149
+
150
+ # # Dependency to get the database instance
151
+ # async def get_db() -> AsyncGenerator[Database, None]:
152
+ # settings = DatabaseSettings()
153
+ # db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
154
+ # await db.connect()
155
+ # try:
156
+ # yield db
157
+ # finally:
158
+ # await db.disconnect()
159
+
160
+
161
+ # # Example usage
162
+ # if __name__ == "__main__":
163
+
164
+ # async def main():
165
+ # settings = DatabaseSettings()
166
+ # db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
167
+ # await db.connect()
168
+
169
+ # try:
170
+ # # Example query
171
+ # query = """
172
+ # SELECT
173
+ # ppt.type AS product_type,
174
+ # pc.name AS product_category
175
+ # FROM
176
+ # product_producttype ppt
177
+ # INNER JOIN
178
+ # product_category pc
179
+ # ON
180
+ # ppt.category_id = pc.id
181
+ # """
182
+ # result = await db.fetch(query)
183
+ # print(result)
184
+ # finally:
185
+ # await db.disconnect()
186
+
187
+ # asyncio.run(main())
188
+
189
+ # import logging
190
+ # from urllib.parse import urlparse
191
+ # from typing import Dict, List, Optional, AsyncGenerator
192
+ # from pydantic_settings import BaseSettings
193
+ # from pydantic import PostgresDsn
194
+ # import pg8000
195
+ # from pg8000 import Connection, Cursor
196
+ # from pg8000.exceptions import DatabaseError
197
+ # import asyncio
198
+ # from contextlib import asynccontextmanager
199
+ # from dataclasses import dataclass
200
+ # from threading import Lock
201
+
202
+ # # Set up structured logging
203
+ # logging.basicConfig(
204
+ # level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
205
+ # )
206
+ # logger = logging.getLogger(__name__)
207
+
208
+
209
+ # class DatabaseSettings(BaseSettings):
210
+ # db_url: PostgresDsn
211
+ # pool_size: int = 5
212
+
213
+ # class Config:
214
+ # env_file = ".env"
215
+
216
+
217
+ # @dataclass
218
+ # class DatabaseConfig:
219
+ # username: str
220
+ # password: str
221
+ # hostname: str
222
+ # port: int
223
+ # database: str
224
+
225
+
226
+ # class DatabaseError(Exception):
227
+ # """Custom exception for database errors."""
228
+
229
+ # pass
230
+
231
+
232
+ # class Database:
233
+ # def __init__(self, db_url: str, pool_size: int):
234
+ # self.db_url = db_url
235
+ # self.pool_size = pool_size
236
+ # self.pool: List[Connection] = []
237
+ # self.lock = Lock()
238
+ # self.config = self._parse_db_url()
239
+
240
+ # def _parse_db_url(self) -> DatabaseConfig:
241
+ # """Parse the database URL into components."""
242
+ # # Convert PostgresDsn to a string
243
+ # db_url_str = str(self.db_url)
244
+ # result = urlparse(db_url_str)
245
+ # return DatabaseConfig(
246
+ # username=result.username,
247
+ # password=result.password,
248
+ # hostname=result.hostname,
249
+ # port=result.port or 5432,
250
+ # database=result.path.lstrip("/"),
251
+ # )
252
+
253
+ # async def connect(self) -> None:
254
+ # """Create a connection pool."""
255
+ # try:
256
+ # for _ in range(self.pool_size):
257
+ # conn = await self._create_connection()
258
+ # self.pool.append(conn)
259
+ # logger.info(
260
+ # f"Database connection pool created with {self.pool_size} connections."
261
+ # )
262
+ # except DatabaseError as e:
263
+ # logger.error(f"Failed to create database connection pool: {e}")
264
+ # raise
265
+
266
+ # async def _create_connection(self) -> Connection:
267
+ # """Create a single database connection."""
268
+ # try:
269
+ # conn = pg8000.connect(
270
+ # user=self.config.username,
271
+ # password=self.config.password,
272
+ # host=self.config.hostname,
273
+ # port=self.config.port,
274
+ # database=self.config.database,
275
+ # )
276
+ # return conn
277
+ # except DatabaseError as e:
278
+ # logger.error(f"Failed to create database connection: {e}")
279
+ # raise DatabaseError("Failed to create database connection.")
280
+
281
+ # async def disconnect(self) -> None:
282
+ # """Close all connections in the pool."""
283
+ # with self.lock:
284
+ # for conn in self.pool:
285
+ # conn.close()
286
+ # self.pool.clear()
287
+ # logger.info("Database connection pool closed.")
288
+
289
+ # @asynccontextmanager
290
+ # async def get_connection(self) -> AsyncGenerator[Connection, None]:
291
+ # """Acquire a connection from the pool."""
292
+ # with self.lock:
293
+ # if not self.pool:
294
+ # raise DatabaseError("Database connection pool is empty.")
295
+ # conn = self.pool.pop()
296
+ # try:
297
+ # yield conn
298
+ # finally:
299
+ # with self.lock:
300
+ # self.pool.append(conn)
301
+
302
+ # async def fetch(self, query: str, *args) -> List[Dict]:
303
+ # """
304
+ # Execute a SELECT query and return the results as a list of dictionaries.
305
+
306
+ # Args:
307
+ # query (str): The SQL query to execute.
308
+ # *args: Query parameters.
309
+
310
+ # Returns:
311
+ # List[Dict]: A list of dictionaries where keys are column names and values are column values.
312
+ # """
313
+ # try:
314
+ # async with self.get_connection() as conn:
315
+ # cursor: Cursor = conn.cursor()
316
+ # cursor.execute(query, args)
317
+ # rows = cursor.fetchall()
318
+ # columns = [desc[0] for desc in cursor.description]
319
+ # return [dict(zip(columns, row)) for row in rows]
320
+ # except DatabaseError as e:
321
+ # logger.error(f"Error executing query: {query}. Error: {e}")
322
+ # raise DatabaseError(f"Failed to execute query: {query}")
323
+
324
+ # async def execute(self, query: str, *args) -> None:
325
+ # """
326
+ # Execute an INSERT, UPDATE, or DELETE query.
327
+
328
+ # Args:
329
+ # query (str): The SQL query to execute.
330
+ # *args: Query parameters.
331
+ # """
332
+ # try:
333
+ # async with self.get_connection() as conn:
334
+ # cursor: Cursor = conn.cursor()
335
+ # cursor.execute(query, args)
336
+ # conn.commit()
337
+ # except DatabaseError as e:
338
+ # logger.error(f"Error executing query: {query}. Error: {e}")
339
+ # raise DatabaseError(f"Failed to execute query: {query}")
340
+
341
+
342
+ # # Dependency to get the database instance
343
+ # async def get_db() -> AsyncGenerator[Database, None]:
344
+ # settings = DatabaseSettings()
345
+ # db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
346
+ # await db.connect()
347
+ # try:
348
+ # yield db
349
+ # finally:
350
+ # await db.disconnect()
351
+
352
+
353
+ # # Example usage
354
+ # if __name__ == "__main__":
355
+
356
+ # async def main():
357
+ # settings = DatabaseSettings()
358
+ # db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
359
+ # await db.connect()
360
+
361
+ # try:
362
+ # # Example query
363
+ # query = "SELECT * FROM your_table LIMIT 10"
364
+ # query = """
365
+ # SELECT
366
+ # ppt.type AS product_type,
367
+ # pc.name AS product_category
368
+ # FROM
369
+ # product_producttype ppt
370
+ # INNER JOIN
371
+ # product_category pc
372
+ # ON
373
+ # ppt.category_id = pc.id
374
+ # """
375
+ # result = await db.fetch(query)
376
+ # print(result)
377
+ # finally:
378
+ # await db.disconnect()
379
+
380
+ # asyncio.run(main())
381
+
382
+ import logging
383
+ from typing import AsyncGenerator, List, Optional, Dict
384
+ from pydantic_settings import BaseSettings
385
+ from pydantic import PostgresDsn
386
+ import pg8000
387
+ from pg8000 import Connection
388
+ from pg8000.exceptions import DatabaseError as Pg8000DatabaseError
389
+ import asyncio
390
+ from contextlib import asynccontextmanager
391
+ from threading import Lock
392
+ from urllib.parse import urlparse
393
+
394
+ # Set up structured logging
395
+ logging.basicConfig(
396
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
397
+ )
398
+ logger = logging.getLogger(__name__)
399
+
400
+
401
class DatabaseSettings(BaseSettings):
    """Database configuration loaded from the environment (or a .env file).

    Fields are populated by pydantic-settings from environment variables
    named after each field (e.g. DB_URL, POOL_SIZE).
    """

    db_url: PostgresDsn  # full PostgreSQL DSN, e.g. postgresql://user:pass@host:5432/db
    pool_size: int = 5  # Default pool size is 5

    class Config:
        # NOTE(review): pydantic v2 prefers `model_config = SettingsConfigDict(...)`;
        # the legacy inner Config class still works but is deprecated.
        env_file = ".env"
407
+
408
+
409
# Custom database errors — a small hierarchy rooted at DatabaseError so
# callers can catch either the base class or a specific failure mode.
class DatabaseError(Exception):
    """Base exception for database errors."""


class ConnectionError(DatabaseError):
    """Exception raised when a database connection fails.

    NOTE(review): this shadows the builtin ConnectionError within this
    module; callers must import it explicitly to catch the right one.
    """


class PoolExhaustedError(DatabaseError):
    """Exception raised when the connection pool is exhausted."""


class QueryExecutionError(DatabaseError):
    """Exception raised when a query execution fails."""


class HealthCheckError(DatabaseError):
    """Exception raised when a health check fails."""
438
+
439
+
440
class Database:
    """Fixed-size pg8000 connection pool with async-style accessors.

    NOTE(review): pg8000 is a synchronous driver, so the async methods
    below block the event loop while talking to PostgreSQL — confirm this
    is acceptable for the expected request volume.
    """

    def __init__(self, db_url: PostgresDsn, pool_size: int):
        self.db_url = db_url  # DSN used to open every pooled connection
        self.pool_size = pool_size  # number of connections opened by connect()
        self.pool: List[Connection] = []  # idle connections, used LIFO
        self.lock = Lock()  # guards every access to `pool`

    async def connect(self) -> None:
        """Create a connection pool."""
        try:
            # Convert PostgresDsn to a string
            db_url_str = str(self.db_url)
            result = urlparse(db_url_str)
            for _ in range(self.pool_size):
                conn = pg8000.connect(
                    user=result.username,
                    password=result.password,
                    host=result.hostname,
                    port=result.port or 5432,  # default PostgreSQL port
                    database=result.path.lstrip("/"),
                )
                self.pool.append(conn)
            logger.info(
                f"Database connection pool created with {self.pool_size} connections."
            )
        except Pg8000DatabaseError as e:
            logger.error(f"Failed to create database connection pool: {e}")
            raise ConnectionError("Failed to create database connection pool.") from e

    async def disconnect(self) -> None:
        """Close all connections in the pool."""
        with self.lock:
            for conn in self.pool:
                conn.close()
            self.pool.clear()
            logger.info("Database connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator[Connection, None]:
        """Acquire a connection from the pool.

        Raises:
            PoolExhaustedError: If no idle connection is available.
            ConnectionError: If the connection fails while in use.

        NOTE(review): the connection is returned to the pool in `finally`
        even after a driver error, so a broken connection may be handed
        out again — consider discarding it instead.
        """
        with self.lock:
            if not self.pool:
                logger.error("Connection pool is exhausted.")
                raise PoolExhaustedError("No available connections in the pool.")
            conn = self.pool.pop()
        try:
            yield conn
        except Pg8000DatabaseError as e:
            logger.error(f"Connection error: {e}")
            raise ConnectionError("Failed to use database connection.") from e
        finally:
            with self.lock:
                self.pool.append(conn)

    async def fetch(self, query: str, *args) -> List[Dict]:
        """
        Execute a SELECT query and return the results as a list of dictionaries.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters.

        Returns:
            List[Dict]: A list of dictionaries where keys are column names and values are column values.

        Raises:
            QueryExecutionError: If the query execution fails.
        """
        try:
            async with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(query, args)
                rows = cursor.fetchall()
                # cursor.description holds one (name, ...) tuple per column
                columns = [desc[0] for desc in cursor.description]
                return [dict(zip(columns, row)) for row in rows]
        except Pg8000DatabaseError as e:
            logger.error(f"Query execution failed: {e}")
            raise QueryExecutionError(f"Failed to execute query: {query}") from e

    async def execute(self, query: str, *args) -> None:
        """
        Execute an INSERT, UPDATE, or DELETE query.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters.

        Raises:
            QueryExecutionError: If the query execution fails.
        """
        try:
            async with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(query, args)
                conn.commit()  # persist the write before releasing the connection
        except Pg8000DatabaseError as e:
            logger.error(f"Query execution failed: {e}")
            raise QueryExecutionError(f"Failed to execute query: {query}") from e

    async def health_check(self) -> bool:
        """
        Perform a health check by executing a simple query (e.g., SELECT 1).

        Returns:
            bool: True if the database is healthy, False otherwise.
            # NOTE(review): in practice this returns True or raises
            # HealthCheckError — it never actually returns False.

        Raises:
            HealthCheckError: If the health check fails.
        """
        try:
            async with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                result = cursor.fetchone()
                cursor.close()

                # Check if the result is as expected
                if result and result[0] == 1:
                    logger.info("Database health check succeeded.")
                    return True
                else:
                    logger.error("Database health check failed: Unexpected result.")
                    raise HealthCheckError("Unexpected result from health check query.")
        except Pg8000DatabaseError as e:
            logger.error(f"Health check failed: {e}")
            raise HealthCheckError("Failed to perform health check.") from e
566
+
567
+
568
# FastAPI dependency that yields a connected Database instance.
async def get_db() -> AsyncGenerator[Database, None]:
    """Open a pooled Database for the request and tear it down afterwards."""
    cfg = DatabaseSettings()
    database = Database(db_url=cfg.db_url, pool_size=cfg.pool_size)
    await database.connect()
    try:
        yield database
    finally:
        # Always release pooled connections, even if the request failed.
        await database.disconnect()
577
+
578
+
579
# Example usage
if __name__ == "__main__":

    async def main():
        """Connect, run a single health check, and report the outcome."""
        settings = DatabaseSettings()
        db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
        await db.connect()

        try:
            # Perform a health check
            healthy = await db.health_check()
            status = "Success" if healthy else "Failure"
            print(f"Database health check: {status}")
        except HealthCheckError as exc:
            print(f"Health check failed: {exc}")
        finally:
            await db.disconnect()

    asyncio.run(main())
src/api/exceptions.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Application-wide exception hierarchy, grouped by subsystem so callers
# can catch a whole family (DatabaseError, EmbeddingError, HuggingFaceError)
# or one specific failure.
class DatabaseError(Exception):
    """Base exception for database errors."""


class QueryExecutionError(DatabaseError):
    """Exception raised when a database query fails."""


class EmbeddingError(Exception):
    """Base exception for embedding-related errors."""


class OpenAIError(EmbeddingError):
    """Exception raised when OpenAI API fails."""


class HuggingFaceError(Exception):
    """Base exception for Hugging Face-related errors."""


class DatasetNotFoundError(HuggingFaceError):
    """Exception raised when a dataset is not found."""


class DatasetPushError(HuggingFaceError):
    """Exception raised when pushing a dataset to Hugging Face Hub fails."""
src/api/models/embedding_models.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Dict
3
+
4
+
5
# Pydantic models for request validation
class CreateEmbeddingRequest(BaseModel):
    """Request body for generating embeddings from a SQL query's rows."""

    query: str  # SQL query whose result rows will be embedded
    target_column: str = "product_type"  # column containing the text to embed
    output_column: str = "embedding"  # column that receives the vectors
    model: str = "text-embedding-3-small"  # OpenAI embedding model name
    batch_size: int = 100  # rows processed per embedding batch
    dataset_name: str = "re-mind/product_type_embedding"  # HF Hub destination


class UpdateEmbeddingRequest(BaseModel):
    """Request body for replacing whole columns of a hosted dataset."""

    dataset_name: str  # Hugging Face Hub dataset identifier
    updates: Dict[str, List]  # Column name -> List of values


class DeleteEmbeddingRequest(BaseModel):
    """Request body for removing columns from a hosted dataset."""

    dataset_name: str  # Hugging Face Hub dataset identifier
    columns: List[str]  # List of columns to delete
src/api/services/embedding_service.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import AsyncOpenAI
2
+ import logging
3
+ from typing import List, Dict
4
+ import pandas as pd
5
+ import asyncio
6
+ from src.api.exceptions import OpenAIError
7
+
8
+ # Set up structured logging
9
+ logging.basicConfig(
10
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
11
+ )
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class EmbeddingService:
    """Generates OpenAI embeddings for a DataFrame column in concurrent batches."""

    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-embedding-3-small",
        batch_size: int = 100,
    ):
        """
        Args:
            openai_api_key: API key used to authenticate with OpenAI.
            model: Embedding model name.
            batch_size: Number of rows processed per concurrent batch.
        """
        self.client = AsyncOpenAI(api_key=openai_api_key)
        self.model = model
        self.batch_size = batch_size

    async def get_embedding(self, text: str) -> List[float]:
        """Generate embeddings for the given text using OpenAI.

        Raises:
            OpenAIError: If the OpenAI API call fails.
        """
        # Newlines can degrade embedding quality; flatten to single spaces.
        text = text.replace("\n", " ")
        try:
            response = await self.client.embeddings.create(
                input=[text], model=self.model
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            raise OpenAIError(f"OpenAI API error: {e}") from e

    async def create_embeddings(
        self, df: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Create embeddings for the target column in the dataset.

        Args:
            df: Input rows; the caller's frame is not modified in place.
            target_column: Column containing the text to embed.
            output_column: Column that will hold the embedding vectors.

        Returns:
            A new DataFrame with `output_column` populated.

        Raises:
            OpenAIError: If any embedding request fails.
        """
        logger.info("Generating embeddings...")
        if df.empty:
            # pd.concat([]) raises ValueError; short-circuit on empty input.
            return df.copy()
        # Copy each slice so we never assign into a view of the caller's
        # frame (avoids SettingWithCopyWarning and accidental mutation).
        batches = [
            df[i : i + self.batch_size].copy()
            for i in range(0, len(df), self.batch_size)
        ]
        processed_batches = await asyncio.gather(
            *[
                self._process_batch(batch, target_column, output_column)
                for batch in batches
            ]
        )
        return pd.concat(processed_batches)

    async def _process_batch(
        self, df_batch: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Embed every row of one batch concurrently and attach the results."""
        embeddings = await asyncio.gather(
            *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
        )
        df_batch[output_column] = embeddings
        return df_batch
src/api/services/huggingface_service.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset, load_dataset
2
+ import logging
3
+ from typing import Optional, Dict, List
4
+ import pandas as pd
5
+ from src.api.exceptions import DatasetNotFoundError, DatasetPushError
6
+
7
+ # Set up structured logging
8
+ logging.basicConfig(
9
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
10
+ )
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class HuggingFaceService:
    """Thin async wrapper around Hugging Face `datasets` Hub operations."""

    async def push_to_hub(self, df: pd.DataFrame, dataset_name: str) -> None:
        """Push the dataset to Hugging Face Hub.

        Raises:
            DatasetPushError: If conversion or upload fails.
        """
        try:
            logger.info(f"Creating Hugging Face Dataset: {dataset_name}...")
            ds = Dataset.from_pandas(df)
            # `from_pandas` only adds this helper column when the frame has a
            # non-default index; removing it unconditionally would crash for
            # frames with a default RangeIndex.
            if "__index_level_0__" in ds.column_names:
                ds = ds.remove_columns("__index_level_0__")
            ds.push_to_hub(dataset_name)
            logger.info(f"Dataset pushed to Hugging Face Hub: {dataset_name}")
        except Exception as e:
            logger.error(f"Failed to push dataset to Hugging Face Hub: {e}")
            raise DatasetPushError(f"Failed to push dataset: {e}") from e

    async def read_dataset(self, dataset_name: str) -> Optional[pd.DataFrame]:
        """Read the `train` split of a dataset from Hugging Face Hub.

        Raises:
            DatasetNotFoundError: If the dataset cannot be loaded.
        """
        try:
            logger.info(f"Loading dataset from Hugging Face Hub: {dataset_name}...")
            ds = load_dataset(dataset_name)
            df = ds["train"].to_pandas()
            return df
        except Exception as e:
            logger.error(f"Failed to read dataset: {e}")
            raise DatasetNotFoundError(f"Dataset not found: {e}") from e

    async def update_dataset(
        self, dataset_name: str, updates: Dict[str, List]
    ) -> Optional[pd.DataFrame]:
        """Replace whole columns of a hosted dataset and push the result.

        Raises:
            DatasetNotFoundError: If the dataset cannot be read.
            DatasetPushError: If applying the updates or pushing fails.
        """
        try:
            df = await self.read_dataset(dataset_name)
            for column, values in updates.items():
                if column in df.columns:
                    df[column] = values
                else:
                    logger.warning(f"Column '{column}' not found in dataset.")
            await self.push_to_hub(df, dataset_name)
            return df
        except DatasetNotFoundError:
            # Preserve the more specific read failure instead of masking it
            # as a push failure.
            raise
        except Exception as e:
            logger.error(f"Failed to update dataset: {e}")
            raise DatasetPushError(f"Failed to update dataset: {e}") from e

    async def delete_columns(
        self, dataset_name: str, columns: List[str]
    ) -> Optional[pd.DataFrame]:
        """Delete columns from a dataset on Hugging Face Hub.

        Raises:
            DatasetNotFoundError: If the dataset cannot be read.
            DatasetPushError: If dropping the columns or pushing fails.
        """
        try:
            df = await self.read_dataset(dataset_name)
            for column in columns:
                if column in df.columns:
                    df.drop(column, axis=1, inplace=True)
                else:
                    logger.warning(f"Column '{column}' not found in dataset.")
            await self.push_to_hub(df, dataset_name)
            return df
        except DatasetNotFoundError:
            # Preserve the more specific read failure instead of masking it.
            raise
        except Exception as e:
            logger.error(f"Failed to delete columns: {e}")
            raise DatasetPushError(f"Failed to delete columns: {e}") from e
src/main.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, Depends, HTTPException
3
+ from fastapi.responses import JSONResponse, RedirectResponse
4
+ from pydantic import BaseModel
5
+ from typing import List, Dict
6
+ from src.api.models.embedding_models import (
7
+ CreateEmbeddingRequest,
8
+ UpdateEmbeddingRequest,
9
+ DeleteEmbeddingRequest,
10
+ )
11
+ from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
12
+ from src.api.services.embedding_service import EmbeddingService
13
+ from src.api.services.huggingface_service import HuggingFaceService
14
+ from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
15
+ import pandas as pd
16
+ import logging
17
+ from dotenv import load_dotenv
18
+
19
# Load environment variables from a local .env file (if present) so that
# secrets such as OPENAI_API_KEY are readable via os.getenv().
load_dotenv()

# Set up structured logging for the whole application; handlers created by
# logging.getLogger(__name__) elsewhere inherit this configuration.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app — the metadata below is rendered at /docs.
app = FastAPI(
    title="Similarity Search API",
    description="A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.",
    version="1.0.0",
)
34
+
35
+
36
@app.get("/")
async def root():
    """Redirect the bare root URL to the interactive Swagger docs."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
40
+
41
+
42
# Health check endpoint
@app.get("/health")
async def health_check(db: Database = Depends(get_db)):
    """Report service health based on a database connectivity probe.

    Returns:
        ``{"status": "healthy"}`` when the database probe succeeds.

    Raises:
        HTTPException: 500 when the probe raises HealthCheckError or
            reports the database as unhealthy.
    """
    try:
        is_healthy = await db.health_check()
    except HealthCheckError as e:
        # Chain the cause so the original failure appears in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not is_healthy:
        raise HTTPException(status_code=500, detail="Database is unhealthy")
    return {"status": "healthy"}
52
+
53
+
54
# Dependency factory for EmbeddingService
def get_embedding_service() -> EmbeddingService:
    """Build an EmbeddingService configured from the OPENAI_API_KEY env var."""
    api_key = os.getenv("OPENAI_API_KEY")
    return EmbeddingService(openai_api_key=api_key)
57
+
58
+
59
# Dependency factory for HuggingFaceService
def get_huggingface_service() -> HuggingFaceService:
    """Build a fresh HuggingFaceService for the current request."""
    service = HuggingFaceService()
    return service
62
+
63
+
64
# Endpoint to create embeddings
@app.post("/create_embedding")
async def create_embedding(
    request: CreateEmbeddingRequest,
    db: Database = Depends(get_db),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Create embeddings for the target column in the dataset.

    Pipeline:
        1. Run ``request.query`` against the database.
        2. Embed ``request.target_column`` into ``request.output_column``.
        3. Push the resulting dataset to the Hugging Face Hub.

    Raises:
        HTTPException: 500 on database, OpenAI, push, or unexpected failure.
    """
    try:
        # Step 1: Query the database
        logger.info("Fetching data from the database...")
        result = await db.fetch(request.query)
        df = pd.DataFrame(result)

        # Step 2: Generate embeddings
        df = await embedding_service.create_embeddings(
            df, request.target_column, request.output_column
        )

        # Step 3: Push to Hugging Face Hub
        await huggingface_service.push_to_hub(df, request.dataset_name)

        return JSONResponse(
            content={
                "message": "Embeddings created and pushed to Hugging Face Hub.",
                "dataset_name": request.dataset_name,
                "num_rows": len(df),
            }
        )
    # "from e" chains the cause so tracebacks show the original failure.
    except QueryExecutionError as e:
        logger.error(f"Database query failed: {e}")
        raise HTTPException(
            status_code=500, detail=f"Database query failed: {e}"
        ) from e
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}") from e
    except DatasetPushError as e:
        logger.error(f"Failed to push dataset: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to push dataset: {e}"
        ) from e
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
108
+
109
+
110
# Endpoint to read embeddings
@app.get("/read_embeddings/{dataset_name}")
async def read_embeddings(
    dataset_name: str,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Read embeddings from a Hugging Face dataset.

    Returns:
        The full dataset as a list of row dictionaries.

    Raises:
        HTTPException: 404 if the dataset is not found, 500 otherwise.
    """
    try:
        df = await huggingface_service.read_dataset(dataset_name)
        return df.to_dict(orient="records")
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        # "from e" preserves the original cause in tracebacks.
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}") from e
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
128
+
129
+
130
# Endpoint to update embeddings
@app.post("/update_embeddings")
async def update_embeddings(
    request: UpdateEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Update embeddings in a Hugging Face dataset.

    Applies ``request.updates`` (column -> replacement values) and pushes
    the modified dataset back to the Hub.

    Raises:
        HTTPException: 500 if the update or push fails.
    """
    try:
        # The service pushes the updated dataset itself; the returned
        # DataFrame is not needed here.
        await huggingface_service.update_dataset(
            request.dataset_name, request.updates
        )
        return {
            "message": "Embeddings updated successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        logger.error(f"Failed to update dataset: {e}")
        # "from e" preserves the original cause in tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Failed to update dataset: {e}"
        ) from e
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
153
+
154
+
155
# Endpoint to delete embeddings
@app.post("/delete_embeddings")
async def delete_embeddings(
    request: DeleteEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete embeddings from a Hugging Face dataset.

    Drops ``request.columns`` from the dataset and pushes the modified
    dataset back to the Hub.

    Raises:
        HTTPException: 500 if the column removal or push fails.
    """
    try:
        # The service pushes the modified dataset itself; the returned
        # DataFrame is not needed here.
        await huggingface_service.delete_columns(
            request.dataset_name, request.columns
        )
        return {
            "message": "Embeddings deleted successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        logger.error(f"Failed to delete columns: {e}")
        # "from e" preserves the original cause in tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Failed to delete columns: {e}"
        ) from e
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e