# fastelectronicvegetable
# cors middleware
# d29e2b3
# raw
# history blame
# 1.7 kB
from typing import Generic, List, Optional, TypeVar
from functools import partial
from pydantic import BaseModel, ValidationError, validator
from pydantic.generics import GenericModel
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import os, asyncio, numpy, ujson
from fastapi.middleware.cors import CORSMiddleware
# Sentence-embedding model, constructed once when the module is imported and
# reused by every call to `_encode` below.
MODEL = SentenceTransformer("all-mpnet-base-v2")
def cache(func):
    """Memoize a batch function on a per-sentence basis.

    ``func`` must accept a list of sentences and return one result per
    sentence, in the same order. The returned wrapper only forwards the
    sentences it has not seen before and serves the rest from an in-memory
    dict keyed by the sentence text.

    The cache is never evicted, so it grows with every distinct sentence
    seen for the lifetime of the process.

    Args:
        func: Callable taking ``List[str]`` and returning a sequence of
            results aligned with its input.

    Returns:
        A wrapper with the same call signature that returns a list of
        results aligned with the sentences it was given.
    """
    # Local import: the module header only pulls in ``partial`` from functools.
    from functools import wraps

    inner_cache = {}

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # dict.fromkeys() drops duplicates while keeping first-seen order, so
        # a sentence repeated inside one request is only computed once.
        not_in_cache = [s for s in dict.fromkeys(sentences) if s not in inner_cache]
        if not_in_cache:
            # Hand func a copy so it cannot disturb the pairing done below.
            processed_sentences = func(list(not_in_cache))
            inner_cache.update(zip(not_in_cache, processed_sentences))
        return [inner_cache[s] for s in sentences]

    return inner
@cache
def _encode(sentences: List[str]):
    """Embed ``sentences`` with ``MODEL`` and round each vector to 3 decimals.

    The model is asked for normalized embeddings as a tensor, in batches of
    four, with the progress bar enabled. Because of the ``@cache`` decorator
    this function only receives sentences that have not been embedded yet.

    Args:
        sentences: Texts to embed.

    Returns:
        A list with one rounded ``numpy.ndarray`` per input sentence, in
        input order.
    """
    embeddings = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        convert_to_tensor=True,
        batch_size=4,
        show_progress_bar=True,
    )
    # Tensor.numpy() is only valid for CPU tensors; .cpu() is a no-op when the
    # tensor already lives on the CPU, so this stays correct on GPU hosts too.
    return [numpy.around(row.cpu().numpy(), 3) for row in embeddings]
async def encode(sentences: List[str]) -> List[numpy.ndarray]:
    """Run ``_encode`` in the loop's default executor and await the result.

    Offloading keeps the event loop free while the embedding call runs.

    Args:
        sentences: Texts to embed.

    Returns:
        Whatever ``_encode`` returns for ``sentences``.
    """
    # Inside a coroutine there is always a running loop; get_running_loop()
    # is the recommended accessor and never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    # Passing None selects the event loop's default executor.
    return await loop.run_in_executor(None, _encode, sentences)
# Request body accepted by the POST /embed endpoint.
class EmbedReq(BaseModel):
    # Texts to embed; the endpoint returns one vector per entry.
    sentences: List[str]
# Application object; the /embed route and the CORS middleware below are
# registered on it.
app = FastAPI()
@app.post("/embed")
async def embed(embed: EmbedReq):
    """Embed the sentences in the request and return them as a JSON string.

    Each embedding array is converted to a plain list of floats, and the
    list of lists is serialized with ``ujson.dumps``.
    """
    # NOTE(review): the handler returns an already-serialized str; FastAPI
    # will presumably JSON-encode that string again, so clients would need to
    # decode twice. Left unchanged to keep the response shape -- confirm
    # against the consumers before altering.
    vectors = await encode(embed.sentences)
    as_plain_lists = [vector.tolist() for vector in vectors]
    return ujson.dumps(as_plain_lists)
# Register CORS handling on the app. Origins, methods and headers are all set
# to the wildcard "*".
# NOTE(review): presumably intentional for a publicly callable service --
# narrow allow_origins if only known front-ends should call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)