mudaza commited on
Commit
b2f9c3c
·
1 Parent(s): acad5f5

update code and requirements

Browse files
Files changed (2) hide show
  1. app.py +29 -0
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,36 @@
 
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
  from fastapi import FastAPI
3
+ import pickle
4
+ import pandas as pd
5
+ from pydantic import BaseModel
6
+
7
+ corpus = pickle.load(open("./corpus/all_embeddings.pickle", "rb"))
8
+ label_encoder = pickle.load("./corpus/label_encoder.pickle")
9
+ model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
10
+ df = pd.DataFrame(data={"label": pickle.load(open("./corpus/y_all.pickle"))})
11
 
12
  app = FastAPI()
13
 
14
+ class Disease(BaseModel):
15
+ id: int
16
+ name: str
17
+ score: float
18
+
19
  @app.get("/")
20
  def greet_json():
21
  return {"Hello": "World!"}
22
+
23
+ @app.post("/", response_model=list[Disease])
24
+ async def predict(query: str):
25
+ query_embedding = model.encode(query).astype('float')
26
+ similarity_vectors = model.similarity(q, all_embeddings)
27
+ scores, indicies = torch.topk(similarity_vectors, k=len(all_embeddings))
28
+ id = df.iloc[indicies]
29
+ id = df.drop_duplicates("label")
30
+ scores = scores[id.index]
31
+ diseases = label_encoder.inverse_transform(id.label.values)
32
+ id = id.label.values
33
+ diseases = [dict("id": value[0], "name": value[1], "score" : value[2]) for value in zip(id, diseases, scores)]
34
+ return diseases
35
+
36
+
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  fastapi
2
  uvicorn[standard]
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ pandas
4
+ sentence-transformers
5
+ pydantic