Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
3408aae
1
Parent(s):
5d6ca81
rename
Browse files
main.py
CHANGED
@@ -9,9 +9,13 @@ from httpx import AsyncClient
|
|
9 |
from huggingface_hub import DatasetCard
|
10 |
from pydantic import BaseModel
|
11 |
from starlette.responses import RedirectResponse
|
12 |
-
from starlette.status import
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
from
|
15 |
|
16 |
# Set up logging
|
17 |
logging.basicConfig(
|
@@ -97,6 +101,14 @@ class DatasetCardNotFoundError(HTTPException):
|
|
97 |
)
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
@app.get("/similar", response_model=QueryResponse)
|
101 |
@cache(ttl="1h")
|
102 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
@@ -115,7 +127,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
115 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
116 |
logger.info(f"Dataset {dataset_id} added to collection")
|
117 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
118 |
-
|
|
|
|
|
119 |
raise
|
120 |
except Exception as e:
|
121 |
logger.error(
|
@@ -157,6 +171,44 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
157 |
) from e
|
158 |
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
if __name__ == "__main__":
|
161 |
import uvicorn
|
162 |
|
|
|
9 |
from huggingface_hub import DatasetCard
|
10 |
from pydantic import BaseModel
|
11 |
from starlette.responses import RedirectResponse
|
12 |
+
from starlette.status import (
|
13 |
+
HTTP_404_NOT_FOUND,
|
14 |
+
HTTP_500_INTERNAL_SERVER_ERROR,
|
15 |
+
HTTP_403_FORBIDDEN,
|
16 |
+
)
|
17 |
|
18 |
+
from load_card_data import get_embedding_function, get_save_path, refresh_data
|
19 |
|
20 |
# Set up logging
|
21 |
logging.basicConfig(
|
|
|
101 |
)
|
102 |
|
103 |
|
104 |
+
class DatasetNotForAllAudiencesError(HTTPException):
|
105 |
+
def __init__(self, dataset_id: str):
|
106 |
+
super().__init__(
|
107 |
+
status_code=HTTP_403_FORBIDDEN,
|
108 |
+
detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
@app.get("/similar", response_model=QueryResponse)
|
113 |
@cache(ttl="1h")
|
114 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
127 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
128 |
logger.info(f"Dataset {dataset_id} added to collection")
|
129 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
130 |
+
if result.get("not-for-all-audiences"):
|
131 |
+
raise DatasetNotForAllAudiencesError(dataset_id)
|
132 |
+
except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
|
133 |
raise
|
134 |
except Exception as e:
|
135 |
logger.error(
|
|
|
171 |
) from e
|
172 |
|
173 |
|
174 |
+
@app.post("/similar_by_text", response_model=QueryResponse)
|
175 |
+
@cache(ttl="1h")
|
176 |
+
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
177 |
+
try:
|
178 |
+
logger.info(f"Querying datasets by text: {query}")
|
179 |
+
collection = client.get_collection(
|
180 |
+
name="dataset_cards", embedding_function=get_embedding_function()
|
181 |
+
)
|
182 |
+
print(query)
|
183 |
+
query_result = collection.query(
|
184 |
+
query_texts=query, n_results=n, include=["distances"]
|
185 |
+
)
|
186 |
+
print(query_result)
|
187 |
+
|
188 |
+
if not query_result["ids"]:
|
189 |
+
logger.info(f"No similar datasets found for query: {query}")
|
190 |
+
raise HTTPException(
|
191 |
+
status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
|
192 |
+
)
|
193 |
+
|
194 |
+
# Prepare the response
|
195 |
+
results = [
|
196 |
+
QueryResult(dataset_id=str(id), similarity=float(1 - distance))
|
197 |
+
for id, distance in zip(
|
198 |
+
query_result["ids"][0], query_result["distances"][0]
|
199 |
+
)
|
200 |
+
]
|
201 |
+
logger.info(f"Found {len(results)} similar datasets for query: {query}")
|
202 |
+
return QueryResponse(results=results)
|
203 |
+
|
204 |
+
except Exception as e:
|
205 |
+
logger.error(f"Error querying datasets by text {query}: {str(e)}")
|
206 |
+
raise HTTPException(
|
207 |
+
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
208 |
+
detail="An unexpected error occurred.",
|
209 |
+
) from e
|
210 |
+
|
211 |
+
|
212 |
if __name__ == "__main__":
|
213 |
import uvicorn
|
214 |
|