davanstrien HF staff commited on
Commit
3408aae
·
1 Parent(s): 5d6ca81
Files changed (1) hide show
  1. main.py +55 -3
main.py CHANGED
@@ -9,9 +9,13 @@ from httpx import AsyncClient
9
  from huggingface_hub import DatasetCard
10
  from pydantic import BaseModel
11
  from starlette.responses import RedirectResponse
12
- from starlette.status import HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
 
 
 
 
13
 
14
- from load_data import get_embedding_function, get_save_path, refresh_data
15
 
16
  # Set up logging
17
  logging.basicConfig(
@@ -97,6 +101,14 @@ class DatasetCardNotFoundError(HTTPException):
97
  )
98
 
99
 
 
 
 
 
 
 
 
 
100
  @app.get("/similar", response_model=QueryResponse)
101
  @cache(ttl="1h")
102
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
@@ -115,7 +127,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
115
  collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
116
  logger.info(f"Dataset {dataset_id} added to collection")
117
  result = collection.get(ids=[dataset_id], include=["embeddings"])
118
- except DatasetCardNotFoundError:
 
 
119
  raise
120
  except Exception as e:
121
  logger.error(
@@ -157,6 +171,44 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
157
  ) from e
158
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  if __name__ == "__main__":
161
  import uvicorn
162
 
 
9
  from huggingface_hub import DatasetCard
10
  from pydantic import BaseModel
11
  from starlette.responses import RedirectResponse
12
+ from starlette.status import (
13
+ HTTP_404_NOT_FOUND,
14
+ HTTP_500_INTERNAL_SERVER_ERROR,
15
+ HTTP_403_FORBIDDEN,
16
+ )
17
 
18
+ from load_card_data import get_embedding_function, get_save_path, refresh_data
19
 
20
  # Set up logging
21
  logging.basicConfig(
 
101
  )
102
 
103
 
104
+ class DatasetNotForAllAudiencesError(HTTPException):
105
+ def __init__(self, dataset_id: str):
106
+ super().__init__(
107
+ status_code=HTTP_403_FORBIDDEN,
108
+ detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
109
+ )
110
+
111
+
112
  @app.get("/similar", response_model=QueryResponse)
113
  @cache(ttl="1h")
114
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
 
127
  collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
128
  logger.info(f"Dataset {dataset_id} added to collection")
129
  result = collection.get(ids=[dataset_id], include=["embeddings"])
130
+ if result.get("not-for-all-audiences"):
131
+ raise DatasetNotForAllAudiencesError(dataset_id)
132
+ except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
133
  raise
134
  except Exception as e:
135
  logger.error(
 
171
  ) from e
172
 
173
 
174
+ @app.post("/similar_by_text", response_model=QueryResponse)
175
+ @cache(ttl="1h")
176
+ async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
177
+ try:
178
+ logger.info(f"Querying datasets by text: {query}")
179
+ collection = client.get_collection(
180
+ name="dataset_cards", embedding_function=get_embedding_function()
181
+ )
182
+ print(query)
183
+ query_result = collection.query(
184
+ query_texts=query, n_results=n, include=["distances"]
185
+ )
186
+ print(query_result)
187
+
188
+ if not query_result["ids"]:
189
+ logger.info(f"No similar datasets found for query: {query}")
190
+ raise HTTPException(
191
+ status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
192
+ )
193
+
194
+ # Prepare the response
195
+ results = [
196
+ QueryResult(dataset_id=str(id), similarity=float(1 - distance))
197
+ for id, distance in zip(
198
+ query_result["ids"][0], query_result["distances"][0]
199
+ )
200
+ ]
201
+ logger.info(f"Found {len(results)} similar datasets for query: {query}")
202
+ return QueryResponse(results=results)
203
+
204
+ except Exception as e:
205
+ logger.error(f"Error querying datasets by text {query}: {str(e)}")
206
+ raise HTTPException(
207
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR,
208
+ detail="An unexpected error occurred.",
209
+ ) from e
210
+
211
+
212
  if __name__ == "__main__":
213
  import uvicorn
214