import asyncio
import base64
from typing import Optional, Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key

embedding_router = APIRouter()


def get_embedding_engine():
    yield EMBEDDED_MODEL
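

# The handler below serves both /embeddings and the legacy engines route, and
# dispatches on the engine type: an AsyncOpenAI client proxies requests to a
# text-embeddings-inference (TEI) server, while a SentenceTransformer encodes
# locally in-process.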
@embedding_router.post("/embeddings", dependencies=[Depends(check_api_key)])
@embedding_router.post("/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingCreateParams,
    model_name: Optional[str] = None,
    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
"""Creates embeddings for the text"""
if request.model is None:
request.model = model_name
request.input = request.input
if isinstance(request.input, str):
request.input = [request.input]
elif isinstance(request.input, list):
if isinstance(request.input[0], int):
decoding = tiktoken.model.encoding_for_model(request.model)
request.input = [decoding.decode(request.input)]
elif isinstance(request.input[0], list):
decoding = tiktoken.model.encoding_for_model(request.model)
request.input = [decoding.decode(text) for text in request.input]

    data, total_tokens = [], 0

    # support for tei: https://github.com/huggingface/text-embeddings-inference
    if isinstance(client, AsyncOpenAI):
        # Fan out concurrent requests to the TEI server: each global batch is
        # split into max_concurrent_requests chunks of max_client_batch_size.
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
        for i in range(0, len(request.input), global_batch_size):
            tasks = []
            texts = request.input[i: i + global_batch_size]
            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
                tasks.append(
                    client.embeddings.create(
                        # Crude character-level truncation, likely to stay
                        # within the server's sequence limit.
                        input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
                        model=request.model,
                    )
                )
            res = await asyncio.gather(*tasks)

            vecs = np.asarray([e.embedding for r in res for e in r.data])
            # Zero-pad on the right when the configured embedding size
            # exceeds the model's native dimension.
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=i + j,  # i is already the start offset of this global batch
                    object="embedding",
                    embedding=embed,
                )
                for j, embed in enumerate(vecs)
            )
            total_tokens += sum(r.usage.total_tokens for r in res)
    else:
        # Local SentenceTransformer path: encode in fixed batches of 1024.
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            # Character count, used here as a rough token-count proxy.
            token_num = sum(len(text) for text in batch)
            vecs = client.encode(batch, normalize_embeddings=True)

            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=num_batch * 1024 + i,
                    object="embedding",
                    embedding=embedding,
                )
                for i, embedding in enumerate(vecs)
            )
            total_tokens += token_num

    return CreateEmbeddingResponse(
        data=data,
        model=request.model,
        object="list",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )
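
# Example client call (a sketch; assumes this router is mounted under /v1 at
# http://localhost:8000 and that "bge-large-zh" is a served embedding model):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-...")
#   resp = client.embeddings.create(input=["hello world"], model="bge-large-zh")
#   print(len(resp.data[0].embedding))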