Codewithsalty commited on
Commit
37adc03
·
verified ·
1 Parent(s): 52e4323

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +20 -14
newapi.py CHANGED
@@ -2,21 +2,20 @@
2
 
3
  from fastapi import FastAPI, File, UploadFile
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
  import torch
7
  import torchvision.transforms as transforms
8
  from PIL import Image
9
  import io
10
  import os
11
 
12
- # Use a writable directory in Hugging Face Spaces
13
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
14
  os.environ["HF_HOME"] = "/tmp/.cache"
15
 
16
- # Define FastAPI app
17
  app = FastAPI(title="🧠 Brain Tumor Detection API")
18
 
19
- # Enable CORS
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
@@ -24,26 +23,27 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- # Image transform
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
 
30
  transforms.ToTensor(),
31
  transforms.Normalize(mean=[0.5], std=[0.5]),
32
  ])
33
 
34
- # Define your model directly inside this file (to avoid import errors)
35
  import torch.nn as nn
36
 
37
  class BrainTumorModel(nn.Module):
38
  def __init__(self):
39
  super(BrainTumorModel, self).__init__()
40
- self.con1d = nn.Conv2d(3, 32, kernel_size=3)
41
  self.con2d = nn.Conv2d(32, 64, kernel_size=3)
42
  self.con3d = nn.Conv2d(64, 128, kernel_size=3)
43
  self.pool = nn.MaxPool2d(2)
44
- self.fc1 = nn.Linear(128 * 25 * 25, 256)
45
- self.fc2 = nn.Linear(256, 128)
46
- self.output = nn.Linear(128, 2)
47
 
48
  def forward(self, x):
49
  x = self.pool(torch.relu(self.con1d(x)))
@@ -65,25 +65,31 @@ btd_model = BrainTumorModel()
65
  btd_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
66
  btd_model.eval()
67
 
68
- # Define prediction endpoint
69
  @app.post("/predict/")
70
  async def predict(file: UploadFile = File(...)):
71
  try:
72
  contents = await file.read()
73
- image = Image.open(io.BytesIO(contents)).convert("RGB")
74
  image_tensor = transform(image).unsqueeze(0)
75
 
76
  with torch.no_grad():
77
  output = btd_model(image_tensor)
78
  prediction = torch.argmax(output, dim=1).item()
79
 
80
- result = {0: "No tumor", 1: "Tumor detected"}[prediction]
 
 
 
 
 
 
81
  return {"prediction": result}
82
 
83
  except Exception as e:
84
  return {"error": str(e)}
85
 
86
- # Health check endpoint
87
  @app.get("/")
88
  def root():
89
  return {"message": "🧠 Brain Tumor Detection API is running!"}
 
2
 
3
  from fastapi import FastAPI, File, UploadFile
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  import torch
6
  import torchvision.transforms as transforms
7
  from PIL import Image
8
  import io
9
  import os
10
 
11
+ # Set writable cache directories
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
13
  os.environ["HF_HOME"] = "/tmp/.cache"
14
 
15
+ # FastAPI setup
16
  app = FastAPI(title="🧠 Brain Tumor Detection API")
17
 
18
+ # Allow CORS
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
23
  allow_headers=["*"],
24
  )
25
 
26
+ # Define image transform (grayscale)
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
+ transforms.Grayscale(num_output_channels=1), # Ensure grayscale
30
  transforms.ToTensor(),
31
  transforms.Normalize(mean=[0.5], std=[0.5]),
32
  ])
33
 
34
+ # Define the exact same model used during training
35
  import torch.nn as nn
36
 
37
  class BrainTumorModel(nn.Module):
38
  def __init__(self):
39
  super(BrainTumorModel, self).__init__()
40
+ self.con1d = nn.Conv2d(1, 32, kernel_size=3) # Input is grayscale (1 channel)
41
  self.con2d = nn.Conv2d(32, 64, kernel_size=3)
42
  self.con3d = nn.Conv2d(64, 128, kernel_size=3)
43
  self.pool = nn.MaxPool2d(2)
44
+ self.fc1 = nn.Linear(128 * 25 * 25, 512) # Matches your saved model
45
+ self.fc2 = nn.Linear(512, 256)
46
+ self.output = nn.Linear(256, 4) # 4 classes expected
47
 
48
  def forward(self, x):
49
  x = self.pool(torch.relu(self.con1d(x)))
 
65
  btd_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
66
  btd_model.eval()
67
 
68
+ # Prediction endpoint
69
  @app.post("/predict/")
70
  async def predict(file: UploadFile = File(...)):
71
  try:
72
  contents = await file.read()
73
+ image = Image.open(io.BytesIO(contents)).convert("L") # Grayscale
74
  image_tensor = transform(image).unsqueeze(0)
75
 
76
  with torch.no_grad():
77
  output = btd_model(image_tensor)
78
  prediction = torch.argmax(output, dim=1).item()
79
 
80
+ result = {
81
+ 0: "No tumor",
82
+ 1: "Glioma",
83
+ 2: "Meningioma",
84
+ 3: "Pituitary tumor"
85
+ }[prediction]
86
+
87
  return {"prediction": result}
88
 
89
  except Exception as e:
90
  return {"error": str(e)}
91
 
92
+ # Health check
93
  @app.get("/")
94
  def root():
95
  return {"message": "🧠 Brain Tumor Detection API is running!"}