from typing import List

import asyncio

import numpy
from fastapi import FastAPI
from fastapi.responses import UJSONResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

MODEL = SentenceTransformer("all-mpnet-base-v2")


def cache(func):
    # Memoize embeddings per sentence so repeated inputs are only encoded once.
    inner_cache = dict()

    def inner(sentences: List[str]):
        if len(sentences) == 0:
            return []
        not_in_cache = [s for s in sentences if s not in inner_cache]
        if len(not_in_cache) > 0:
            processed_sentences = func(not_in_cache)
            for sentence, embedding in zip(not_in_cache, processed_sentences):
                inner_cache[sentence] = embedding
        return [inner_cache[s] for s in sentences]

    return inner


@cache
def _encode(sentences: List[str]) -> List[numpy.ndarray]:
    # Encode, move each tensor to the CPU, and round to 3 decimal places.
    embeddings = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        convert_to_tensor=True,
        batch_size=4,
        show_progress_bar=True,
    )
    return [numpy.around(e.cpu().numpy(), 3) for e in embeddings]


async def encode(sentences: List[str]) -> List[numpy.ndarray]:
    # Run the blocking encode call in the default thread pool so the event
    # loop stays responsive while the model works.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _encode, sentences)


class SemanticSearchReq(BaseModel):
    query: str
    candidates: List[str]


class EmbedReq(BaseModel):
    sentences: List[str]


app = FastAPI()


@app.post("/embed", response_class=UJSONResponse)
async def embed(req: EmbedReq):
    result = await encode(req.sentences)
    # Convert each ndarray to an ordinary list of floats and let FastAPI
    # serialize the response (via ujson); calling ujson.dumps here ourselves
    # would double-encode the payload as a JSON string.
    return [r.tolist() for r in result]
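
# A minimal sketch of exercising the /embed endpoint locally with FastAPI's
# TestClient (requires httpx). The example sentences below are placeholders,
# not part of the original module.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    response = client.post(
        "/embed",
        json={"sentences": ["semantic search with transformers", "a fastapi embedding service"]},
    )
    embeddings = response.json()
    # Expect one 768-dimensional vector per input sentence for all-mpnet-base-v2.
    print(len(embeddings), len(embeddings[0]))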