kedimestan commited on
Commit
217fb69
·
verified ·
1 Parent(s): da2762e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -37
app.py CHANGED
@@ -1,84 +1,99 @@
1
- from fastapi import FastAPI, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
- import uvicorn
4
  import torch
5
  import torch.nn as nn
6
- from torchvision.models import vgg19
7
  from torchvision import transforms
 
8
  from PIL import Image
9
  import io
 
10
 
11
- # FastAPI uygulaması
12
  app = FastAPI()
13
 
14
- # Modeli yükle ve ayarla
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
 
 
 
 
 
17
  def load_model():
18
- model = vgg19(pretrained=False) # Pretrained ağırlıklar olmadan model
 
 
 
19
  model.classifier = nn.Sequential(
20
- nn.Linear(25088, 12544), # 25088 -> 12544
21
  nn.ReLU(),
22
  nn.Dropout(0.5),
23
- nn.Linear(12544, 6272), # 12544 -> 6272
24
  nn.ReLU(),
25
  nn.Dropout(0.5),
26
- nn.Linear(6272, 3136), # 6272 -> 3136
27
  nn.ReLU(),
28
  nn.Dropout(0.4),
29
- nn.Linear(3136, 1568), # 3136 -> 1568
30
  nn.ReLU(),
31
  nn.Dropout(0.4),
32
- nn.Linear(1568, 784), # 1568 -> 784
33
  nn.ReLU(),
34
  nn.Dropout(0.3),
35
- nn.Linear(784, 392), # 784 -> 392
36
  nn.ReLU(),
37
  nn.Dropout(0.3),
38
- nn.Linear(392, 196), # 392 -> 196
39
  nn.ReLU(),
40
  nn.Dropout(0.2),
41
- nn.Linear(196, 98), # 196 -> 98
42
  nn.ReLU(),
43
  nn.Dropout(0.2),
44
- nn.Linear(98, 49), # 98 -> 49
45
  nn.ReLU(),
46
  nn.Dropout(0.1),
47
- nn.Linear(49, 1), # 49 -> 1 (Binary classification)
48
  nn.Sigmoid()
49
  )
 
 
 
50
  model = model.to(device)
51
- model.load_state_dict(torch.load("best_model.pth", map_location=device))
52
  model.eval()
53
  return model
54
 
55
  model = load_model()
56
 
57
- # Görsel işleme transformasyonu
58
  transform = transforms.Compose([
59
  transforms.Resize((224, 224)),
60
  transforms.ToTensor(),
61
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
62
  ])
63
 
64
- def process_image(file: UploadFile):
65
- try:
66
- image = Image.open(io.BytesIO(file.file.read())).convert("RGB")
67
- return transform(image).unsqueeze(0).to(device)
68
- except Exception as e:
69
- raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")
 
70
 
71
- # Ana endpoint
72
  @app.post("/predict")
73
- async def predict(file: UploadFile):
74
- if not file.content_type.startswith("image/"):
75
- raise HTTPException(status_code=400, detail="File must be an image.")
76
- image_tensor = process_image(file)
77
- with torch.no_grad():
78
- output = model(image_tensor)
79
- prediction = float(output.item())
80
- return JSONResponse({"prediction": prediction})
 
 
 
81
 
82
- # Uygulama başlatıcı
83
- if __name__ == "__main__":
84
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from fastapi.responses import JSONResponse
 
3
  import torch
4
  import torch.nn as nn
 
5
  from torchvision import transforms
6
+ from torchvision.models import vgg19
7
  from PIL import Image
8
  import io
9
+ from huggingface_hub import hf_hub_download
10
 
11
+ # FastAPI uygulamasını başlat
12
  app = FastAPI()
13
 
14
+ # Cihaz ayarı
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # Hugging Face modelini yükleme bilgileri
18
+ REPO_ID = "kedimestan/retinoblastomaDetectionVGG19"
19
+ MODEL_FILE = "pytorch_model.bin"
20
+
21
+ # Modeli indir ve yükle
22
  def load_model():
23
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
24
+
25
+ # Model yapısını yeniden tanımla (VGG19 + özel sınıflandırıcı)
26
+ model = vgg19(pretrained=False)
27
  model.classifier = nn.Sequential(
28
+ nn.Linear(25088, 12544),
29
  nn.ReLU(),
30
  nn.Dropout(0.5),
31
+ nn.Linear(12544, 6272),
32
  nn.ReLU(),
33
  nn.Dropout(0.5),
34
+ nn.Linear(6272, 3136),
35
  nn.ReLU(),
36
  nn.Dropout(0.4),
37
+ nn.Linear(3136, 1568),
38
  nn.ReLU(),
39
  nn.Dropout(0.4),
40
+ nn.Linear(1568, 784),
41
  nn.ReLU(),
42
  nn.Dropout(0.3),
43
+ nn.Linear(784, 392),
44
  nn.ReLU(),
45
  nn.Dropout(0.3),
46
+ nn.Linear(392, 196),
47
  nn.ReLU(),
48
  nn.Dropout(0.2),
49
+ nn.Linear(196, 98),
50
  nn.ReLU(),
51
  nn.Dropout(0.2),
52
+ nn.Linear(98, 49),
53
  nn.ReLU(),
54
  nn.Dropout(0.1),
55
+ nn.Linear(49, 1),
56
  nn.Sigmoid()
57
  )
58
+
59
+ # Model ağırlıklarını yükle
60
+ model.load_state_dict(torch.load(model_path, map_location=device))
61
  model = model.to(device)
 
62
  model.eval()
63
  return model
64
 
65
  model = load_model()
66
 
67
+ # Görüntü dönüşüm pipeline'ı
68
  transform = transforms.Compose([
69
  transforms.Resize((224, 224)),
70
  transforms.ToTensor(),
71
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
72
  ])
73
 
74
+ # Tahmin fonksiyonu
75
+ def predict(image: Image.Image):
76
+ input_tensor = transform(image).unsqueeze(0).to(device)
77
+ with torch.no_grad():
78
+ output = model(input_tensor).squeeze(0).cpu().numpy()
79
+ prediction = "Positive" if output[0] > 0.5 else "Negative"
80
+ return {"Prediction": prediction, "Probability": round(float(output[0]), 2)}
81
 
82
+ # Ana API rotası
83
  @app.post("/predict")
84
+ async def predict_image(file: UploadFile = File(...)):
85
+ try:
86
+ # Görüntüyü oku
87
+ image_data = await file.read()
88
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
89
+
90
+ # Tahmin yap
91
+ result = predict(image)
92
+ return JSONResponse(content=result)
93
+ except Exception as e:
94
+ return JSONResponse(content={"error": str(e)}, status_code=400)
95
 
96
+ # Ana sayfa
97
+ @app.get("/")
98
+ def home():
99
+ return {"message": "Upload an image to /predict for classification."}