File size: 2,438 Bytes
d501ffc
69047b7
da2762e
69047b7
217fb69
69047b7
 
217fb69
69047b7
217fb69
69047b7
 
217fb69
 
81f7942
217fb69
 
da2762e
217fb69
 
 
 
da2762e
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
da2762e
 
217fb69
069eaa5
 
 
 
da2762e
 
 
c22075a
da2762e
69047b7
217fb69
69047b7
da2762e
 
217fb69
69047b7
 
217fb69
 
 
 
 
57e898f
 
da2762e
d501ffc
d164cde
57e898f
d501ffc
d164cde
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vgg19
from PIL import Image
import io
from huggingface_hub import hf_hub_download

# Device selection: run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub coordinates of the checkpoint to download
MODEL_FILE = "best_model.pth"
REPO_ID = "kedimestan/retinoblastomaDetectionVGG19"

# Download the checkpoint and load the model
def load_model():
    """Download the checkpoint and return the model ready for inference.

    The checkpoint ``MODEL_FILE`` is fetched from the Hugging Face Hub
    repository ``REPO_ID``. A VGG19 network is created without pretrained
    weights, its classifier is replaced by a stack of fully connected layers
    ending in a single sigmoid unit, and the downloaded state dict is loaded
    into it.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)

    # Calling vgg19() with no arguments yields randomly initialised weights on
    # both old and new torchvision releases and avoids the deprecated
    # ``pretrained`` keyword; the real weights come from the checkpoint below.
    model = vgg19()

    # Custom classifier head: each hidden layer halves the width, starting at
    # the 25088 inputs the original code feeds into the classifier.
    layer_widths = [25088, 12544, 6272, 3136, 1568, 784, 392, 196, 98, 49]
    dropout_rates = [0.5, 0.5, 0.4, 0.4, 0.3, 0.3, 0.2, 0.2, 0.1]

    # Module order (Linear, ReLU, Dropout, ..., Linear, Sigmoid) must stay
    # exactly like this: the Sequential indices are part of the state-dict
    # keys stored in the checkpoint.
    layers = []
    for in_features, out_features, rate in zip(
        layer_widths, layer_widths[1:], dropout_rates
    ):
        layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
    layers += [nn.Linear(layer_widths[-1], 1), nn.Sigmoid()]
    model.classifier = nn.Sequential(*layers)

    # NOTE(review): torch.load unpickles a file downloaded from the Hub.
    # If the installed torch version supports it and the checkpoint is a plain
    # state dict, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    # Move to the target device and disable dropout for inference.
    model = model.to(device)
    model.eval()
    return model

# Build the network once at import time
model = load_model()

# Image preprocessing pipeline: fixed 224x224 resize, tensor conversion,
# then per-channel normalisation
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

# Prediction function
def predict(image: Image.Image):
    """Classify one image with the module-level model.

    Args:
        image: PIL image to classify, or ``None`` when nothing was supplied.

    Returns:
        ``"Normal"`` when the model's sigmoid output is greater than 0.5,
        otherwise ``"Hasta"``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail with a readable message instead of an opaque error raised
        # from inside the transform pipeline.
        raise gr.Error("Lütfen bir görsel yükleyin.")

    # The normalisation step uses three-channel statistics, so grayscale,
    # palette or RGBA images must be converted to RGB first.
    rgb_image = image.convert("RGB")
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The network ends in a single sigmoid unit, so the output holds
        # exactly one value for the one-image batch.
        score = model(input_tensor).item()

    # NOTE(review): mapping scores above 0.5 to "Normal" presumes the
    # training labels encoded "Normal" as 1 -- confirm against the training
    # code.
    return "Normal" if score > 0.5 else "Hasta"

# Gradio interface: one image input (delivered to predict as a PIL image)
# and one text output showing the prediction
inputs = gr.Image(label="Görsel Yükle", type="pil")
outputs = gr.Textbox(label="Tahmin Sonucu")

demo = gr.Interface(
    predict,
    inputs,
    outputs,
    title="Retinoblastoma Tespiti",
    theme="default",
)
demo.launch(debug=True)