Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
9ed5b2c
1
Parent(s):
abbed11
chore: Refactor error handling in api_query_dataset
Browse files
main.py
CHANGED
@@ -9,6 +9,7 @@ from httpx import AsyncClient
|
|
9 |
from huggingface_hub import DatasetCard
|
10 |
from pydantic import BaseModel
|
11 |
from starlette.responses import RedirectResponse
|
|
|
12 |
|
13 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
14 |
|
@@ -31,15 +32,6 @@ async_client = AsyncClient(
|
|
31 |
)
|
32 |
|
33 |
|
34 |
-
class QueryResult(BaseModel):
|
35 |
-
dataset_id: str
|
36 |
-
similarity: float
|
37 |
-
|
38 |
-
|
39 |
-
class QueryResponse(BaseModel):
|
40 |
-
results: List[QueryResult]
|
41 |
-
|
42 |
-
|
43 |
@asynccontextmanager
|
44 |
async def lifespan(app: FastAPI):
|
45 |
global collection
|
@@ -88,6 +80,23 @@ async def try_get_card(hub_id: str) -> Optional[str]:
|
|
88 |
return None
|
89 |
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
@app.get("/similar", response_model=QueryResponse)
|
92 |
@cache(ttl="1h")
|
93 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
@@ -101,16 +110,18 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
101 |
embedding_function = get_embedding_function()
|
102 |
card = await try_get_card(dataset_id)
|
103 |
if card is None:
|
104 |
-
|
105 |
embeddings = embedding_function(card)
|
106 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
107 |
logger.info(f"Dataset {dataset_id} added to collection")
|
108 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
|
|
|
|
109 |
except Exception as e:
|
110 |
logger.error(
|
111 |
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
112 |
)
|
113 |
-
|
114 |
|
115 |
embedding = result["embeddings"][0]
|
116 |
|
@@ -121,7 +132,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
121 |
|
122 |
if not query_result["ids"]:
|
123 |
logger.info(f"No similar datasets found for: {dataset_id}")
|
124 |
-
|
|
|
|
|
125 |
|
126 |
# Prepare the response
|
127 |
results = [
|
@@ -134,9 +147,15 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
134 |
logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
|
135 |
return QueryResponse(results=results)
|
136 |
|
|
|
|
|
137 |
except Exception as e:
|
138 |
logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
|
139 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
140 |
|
141 |
if __name__ == "__main__":
|
142 |
import uvicorn
|
|
|
9 |
from huggingface_hub import DatasetCard
|
10 |
from pydantic import BaseModel
|
11 |
from starlette.responses import RedirectResponse
|
12 |
+
from starlette.status import HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
|
13 |
|
14 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
15 |
|
|
|
32 |
)
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
@asynccontextmanager
|
36 |
async def lifespan(app: FastAPI):
|
37 |
global collection
|
|
|
80 |
return None
|
81 |
|
82 |
|
83 |
+
class QueryResult(BaseModel):
|
84 |
+
dataset_id: str
|
85 |
+
similarity: float
|
86 |
+
|
87 |
+
|
88 |
+
class QueryResponse(BaseModel):
|
89 |
+
results: List[QueryResult]
|
90 |
+
|
91 |
+
|
92 |
+
class DatasetCardNotFoundError(HTTPException):
|
93 |
+
def __init__(self, dataset_id: str):
|
94 |
+
super().__init__(
|
95 |
+
status_code=HTTP_404_NOT_FOUND,
|
96 |
+
detail=f"No dataset card available for dataset: {dataset_id}",
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
@app.get("/similar", response_model=QueryResponse)
|
101 |
@cache(ttl="1h")
|
102 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
110 |
embedding_function = get_embedding_function()
|
111 |
card = await try_get_card(dataset_id)
|
112 |
if card is None:
|
113 |
+
raise DatasetCardNotFoundError(dataset_id)
|
114 |
embeddings = embedding_function(card)
|
115 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
116 |
logger.info(f"Dataset {dataset_id} added to collection")
|
117 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
118 |
+
except DatasetCardNotFoundError:
|
119 |
+
raise
|
120 |
except Exception as e:
|
121 |
logger.error(
|
122 |
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
123 |
)
|
124 |
+
raise DatasetCardNotFoundError(dataset_id) from e
|
125 |
|
126 |
embedding = result["embeddings"][0]
|
127 |
|
|
|
132 |
|
133 |
if not query_result["ids"]:
|
134 |
logger.info(f"No similar datasets found for: {dataset_id}")
|
135 |
+
raise HTTPException(
|
136 |
+
status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
|
137 |
+
)
|
138 |
|
139 |
# Prepare the response
|
140 |
results = [
|
|
|
147 |
logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
|
148 |
return QueryResponse(results=results)
|
149 |
|
150 |
+
except (HTTPException, DatasetCardNotFoundError):
|
151 |
+
raise
|
152 |
except Exception as e:
|
153 |
logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
|
154 |
+
raise HTTPException(
|
155 |
+
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
156 |
+
detail="An unexpected error occurred.",
|
157 |
+
) from e
|
158 |
+
|
159 |
|
160 |
if __name__ == "__main__":
|
161 |
import uvicorn
|