Update apis/chat_api.py
apis/chat_api.py  CHANGED  +26 -1
@@ -26,7 +26,7 @@ from networks.huggingface_streamer import HuggingfaceStreamer
 from networks.huggingchat_streamer import HuggingchatStreamer
 from networks.openai_streamer import OpenaiStreamer
 
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, CrossEncoder
 import tiktoken
 
 class EmbeddingsAPIInference:
@@ -72,6 +72,9 @@ class ChatAPIApp:
             "intfloat/multilingual-e5-large-instruct":EmbeddingsAPIInference("intfloat/multilingual-e5-large-instruct"),
             "mixedbread-ai/mxbai-embed-large-v1":EmbeddingsAPIInference("mixedbread-ai/mxbai-embed-large-v1")
         }
+        self.rerank = {
+            "bge-reranker-v2-m3":CrossEncoder("BAAI/bge-reranker-v2-m3")
+        }
 
     def get_available_models(self):
         return {"object": "list", "data": AVAILABLE_MODELS_DICTS}
@@ -376,6 +379,22 @@ class ChatAPIApp:
             return {"embedding": embeddings}#.tolist()}
         except ValueError as e:
             raise HTTPException(status_code=400, detail=str(e))
+
+
+    class RerankRequest(BaseModel):
+        model: str
+        input: str
+        document: list
+        top_k: int
+        return_documents: bool
+
+    def get_rerank(self, request: RerankRequest, api_key: str = Depends(extract_api_key)):
+        return self.rerank[request.model].rank(
+            request.input,
+            request.document,
+            top_k=request.top_k,
+            return_documents=request.return_documents
+        )
 
 
     def get_readme(self):
@@ -400,6 +419,12 @@ class ChatAPIApp:
             include_in_schema=include_in_schema,
         )(self.get_available_models)
 
+        self.app.post(
+            prefix+"/rerank",
+            summary="Rerank documents",
+            include_in_schema=include_in_schema,
+        )(self.get_rerank)
+
         self.app.post(
            prefix + "/chat/completions",
            summary="OpenAI Chat completions in conversation session",
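For context, the new get_rerank handler delegates directly to CrossEncoder.rank from sentence-transformers. Below is a minimal local sketch of what that call does, assuming a recent sentence-transformers release that provides CrossEncoder.rank; the model name is the one this commit registers under self.rerank, while the query and documents are made-up placeholders:

from sentence_transformers import CrossEncoder

# Same model the commit registers as self.rerank["bge-reranker-v2-m3"]
reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")

# Placeholder query and documents, for illustration only
query = "What is the capital of France?"
documents = [
    "Paris is the capital and largest city of France.",
    "Berlin is the capital of Germany.",
    "The Eiffel Tower is located in Paris.",
]

# rank() scores every (query, document) pair with the cross-encoder and
# returns the top_k results sorted by score; with return_documents=True
# each result also carries the original text.
results = reranker.rank(query, documents, top_k=2, return_documents=True)
for r in results:
    # Each entry is a dict like {"corpus_id": 0, "score": ..., "text": "..."}
    print(r["corpus_id"], round(float(r["score"]), 4), r["text"])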
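Once the route from the last hunk is registered, the endpoint accepts a JSON body whose fields mirror the RerankRequest model. The sketch below is a hypothetical client call: the base URL, the /v1 prefix, and the Bearer-token header are assumptions (the diff does not show the prefix value or how extract_api_key reads the key); only the field names and the model id come from the commit itself:

import requests

# Assumed host and prefix; adjust to wherever this Space is served and to
# whatever `prefix` the app registers its routes under.
url = "http://127.0.0.1:8000/v1/rerank"

payload = {
    # Field names mirror the RerankRequest model added in this commit
    "model": "bge-reranker-v2-m3",
    "input": "What is the capital of France?",
    "document": [
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
    ],
    "top_k": 2,
    "return_documents": True,
}

# extract_api_key is assumed here to accept a Bearer token; the exact header
# this app expects is not shown in the diff.
headers = {"Authorization": "Bearer sk-xxx"}

resp = requests.post(url, json=payload, headers=headers)
resp.raise_for_status()
print(resp.json())  # expected: the list returned by CrossEncoder.rank()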
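Design-wise, the rerank models follow the same pattern as the existing embedding models: they are instantiated eagerly in the constructor and looked up by name via request.model, so supporting another reranker only means adding one more CrossEncoder entry to self.rerank. The handler returns CrossEncoder.rank's output as-is, so the response shape is whatever that call produces for the chosen top_k and return_documents settings.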