Spaces:
Sleeping
Sleeping
import asyncio
import os
from functools import partial, wraps
from typing import Generic, List, Optional, TypeVar

import numpy
import ujson
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError, validator
from pydantic.generics import GenericModel
from sentence_transformers import SentenceTransformer
# Sentence-embedding model, instantiated once at import time and reused by
# every call to `_encode`.
MODEL = SentenceTransformer("all-mpnet-base-v2")
def cache(func):
    """Memoize a batch function on a per-sentence basis.

    ``func`` must accept a list of sentences and return one result per
    sentence, in the same order. The wrapper forwards only sentences that
    have not been seen before (each distinct sentence at most once per call)
    and serves everything else from an in-memory dict.

    Args:
        func: Callable taking ``List[str]`` and returning an iterable of
            per-sentence results aligned with its input.

    Returns:
        A wrapper taking ``sentences: List[str]`` and returning a list of
        results aligned with ``sentences`` (``[]`` for empty input).
    """
    # NOTE: the cache is unbounded and lives as long as the wrapper does.
    inner_cache = {}

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # sentence repeated within one call is computed only once.
        not_in_cache = [
            s for s in dict.fromkeys(sentences) if s not in inner_cache
        ]
        if not_in_cache:
            # Pass a copy so `func` cannot disturb the key order used below.
            processed_sentences = func(list(not_in_cache))
            inner_cache.update(zip(not_in_cache, processed_sentences))
        return [inner_cache[s] for s in sentences]

    return inner
def _encode(sentences: List[str]) -> List[numpy.ndarray]:
    """Embed ``sentences`` with ``MODEL`` and round each value to 3 decimals.

    Embeddings are L2-normalized by the model (``normalize_embeddings=True``)
    before rounding.

    Args:
        sentences: Sentences to embed.

    Returns:
        A list with one 1-D ``numpy.ndarray`` per input sentence, in input
        order.
    """
    # NOTE(review): the `cache` helper in this file is not applied to this
    # function in the visible source -- confirm whether that is intended.
    #
    # Ask for NumPy output directly: calling `.numpy()` on the tensors
    # returned with `convert_to_tensor=True` fails when the model runs on a
    # CUDA device.
    embeddings = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        convert_to_numpy=True,
        batch_size=4,
        show_progress_bar=True,
    )
    # Round the whole matrix at once, then split it into per-sentence rows.
    return list(numpy.around(embeddings, 3))
async def encode(sentences: List[str]) -> List[numpy.ndarray]:
    """Await ``_encode(sentences)`` executed in the loop's default executor.

    Passing ``None`` as the executor makes the event loop use its default
    executor, so the call to ``_encode`` does not run on the event loop itself.
    """
    return await asyncio.get_event_loop().run_in_executor(
        None, _encode, sentences
    )
# Request body for the `embed` handler.
class EmbedReq(BaseModel):
    # Sentences to be embedded, passed to `encode` as-is.
    sentences: List[str]
# FastAPI application object; the CORS middleware is attached to it further
# down in this file.
app = FastAPI()
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm how it is registered with `app`.
async def embed(embed: EmbedReq):
    """Embed the request's sentences and return them serialized with ujson.

    Args:
        embed: Request model carrying the ``sentences`` to embed.

    Returns:
        The string produced by ``ujson.dumps`` for a list of float lists, one
        inner list per input sentence.
    """
    vectors = await encode(embed.sentences)
    # ndarrays are not JSON-serializable; turn each into a plain list of floats.
    as_lists = [vector.tolist() for vector in vectors]
    return ujson.dumps(as_lists)
# CORS is fully permissive: any origin, HTTP method and request header is
# allowed.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)