amaye15 committed on
Commit
b96eea7
·
1 Parent(s): 494872d

Feat - Additional Columns Returned

Browse files
src/api/models/embedding_models.py CHANGED
@@ -1,5 +1,5 @@
1
  from pydantic import BaseModel
2
- from typing import List, Dict
3
 
4
 
5
  # Pydantic models for request validation
@@ -56,3 +56,6 @@ class SearchEmbeddingRequest(BaseModel):
56
  embedding_column: str # Column containing the embeddings to search against
57
  num_results: int # Number of results to return
58
  dataset_name: str # Name of the dataset to search in
 
 
 
 
1
  from pydantic import BaseModel
2
+ from typing import List, Dict, Optional
3
 
4
 
5
  # Pydantic models for request validation
 
56
  embedding_column: str # Column containing the embeddings to search against
57
  num_results: int # Number of results to return
58
  dataset_name: str # Name of the dataset to search in
59
+ additional_columns: Optional[List[str]] = (
60
+ None # Optional list of additional columns to include in the results
61
+ )
src/api/services/embedding_service.py CHANGED
@@ -1,6 +1,6 @@
1
  from openai import AsyncOpenAI
2
  import logging
3
- from typing import List, Dict, Union
4
  from datasets import Dataset
5
  import asyncio
6
  import numpy as np
@@ -119,6 +119,48 @@ class EmbeddingService:
119
  f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
120
  )
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  async def search_embeddings(
123
  self,
124
  query_embeddings: List[List[float]],
@@ -126,6 +168,7 @@ class EmbeddingService:
126
  embedding_column: str,
127
  target_column: str,
128
  num_results: int,
 
129
  ) -> Dict[str, List]:
130
  """
131
  Perform a cosine similarity search between query embeddings and dataset embeddings.
@@ -136,9 +179,11 @@ class EmbeddingService:
136
  embedding_column: The column in the dataset containing embeddings.
137
  target_column: The column to return in the results.
138
  num_results: The number of results to return.
 
139
 
140
  Returns:
141
- A dictionary of lists containing the target column values and their similarity scores.
 
142
  """
143
  dataset_embeddings = np.array(dataset[embedding_column])
144
  query_embeddings = np.array(query_embeddings)
@@ -152,11 +197,19 @@ class EmbeddingService:
152
  "similarity": [],
153
  }
154
 
 
 
 
 
 
155
  # Get the top-k results for each query
156
  for query_similarities in similarities:
157
  top_k_indices = np.argsort(query_similarities)[-num_results:][::-1]
158
  for idx in top_k_indices:
159
  results[target_column].append(dataset[target_column][idx])
160
  results["similarity"].append(float(query_similarities[idx]))
 
 
 
161
 
162
  return results
 
1
  from openai import AsyncOpenAI
2
  import logging
3
+ from typing import List, Dict, Union, Optional
4
  from datasets import Dataset
5
  import asyncio
6
  import numpy as np
 
119
  f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
120
  )
121
 
122
+ # async def search_embeddings(
123
+ # self,
124
+ # query_embeddings: List[List[float]],
125
+ # dataset: Dataset,
126
+ # embedding_column: str,
127
+ # target_column: str,
128
+ # num_results: int,
129
+ # ) -> Dict[str, List]:
130
+ # """
131
+ # Perform a cosine similarity search between query embeddings and dataset embeddings.
132
+
133
+ # Args:
134
+ # query_embeddings: List of embeddings for the query texts.
135
+ # dataset: The dataset to search in.
136
+ # embedding_column: The column in the dataset containing embeddings.
137
+ # target_column: The column to return in the results.
138
+ # num_results: The number of results to return.
139
+
140
+ # Returns:
141
+ # A dictionary of lists containing the target column values and their similarity scores.
142
+ # """
143
+ # dataset_embeddings = np.array(dataset[embedding_column])
144
+ # query_embeddings = np.array(query_embeddings)
145
+
146
+ # # Compute cosine similarity
147
+ # similarities = cosine_similarity(query_embeddings, dataset_embeddings)
148
+
149
+ # # Initialize the results dictionary
150
+ # results = {
151
+ # target_column: [],
152
+ # "similarity": [],
153
+ # }
154
+
155
+ # # Get the top-k results for each query
156
+ # for query_similarities in similarities:
157
+ # top_k_indices = np.argsort(query_similarities)[-num_results:][::-1]
158
+ # for idx in top_k_indices:
159
+ # results[target_column].append(dataset[target_column][idx])
160
+ # results["similarity"].append(float(query_similarities[idx]))
161
+
162
+ # return results
163
+
164
  async def search_embeddings(
165
  self,
166
  query_embeddings: List[List[float]],
 
168
  embedding_column: str,
169
  target_column: str,
170
  num_results: int,
171
+ additional_columns: Optional[List[str]] = None,
172
  ) -> Dict[str, List]:
173
  """
174
  Perform a cosine similarity search between query embeddings and dataset embeddings.
 
179
  embedding_column: The column in the dataset containing embeddings.
180
  target_column: The column to return in the results.
181
  num_results: The number of results to return.
182
+ additional_columns: List of additional columns to include in the results.
183
 
184
  Returns:
185
+ A dictionary of lists containing the target column values, their similarity scores,
186
+ and any additional columns specified.
187
  """
188
  dataset_embeddings = np.array(dataset[embedding_column])
189
  query_embeddings = np.array(query_embeddings)
 
197
  "similarity": [],
198
  }
199
 
200
+ # Add additional columns to the results dictionary
201
+ if additional_columns:
202
+ for column in additional_columns:
203
+ results[column] = []
204
+
205
  # Get the top-k results for each query
206
  for query_similarities in similarities:
207
  top_k_indices = np.argsort(query_similarities)[-num_results:][::-1]
208
  for idx in top_k_indices:
209
  results[target_column].append(dataset[target_column][idx])
210
  results["similarity"].append(float(query_similarities[idx]))
211
+ if additional_columns:
212
+ for column in additional_columns:
213
+ results[column].append(dataset[column][idx])
214
 
215
  return results
src/main.py CHANGED
@@ -256,6 +256,7 @@ async def search_embedding(
256
  request.embedding_column,
257
  request.target_column,
258
  request.num_results,
 
259
  )
260
 
261
  return JSONResponse(
 
256
  request.embedding_column,
257
  request.target_column,
258
  request.num_results,
259
+ request.additional_columns,
260
  )
261
 
262
  return JSONResponse(