# app.py — last commit 8fed2d7 ("Create app.py") by Den-d3j2d; 617 bytes.
from functools import lru_cache
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from src.config import ModelConfig
from src.model.encoder import ProdFeatureEncoder


# ASGI application object; the route decorators in this module register on it.
app: FastAPI = FastAPI()


class TextInput(BaseModel):
    """Request body for ``POST /encode_text``."""

    # Raw text handed to the encoder.
    text: str


class EmbeddingOutput(BaseModel):
    """Response body for ``POST /encode_text``."""

    # Embedding vector as a flat list of floats; FastAPI validates the
    # handler's return value against this because it is the response_model.
    embedding: List[float]


@lru_cache(maxsize=1)
def _get_encoder() -> ProdFeatureEncoder:
    """Build the text encoder once and reuse the same instance for every request.

    Previously a fresh ``ProdFeatureEncoder`` was constructed on each call.
    That repeats construction cost per request (presumably significant if it
    loads pretrained weights — TODO confirm). It also means that, if the
    constructor initializes any weights randomly, each request was embedded
    by a different model.
    """
    model = ProdFeatureEncoder(config=ModelConfig())
    # torch.nn.Module instances start in training mode; switch to eval mode so
    # layers such as dropout behave deterministically when serving.
    if isinstance(model, torch.nn.Module):
        model.eval()
    return model


@app.post("/encode_text", response_model=EmbeddingOutput)
async def encode_text(text_input: TextInput):
    """Encode ``text_input.text`` into an embedding vector.

    Args:
        text_input: Request body carrying the text to encode.

    Returns:
        A dict ``{"embedding": [...]}`` matching ``EmbeddingOutput``.
    """
    # The first request pays the one-time construction cost; later requests
    # reuse the cached encoder.
    model = _get_encoder()
    # NOTE(review): the forward pass runs synchronously inside an ``async def``
    # handler, so it blocks the event loop while it executes. A plain ``def``
    # handler (which FastAPI runs in a threadpool) would keep the server
    # responsive, but first confirm the encoder/tokenizer tolerates concurrent
    # calls from several threads.
    with torch.no_grad():  # inference only: skip building an autograd graph
        embedding = model(text_input.text)
    # Assumes the encoder returns a tensor/array for a single text — TODO
    # confirm its shape. If it keeps a leading batch dimension of 1, drop it so
    # the payload is a flat List[float] rather than a nested list, which would
    # fail response validation.
    if embedding.ndim > 1 and embedding.shape[0] == 1:
        embedding = embedding[0]
    return {"embedding": embedding.tolist()}