# kedimestan's picture
# Update app.py
# 57e898f verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vgg19
from PIL import Image
import io
from huggingface_hub import hf_hub_download
# Runtime configuration.

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub repository and checkpoint file that `load_model` downloads
# and loads as the model's state_dict.
REPO_ID = "kedimestan/retinoblastomaDetectionVGG19"
MODEL_FILE = "best_model.pth"
# Modeli indir ve yükle
def load_model():
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
# Model yapısını yeniden tanımla (VGG19 + özel sınıflandırıcı)
model = vgg19(pretrained=False)
model.classifier = nn.Sequential(
nn.Linear(25088, 12544),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(12544, 6272),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(6272, 3136),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(3136, 1568),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(1568, 784),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(784, 392),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(392, 196),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(196, 98),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(98, 49),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(49, 1),
nn.Sigmoid()
)
# Ağırlıkları yükle (state_dict), modeli yüklemeden önce
model.load_state_dict(torch.load(model_path, map_location=device))
# Cihazda modele yükle
model = model.to(device)
model.eval()
return model
# Load the model once at import time; `predict` reuses this single instance
# for every request.
model = load_model()
# Image preprocessing applied before inference: resize to a fixed 224x224,
# convert the PIL image to a float tensor, then normalize each of the three
# channels with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Inference entry point wired to the Gradio interface.
def predict(image: Image.Image):
    """Classify one image and return the label displayed in the UI.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None`` when
            the form is submitted without an upload.

    Returns:
        str: ``"Normal"`` when the model's sigmoid output is above 0.5,
        otherwise ``"Hasta"``; an upload prompt when ``image`` is ``None``.
    """
    if image is None:
        return "Lütfen bir görsel yükleyin."

    # Normalize in `transform` has 3-channel mean/std; RGBA, grayscale or
    # palette images would fail there, so force 3 channels first.
    image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Batch of one and a single sigmoid unit, so the output holds exactly
        # one value; .item() extracts it as a Python float.
        probability = model(input_tensor).item()

    # NOTE(review): which side of 0.5 means "Normal" depends on the label
    # encoding used at training time — confirm against the training code.
    return "Normal" if probability > 0.5 else "Hasta"
# Gradio UI: one image upload in, one text label out, served by `predict`.
inputs = gr.Image(type="pil", label="Görsel Yükle")
outputs = gr.Textbox(label="Tahmin Sonucu")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Retinoblastoma Tespiti",
    theme="default",
)

# Start the app with debug mode enabled.
demo.launch(debug=True)