File size: 1,611 Bytes
69242cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Standard library
import asyncio
import os
from functools import partial, wraps
from typing import Generic, List, Optional, TypeVar

# Third-party
import numpy
import ujson
from fastapi import FastAPI
from pydantic import BaseModel, ValidationError, validator
from pydantic.generics import GenericModel
from sentence_transformers import SentenceTransformer

MODEL = SentenceTransformer("all-mpnet-base-v2")

def cache(func):
    inner_cache = dict()
    def inner(sentences: List[str]):
        if len(sentences) == 0:
            return []
        not_in_cache = list(filter(lambda s: s not in inner_cache.keys(), sentences))
        if len(not_in_cache) > 0:
            processed_sentences = func(list(not_in_cache))
            for sentence, embedding in zip(not_in_cache, processed_sentences):
                inner_cache[sentence] = embedding
        return [inner_cache[s] for s in sentences]
    return inner

@cache
def _encode(sentences: List[str]):
    array = [numpy.around(a.numpy(), 3) for a in MODEL.encode(sentences, normalize_embeddings=True, convert_to_tensor=True, batch_size=4, show_progress_bar=True)]
    return array

async def encode(sentences: List[str]) -> List[numpy.ndarray]:
    loop = asyncio.get_event_loop()
    result = await loop.run_in_executor(None, _encode, sentences)
    return result

class SemanticSearchReq(BaseModel):
    query: str
    candidates: List[str]

class EmbedReq(BaseModel):
    sentences: List[str]

app = FastAPI()

@app.post("/embed")
async def embed(embed: EmbedReq):
    result = await encode(embed.sentences)
    # Convert it to an ordinary list of floats
    return ujson.dumps([r.tolist() for r in result])