# NOTE(review): removed HuggingFace file-viewer page chrome that was scraped
# into the source ("fastelectronicvegetable / 69242cf / raw history blame /
# 1.61 kB") — it is not Python and broke the module at import time.
import asyncio
import os
from functools import partial, wraps
from typing import Generic, List, Optional, TypeVar

import numpy
import ujson
from fastapi import FastAPI
from pydantic import BaseModel, ValidationError, validator
from pydantic.generics import GenericModel
from sentence_transformers import SentenceTransformer
# Shared sentence-embedding model, loaded once at import time so every
# request reuses the same instance.  NOTE(review): downloads the weights on
# first run and blocks process startup until loading finishes.
MODEL = SentenceTransformer("all-mpnet-base-v2")
def cache(func):
    """Memoize a batch sentence-embedding function on a per-sentence basis.

    Wraps a function that maps a list of sentences to a parallel list of
    embeddings.  Sentences seen before are answered from an in-memory dict;
    only unseen sentences are forwarded to ``func`` (as one batch), and the
    results are returned in the caller's original order.

    NOTE(review): the cache is unbounded — fine for a small fixed corpus,
    but a memory leak if the set of query sentences grows without limit.
    """
    inner_cache = {}

    @wraps(func)  # preserve the wrapped function's name/docstring
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # Dedupe while preserving order (dict.fromkeys keeps first-seen
        # order) so a duplicated new sentence is encoded only once.
        missing = [s for s in dict.fromkeys(sentences) if s not in inner_cache]
        if missing:
            for sentence, embedding in zip(missing, func(missing)):
                inner_cache[sentence] = embedding
        # Serve every sentence from the cache, duplicates included.
        return [inner_cache[s] for s in sentences]

    return inner
@cache
def _encode(sentences: List[str]):
    """Embed *sentences* with the shared model.

    Returns one numpy array per sentence, L2-normalized and rounded to
    three decimals to keep the serialized payload small.
    """
    tensors = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        convert_to_tensor=True,
        batch_size=4,
        show_progress_bar=True,
    )
    return [numpy.around(tensor.numpy(), 3) for tensor in tensors]
async def encode(sentences: List[str]) -> List[numpy.ndarray]:
    """Run the blocking ``_encode`` in the default thread-pool executor.

    Model inference holds the CPU for a while; offloading it keeps the
    event loop responsive to other requests.
    """
    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() there has been deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _encode, sentences)
class SemanticSearchReq(BaseModel):
    """Request body pairing one query string with candidate sentences.

    NOTE(review): no route in this file uses this model — presumably a
    search endpoint exists elsewhere or is planned; confirm before removing.
    """
    query: str
    candidates: List[str]
class EmbedReq(BaseModel):
    """Request body for the /embed endpoint: the sentences to encode."""
    sentences: List[str]
# FastAPI application instance; routes below are registered on it.
app = FastAPI()
@app.post("/embed")
async def embed(embed: EmbedReq):
    """Embed the request's sentences; respond with lists of floats.

    Inference is offloaded to the thread pool via ``encode`` so the event
    loop stays free while the model runs.
    """
    result = await encode(embed.sentences)
    # Return plain Python lists and let FastAPI serialize the response.
    # The previous ``return ujson.dumps(...)`` double-encoded: FastAPI
    # JSON-serializes the returned value, so clients received a JSON
    # *string* containing JSON instead of a JSON array.
    return [r.tolist() for r in result]