Codewithsalty commited on
Commit
374a9d4
·
verified ·
1 Parent(s): 44b85c7

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +7 -3
newapi.py CHANGED
@@ -8,25 +8,28 @@ from utils import BrainTumorModel, GliomaStageModel
8
 
9
  app = FastAPI()
10
 
11
- # Local model paths (files must be uploaded to the Space)
12
  btd_model_path = "brain_tumor_model.pth"
13
  glioma_model_path = "glioma_stage_model.pth"
14
 
15
- # Load models
16
  btd_model = BrainTumorModel()
17
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
18
  btd_model.eval()
19
 
 
20
  glioma_model = GliomaStageModel()
21
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
22
  glioma_model.eval()
23
 
24
- # Image transformation
25
  transform = transforms.Compose([
26
  transforms.Resize((224, 224)),
27
  transforms.ToTensor(),
28
  ])
29
 
 
 
30
  @app.post("/predict/")
31
  async def predict(file: UploadFile = File(...)):
32
  try:
@@ -45,6 +48,7 @@ async def predict(file: UploadFile = File(...)):
45
  except Exception as e:
46
  return JSONResponse(content={"error": str(e)})
47
 
 
48
  @app.post("/glioma-stage/")
49
  async def glioma_stage(file: UploadFile = File(...)):
50
  try:
 
8
 
9
  app = FastAPI()
10
 
11
+ # === Use exact filenames from the Space directory ===
12
  btd_model_path = "brain_tumor_model.pth"
13
  glioma_model_path = "glioma_stage_model.pth"
14
 
15
+ # === Load Brain Tumor Model ===
16
  btd_model = BrainTumorModel()
17
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
18
  btd_model.eval()
19
 
20
+ # === Load Glioma Stage Model ===
21
  glioma_model = GliomaStageModel()
22
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
23
  glioma_model.eval()
24
 
25
+ # === Image Transform ===
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
  ])
30
 
31
+ # === Routes ===
32
+
33
  @app.post("/predict/")
34
  async def predict(file: UploadFile = File(...)):
35
  try:
 
48
  except Exception as e:
49
  return JSONResponse(content={"error": str(e)})
50
 
51
+
52
  @app.post("/glioma-stage/")
53
  async def glioma_stage(file: UploadFile = File(...)):
54
  try: