Update apis/chat_api.py
Browse files- apis/chat_api.py +21 -21
apis/chat_api.py
CHANGED
@@ -175,7 +175,7 @@ class ChatAPIApp:
|
|
175 |
data_response = streamer.chat_return_dict(stream_response)
|
176 |
return data_response
|
177 |
|
178 |
-
def chat_embedding(texts, model_name, api_key):
|
179 |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
|
180 |
headers = {"Authorization": f"Bearer {api_key}"}
|
181 |
response = requests.post(api_url, headers=headers, json={"inputs": texts})
|
@@ -189,26 +189,26 @@ class ChatAPIApp:
|
|
189 |
|
190 |
|
191 |
async def embedding(request: QueryRequest):
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
|
213 |
def setup_routes(self):
|
214 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|
|
|
175 |
data_response = streamer.chat_return_dict(stream_response)
|
176 |
return data_response
|
177 |
|
178 |
+
async def chat_embedding(texts, model_name, api_key):
|
179 |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
|
180 |
headers = {"Authorization": f"Bearer {api_key}"}
|
181 |
response = requests.post(api_url, headers=headers, json={"inputs": texts})
|
|
|
189 |
|
190 |
|
191 |
async def embedding(request: QueryRequest):
|
192 |
+
try:
|
193 |
+
for attempt in range(3): # Retry logic
|
194 |
+
try:
|
195 |
+
embeddings = await chat_embedding(request.texts, request.model_name, request.api_key)
|
196 |
+
data = [
|
197 |
+
{"object": "embedding", "index": i, "embedding": embedding}
|
198 |
+
for i, embedding in enumerate(embeddings)
|
199 |
+
]
|
200 |
+
return {
|
201 |
+
"object": "list",
|
202 |
+
"data": data,
|
203 |
+
"model": request.model_name,
|
204 |
+
"usage": {"prompt_tokens": len(request.texts), "total_tokens": len(request.texts)}
|
205 |
+
}
|
206 |
+
except RuntimeError as e:
|
207 |
+
if attempt < 2: # Don't sleep on the last attempt
|
208 |
+
await asyncio.sleep(10) # Delay for the retry
|
209 |
+
raise HTTPException(status_code=503, detail="The model is currently loading, please try again later.")
|
210 |
+
except Exception as e:
|
211 |
+
raise HTTPException(status_code=500, detail=str(e))
|
212 |
|
213 |
def setup_routes(self):
|
214 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|