File size: 1,736 Bytes
75cff75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import functools

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms

import torchxrayvision as xrv


@functools.lru_cache(maxsize=8)
def _load_model(model_name):
    """Load the torchxrayvision model *model_name* once and reuse it afterwards."""
    model = xrv.models.get_model(model_name, from_hf_hub=True)
    # Inference only: make layers such as BatchNorm use running statistics.
    # Harmless if the loader already returned the model in eval mode.
    model.eval()
    return model


def classify_image(img, model_name):
    """Predict chest x-ray pathology scores for a single image.

    Args:
        img: Image as a numpy array with pixel values no larger than 255.
            It may be 2D (grayscale) or 3D with channels last; only the first
            channel of a 3D array is used.
        model_name: Name of a pre-trained torchxrayvision model, e.g.
            "densenet121-res224-all".

    Returns:
        dict mapping each pathology label the model reports to its predicted
        score as a float (the format a Gradio ``Label`` output expects).

    Raises:
        ValueError: If no image is given, or the image is not 2D after
            channel selection.
    """
    if img is None:
        raise ValueError("No image provided")

    # Cached: the weights are fetched once per model name, not on every call.
    model = _load_model(model_name)

    # Rescale 8-bit pixel values (max 255) into torchxrayvision's input range.
    img = xrv.datasets.normalize(img, 255)

    # Reduce to a single 2D plane by keeping the first channel.
    if img.ndim > 2:
        img = img[:, :, 0]
    if img.ndim != 2:
        raise ValueError(f"Expected a 2D image, got an array with shape {img.shape}")

    # Add the channel axis: (H, W) -> (1, H, W).
    img = img[None, :, :]

    # Presumably crops the image to a centred square; see XRayCenterCrop.
    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
    img = transform(img)

    with torch.no_grad():
        # Add the batch axis: (1, H, W) -> (1, 1, H, W).
        batch = torch.from_numpy(img).unsqueeze(0)
        preds = model(batch).cpu()

    # Label outputs with the model's own target list rather than the global
    # default list. Models trained on a single dataset use "" for targets they
    # were not trained on; those placeholder outputs are skipped.
    labels = getattr(model, "pathologies", xrv.datasets.default_pathologies)
    return {
        label: float(score)
        for label, score in zip(labels, preds[0].numpy())
        if label
    }


# Pre-trained torchxrayvision models offered in the UI; the first is the default.
MODEL_NAMES = [
    "densenet121-res224-all",
    "densenet121-res224-nih",
    "densenet121-res224-pc",
    "densenet121-res224-chex",
    "densenet121-res224-rsna",
    "densenet121-res224-mimic_nb",
    "densenet121-res224-mimic_ch",
    "resnet50-res512-all",
]

demo = gr.Interface(
    fn=classify_image,
    inputs=[
        # Requests a 224x224 single-channel ("L") image from Gradio.
        # NOTE(review): the res512 model also receives 224x224 input here;
        # presumably torchxrayvision rescales it internally -- confirm.
        gr.Image(shape=(224, 224), image_mode="L"),
        gr.Dropdown(
            MODEL_NAMES,
            value=MODEL_NAMES[0],
            type="value",
            label="Pre-trained model",
        ),
    ],
    # Component API, consistent with gr.Image / gr.Dropdown above
    # (gr.outputs.* is the deprecated legacy namespace).
    outputs=gr.Label(),
    title="Classify chest x-ray image",
    examples=[
        ["16747_3_1.jpg", "densenet121-res224-all"],
        ["00000001_000.png", "resnet50-res512-all"],
    ],
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    demo.launch()