amaye15 committed on
Commit
6f4f307
·
1 Parent(s): fdc226e

Feat - Improved - Update Endpoint

Browse files
docker-compose.yml CHANGED
@@ -1,5 +1,3 @@
1
- version: "3.9"
2
-
3
  services:
4
  app:
5
  build:
 
 
 
1
  services:
2
  app:
3
  build:
src/api/dependency.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from src.api.services.embedding_service import EmbeddingService
3
+ from src.api.services.huggingface_service import HuggingFaceService
4
+
5
+
6
+ # Dependency to get EmbeddingService
7
+ def get_embedding_service() -> EmbeddingService:
8
+ return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
9
+
10
+
11
+ # Dependency to get HuggingFaceService
12
+ def get_huggingface_service() -> HuggingFaceService:
13
+ return HuggingFaceService()
src/api/models/embedding_models.py CHANGED
@@ -17,10 +17,37 @@ class ReadEmbeddingRequest(BaseModel):
17
  dataset_name: str
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class UpdateEmbeddingRequest(BaseModel):
21
- dataset_name: str
22
- updates: Dict[str, List] # Column name -> List of values
 
 
 
 
 
 
23
 
24
 
25
  class DeleteEmbeddingRequest(BaseModel):
26
  dataset_name: str
 
 
 
 
 
 
 
 
 
17
  dataset_name: str
18
 
19
 
20
+ # class UpdateEmbeddingRequest(BaseModel):
21
+ # updates: Dict[str, List] # Column name -> List of values
22
+ # target_column: str = "product_type"
23
+ # output_column: str = "embedding"
24
+ # model: str = "text-embedding-3-small"
25
+ # batch_size: int = 10
26
+ # max_concurrent_requests: int = 10
27
+ # dataset_name: str = "re-mind/product_type_embedding"
28
+
29
+ from pydantic import BaseModel
30
+ from typing import Dict, List
31
+
32
+
33
  class UpdateEmbeddingRequest(BaseModel):
34
+ dataset_name: str = "re-mind/product_type_embedding"
35
+ updates: Dict[
36
+ str, List
37
+ ] # Dictionary of column names and their corresponding values
38
+ target_column: str = (
39
+ "product_type" # Column in the new data to generate embeddings for
40
+ )
41
+ output_column: str = "embedding" # Column to store the generated embeddings
42
 
43
 
44
  class DeleteEmbeddingRequest(BaseModel):
45
  dataset_name: str
