File size: 3,003 Bytes
077a255
b6d8eef
 
0e089a6
c472fe6
077a255
eb570bb
d63f7ea
0e089a6
 
c472fe6
de4a2f9
5a3ed26
d63f7ea
 
 
7f13cff
5a3ed26
7f13cff
5a3ed26
 
de4a2f9
0e089a6
c472fe6
0e089a6
c472fe6
0e089a6
de4a2f9
c472fe6
0e089a6
077a255
c472fe6
eb570bb
7c90ba3
eb570bb
9a82551
 
e49a0f8
eb570bb
5a3ed26
51daf88
 
de4a2f9
 
 
 
e49a0f8
 
7c90ba3
 
51daf88
 
eb570bb
 
c472fe6
263948f
0e089a6
 
 
d7af0b8
f964bf0
 
 
0e089a6
7f13cff
c472fe6
7c90ba3
288f88f
 
 
eb570bb
2f1dbb2
0e089a6
 
7f13cff
288f88f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools

import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image


def pil_to_torch(image, ref_size=256):
    """Convert an image to a resized float tensor of shape (1, 3, H, W) in [0, 1].

    Args:
        image: PIL image or array-like of 8-bit pixels (values are divided by
            255), shaped (H, W), (H, W, 1), (H, W, 2), (H, W, 3) or (H, W, 4).
        ref_size: target size in pixels. When ``ref_size == 128`` the image is
            resized to a ``ref_size x ref_size`` square (the size ``image_mod``
            requests for DiffUNet); otherwise the longest side becomes
            ``ref_size`` and the aspect ratio is preserved.

    Returns:
        float32 tensor of shape (1, 3, H', W').
    """
    array = np.array(image)
    if array.ndim == 2:
        # Grayscale without a channel axis: add one so the code below is uniform.
        array = array[:, :, None]
    if array.shape[2] <= 2:
        # Gray or gray+alpha: replicate the luminance channel to RGB, since the
        # colour denoisers presumably expect 3 input channels.
        array = np.repeat(array[:, :, :1], 3, axis=2)
    else:
        # Drop an alpha channel (e.g. RGBA PNG) if present.
        array = array[:, :, :3]

    tensor = torch.tensor(array.transpose((2, 0, 1))).float() / 255
    tensor = tensor.unsqueeze(0)

    height, width = tensor.shape[2], tensor.shape[3]
    if ref_size == 128:
        size = (ref_size, ref_size)
    elif height > width:
        # max(1, ...) keeps very thin images from collapsing to a zero-sized axis.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)

    return torch.nn.functional.interpolate(tensor, size=size, mode='bilinear')


def torch_to_pil(image):
    """Turn a single image tensor, (1, C, H, W) or (C, H, W), into a PIL image.

    Values are clipped to [0, 1], scaled to [0, 255] and truncated to uint8.
    """
    chw = image.squeeze(0).cpu().detach()
    hwc = chw.numpy().transpose((1, 2, 0))
    scaled = np.clip(hwc, 0, 1) * 255
    return PIL.Image.fromarray(scaled.astype(np.uint8))


# Learned networks presumably load pretrained weights on construction (deepinv's
# default), which is costly, so each is built once per process and reused.
# The key space is the three names below, so an unbounded cache is safe.
@functools.lru_cache(maxsize=None)
def _load_pretrained(name):
    """Return the cached deepinv network for ``name`` (DnCNN, DRUNet or DiffUNet)."""
    if name == 'DnCNN':
        return dinv.models.DnCNN()
    if name == 'DRUNet':
        return dinv.models.DRUNet()
    if name == 'DiffUNet':
        return dinv.models.DiffUNet()
    raise ValueError("Invalid denoiser")


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to ``image`` and denoise it with the chosen method.

    Args:
        image: input image as delivered by the Gradio ``Image`` component.
        noise_level: standard deviation of the additive Gaussian noise on the
            [0, 1] intensity scale produced by ``pil_to_torch``. Cast to float
            defensively in case the dropdown delivers it as a string.
        denoiser: name of the denoiser (one of the dropdown choices).

    Returns:
        Tuple ``(noisy, estimated)`` of PIL images.

    Raises:
        ValueError: if no image is given or ``denoiser`` is unknown.
    """
    if image is None:
        raise ValueError("No input image provided")
    noise_level = float(noise_level)
    image = pil_to_torch(image, ref_size=128 if denoiser == 'DiffUNet' else 256)

    if denoiser == 'DnCNN':
        den = _load_pretrained('DnCNN')
        # The DnCNN weights are presumably trained for a single noise level
        # sigma0; rescale the input so its noise matches sigma0, then undo the
        # scaling on the output.
        sigma0 = 2 / 255

        def model(x, sigma):
            return den(x * sigma0 / sigma) * sigma / sigma0
    elif denoiser in ('DRUNet', 'DiffUNet'):
        model = _load_pretrained(denoiser)
    elif denoiser == 'MedianFilter':
        model = dinv.models.MedianFilter(kernel_size=5)
    elif denoiser == 'BM3D':
        model = dinv.models.BM3D()
    elif denoiser == 'TV':
        # Classical denoisers are cheap to build and may keep internal state,
        # so they are constructed fresh for every request.
        model = dinv.models.TVDenoiser()
    elif denoiser == 'TGV':
        model = dinv.models.TGVDenoiser()
    elif denoiser == 'Wavelets':
        model = dinv.models.WaveletPrior()
    else:
        raise ValueError("Invalid denoiser")

    # Pure inference: no autograd graph is needed for the outputs.
    with torch.no_grad():
        noisy = image + torch.randn_like(image) * noise_level
        estimated = model(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)


# Gradio UI: input image, noise level and denoiser choice in; noisy and
# denoised images out (matching the tuple returned by image_mod).
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')

noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')

denoiser = gr.Dropdown(
    choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'],
    value='DnCNN',
    label='Denoiser',
)

demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://deepinv.github.io/deepinv/_static/deepinv_logolarge.png', 0.1, 'DnCNN']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description=(
        "Denoise an image using a variety of denoisers and noise levels using the "
        "deepinverse library (https://deepinv.github.io/). This example is intended "
        "to be run on a CPU, so the input image is automatically resized to reduce "
        "the computation time: its longest side is set to 256 pixels (128x128 for "
        "DiffUNet). For more advanced models, please run the code locally."
    ),
)

# Launch only when run as a script, so the module can be imported (e.g. by
# Gradio's reload mode or tests) without starting a server.
if __name__ == "__main__":
    demo.launch()