Codewithsalty commited on
Commit
fc30ac5
·
verified ·
1 Parent(s): 946ddc0

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +16 -7
newapi.py CHANGED
@@ -5,24 +5,33 @@ from PIL import Image
5
  import torch
6
  import torchvision.transforms as transforms
7
  from utils import BrainTumorModel, GliomaStageModel
 
8
 
9
  app = FastAPI()
10
 
11
- # Load models (updated to local .pth files)
12
- btd_model_path = "brain_tumor_model.pth"
13
- glioma_model_path = "glioma_stage_model.pth"
 
 
 
14
 
15
- # Initialize and load Brain Tumor Detection Model
 
 
 
 
 
 
16
  btd_model = BrainTumorModel()
17
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
18
  btd_model.eval()
19
 
20
- # Initialize and load Glioma Stage Detection Model
21
  glioma_model = GliomaStageModel()
22
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
23
  glioma_model.eval()
24
 
25
- # Define preprocessing
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
@@ -62,4 +71,4 @@ async def glioma_stage(file: UploadFile = File(...)):
62
  return JSONResponse(content={"glioma_stage": result})
63
 
64
  except Exception as e:
65
- return JSONResponse(content={"error": str(e)})
 
5
  import torch
6
  import torchvision.transforms as transforms
7
  from utils import BrainTumorModel, GliomaStageModel
8
+ from huggingface_hub import hf_hub_download
9
 
10
  app = FastAPI()
11
 
12
+ # Download models from the Space's repo
13
+ btd_model_path = hf_hub_download(
14
+ repo_id="Codewithsalty/brain-tumor-api",
15
+ filename="brain_tumor_model.pth",
16
+ repo_type="space"
17
+ )
18
 
19
+ glioma_model_path = hf_hub_download(
20
+ repo_id="Codewithsalty/brain-tumor-api",
21
+ filename="glioma_stage_model.pth",
22
+ repo_type="space"
23
+ )
24
+
25
+ # Load and prepare models
26
  btd_model = BrainTumorModel()
27
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
28
  btd_model.eval()
29
 
 
30
  glioma_model = GliomaStageModel()
31
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
32
  glioma_model.eval()
33
 
34
+ # Image preprocessing
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
 
71
  return JSONResponse(content={"glioma_stage": result})
72
 
73
  except Exception as e:
74
+ return JSONResponse(content={"error": str(e)})