46
+
47
+
48
+ # Request model for the /embed endpoint
49
+ class EmbedRequest(BaseModel):
50
+ texts: List[str] # List of strings to generate embeddings for
51
+ output_column: str = (
52
+ "embeddings" # Column to store embeddings (default: "embeddings")
53
+ )
src/api/services/embedding_service.py CHANGED
@@ -1,146 +1,3 @@
1
- # from openai import AsyncOpenAI
2
- # import logging
3
- # from typing import List, Dict
4
- # import pandas as pd
5
- # import asyncio
6
- # from src.api.exceptions import OpenAIError
7
-
8
- # # Set up structured logging
9
- # logging.basicConfig(
10
- # level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
11
- # )
12
- # logger = logging.getLogger(__name__)
13
-
14
-
15
- # class EmbeddingService:
16
- # def __init__(
17
- # self,
18
- # openai_api_key: str,
19
- # model: str = "text-embedding-3-small",
20
- # batch_size: int = 100,
21
- # ):
22
- # self.client = AsyncOpenAI(api_key=openai_api_key)
23
- # self.model = model
24
- # self.batch_size = batch_size
25
-
26
- # async def get_embedding(self, text: str) -> List[float]:
27
- # """Generate embeddings for the given text using OpenAI."""
28
- # text = text.replace("\n", " ")
29
- # try:
30
- # response = await self.client.embeddings.create(
31
- # input=[text], model=self.model
32
- # )
33
- # return response.data[0].embedding
34
- # except Exception as e:
35
- # logger.error(f"Failed to generate embedding: {e}")
36
- # raise OpenAIError(f"OpenAI API error: {e}")
37
-
38
- # async def create_embeddings(
39
- # self, df: pd.DataFrame, target_column: str, output_column: str
40
- # ) -> pd.DataFrame:
41
- # """Create embeddings for the target column in the dataset."""
42
- # logger.info("Generating embeddings...")
43
- # batches = [
44
- # df[i : i + self.batch_size] for i in range(0, len(df), self.batch_size)
45
- # ]
46
- # processed_batches = await asyncio.gather(
47
- # *[
48
- # self._process_batch(batch, target_column, output_column)
49
- # for batch in batches
50
- # ]
51
- # )
52
- # return pd.concat(processed_batches)
53
-
54
- # async def _process_batch(
55
- # self, df_batch: pd.DataFrame, target_column: str, output_column: str
56
- # ) -> pd.DataFrame:
57
- # """Process a batch of rows to generate embeddings."""
58
- # embeddings = await asyncio.gather(
59
- # *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
60
- # )
61
- # df_batch[output_column] = embeddings
62
- # return df_batch
63
-
64
- # from openai import AsyncOpenAI
65
- # import logging
66
- # from typing import List, Dict
67
- # import pandas as pd
68
- # import asyncio
69
- # from src.api.exceptions import OpenAIError
70
-
71
- # # Set up structured logging
72
- # logging.basicConfig(
73
- # level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
74
- # )
75
- # logger = logging.getLogger(__name__)
76
-
77
-
78
- # class EmbeddingService:
79
- # def __init__(
80
- # self,
81
- # openai_api_key: str,
82
- # model: str = "text-embedding-3-small",
83
- # batch_size: int = 10,
84
- # max_concurrent_requests: int = 10, # Limit to 10 concurrent requests
85
- # ):
86
- # self.client = AsyncOpenAI(api_key=openai_api_key)
87
- # self.model = model
88
- # self.batch_size = batch_size
89
- # self.semaphore = asyncio.Semaphore(max_concurrent_requests) # Rate limiter
90
- # self.total_requests = 0 # Total number of requests to process
91
- # self.completed_requests = 0 # Number of completed requests
92
-
93
- # async def get_embedding(self, text: str) -> List[float]:
94
- # """Generate embeddings for the given text using OpenAI."""
95
- # text = text.replace("\n", " ")
96
- # try:
97
- # async with self.semaphore: # Acquire a semaphore slot
98
- # response = await self.client.embeddings.create(
99
- # input=[text], model=self.model
100
- # )
101
- # self.completed_requests += 1 # Increment completed requests
102
- # self._log_progress() # Log progress
103
- # return response.data[0].embedding
104
- # except Exception as e:
105
- # logger.error(f"Failed to generate embedding: {e}")
106
- # raise OpenAIError(f"OpenAI API error: {e}")
107
-
108
- # async def create_embeddings(
109
- # self, df: pd.DataFrame, target_column: str, output_column: str
110
- # ) -> pd.DataFrame:
111
- # """Create embeddings for the target column in the dataset."""
112
- # logger.info("Generating embeddings...")
113
- # self.total_requests = len(df) # Set total number of requests
114
- # self.completed_requests = 0 # Reset completed requests counter
115
-
116
- # batches = [
117
- # df[i : i + self.batch_size] for i in range(0, len(df), self.batch_size)
118
- # ]
119
- # processed_batches = await asyncio.gather(
120
- # *[
121
- # self._process_batch(batch, target_column, output_column)
122
- # for batch in batches
123
- # ]
124
- # )
125
- # return pd.concat(processed_batches)
126
-
127
- # async def _process_batch(
128
- # self, df_batch: pd.DataFrame, target_column: str, output_column: str
129
- # ) -> pd.DataFrame:
130
- # """Process a batch of rows to generate embeddings."""
131
- # embeddings = await asyncio.gather(
132
- # *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
133
- # )
134
- # df_batch[output_column] = embeddings
135
- # return df_batch
136
-
137
- # def _log_progress(self):
138
- # """Log the progress of embedding generation."""
139
- # progress = (self.completed_requests / self.total_requests) * 100
140
- # logger.info(
141
- # f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
142
- # )
143
-
144
  from openai import AsyncOpenAI
145
  import logging
146
  from typing import List, Dict, Union
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from openai import AsyncOpenAI
2
  import logging
3
  from typing import List, Dict, Union
src/api/services/huggingface_service.py CHANGED
@@ -1,8 +1,9 @@
1
- from datasets import Dataset, load_dataset
2
  from huggingface_hub import HfApi, HfFolder
3
  import logging
4
  from typing import Optional, Dict, List
5
  import pandas as pd
 
