jonathanjordan21 committed on
Commit
b4970eb
·
verified ·
1 Parent(s): 9f702a0

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +26 -1
apis/chat_api.py CHANGED
@@ -26,7 +26,7 @@ from networks.huggingface_streamer import HuggingfaceStreamer
26
  from networks.huggingchat_streamer import HuggingchatStreamer
27
  from networks.openai_streamer import OpenaiStreamer
28
 
29
- from sentence_transformers import SentenceTransformer
30
  import tiktoken
31
 
32
  class EmbeddingsAPIInference:
@@ -72,6 +72,9 @@ class ChatAPIApp:
72
  "intfloat/multilingual-e5-large-instruct":EmbeddingsAPIInference("intfloat/multilingual-e5-large-instruct"),
73
  "mixedbread-ai/mxbai-embed-large-v1":EmbeddingsAPIInference("mixedbread-ai/mxbai-embed-large-v1")
74
  }
 
 
 
75
 
76
  def get_available_models(self):
77
  return {"object": "list", "data": AVAILABLE_MODELS_DICTS}
@@ -376,6 +379,22 @@ class ChatAPIApp:
376
  return {"embedding": embeddings}#.tolist()}
377
  except ValueError as e:
378
  raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
 
381
  def get_readme(self):
@@ -400,6 +419,12 @@ class ChatAPIApp:
400
  include_in_schema=include_in_schema,
401
  )(self.get_available_models)
402
 
 
 
 
 
 
 
403
  self.app.post(
404
  prefix + "/chat/completions",
405
  summary="OpenAI Chat completions in conversation session",
 
26
  from networks.huggingchat_streamer import HuggingchatStreamer
27
  from networks.openai_streamer import OpenaiStreamer
28
 
29
+ from sentence_transformers import SentenceTransformer, CrossEncoder
30
  import tiktoken
31
 
32
  class EmbeddingsAPIInference:
 
72
  "intfloat/multilingual-e5-large-instruct":EmbeddingsAPIInference("intfloat/multilingual-e5-large-instruct"),
73
  "mixedbread-ai/mxbai-embed-large-v1":EmbeddingsAPIInference("mixedbread-ai/mxbai-embed-large-v1")
74
  }
75
+ self.rerank = {
76
+ "bge-reranker-v2-m3":CrossEncoder("BAAI/bge-reranker-v2-m3")
77
+ }
78
 
79
  def get_available_models(self):
80
  return {"object": "list", "data": AVAILABLE_MODELS_DICTS}
 
379
  return {"embedding": embeddings}#.tolist()}
380
  except ValueError as e:
381
  raise HTTPException(status_code=400, detail=str(e))
382
+
383
+
384
class RerankRequest(BaseModel):
    """Request body for the rerank endpoint.

    The fields are forwarded by ``get_rerank`` to the selected reranker's
    ``rank`` call.
    """

    # Key into the ``self.rerank`` registry (e.g. "bge-reranker-v2-m3").
    model: str
    # Query text the documents are ranked against.
    input: str
    # Candidate documents to be ranked.
    document: list
    # Forwarded as ``top_k`` to the ranker.
    top_k: int
    # Forwarded as ``return_documents`` to the ranker. Defaults to False so
    # clients that only need the ranking may omit it; clients that already
    # send the field are unaffected.
    return_documents: bool = False
390
+
391
def get_rerank(self, request: RerankRequest, api_key: str = Depends(extract_api_key)):
    """Rank ``request.document`` against ``request.input`` with a registered reranker.

    Args:
        request: Rerank payload; ``request.model`` selects an entry of
            ``self.rerank``.
        api_key: Value resolved by the ``extract_api_key`` dependency; it is
            not read inside this handler.

    Returns:
        The list of hit dicts produced by the reranker's ``rank`` call, with
        each ``score`` converted to a built-in ``float``. An empty list when
        no documents were supplied.

    Raises:
        HTTPException: 400 when ``request.model`` is not a registered reranker.
    """
    reranker = self.rerank.get(request.model)
    if reranker is None:
        # Indexing the registry directly would raise KeyError and surface as
        # an opaque 500; report a client error like the embedding handler does.
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unknown rerank model: {request.model!r}. "
                f"Available models: {sorted(self.rerank)}"
            ),
        )

    if not request.document:
        # Nothing to rank; skip the model call entirely.
        # NOTE(review): some sentence-transformers versions appear to index
        # the first pair unconditionally and fail on empty input - confirm.
        return []

    hits = reranker.rank(
        request.input,
        request.document,
        top_k=request.top_k,
        return_documents=request.return_documents,
    )
    # NOTE(review): scores presumably come back as numpy scalars, which the
    # JSON response encoder does not accept - confirm against the installed
    # sentence-transformers version. float() is a no-op for native floats.
    return [{**hit, "score": float(hit["score"])} for hit in hits]
398
 
399
 
400
  def get_readme(self):
 
419
  include_in_schema=include_in_schema,
420
  )(self.get_available_models)
421
 
422
+ self.app.post(
423
+ prefix+"/rerank",
424
+ summary="Rerank documents",
425
+ include_in_schema=include_in_schema,
426
+ )(self.get_rerank)
427
+
428
  self.app.post(
429
  prefix + "/chat/completions",
430
  summary="OpenAI Chat completions in conversation session",