# Hugging Face Space page header (kept for provenance):
#   jonathanjordan21 - "Update app.py" - commit 4cdd282 (verified)
from fastapi import FastAPI
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import numpy as np
from typing import List
from pydantic import BaseModel
# ASGI application object; the @app.get / @app.post decorators in this file
# register their endpoints on it.
app = FastAPI()
class InputListModel(BaseModel):
    # JSON request body accepted by POST /predict_list.
    # Both fields are required lists of strings.
    keywords: List[str]  # texts embedded as the "keyword" side of each comparison
    contents: List[str]  # texts embedded as the "content" side of each comparison
class InputModel(BaseModel):
    # JSON request body accepted by POST /predict.
    # Both fields are required strings.
    keyword: str  # text embedded as the "keyword" side of the comparison
    content: str  # text embedded as the "content" side of the comparison
# Shared bi-encoder used by the prediction endpoints; it is loaded once at
# import time so requests do not pay the model-loading cost.
#
# History: cross-encoder rerankers ("jinaai/jina-reranker-v2-base-multilingual"
# and "Alibaba-NLP/gte-multilingual-reranker-base") were tried here earlier and
# were loaded with trust_remote_code=True.
#
# trust_remote_code is deliberately NOT enabled for this checkpoint: the flag
# allows Python code shipped inside the model repository to be executed at load
# time, which is an unnecessary risk for a stock sentence-transformers model.
# NOTE(review): re-enable it only if the model is swapped for one that ships
# custom modelling code.
model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for GET "/"."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/predict")
async def predict(inp: InputModel):
    """Score how similar one keyword is to one content string.

    Both strings are embedded with the module-level ``model`` and compared by
    cosine similarity along the last dimension. The similarity is then mapped
    from the [-1, 1] range onto [0, 1] via ``(sim + 1) / 2``.

    Args:
        inp: Request body carrying ``keyword`` and ``content``.

    Returns:
        ``{"results": score}`` where ``score`` is ``Tensor.tolist()`` of the
        rescaled similarity (presumably a single float for single-string
        input -- TODO confirm against ``SentenceTransformer.encode``).
    """
    content_vec = model.encode(inp.content, convert_to_tensor=True)
    keyword_vec = model.encode(inp.keyword, convert_to_tensor=True)

    similarity = torch.nn.functional.cosine_similarity(
        content_vec, keyword_vec, dim=-1
    )
    rescaled = (similarity + 1) / 2
    return {"results": rescaled.tolist()}
@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score keyword/content pairs by embedding cosine similarity.

    ``keywords[i]`` is compared with ``contents[i]``. When exactly one of the
    two lists holds a single item, that item is compared against every item
    of the other list (tensor broadcasting in ``cosine_similarity``). Each
    similarity is mapped from [-1, 1] onto [0, 1] via ``(sim + 1) / 2``.

    Args:
        inp: Request body carrying the ``keywords`` and ``contents`` lists.

    Returns:
        ``{"results": [...]}`` with one score per compared pair, or
        ``{"results": []}`` when both lists are empty.

    Raises:
        HTTPException: status 422 when the lists cannot be paired, i.e. they
            have different lengths and neither has length 1, or only one of
            them is empty.
    """
    # Function-scope import keeps this handler independent of the
    # module-level import list.
    from fastapi import HTTPException

    n_keywords = len(inp.keywords)
    n_contents = len(inp.contents)

    # Nothing to score: answer directly instead of encoding empty batches.
    if n_keywords == 0 and n_contents == 0:
        return {"results": []}

    one_side_empty = 0 in (n_keywords, n_contents)
    can_pair = n_keywords == n_contents or 1 in (n_keywords, n_contents)
    if one_side_empty or not can_pair:
        # Reject up front with a client-facing message; otherwise the
        # mismatch would only fail later, inside the tensor computation.
        raise HTTPException(
            status_code=422,
            detail=(
                "keywords and contents must have the same length "
                "(or one of them must contain exactly one item); "
                f"got {n_keywords} keywords and {n_contents} contents"
            ),
        )

    content_emb = model.encode(inp.contents, convert_to_tensor=True)
    keyword_emb = model.encode(inp.keywords, convert_to_tensor=True)

    similarity = torch.nn.functional.cosine_similarity(
        content_emb, keyword_emb, dim=-1
    )
    scores = (similarity + 1) / 2
    return {"results": scores.tolist()}
# @app.post("/predict_list")
# async def predict_list(inp : InputListModel):
# sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)]
# scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
# # (-scores).argsort().tolist()
# return {"results":scores.tolist()}
# @app.post("/predict")
# async def predict(inp : InputModel):
# sentence_pairs = [[inp.keyword, inp.content]]
# scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
# # (-scores).argsort().tolist()
# return {"results":scores.tolist()[0]}
# keywords = model.encode(inp.keywords)
# contents = model.encode(inp.contents)
# return {"results":np.linalg.norm(contents-keywords).tolist()}