kedimestan commited on
Commit
da2762e
·
verified ·
1 Parent(s): 1309856

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -38
app.py CHANGED
@@ -1,56 +1,84 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
 
3
  import torch
 
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
7
- from huggingface_hub import hf_hub_download
8
 
9
- # FastAPI uygulamasını başlat
10
  app = FastAPI()
11
 
12
- # Cihaz ayarı
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
- # Hugging Face modelini yükleme
16
- REPO_ID = "kedimestan/retinoblastomaDetectionVGG19"
17
- MODEL_FILE = "best_model.pth"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Model dosyasını indirin
20
- model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
21
- model = torch.load(model_path, map_location=device)
22
- model.eval()
23
 
24
- # Görüntü dönüşüm pipeline'ı
25
  transform = transforms.Compose([
26
- transforms.Resize((224, 224)), # Modelle uyumlu olacak şekilde yeniden boyutlandır
27
- transforms.ToTensor(), # Tensor'a çevir
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize et
29
  ])
30
 
31
- # Tahmin fonksiyonu
32
- def predict(image: Image.Image):
33
- input_tensor = transform(image).unsqueeze(0).to(device)
34
- with torch.no_grad():
35
- output = model(input_tensor).squeeze(0).cpu().numpy()
36
- prediction = "Positive" if output[0] > 0.5 else "Negative"
37
- return {"Prediction": prediction, "Probability": round(float(output[0]), 2)}
38
-
39
- # Ana API rotası
40
- @app.post("/predict")
41
- async def predict_image(file: UploadFile = File(...)):
42
  try:
43
- # Görüntüyü oku
44
- image_data = await file.read()
45
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
46
-
47
- # Tahmin yap
48
- result = predict(image)
49
- return JSONResponse(content=result)
50
  except Exception as e:
51
- return JSONResponse(content={"error": str(e)}, status_code=400)
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Ana sayfa
54
- @app.get("/")
55
- def home():
56
- return {"message": "Upload an image to /predict for classification."}
 
1
+ from fastapi import FastAPI, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ import uvicorn
4
  import torch
5
+ import torch.nn as nn
6
+ from torchvision.models import vgg19
7
  from torchvision import transforms
8
  from PIL import Image
9
  import io
 
10
 
11
+ # FastAPI uygulaması
12
  app = FastAPI()
13
 
14
+ # Modeli yükle ve ayarla
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ def load_model():
18
+ model = vgg19(pretrained=False) # Pretrained ağırlıklar olmadan model
19
+ model.classifier = nn.Sequential(
20
+ nn.Linear(25088, 12544), # 25088 -> 12544
21
+ nn.ReLU(),
22
+ nn.Dropout(0.5),
23
+ nn.Linear(12544, 6272), # 12544 -> 6272
24
+ nn.ReLU(),
25
+ nn.Dropout(0.5),
26
+ nn.Linear(6272, 3136), # 6272 -> 3136
27
+ nn.ReLU(),
28
+ nn.Dropout(0.4),
29
+ nn.Linear(3136, 1568), # 3136 -> 1568
30
+ nn.ReLU(),
31
+ nn.Dropout(0.4),
32
+ nn.Linear(1568, 784), # 1568 -> 784
33
+ nn.ReLU(),
34
+ nn.Dropout(0.3),
35
+ nn.Linear(784, 392), # 784 -> 392
36
+ nn.ReLU(),
37
+ nn.Dropout(0.3),
38
+ nn.Linear(392, 196), # 392 -> 196
39
+ nn.ReLU(),
40
+ nn.Dropout(0.2),
41
+ nn.Linear(196, 98), # 196 -> 98
42
+ nn.ReLU(),
43
+ nn.Dropout(0.2),
44
+ nn.Linear(98, 49), # 98 -> 49
45
+ nn.ReLU(),
46
+ nn.Dropout(0.1),
47
+ nn.Linear(49, 1), # 49 -> 1 (Binary classification)
48
+ nn.Sigmoid()
49
+ )
50
+ model = model.to(device)
51
+ model.load_state_dict(torch.load("best_model.pth", map_location=device))
52
+ model.eval()
53
+ return model
54
 
55
+ model = load_model()
 
 
 
56
 
57
+ # Görsel işleme transformasyonu
58
  transform = transforms.Compose([
59
+ transforms.Resize((224, 224)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
62
  ])
63
 
64
+ def process_image(file: UploadFile):
 
 
 
 
 
 
 
 
 
 
65
  try:
66
+ image = Image.open(io.BytesIO(file.file.read())).convert("RGB")
67
+ return transform(image).unsqueeze(0).to(device)
 
 
 
 
 
68
  except Exception as e:
69
+ raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")
70
+
71
+ # Ana endpoint
72
+ @app.post("/predict")
73
+ async def predict(file: UploadFile):
74
+ if not file.content_type.startswith("image/"):
75
+ raise HTTPException(status_code=400, detail="File must be an image.")
76
+ image_tensor = process_image(file)
77
+ with torch.no_grad():
78
+ output = model(image_tensor)
79
+ prediction = float(output.item())
80
+ return JSONResponse({"prediction": prediction})
81
 
82
+ # Uygulama başlatıcı
83
+ if __name__ == "__main__":
84
+ uvicorn.run(app, host="0.0.0.0", port=8000)