6
  from src.api.exceptions import (
7
  DatasetNotFoundError,
8
  DatasetPushError,
@@ -45,19 +46,69 @@ class HuggingFaceService:
45
  logger.error(f"Failed to read dataset: {e}")
46
  raise DatasetNotFoundError(f"Dataset not found: {e}")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  async def update_dataset(
49
- self, dataset_name: str, updates: Dict[str, List]
 
 
 
 
50
  ) -> Optional[pd.DataFrame]:
51
- """Update a dataset on Hugging Face Hub."""
52
  try:
53
- df = await self.read_dataset(dataset_name)
54
- for column, values in updates.items():
55
- if column in df.columns:
56
- df[column] = values
57
- else:
58
- logger.warning(f"Column '{column}' not found in dataset.")
59
- await self.push_to_hub(df, dataset_name)
60
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
  logger.error(f"Failed to update dataset: {e}")
63
  raise DatasetPushError(f"Failed to update dataset: {e}")
 
1
+ from datasets import Dataset, load_dataset, concatenate_datasets
2
  from huggingface_hub import HfApi, HfFolder
3
  import logging
4
  from typing import Optional, Dict, List
5
  import pandas as pd
6
+ from src.api.dependency import get_embedding_service, get_huggingface_service
7
  from src.api.exceptions import (
8
  DatasetNotFoundError,
9
  DatasetPushError,
 
46
  logger.error(f"Failed to read dataset: {e}")
47
  raise DatasetNotFoundError(f"Dataset not found: {e}")
48
 
49
+ # async def update_dataset(
50
+ # self, dataset_name: str, updates: Dict[str, List]
51
+ # ) -> Optional[pd.DataFrame]:
52
+ # """Update a dataset on Hugging Face Hub."""
53
+
54
+ # embedding_service = get_embedding_service()
55
+
56
+ # try:
57
+ # df_src = await self.read_dataset(dataset_name)
58
+ # df_src = Dataset.from_dict(df_src)
59
+ # df_update = Dataset.from_dict(updates)
60
+
61
+ # df = concatenate_datasets(df_src, df_update)
62
+
63
+ # # for column, values in updates.items():
64
+ # # if column in df.columns:
65
+ # # df[column] = values
66
+ # # else:
67
+ # # logger.warning(f"Column '{column}' not found in dataset.")
68
+ # # await self.push_to_hub(df, dataset_name)
69
+ # # return df
70
+ # except Exception as e:
71
+ # logger.error(f"Failed to update dataset: {e}")
72
+ # raise DatasetPushError(f"Failed to update dataset: {e}")
73
+
74
  async def update_dataset(
75
+ self,
76
+ dataset_name: str,
77
+ updates: Dict[str, List],
78
+ target_column: str,
79
+ output_column: str = "embeddings",
80
  ) -> Optional[pd.DataFrame]:
81
+ """Update a dataset on Hugging Face Hub by generating embeddings for new data and concatenating it with the existing dataset."""
82
  try:
83
+ # Step 1: Load the existing dataset from Hugging Face Hub
84
+ logger.info(
85
+ f"Loading existing dataset from Hugging Face Hub: {dataset_name}..."
86
+ )
87
+ existing_ds = await self.read_dataset(dataset_name)
88
+ existing_df = pd.DataFrame(existing_ds)
89
+
90
+ # Step 2: Convert the new updates into a DataFrame
91
+ logger.info("Converting updates to DataFrame...")
92
+ new_df = pd.DataFrame(updates)
93
+
94
+ # Step 3: Generate embeddings for the new data
95
+ logger.info("Generating embeddings for the new data...")
96
+ embedding_service = get_embedding_service() # Get the embedding service
97
+ new_df = await embedding_service.create_embeddings(
98
+ new_df, target_column, output_column
99
+ )
100
+
101
+ # Step 4: Concatenate the existing DataFrame with the new DataFrame
102
+ logger.info("Concatenating existing dataset with new data...")
103
+ updated_df = pd.concat([existing_df, new_df], ignore_index=True)
104
+
105
+ # Step 5: Push the updated dataset back to Hugging Face Hub
106
+ logger.info(
107
+ f"Pushing updated dataset to Hugging Face Hub: {dataset_name}..."
108
+ )
109
+ await self.push_to_hub(updated_df, dataset_name)
110
+
111
+ # return updated_df
112
  except Exception as e:
113
  logger.error(f"Failed to update dataset: {e}")
114
  raise DatasetPushError(f"Failed to update dataset: {e}")
src/main.py CHANGED
@@ -197,11 +197,13 @@ from src.api.models.embedding_models import (
197
  ReadEmbeddingRequest,
198
  UpdateEmbeddingRequest,
199
  DeleteEmbeddingRequest,
 
200
  )
201
  from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
202
  from src.api.services.embedding_service import EmbeddingService
203
  from src.api.services.huggingface_service import HuggingFaceService
204
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
 
205
  import pandas as pd
206
  import logging
207
  from dotenv import load_dotenv
@@ -249,24 +251,6 @@ async def health_check(db: Database = Depends(get_db)):
249
  raise HTTPException(status_code=500, detail=str(e))
250
 
251
 
252
- # Dependency to get EmbeddingService
253
- def get_embedding_service() -> EmbeddingService:
254
- return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
255
-
256
-
257
- # Dependency to get HuggingFaceService
258
- def get_huggingface_service() -> HuggingFaceService:
259
- return HuggingFaceService()
260
-
261
-
262
- # Request model for the /embed endpoint
263
- class EmbedRequest(BaseModel):
264
- texts: List[str] # List of strings to generate embeddings for
265
- output_column: str = (
266
- "embeddings" # Column to store embeddings (default: "embeddings")
267
- )
268
-
269
-
270
  # Endpoint to generate embeddings for a list of strings
271
  @app.post("/embed")
272
  async def embed(
@@ -363,21 +347,51 @@ async def read_embeddings(
363
 
364
 
365
  # Endpoint to update embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  @app.post("/update_embeddings")
367
  async def update_embeddings(
368
  request: UpdateEmbeddingRequest,
369
  huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
370
  ):
371
  """
372
- Update embeddings in a Hugging Face dataset.
373
  """
374
  try:
375
- df = await huggingface_service.update_dataset(
376
- request.dataset_name, request.updates
 
 
 
 
377
  )
 
378
  return {
379
  "message": "Embeddings updated successfully.",
380
  "dataset_name": request.dataset_name,
 
381
  }
382
  except DatasetPushError as e:
383
  logger.error(f"Failed to update dataset: {e}")
 
197
  ReadEmbeddingRequest,
198
  UpdateEmbeddingRequest,
199
  DeleteEmbeddingRequest,
200
+ EmbedRequest,
201
  )
202
  from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
203
  from src.api.services.embedding_service import EmbeddingService
204
  from src.api.services.huggingface_service import HuggingFaceService
205
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
206
+ from src.api.dependency import get_embedding_service, get_huggingface_service
207
  import pandas as pd
208
  import logging
209
  from dotenv import load_dotenv
 
251
  raise HTTPException(status_code=500, detail=str(e))
252
 
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # Endpoint to generate embeddings for a list of strings
255
  @app.post("/embed")
256
  async def embed(
 
347
 
348
 
349
  # Endpoint to update embeddings
350
+ # @app.post("/update_embeddings")
351
+ # async def update_embeddings(
352
+ # request: UpdateEmbeddingRequest,
353
+ # huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
354
+ # ):
355
+ # """
356
+ # Update embeddings in a Hugging Face dataset.
357
+ # """
358
+ # try:
359
+ # df = await huggingface_service.update_dataset(
360
+ # request.dataset_name, request.updates
361
+ # )
362
+ # return {
363
+ # "message": "Embeddings updated successfully.",
364
+ # "dataset_name": request.dataset_name,
365
+ # }
366
+ # except DatasetPushError as e:
367
+ # logger.error(f"Failed to update dataset: {e}")
368
+ # raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}")
369
+ # except Exception as e:
370
+ # logger.error(f"An error occurred: {e}")
371
+ # raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
372
+
373
+
374
  @app.post("/update_embeddings")
375
  async def update_embeddings(
376
  request: UpdateEmbeddingRequest,
377
  huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
378
  ):
379
  """
380
+ Update embeddings in a Hugging Face dataset by generating embeddings for new data and concatenating it with the existing dataset.
381
  """
382
  try:
383
+ # Call the update_dataset method to generate embeddings, concatenate, and push the updated dataset
384
+ updated_df = await huggingface_service.update_dataset(
385
+ request.dataset_name,
386
+ request.updates,
387
+ request.target_column,
388
+ request.output_column,
389
  )
390
+
391
  return {
392
  "message": "Embeddings updated successfully.",
393
  "dataset_name": request.dataset_name,
394
+ "num_rows": len(updated_df),
395
  }
396
  except DatasetPushError as e:
397
  logger.error(f"Failed to update dataset: {e}")