mudaza's picture
modified code and add files
c64e7a9
raw
history blame
1.87 kB
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import pickle
import pandas as pd
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import torch
def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code, so the files under ./corpus
    # must only ever come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Embeddings that incoming queries are compared against, cast to float64.
# NOTE(review): presumably a NumPy array with one row per entry — confirm.
corpus = _load_pickle("./corpus/all_embeddings_disease.pickle").astype("float")

# Sentence encoder used for incoming queries.
# Alternative checkpoint (currently disabled):
#   sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
# NOTE(review): the stored embeddings presumably have to come from the same
# checkpoint as the one loaded here — verify before switching models.
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

# Per-entry metadata; the request handler reads its "name" and "url" columns
# and looks rows up by position.
df = pd.DataFrame(_load_pickle("./corpus/y_all_disease.pickle"))
# Application instance that the route decorators in this module attach to.
app = FastAPI()

# CORS policy: every origin, method and header is allowed, with credentials
# enabled.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before exposing the service publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class Disease(BaseModel):
    """One ranked entry in the response returned by the POST "/" endpoint."""

    # Index label of the matching row in the metadata DataFrame.
    id: int
    # Value of that row's "name" column.
    name: str
    # Value of that row's "url" column.
    url: str
    # Similarity between the query and this entry, as computed by the model.
    score: float
class Symptoms(BaseModel):
    """Request body accepted by the POST "/" endpoint."""

    # Text that gets embedded and compared against the stored corpus.
    query: str
@app.get("/")
def home():
    """Print the first metadata row to stdout and return a fixed greeting."""
    first_row = df.iloc[0]
    print(first_row)
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    """Rank every corpus entry by similarity to the submitted symptom text.

    Args:
        symptoms: Request body whose ``query`` field holds the text to embed.

    Returns:
        One dict per corpus entry with the keys ``id``, ``name``, ``url`` and
        ``score``, ordered from the highest to the lowest similarity score.
        ``id`` is the index label of the entry's row in ``df``.
    """
    # NOTE(review): encode() is a synchronous call inside an ``async def``
    # handler, so it presumably blocks the event loop while it runs — confirm
    # that this is acceptable for the expected load.
    # The cast mirrors the float cast applied to ``corpus`` at load time.
    query_embedding = model.encode(symptoms.query).astype('float')

    # similarity() is given a single query, so only its first row is relevant.
    similarity_vectors = model.similarity(query_embedding, corpus)[0]

    # k == len(corpus) keeps every entry: topk acts as a descending sort here.
    scores, indices = torch.topk(similarity_vectors, k=len(corpus))

    # Convert the tensors to plain Python lists before handing them to pandas
    # and to the response, instead of relying on implicit tensor coercion.
    ranked = df.iloc[indices.tolist()]
    return [
        {"id": row_id, "name": name, "url": url, "score": score}
        for row_id, name, url, score in zip(
            ranked.index, ranked["name"], ranked["url"], scores.tolist()
        )
    ]