"""FastAPI service that ranks diseases by semantic similarity to a symptom query.

At import time it loads precomputed disease embeddings and their metadata from
``./corpus`` and a multilingual SentenceTransformer model. ``POST /`` embeds the
incoming query text and returns every corpus entry, ordered from most to least
similar.
"""

import pickle

import pandas as pd
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only use this on trusted,
    locally produced files such as the ones under ./corpus.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Precomputed disease embeddings, cast to float64 to match the query embedding
# dtype below. Presumably one row per disease, shape (n_diseases, dim) — TODO confirm.
corpus = _load_pickle("./corpus/all_embeddings_disease.pickle").astype("float")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
# Disease metadata. predict() reads its "name" and "url" columns and assumes its
# rows are positionally aligned with the rows of `corpus` — verify against the
# script that produced the pickles.
df = pd.DataFrame(_load_pickle("./corpus/y_all_disease.pickle"))

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Disease(BaseModel):
    """One ranked search result returned by ``POST /``."""

    id: int  # index label of the row in `df`
    name: str
    url: str
    score: float  # similarity between the query and this disease's embedding


class Symptoms(BaseModel):
    """Request body for ``POST /``: free-text description of symptoms."""

    query: str


@app.get("/")
def home():
    """Health-check endpoint; also prints the first metadata row to stdout."""
    print(df.iloc[0])
    return {"Hello": "World!"}


@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    """Return every disease in the corpus, ranked by similarity to ``symptoms.query``.

    Returns:
        A list of ``{"id", "name", "url", "score"}`` dicts sorted by descending
        score, one per corpus row.
    """
    # NOTE(review): encode/similarity are synchronous calls inside an async
    # handler, so they block the event loop for their duration. Moving this to a
    # threadpool (plain `def`) would allow concurrent tokenizer use — confirm the
    # model is safe for that before changing.
    query_embedding = model.encode(symptoms.query).astype("float")
    similarities = model.similarity(query_embedding, corpus)[0]

    # k == number of corpus rows, so topk is a full descending sort.
    scores, indices = torch.topk(similarities, k=len(corpus))

    # Convert tensors to plain Python lists before handing them to pandas.
    ranked = df.iloc[indices.tolist()]
    return [
        {"id": int(row_id), "name": name, "url": url, "score": float(score)}
        for row_id, name, url, score in zip(
            ranked.index, ranked["name"], ranked["url"], scores.tolist()
        )
    ]