amaye15 committed
Commit cccaa2c · 1 Parent(s): 6f2dd4d

Feat - New Endpoint - Similarity Search
requirements.txt CHANGED
@@ -5,4 +5,5 @@ uvicorn
 fastapi
 openai
 pandas
-datasets
+datasets
+scikit-learn
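
Note: scikit-learn is pulled in for the cosine_similarity helper used by the new search (see src/api/services/embedding_service.py below). A minimal sketch of that call, with toy vectors rather than anything from this repo:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

queries = np.array([[0.1, 0.9], [0.8, 0.2]])             # 2 query embeddings
corpus = np.array([[0.1, 0.8], [0.9, 0.1], [0.5, 0.5]])  # 3 stored embeddings

# Returns a (2, 3) matrix: one row per query, one similarity per stored vector.
print(cosine_similarity(queries, corpus))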
src/api/database.py CHANGED
@@ -110,31 +110,6 @@ class Database:
         with self.lock:
             self.pool.append(conn)
 
-    # async def fetch(self, query: str, *args) -> List[Dict]:
-    #     """
-    #     Execute a SELECT query and return the results as a list of dictionaries.
-
-    #     Args:
-    #         query (str): The SQL query to execute.
-    #         *args: Query parameters.
-
-    #     Returns:
-    #         List[Dict]: A list of dictionaries where keys are column names and values are column values.
-
-    #     Raises:
-    #         QueryExecutionError: If the query execution fails.
-    #     """
-    #     try:
-    #         async with self.get_connection() as conn:
-    #             cursor = conn.cursor()
-    #             cursor.execute(query, args)
-    #             rows = cursor.fetchall()
-    #             columns = [desc[0] for desc in cursor.description]
-    #             return [dict(zip(columns, row)) for row in rows]
-    #     except Pg8000DatabaseError as e:
-    #         logger.error(f"Query execution failed: {e}")
-    #         raise QueryExecutionError(f"Failed to execute query: {query}") from e
-
     async def fetch(self, query: str, *args) -> Dict[str, List]:
         """
         Execute a SELECT query and return the results as a dictionary of lists.
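
Review note: the surviving fetch returns a column-oriented Dict[str, List], which is exactly the shape datasets.Dataset.from_dict expects; src/main.py relies on this hand-off. A small sketch with a hypothetical two-column result:

from datasets import Dataset

# Hypothetical fetch result: one key per column, one list entry per row.
result = {
    "id": [1, 2, 3],
    "text": ["alpha", "beta", "gamma"],
}

dataset = Dataset.from_dict(result)  # same call main.py makes on the fetch result
print(dataset.num_rows)  # 3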
src/api/models/embedding_models.py CHANGED
@@ -48,3 +48,11 @@ class EmbedRequest(BaseModel):
     output_column: str = (
         "embedding"  # Column to store embeddings (default: "embeddings")
     )
+
+
+class SearchEmbeddingRequest(BaseModel):
+    texts: List[str]  # List of texts to search for
+    target_column: str  # Column to return in the results
+    embedding_column: str  # Column containing the embeddings to search against
+    num_results: int  # Number of results to return
+    dataset_name: str  # Name of the dataset to search in
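
For reference, a payload matching the new model might look like the following; every field value here is illustrative, not taken from the repo:

from src.api.models.embedding_models import SearchEmbeddingRequest

request = SearchEmbeddingRequest(
    texts=["How do I reset my password?"],
    target_column="text",
    embedding_column="embedding",
    num_results=5,
    dataset_name="my-org/support-articles",  # hypothetical dataset
)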
src/api/services/embedding_service.py CHANGED
@@ -1,137 +1,10 @@
-# from openai import AsyncOpenAI
-# import logging
-# from typing import List, Dict, Union
-# import pandas as pd
-# import asyncio
-# from src.api.exceptions import OpenAIError
-
-# # Set up structured logging
-# logging.basicConfig(
-#     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-# )
-# logger = logging.getLogger(__name__)
-
-
-# class EmbeddingService:
-#     def __init__(
-#         self,
-#         openai_api_key: str,
-#         model: str = "text-embedding-3-small",
-#         batch_size: int = 10,
-#         max_concurrent_requests: int = 10,  # Limit to 10 concurrent requests
-#     ):
-#         self.client = AsyncOpenAI(api_key=openai_api_key)
-#         self.model = model
-#         self.batch_size = batch_size
-#         self.semaphore = asyncio.Semaphore(max_concurrent_requests)  # Rate limiter
-#         self.total_requests = 0  # Total number of requests to process
-#         self.completed_requests = 0  # Number of completed requests
-
-#     async def get_embedding(self, text: str) -> List[float]:
-#         """Generate embeddings for the given text using OpenAI."""
-#         text = text.replace("\n", " ")
-#         try:
-#             async with self.semaphore:  # Acquire a semaphore slot
-#                 response = await self.client.embeddings.create(
-#                     input=[text], model=self.model
-#                 )
-#                 self.completed_requests += 1  # Increment completed requests
-#                 self._log_progress()  # Log progress
-#                 return response.data[0].embedding
-#         except Exception as e:
-#             logger.error(f"Failed to generate embedding: {e}")
-#             raise OpenAIError(f"OpenAI API error: {e}")
-
-#     async def create_embeddings(
-#         self,
-#         data: Union[pd.DataFrame, List[str]],
-#         target_column: str = None,
-#         output_column: str = "embeddings",
-#     ) -> Union[pd.DataFrame, List[List[float]]]:
-#         """
-#         Create embeddings for either a DataFrame or a list of strings.
-
-#         Args:
-#             data: Either a DataFrame or a list of strings.
-#             target_column: The column in the DataFrame to generate embeddings for (required if data is a DataFrame).
-#             output_column: The column to store embeddings in the DataFrame (default: "embeddings").
-
-#         Returns:
-#             If data is a DataFrame, returns the DataFrame with the embeddings column.
-#             If data is a list of strings, returns a list of embeddings.
-#         """
-#         if isinstance(data, pd.DataFrame):
-#             if not target_column:
-#                 raise ValueError("target_column is required when data is a DataFrame.")
-#             return await self._create_embeddings_for_dataframe(
-#                 data, target_column, output_column
-#             )
-#         elif isinstance(data, list):
-#             return await self._create_embeddings_for_texts(data)
-#         else:
-#             raise TypeError(
-#                 "data must be either a pandas DataFrame or a list of strings."
-#             )
-
-#     async def _create_embeddings_for_dataframe(
-#         self, df: pd.DataFrame, target_column: str, output_column: str
-#     ) -> pd.DataFrame:
-#         """Create embeddings for the target column in the DataFrame."""
-#         logger.info("Generating embeddings for DataFrame...")
-#         self.total_requests = len(df)  # Set total number of requests
-#         self.completed_requests = 0  # Reset completed requests counter
-
-#         batches = [
-#             df[i : i + self.batch_size] for i in range(0, len(df), self.batch_size)
-#         ]
-#         processed_batches = await asyncio.gather(
-#             *[
-#                 self._process_batch(batch, target_column, output_column)
-#                 for batch in batches
-#             ]
-#         )
-#         return pd.concat(processed_batches)
-
-#     async def _create_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
-#         """Create embeddings for a list of strings."""
-#         logger.info("Generating embeddings for list of texts...")
-#         self.total_requests = len(texts)  # Set total number of requests
-#         self.completed_requests = 0  # Reset completed requests counter
-
-#         batches = [
-#             texts[i : i + self.batch_size]
-#             for i in range(0, len(texts), self.batch_size)
-#         ]
-#         embeddings = []
-#         for batch in batches:
-#             batch_embeddings = await asyncio.gather(
-#                 *[self.get_embedding(text) for text in batch]
-#             )
-#             embeddings.extend(batch_embeddings)
-#         return embeddings
-
-#     async def _process_batch(
-#         self, df_batch: pd.DataFrame, target_column: str, output_column: str
-#     ) -> pd.DataFrame:
-#         """Process a batch of rows to generate embeddings."""
-#         embeddings = await asyncio.gather(
-#             *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
-#         )
-#         df_batch[output_column] = embeddings
-#         return df_batch
-
-#     def _log_progress(self):
-#         """Log the progress of embedding generation."""
-#         progress = (self.completed_requests / self.total_requests) * 100
-#         logger.info(
-#             f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
-#         )
-
 from openai import AsyncOpenAI
 import logging
 from typing import List, Dict, Union
 from datasets import Dataset
 import asyncio
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
 from src.api.exceptions import OpenAIError
 
 # Set up structured logging
@@ -245,3 +118,45 @@ class EmbeddingService:
         logger.info(
             f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
         )
+
+    async def search_embeddings(
+        self,
+        query_embeddings: List[List[float]],
+        dataset: Dataset,
+        embedding_column: str,
+        target_column: str,
+        num_results: int,
+    ) -> List[Dict]:
+        """
+        Perform a cosine similarity search between query embeddings and dataset embeddings.
+
+        Args:
+            query_embeddings: List of embeddings for the query texts.
+            dataset: The dataset to search in.
+            embedding_column: The column in the dataset containing embeddings.
+            target_column: The column to return in the results.
+            num_results: The number of results to return.
+
+        Returns:
+            A list of dictionaries containing the target column values and their similarity scores.
+        """
+        dataset_embeddings = np.array(dataset[embedding_column])
+        query_embeddings = np.array(query_embeddings)
+
+        # Compute cosine similarity
+        similarities = cosine_similarity(query_embeddings, dataset_embeddings)
+
+        # Get the top-k results for each query
+        results = []
+        for i, query_similarities in enumerate(similarities):
+            top_k_indices = np.argsort(query_similarities)[-num_results:][::-1]
+            top_k_results = [
+                {
+                    target_column: dataset[target_column][idx],
+                    "similarity": float(query_similarities[idx]),
+                }
+                for idx in top_k_indices
+            ]
+            results.append(top_k_results)
+
+        return results
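
The core of search_embeddings is plain numpy plus scikit-learn, so the top-k logic can be sanity-checked standalone with toy 2-D embeddings:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy stand-ins for dataset[embedding_column] and dataset[target_column].
corpus_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
corpus_texts = ["cats", "stocks", "cat stocks"]
query = np.array([[0.9, 0.1]])  # one query embedding

sims = cosine_similarity(query, corpus_embeddings)[0]
top_k = np.argsort(sims)[-2:][::-1]  # two best matches, best first
for idx in top_k:
    print(corpus_texts[idx], round(float(sims[idx]), 3))
# cats 0.994
# cat stocks 0.781

One reviewer-level caveat: np.argsort sorts each full similarity row, so for large datasets an np.argpartition pre-selection before the final sort would avoid the O(n log n) cost per query.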
src/api/services/huggingface_service.py CHANGED
@@ -1,106 +1,3 @@
-# from datasets import Dataset, load_dataset, concatenate_datasets
-# from huggingface_hub import HfApi, HfFolder
-# import logging
-# import os
-# from typing import Optional, Dict, List
-# import pandas as pd
-# from src.api.services.embedding_service import EmbeddingService
-# from src.api.exceptions import (
-#     DatasetNotFoundError,
-#     DatasetPushError,
-#     DatasetDeleteError,
-# )
-
-# # Set up structured logging
-# logging.basicConfig(
-#     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-# )
-# logger = logging.getLogger(__name__)
-
-
-# class HuggingFaceService:
-#     def __init__(self, hf_token: Optional[str] = None):
-#         """Initialize the HuggingFaceService with an optional token."""
-#         self.hf_api = HfApi()
-#         if hf_token:
-#             HfFolder.save_token(hf_token)  # Save the token for authentication
-
-#     async def push_to_hub(self, df: pd.DataFrame, dataset_name: str) -> None:
-#         """Push the dataset to Hugging Face Hub."""
-#         try:
-#             logger.info(f"Creating Hugging Face Dataset: {dataset_name}...")
-#             ds = Dataset.from_pandas(df)
-#             ds.push_to_hub(dataset_name)
-#             logger.info(f"Dataset pushed to Hugging Face Hub: {dataset_name}")
-#         except Exception as e:
-#             logger.error(f"Failed to push dataset to Hugging Face Hub: {e}")
-#             raise DatasetPushError(f"Failed to push dataset: {e}")
-
-#     async def read_dataset(self, dataset_name: str) -> Optional[pd.DataFrame]:
-#         """Read a dataset from Hugging Face Hub."""
-#         try:
-#             logger.info(f"Loading dataset from Hugging Face Hub: {dataset_name}...")
-#             ds = load_dataset(dataset_name)
-#             df = ds["train"].to_dict()
-#             return df
-#         except Exception as e:
-#             logger.error(f"Failed to read dataset: {e}")
-#             raise DatasetNotFoundError(f"Dataset not found: {e}")
-
-#     async def update_dataset(
-#         self,
-#         dataset_name: str,
-#         updates: Dict[str, List],
-#         target_column: str,
-#         output_column: str = "embeddings",
-#     ) -> Optional[pd.DataFrame]:
-#         """Update a dataset on Hugging Face Hub by generating embeddings for new data and concatenating it with the existing dataset."""
-#         try:
-#             # Step 1: Load the existing dataset from Hugging Face Hub
-#             logger.info(
-#                 f"Loading existing dataset from Hugging Face Hub: {dataset_name}..."
-#             )
-#             existing_ds = await self.read_dataset(dataset_name)
-#             existing_df = pd.DataFrame(existing_ds)
-
-#             # Step 2: Convert the new updates into a DataFrame
-#             logger.info("Converting updates to DataFrame...")
-#             new_df = pd.DataFrame(updates)
-
-#             # Step 3: Generate embeddings for the new data
-#             logger.info("Generating embeddings for the new data...")
-#             embedding_service = EmbeddingService(
-#                 openai_api_key=os.getenv("OPENAI_API_KEY")
-#             )  # Get the embedding service
-#             new_df = await embedding_service.create_embeddings(
-#                 new_df, target_column, output_column
-#             )
-
-#             # Step 4: Concatenate the existing DataFrame with the new DataFrame
-#             logger.info("Concatenating existing dataset with new data...")
-#             updated_df = pd.concat([existing_df, new_df], ignore_index=True)
-
-#             # Step 5: Push the updated dataset back to Hugging Face Hub
-#             logger.info(
-#                 f"Pushing updated dataset to Hugging Face Hub: {dataset_name}..."
-#             )
-#             await self.push_to_hub(updated_df, dataset_name)
-
-#             return updated_df
-#         except Exception as e:
-#             logger.error(f"Failed to update dataset: {e}")
-#             raise DatasetPushError(f"Failed to update dataset: {e}")
-
-#     async def delete_dataset(self, dataset_name: str) -> None:
-#         """Delete a dataset from Hugging Face Hub."""
-#         try:
-#             logger.info(f"Deleting dataset from Hugging Face Hub: {dataset_name}...")
-#             self.hf_api.delete_repo(repo_id=dataset_name, repo_type="dataset")
-#             logger.info(f"Dataset deleted from Hugging Face Hub: {dataset_name}")
-#         except Exception as e:
-#             logger.error(f"Failed to delete dataset: {e}")
-#             raise DatasetDeleteError(f"Failed to delete dataset: {e}")
-
 from datasets import Dataset, load_dataset, concatenate_datasets
 from huggingface_hub import HfApi, HfFolder
 import logging
@@ -141,7 +38,11 @@ class HuggingFaceService:
         """Read a dataset from Hugging Face Hub."""
         try:
             logger.info(f"Loading dataset from Hugging Face Hub: {dataset_name}...")
-            dataset = load_dataset(dataset_name)
+            dataset = load_dataset(
+                dataset_name,
+                keep_in_memory=True,
+                download_mode="force_redownload",
+            )
             return dataset["train"]
         except Exception as e:
             logger.error(f"Failed to read dataset: {e}")
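
Both keep_in_memory=True and download_mode="force_redownload" are standard datasets.load_dataset arguments; the combination skips the local cache, presumably so a search always sees the most recently pushed version of the dataset. In isolation the call looks like this (dataset name is hypothetical):

from datasets import load_dataset

# force_redownload ignores any cached copy; keep_in_memory avoids disk-backed arrow files.
ds = load_dataset(
    "my-org/support-articles",
    keep_in_memory=True,
    download_mode="force_redownload",
)
print(ds["train"].num_rows)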
src/main.py CHANGED
@@ -1,252 +1,3 @@
-# import os
-# from fastapi import FastAPI, Depends, HTTPException
-# from fastapi.responses import JSONResponse, RedirectResponse
-# from fastapi.middleware.gzip import GZipMiddleware
-# from pydantic import BaseModel
-# from typing import List, Dict
-# from src.api.models.embedding_models import (
-#     CreateEmbeddingRequest,
-#     ReadEmbeddingRequest,
-#     UpdateEmbeddingRequest,
-#     DeleteEmbeddingRequest,
-#     EmbedRequest,
-# )
-# from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
-# from src.api.services.embedding_service import EmbeddingService
-# from src.api.services.huggingface_service import HuggingFaceService
-# from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
-
-# # from src.api.dependency import get_embedding_service, get_huggingface_service
-# import pandas as pd
-# import logging
-# from dotenv import load_dotenv
-
-# # Load environment variables
-# load_dotenv()
-
-# # Set up structured logging
-# logging.basicConfig(
-#     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-# )
-# logger = logging.getLogger(__name__)
-
-# description = """A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.
-
-# Direct/API URL:
-# https://re-mind-similarity-search.hf.space
-# """
-
-# # Initialize FastAPI app
-# app = FastAPI(
-#     title="Similarity Search API",
-#     description=description,
-#     version="1.0.0",
-# )
-
-# app.add_middleware(GZipMiddleware, minimum_size=1000)
-
-
-# # Dependency to get EmbeddingService
-# def get_embedding_service() -> EmbeddingService:
-#     return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
-
-
-# # Dependency to get HuggingFaceService
-# def get_huggingface_service() -> HuggingFaceService:
-#     return HuggingFaceService()
-
-
-# # Root endpoint redirects to /docs
-# @app.get("/")
-# async def root():
-#     return RedirectResponse(url="/docs")
-
-
-# # Health check endpoint
-# @app.get("/health")
-# async def health_check(db: Database = Depends(get_db)):
-#     try:
-#         is_healthy = await db.health_check()
-#         if not is_healthy:
-#             raise HTTPException(status_code=500, detail="Database is unhealthy")
-#         return {"status": "healthy"}
-#     except HealthCheckError as e:
-#         raise HTTPException(status_code=500, detail=str(e))
-
-
-# # Endpoint to generate embeddings for a list of strings
-# @app.post("/embed")
-# async def embed(
-#     request: EmbedRequest,
-#     embedding_service: EmbeddingService = Depends(get_embedding_service),
-# ):
-#     """
-#     Generate embeddings for a list of strings and return them in the response.
-#     """
-#     try:
-#         # Step 1: Generate embeddings
-#         logger.info("Generating embeddings for list of texts...")
-#         embeddings = await embedding_service.create_embeddings(request.texts)
-
-#         return JSONResponse(
-#             content={
-#                 "message": "Embeddings generated successfully.",
-#                 "embeddings": embeddings,
-#                 "num_texts": len(request.texts),
-#             }
-#         )
-#     except OpenAIError as e:
-#         logger.error(f"OpenAI API error: {e}")
-#         raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}")
-#     except Exception as e:
-#         logger.error(f"An error occurred: {e}")
-#         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
-# # Endpoint to create embeddings from a database query
-# @app.post("/create_embedding")
-# async def create_embedding(
-#     request: CreateEmbeddingRequest,
-#     db: Database = Depends(get_db),
-#     embedding_service: EmbeddingService = Depends(get_embedding_service),
-#     huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
-# ):
-#     """
-#     Create embeddings for the target column in the dataset.
-#     """
-#     try:
-#         # Step 1: Query the database
-#         logger.info("Fetching data from the database...")
-#         result = await db.fetch(request.query)
-#         df = pd.DataFrame(result)
-
-#         # Step 2: Generate embeddings
-#         df = await embedding_service.create_embeddings(
-#             df, request.target_column, request.output_column
-#         )
-
-#         # Step 3: Push to Hugging Face Hub
-#         await huggingface_service.push_to_hub(df, request.dataset_name)
-
-#         return JSONResponse(
-#             content={
-#                 "message": "Embeddings created and pushed to Hugging Face Hub.",
-#                 "dataset_name": request.dataset_name,
-#                 "num_rows": len(df),
-#             }
-#         )
-#     except QueryExecutionError as e:
-#         logger.error(f"Database query failed: {e}")
-#         raise HTTPException(status_code=500, detail=f"Database query failed: {e}")
-#     except OpenAIError as e:
-#         logger.error(f"OpenAI API error: {e}")
-#         raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}")
-#     except DatasetPushError as e:
-#         logger.error(f"Failed to push dataset: {e}")
-#         raise HTTPException(status_code=500, detail=f"Failed to push dataset: {e}")
-#     except Exception as e:
-#         logger.error(f"An error occurred: {e}")
-#         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
-# # Endpoint to read embeddings
-# @app.post("/read_embeddings")
-# async def read_embeddings(
-#     request: ReadEmbeddingRequest,
-#     huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
-# ):
-#     """
-#     Read embeddings from a Hugging Face dataset.
-#     """
-#     try:
-#         df = await huggingface_service.read_dataset(request.dataset_name)
-#         return df
-#     except DatasetNotFoundError as e:
-#         logger.error(f"Dataset not found: {e}")
-#         raise HTTPException(status_code=404, detail=f"Dataset not found: {e}")
-#     except Exception as e:
-#         logger.error(f"An error occurred: {e}")
-#         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
-# # Endpoint to update embeddings
-# # @app.post("/update_embeddings")
-# # async def update_embeddings(
-# #     request: UpdateEmbeddingRequest,
-# #     huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
-# # ):
-# #     """
-# #     Update embeddings in a Hugging Face dataset.
-# #     """
-# #     try:
-# #         df = await huggingface_service.update_dataset(
-# #             request.dataset_name, request.updates
-# #         )
-# #         return {
-# #             "message": "Embeddings updated successfully.",
-# #             "dataset_name": request.dataset_name,
-# #         }
-# #     except DatasetPushError as e:
-# #         logger.error(f"Failed to update dataset: {e}")
-# #         raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}")
-# #     except Exception as e:
-# #         logger.error(f"An error occurred: {e}")
-# #         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
-# @app.post("/update_embeddings")
-# async def update_embeddings(
-#     request: UpdateEmbeddingRequest,
-#     huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
-# ):
-#     """
-#     Update embeddings in a Hugging Face dataset by generating embeddings for new data and concatenating it with the existing dataset.
-#     """
-#     try:
-#         # Call the update_dataset method to generate embeddings, concatenate, and push the updated dataset
-#         updated_df = await huggingface_service.update_dataset(
-#             request.dataset_name,
-#             request.updates,
-#             request.target_column,
-#             request.output_column,
-#         )
-
-#         return {
-#             "message": "Embeddings updated successfully.",
-#             "dataset_name": request.dataset_name,
-#             "num_rows": len(updated_df),
-#         }
-#     except DatasetPushError as e:
-#         logger.error(f"Failed to update dataset: {e}")
-#         raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}")
-#     except Exception as e:
-#         logger.error(f"An error occurred: {e}")
-#         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
-# # Endpoint to delete embeddings
-# @app.post("/delete_embeddings")
-# async def delete_embeddings(
-#     request: DeleteEmbeddingRequest,
-#     huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
-# ):
-#     """
-#     Delete embeddings from a Hugging Face dataset.
-#     """
-#     try:
-#         await huggingface_service.delete_dataset(request.dataset_name)
-#         return {
-#             "message": "Embeddings deleted successfully.",
-#             "dataset_name": request.dataset_name,
-#         }
-#     except DatasetPushError as e:
-#         logger.error(f"Failed to delete columns: {e}")
-#         raise HTTPException(status_code=500, detail=f"Failed to delete columns: {e}")
-#     except Exception as e:
-#         logger.error(f"An error occurred: {e}")
-#         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-
 import os
 from fastapi import FastAPI, Depends, HTTPException
 from fastapi.responses import JSONResponse, RedirectResponse
@@ -260,6 +11,7 @@ from src.api.models.embedding_models import (
     UpdateEmbeddingRequest,
     DeleteEmbeddingRequest,
     EmbedRequest,
+    SearchEmbeddingRequest,
 )
 from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
 from src.api.services.embedding_service import EmbeddingService
@@ -363,6 +115,10 @@ async def create_embedding(
     Create embeddings for the target column in the dataset.
     """
     try:
+        embedding_service.model = request.model
+        embedding_service.batch_size = request.batch_size
+        # embedding_service.max_concurrent_requests = request.max_concurrent_requests
+
         # Step 1: Query the database
         logger.info("Fetching data from the database...")
         result = await db.fetch(request.query)
@@ -371,8 +127,6 @@ async def create_embedding(
 
         dataset = Dataset.from_dict(result)
 
-        embedding_service.batch_size = request.batch_size
-
         # Step 2: Generate embeddings
         dataset = await embedding_service.create_embeddings(
             dataset, request.target_column, request.output_column
@@ -474,3 +228,45 @@ async def delete_embeddings(
     except Exception as e:
         logger.error(f"An error occurred: {e}")
         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
+
+
+@app.post("/search_embedding")
+async def search_embedding(
+    request: SearchEmbeddingRequest,
+    embedding_service: EmbeddingService = Depends(get_embedding_service),
+    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
+):
+    """
+    Search for similar texts in a dataset using embeddings.
+    """
+    try:
+        # Step 1: Generate embeddings for the query texts
+        logger.info("Generating embeddings for query texts...")
+        query_embeddings = await embedding_service.create_embeddings(request.texts)
+
+        # Step 2: Load the dataset from Hugging Face Hub
+        logger.info(f"Loading dataset from Hugging Face Hub: {request.dataset_name}...")
+        dataset = await huggingface_service.read_dataset(request.dataset_name)
+
+        # Step 3: Perform cosine similarity search
+        logger.info("Performing cosine similarity search...")
+        results = await embedding_service.search_embeddings(
+            query_embeddings,
+            dataset,
+            request.embedding_column,
+            request.target_column,
+            request.num_results,
+        )
+
+        return JSONResponse(
+            content={
+                "message": "Search completed successfully.",
+                "results": results,
+            }
+        )
+    except DatasetNotFoundError as e:
+        logger.error(f"Dataset not found: {e}")
+        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}")
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
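
End to end, the new endpoint can be exercised with any HTTP client. A hypothetical call, assuming the service runs locally on port 8000 and the dataset/column names from the earlier examples:

import requests

payload = {
    "texts": ["How do I reset my password?"],
    "target_column": "text",
    "embedding_column": "embedding",
    "num_results": 5,
    "dataset_name": "my-org/support-articles",  # hypothetical dataset
}

resp = requests.post("http://localhost:8000/search_embedding", json=payload)
resp.raise_for_status()
for hit in resp.json()["results"][0]:  # one result list per query text
    print(hit["text"], hit["similarity"])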