# Page header captured along with this file (not program code):
# mudaza -- commit 105b9cf "update code" (file size 1.45 kB).
import pickle

import pandas as pd
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
def _load_pickle(path):
    """Return the object unpickled from the file at *path*, closing the file afterwards.

    ``pickle.load`` expects an open binary file object, not a path string;
    this helper opens the file in ``"rb"`` mode inside a ``with`` block so
    the handle is released even if unpickling fails.

    Security: unpickling can execute arbitrary code, so *path* must only ever
    point at trusted files -- never at user-supplied data.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Search artefacts, loaded once at import time.
# Presumably one embedding per corpus row -- confirm against the build script.
corpus = _load_pickle("./corpus/all_embeddings.pickle")
# Maps encoded label ids back to names (the predict draft calls
# inverse_transform on it); loaded via a file object, as pickle requires.
label_encoder = _load_pickle("./corpus/label_encoder.pickle")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# Labels for the corpus rows; assumed row-aligned with `corpus` -- verify.
df = pd.DataFrame(data={"label": _load_pickle("./corpus/y_all.pickle")})
app = FastAPI()

# Allow browser clients on any origin to call this API with any method/header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# usually a misconfiguration -- confirm whether credentialed cross-origin
# requests should really be accepted from every site, and list explicit
# origins if not.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class Disease(BaseModel):
    """A single disease match returned to the client.

    Only referenced by the commented-out ``predict`` draft in this file,
    which builds these from the encoded label id, its decoded name, and a
    similarity score.
    """

    id: int  # encoded label id (value of df["label"] in the predict draft)
    name: str  # human-readable name, decoded via label_encoder.inverse_transform
    score: float  # similarity between the query and the matched corpus entry
@app.get("/")
def greet_json():
    """Answer ``GET /`` with a fixed JSON greeting; no input is read."""
    return dict(Hello="World!")
@app.post("/")
async def greet_post():
    """Answer ``POST /`` with a fixed JSON greeting; the request body is ignored.

    NOTE: the commented-out ``predict`` draft in this file targets the same
    ``POST /`` route, so only one of the two can own it.
    """
    reply = {"Hello": "Post World!"}
    return reply
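# NOTE(review): disabled draft of a prediction endpoint. It cannot be
# uncommented as-is:
#   - it registers POST "/", the same route as greet_post; only one may own it;
#   - `q` and `all_embeddings` are undefined (the encoded query is
#     `query_embedding`, the loaded embeddings are `corpus`), and `torch`
#     is never imported;
#   - `id_ = df.drop_duplicates("label")` throws away the ranking computed by
#     `df.iloc[indicies]` (it should de-duplicate `id_`), and
#     `scores[id_.index]` then indexes rank-ordered scores by original row
#     labels, pairing scores with the wrong diseases;
#   - `dict("id": ..., ...)` is a syntax error; a dict literal
#     `{"id": ..., "name": ..., "score": ...}` is needed.
# @app.post("/", response_model=list[Disease])
# async def predict(query: str):
# query_embedding = model.encode(query).astype('float')
# similarity_vectors = model.similarity(q, all_embeddings)
# scores, indicies = torch.topk(similarity_vectors, k=len(all_embeddings))
# id_ = df.iloc[indicies]
# id_ = df.drop_duplicates("label")
# scores = scores[id_.index]
# diseases = label_encoder.inverse_transform(id_.label.values)
# id_ = id_.label.values
# diseases = [dict("id": value[0], "name": value[1], "score" : value[2]) for value in zip(id_, diseases, scores)]
# return diseases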