kedimestan commited on
Commit
c22075a
·
verified ·
1 Parent(s): 85a55f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
 
7
 
8
  # FastAPI uygulamasını başlat
9
  app = FastAPI()
@@ -11,8 +12,13 @@ app = FastAPI()
11
  # Cihaz ayarı
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # Eğitilmiş modeli yükleme
15
- model = torch.load("best_model.pth", map_location=device)
 
 
 
 
 
16
  model.eval()
17
 
18
  # Görüntü dönüşüm pipeline'ı
@@ -26,9 +32,9 @@ transform = transforms.Compose([
26
  def predict(image: Image.Image):
27
  input_tensor = transform(image).unsqueeze(0).to(device)
28
  with torch.no_grad():
29
- output = model(input_tensor).item() # Tahmini al
30
- prediction = "Positive" if output > 0.5 else "Negative"
31
- return {"Prediction": prediction, "Probability": round(output, 2)}
32
 
33
  # Ana API rotası
34
  @app.post("/predict")
@@ -47,4 +53,4 @@ async def predict_image(file: UploadFile = File(...)):
47
  # Ana sayfa
48
  @app.get("/")
49
  def home():
50
- return {"message": "Upload an image to /predict for classification."}
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
7
+ from huggingface_hub import hf_hub_download
8
 
9
  # FastAPI uygulamasını başlat
10
  app = FastAPI()
 
12
  # Cihaz ayarı
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # Hugging Face modelini yükleme
16
+ REPO_ID = "kedimestan/retinoblastomaDetectionVGG19"
17
+ MODEL_FILE = "best_model.bin"
18
+
19
+ # Model dosyasını indirin
20
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
21
+ model = torch.load(model_path, map_location=device)
22
  model.eval()
23
 
24
  # Görüntü dönüşüm pipeline'ı
 
32
  def predict(image: Image.Image):
33
  input_tensor = transform(image).unsqueeze(0).to(device)
34
  with torch.no_grad():
35
+ output = model(input_tensor).squeeze(0).cpu().numpy()
36
+ prediction = "Positive" if output[0] > 0.5 else "Negative"
37
+ return {"Prediction": prediction, "Probability": round(float(output[0]), 2)}
38
 
39
  # Ana API rotası
40
  @app.post("/predict")
 
53
  # Ana sayfa
54
  @app.get("/")
55
  def home():
56
+ return {"message": "Upload an image to /predict for classification."}