File size: 2,903 Bytes
af9162a
52893ae
 
501b516
a983c72
af9162a
ba74db2
501b516
 
 
a983c72
da5fdaa
ece0ce5
 
 
fd548c6
ece0ce5
fd548c6
ece0ce5
 
da5fdaa
501b516
a983c72
501b516
a983c72
501b516
da5fdaa
a983c72
501b516
af9162a
a983c72
ba74db2
501b516
ba74db2
69590ad
 
866446d
ba74db2
ece0ce5
8a7fe4e
 
da5fdaa
 
 
 
866446d
 
69590ad
866446d
8a7fe4e
 
ba74db2
 
a983c72
a4dc15b
501b516
 
 
dbb94b0
7f74cd7
 
 
501b516
fd548c6
a983c72
866446d
744ad2f
 
 
ba74db2
26a9ba5
501b516
 
fd548c6
744ad2f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image


def pil_to_torch(image):
    """Convert a PIL image (or HxWxC / HxW numpy array) into a normalized
    ``(1, 3, H, W)`` float tensor, resized so its longer side is 256 pixels.

    Generalizes the original: grayscale inputs are expanded to 3 channels
    and any alpha channel is dropped, so downstream denoisers always
    receive RGB.
    """
    image = np.array(image)
    if image.ndim == 2:
        # Grayscale: replicate the single channel to a 3-channel RGB image.
        image = np.stack([image] * 3, axis=-1)
    # Drop alpha (RGBA -> RGB) if present; no-op for plain RGB.
    image = image[:, :, :3]
    image = image.transpose((2, 0, 1))  # HWC -> CHW
    image = torch.tensor(image).float() / 255
    image = image.unsqueeze(0)  # add batch dim -> (1, C, H, W)

    # Resize so the longer side equals ref_size, preserving aspect ratio
    # (integer division may round the short side down by a pixel).
    ref_size = 256
    if image.shape[2] > image.shape[3]:  # portrait: height is the long side
        size = (ref_size, ref_size * image.shape[3] // image.shape[2])
    else:  # landscape/square: width is the long side
        size = (ref_size * image.shape[2] // image.shape[3], ref_size)

    image = torch.nn.functional.interpolate(image, size=size, mode='bilinear')
    return image


def torch_to_pil(image):
    """Turn a ``(1, C, H, W)`` tensor with values in [0, 1] into a PIL image."""
    array = image.squeeze(0).cpu().detach().numpy()
    array = array.transpose((1, 2, 0))  # CHW -> HWC
    array = (np.clip(array, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(array)


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to ``image`` and denoise it with the chosen model.

    Args:
        image: input PIL image; resized internally so its longer side is 256.
        noise_level: standard deviation of the added Gaussian noise (on the
            [0, 1] intensity scale).
        denoiser: name of the denoiser to apply (one of the dropdown choices).

    Returns:
        ``(noisy, denoised)`` pair of PIL images.

    Raises:
        ValueError: if ``denoiser`` is not a recognized name.
    """
    image = pil_to_torch(image)
    if denoiser == 'DnCNN':
        den = dinv.models.DnCNN()
        # This DnCNN is trained for a fixed noise level sigma0. Rescale the
        # input so its noise magnitude matches sigma0, denoise, then undo the
        # scaling — a standard trick to use a fixed-sigma model at any sigma.
        sigma0 = 2 / 255
        denoiser = lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0
    elif denoiser == 'MedianFilter':
        denoiser = dinv.models.MedianFilter(kernel_size=5)
    elif denoiser == 'BM3D':
        denoiser = dinv.models.BM3D()
    elif denoiser == 'TV':
        denoiser = dinv.models.TVDenoiser()
    elif denoiser == 'TGV':
        denoiser = dinv.models.TGVDenoiser()
    elif denoiser == 'Wavelets':
        denoiser = dinv.models.WaveletPrior()
    elif denoiser == 'SwinIR':
        denoiser = dinv.models.SwinIR(img_size=256)
    elif denoiser == 'DRUNet':
        denoiser = dinv.models.DRUNet()
    else:
        raise ValueError("Invalid denoiser")
    # Pure inference: disable autograd so no computation graph is built,
    # saving memory and time on the CPU-only demo.
    with torch.no_grad():
        noisy = image + torch.randn_like(image) * noise_level
        estimated = denoiser(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)


# --- Gradio UI -------------------------------------------------------------
# Widgets wired into the Interface below. (A previously unused
# `input_image_output` widget was removed — it appeared in neither the
# inputs nor the outputs list.)
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')

# Noise standard deviations on the [0, 1] intensity scale.
noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')

# Choices must match the names handled inside image_mod.
denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'SwinIR', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DnCNN', label='Denoiser')

demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://deepinv.github.io/deepinv/_static/deepinv_logolarge.png', 0.1, 'DnCNN']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU. We also automatically resize the input image to 256 pixels to reduce the computation time. For more advanced models, please run the code locally.",
)

demo.launch()