File size: 1,656 Bytes
15216b5
68ba513
 
58b19a1
 
 
 
 
15216b5
58b19a1
 
 
 
 
 
 
 
 
 
b38da9b
58b19a1
 
7b7ab95
f489399
58b19a1
 
 
 
 
 
7b7ab95
58b19a1
 
 
 
65e5c64
cd8aa5a
0260030
58b19a1
15216b5
 
7b7ab95
58b19a1
 
f3dfeae
 
 
7b7ab95
f3dfeae
7b7ab95
f3dfeae
58b19a1
 
f3dfeae
 
f6524bf
7b7ab95
 
f3dfeae
58b19a1
f3dfeae
 
 
 
 
 
 
15216b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np


# 2-D UNet taking 3-channel input and emitting a single-channel logit map.
# The hyper-parameters must match the ones the checkpoint was trained with,
# otherwise `load_state_dict` will reject the weights.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)

# Load the weights onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file; only ship a trusted checkpoint,
# and consider `weights_only=True` if the pinned torch version supports it.
_state_dict = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(_state_dict)

# Inference mode for the layers: disables dropout for deterministic output.
model.eval()

def process_image(image):
    """Run the module-level UNet on one image and return the predicted mask.

    Args:
        image: H x W x 3 NumPy array as delivered by the ``gr.Image`` input
            (``type="numpy"``, ``image_mode="RGB"``). Pixel values are divided
            by 255, so 8-bit data is presumably expected -- TODO confirm this
            matches the normalisation used at training time.

    Returns:
        A 2-D NumPy array holding the sigmoid of the model output for the
        first (only) batch item and channel. The mask is at the 512 x 512
        working resolution; it is not resized back to the input size.

    Raises:
        gr.Error: if no image was supplied (Gradio passes ``None`` when the
            form is submitted empty), so the UI shows a readable message
            instead of an opaque ``TypeError``.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Scale to [0, 1] and use float32, the dtype of the model parameters.
    image = image / 255.0
    image = image.astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),  # HWC ndarray -> CHW torch tensor
    ])

    image = inference_transforms(image=image)["image"]
    image = image.unsqueeze(0)  # add the batch dimension: (1, C, H, W)

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        mask_pred = torch.sigmoid(model(image))

    # Drop the batch and channel dimensions to get a plain 2-D mask.
    return mask_pred[0, 0, :, :].numpy()
    

# Web UI: a single RGB image goes in, the model's single-channel prediction
# comes out and is rendered as a greyscale ("L" mode) image.
demo = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",  # guarantees the 3 channels the model expects
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
    # NOTE: no `examples=` are configured; pass a list of sample image paths
    # here to offer one-click demo inputs.
)

demo.launch()