import pickle

import pandas as pd
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Pre-computed sentence embeddings for every disease in the corpus.
corpus = pickle.load(open("./corpus/all_embeddings_disease.pickle", "rb")).astype("float")

# Multilingual encoder used to embed incoming symptom queries.
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Lookup table with a "name" and "url" for each corpus entry, row-aligned with the
# embeddings; its row index serves as the disease id.
df = pd.DataFrame(pickle.load(open("./corpus/y_all_disease.pickle", "rb")))

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Disease(BaseModel):
    id: int
    name: str
    url: str
    score: float


class Symptoms(BaseModel):
    query: str


@app.get("/")
def home():
    print(df.iloc[0])
    return {"Hello": "World!"}

@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    # Embed the symptom description and score it against the whole corpus.
    query_embedding = model.encode(symptoms.query).astype("float")
    similarity_vectors = model.similarity(query_embedding, corpus)[0]

    # Rank every corpus entry by similarity to the query (highest first).
    scores, indices = torch.topk(similarity_vectors, k=len(corpus))

    ls = df.iloc[indices.tolist()].copy()
    ls["scores"] = scores.tolist()

    diseases = [
        {"id": value[0], "name": value[1], "url": value[2], "score": value[3]}
        for value in zip(ls.index, ls["name"], ls["url"], ls["scores"])
    ]
    return diseases
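

# A minimal sketch of how the prediction endpoint could be exercised in-process,
# assuming the pickled corpus files above are present and FastAPI's TestClient
# dependency (httpx) is installed; the query text is only an illustrative example.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    response = client.post("/", json={"query": "fever, cough and sore throat"})
    # Results come back ordered by descending similarity, so the first entries
    # are the closest matches.
    print(response.json()[:3])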