File size: 4,279 Bytes
ca56e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import asyncio
import base64
from typing import Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key

embedding_router = APIRouter()


def get_embedding_engine():
    """FastAPI dependency that hands out the process-wide embedding backend.

    Yields the shared ``EMBEDDED_MODEL`` (either a ``SentenceTransformer``
    or an ``AsyncOpenAI`` client, per application configuration).
    """
    engine = EMBEDDED_MODEL
    yield engine


@embedding_router.post("/embeddings", dependencies=[Depends(check_api_key)])
@embedding_router.post("/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingCreateParams,
    model_name: str = None,
    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
    """Create embeddings for the given input text(s).

    Accepts a single string, a list of strings, a list of token ids, or a
    list of token-id lists (token ids are decoded back to text via tiktoken).
    Dispatches to either a remote OpenAI-compatible backend (e.g. TEI) or a
    local ``SentenceTransformer``, optionally zero-padding vectors up to
    ``SETTINGS.embedding_size`` and base64-encoding them on request.

    Returns:
        CreateEmbeddingResponse with one ``Embedding`` per input item.
    """
    if request.model is None:
        request.model = model_name

    # Normalize the input into a list of strings. Guard against an empty
    # list before peeking at element [0] (previously raised IndexError).
    if isinstance(request.input, str):
        request.input = [request.input]
    elif isinstance(request.input, list) and request.input:
        if isinstance(request.input[0], int):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(request.input)]
        elif isinstance(request.input[0], list):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(tokens) for tokens in request.input]

    data, total_tokens = [], 0

    # support for tei: https://github.com/huggingface/text-embeddings-inference
    if isinstance(client, AsyncOpenAI):
        # Fire up to max_concurrent_requests sub-batches in parallel, each at
        # most max_client_batch_size items.
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
        for i in range(0, len(request.input), global_batch_size):
            tasks = []
            texts = request.input[i: i + global_batch_size]
            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
                tasks.append(
                    client.embeddings.create(
                        # NOTE(review): the 510-char truncation presumably guards
                        # the backend's max sequence length — confirm against the
                        # deployed model.
                        input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
                        model=request.model,
                    )
                )
            res = await asyncio.gather(*tasks)

            vecs = np.asarray([e.embedding for r in res for e in r.data])
            bs, dim = vecs.shape
            # Right-pad with zeros when the backend returns fewer dimensions
            # than the configured embedding size.
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    # BUG FIX: `i` is already an absolute offset (the outer range
                    # steps by global_batch_size), so the global index is i + j —
                    # not i * global_batch_size + j, which was wrong for every
                    # outer batch after the first.
                    index=i + j,
                    object="embedding",
                    embedding=embed,
                )
                for j, embed in enumerate(vecs)
            )
            total_tokens += sum(r.usage.total_tokens for r in res)
    else:
        # Local SentenceTransformer path: encode in fixed batches of 1024.
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            # NOTE(review): this counts characters, not model tokens — kept
            # as-is since callers may rely on the existing usage accounting.
            token_num = sum(len(text) for text in batch)
            vecs = client.encode(batch, normalize_embeddings=True)

            bs, dim = vecs.shape
            # Same zero-padding as the remote path.
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=num_batch * 1024 + k,
                    object="embedding",
                    embedding=embedding,
                )
                for k, embedding in enumerate(vecs)
            )
            total_tokens += token_num

    return CreateEmbeddingResponse(
        data=data,
        model=request.model,
        object="list",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )