File size: 1,619 Bytes
546f352
0bda39c
7e9eb72
1587eea
e396781
 
 
 
 
 
 
7e9eb72
c10710a
 
546f352
7e9eb72
e2f6cfc
 
 
 
 
 
 
 
7e9eb72
546f352
c10710a
7e9eb72
e2f6cfc
546f352
7e9eb72
7b2689f
7e9eb72
567adb4
 
 
 
 
 
 
546f352
7e9eb72
 
e2f6cfc
546f352
67daaca
 
7e9eb72
0387452
7e9eb72
e2f6cfc
 
 
7e9eb72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
from fastai.vision.all import *
import os

# Getters for the x (image) and y (diagnosis) values of each training record.
# They must be defined here under the same names used during training: the
# exported learner refers to them and cannot be loaded without them.
def get_x(r):
    """Placeholder x-getter that exists only so the saved learner can load.

    Always returns an empty string, whatever *r* is. It is presumably unused
    at inference time, since predict() hands a ready image to learn.predict.
    """
    placeholder = ""
    return placeholder

def get_y(r):
    """Return the value stored under the ``"diagnosis"`` key of record *r*."""
    diagnosis = r["diagnosis"]
    return diagnosis


# Load the exported fastai learner. get_x and get_y must already be defined
# at this point, because the learner needs them in order to load.
learn = load_learner('model.pkl')
# Class vocabulary; its order matches the per-class probabilities from predict.
labels = learn.dls.vocab

# Human-readable name for each diagnosis grade, keyed by grade number 0-4
# (the position of each name in the tuple below).
label_descriptions = dict(enumerate((
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
)))

# Prediction function
def predict(img):
    """Classify a retina image and return a confidence for every class.

    Args:
        img: Input from the ``gr.Image(type="filepath")`` component, i.e. an
            image file path, or ``None`` when no image was provided.

    Returns:
        A ``{description: probability}`` dict with one entry per class in the
        model's vocabulary, as consumed by ``gr.Label``; ``None`` when *img*
        is ``None`` (clears the output instead of erroring).
    """
    if img is None:
        # Nothing uploaded (e.g. the input was cleared): don't pass None on
        # to PILImage.create, just clear the label output.
        return None
    img = PILImage.create(img)
    # Only the per-class probabilities are needed; the decoded label and
    # index (first two elements) are ignored.
    _, _, probs = learn.predict(img)
    # Fall back to the raw vocab entry when it has no human-readable
    # description instead of raising KeyError. NOTE(review): this covers the
    # case where the vocab holds e.g. strings ("0") rather than ints.
    return {
        label_descriptions.get(label, str(label)): float(prob)
        for label, prob in zip(labels, probs)
    }

# Gradio Interface
# Header text shown by the interface.
title = "Diabetic Retinopathy Detection"
description = "Detects severity of diabetic retinopathy from a given retina image."

# Footer links (training notebook and model card), passed to gr.Interface
# as its `article`.
article = """
<p style='text-align: center'>
    <a href='https://www.kaggle.com/code/josemauriciodelgado/proliferative-retinopathy' target='_blank'>Kaggle Training Notebook</a> | 
    <a href='https://huggingface.co/jdelgado2002/diabetic_retinopathy_detection' target='_blank'>Model Card</a>
</p>
"""

# Prepare examples if available: the "test" folder is optional, and when it is
# missing or contains no images the interface is built without examples.
test_folder = "test"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _list_example_images(folder, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of image files directly inside *folder*.

    The extension match is case-insensitive (``.JPG`` is accepted). Returns an
    empty list when *folder* does not exist or is not a directory, instead of
    letting ``os.listdir`` raise.
    """
    if not os.path.isdir(folder):
        return []
    # Sorted so the examples appear in a stable order between runs.
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(extensions)
    )


image_paths = _list_example_images(test_folder)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Label(num_top_classes=5),
    # None (the parameter's default) means "no examples" when none were found.
    examples=image_paths or None,
    article=article,
    title=title,
    description=description,
)
demo.launch()