File size: 3,060 Bytes
077a255
b6d8eef
 
0e089a6
c472fe6
077a255
eb570bb
94086a1
0e089a6
 
c472fe6
de4a2f9
5a3ed26
94086a1
d63f7ea
 
7f13cff
5a3ed26
7f13cff
5a3ed26
 
de4a2f9
0e089a6
c472fe6
0e089a6
c472fe6
0e089a6
de4a2f9
c472fe6
0e089a6
077a255
c472fe6
eb570bb
94086a1
eb570bb
9a82551
 
e49a0f8
eb570bb
5a3ed26
51daf88
 
de4a2f9
 
 
 
e49a0f8
 
7c90ba3
 
51daf88
 
eb570bb
 
c472fe6
263948f
0e089a6
 
 
d7af0b8
f964bf0
 
 
0e089a6
94086a1
c472fe6
94086a1
288f88f
 
 
eb570bb
94086a1
0e089a6
 
94086a1
288f88f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools

import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image


def pil_to_torch(image, ref_size=512):
    """Convert an image to a ``(1, 3, H, W)`` float tensor in ``[0, 1]`` and resize it.

    Args:
        image: a PIL image or anything ``np.array`` turns into an ``(H, W)``,
            ``(H, W, 1)``, ``(H, W, 2)``, ``(H, W, 3)`` or ``(H, W, 4)`` array
            with values on the 0..255 scale.
        ref_size: target size. ``256`` yields a square ``256 x 256`` tensor
            (used for DiffUNet); any other value scales the longer side to
            ``ref_size`` while keeping the aspect ratio.

    Returns:
        A float32 tensor of shape ``(1, 3, H', W')``.

    Raises:
        ValueError: if the array has an unsupported number of channels.
    """
    array = np.array(image)
    if array.ndim == 2:
        # Plain grayscale image: add an explicit channel axis.
        array = array[:, :, None]
    if array.shape[2] in (2, 4):
        # Drop the alpha channel of LA / RGBA images.
        array = array[:, :, :-1]
    if array.shape[2] == 1:
        # Replicate grayscale to 3 channels; the models are built with library
        # defaults, presumably RGB (3-channel) -- confirm if that changes.
        array = np.repeat(array, 3, axis=2)
    if array.shape[2] != 3:
        raise ValueError(f"Unsupported number of image channels: {array.shape[2]}")

    # np.array(..., order="C") copies into a fresh writable contiguous buffer,
    # so torch.from_numpy does not alias (or warn about) the caller's data.
    chw = np.array(array.transpose((2, 0, 1)), dtype=np.float32, order="C")
    tensor = torch.from_numpy(chw).unsqueeze(0) / 255

    _, _, height, width = tensor.shape
    if ref_size == 256:
        size = (ref_size, ref_size)
    elif height > width:
        # max(1, ...) keeps very elongated images from collapsing to 0 pixels.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)

    # antialias=True low-pass filters when downsampling large photos, avoiding
    # aliasing artifacts that would otherwise look like structured noise.
    return torch.nn.functional.interpolate(tensor, size=size, mode='bilinear', antialias=True)


def torch_to_pil(image):
    """Convert a ``(1, C, H, W)`` image tensor with values in ``[0, 1]`` to a PIL image.

    Values are clipped to ``[0, 1]``, scaled to ``0..255`` and rounded to the
    nearest integer. A single-channel tensor becomes a grayscale image.
    """
    array = image.detach().cpu().squeeze(0).numpy()
    array = array.transpose((1, 2, 0))
    if array.shape[2] == 1:
        # Pillow's fromarray rejects (H, W, 1) uint8 arrays; a 2-D array
        # produces a grayscale image instead.
        array = array[:, :, 0]
    # Round rather than truncate: after float32 division by 255 and
    # multiplication by 255, exact integer levels can land just below the
    # integer, and truncation would then drop them by one.
    array = np.round(np.clip(array, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(array)


@functools.lru_cache(maxsize=None)
def _load_network(name):
    """Build one of the neural-network denoisers once and reuse it across requests.

    Construction is comparatively expensive (deepinv presumably loads
    pretrained weights here -- confirm). The key space is the three fixed
    names below, so the cache is naturally bounded. lru_cache does not cache
    exceptions, so an invalid name still raises on every call.
    """
    if name == 'DnCNN':
        return dinv.models.DnCNN()
    if name == 'DRUNet':
        return dinv.models.DRUNet()
    if name == 'DiffUNet':
        return dinv.models.DiffUNet()
    raise ValueError("Invalid denoiser")


def _build_denoiser(name):
    """Return a callable ``f(x, sigma)`` implementing the denoiser called ``name``."""
    if name == 'DnCNN':
        den = _load_network(name)
        sigma0 = 2 / 255
        # The network is called without a noise level; presumably it is tuned
        # for noise std sigma0. Scale the input so its noise std becomes
        # sigma0, denoise, then undo the scaling.
        return lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0
    if name in ('DRUNet', 'DiffUNet'):
        return _load_network(name)
    # Classical denoisers are cheap to build. They are created per call so no
    # internal state is shared between requests.
    if name == 'MedianFilter':
        return dinv.models.MedianFilter(kernel_size=5)
    if name == 'BM3D':
        return dinv.models.BM3D()
    if name == 'TV':
        return dinv.models.TVDenoiser()
    if name == 'TGV':
        return dinv.models.TGVDenoiser()
    if name == 'Wavelets':
        return dinv.models.WaveletPrior()
    raise ValueError("Invalid denoiser")


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to an image and denoise it with the selected method.

    Args:
        image: input image from the Gradio ``Image`` component (anything
            ``pil_to_torch`` accepts).
        noise_level: standard deviation of the additive Gaussian noise, on the
            ``[0, 1]`` intensity scale produced by ``pil_to_torch``.
        denoiser: name of the denoiser, one of the Dropdown choices.

    Returns:
        ``(noisy, estimated)`` as PIL images.

    Raises:
        ValueError: if ``denoiser`` is not a known name.
    """
    tensor = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)
    model = _build_denoiser(denoiser)
    # Dropdown values may arrive as strings depending on the Gradio version.
    sigma = float(noise_level)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        noisy = tensor + torch.randn_like(tensor) * sigma
        estimated = model(noisy, sigma)
    return torch_to_pil(noisy), torch_to_pil(estimated)


# Gradio components wired into the Interface below.
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
# Not used by the Interface below.
input_image_output = gr.Image(label='Input Image')

noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')

denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DRUNet', label='Denoiser')

demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description=(
        "Denoise an image using a variety of denoisers and noise levels using the "
        "deepinverse library (https://deepinv.github.io/). This example is intended "
        "to be run on a CPU, so heavier models such as DRUNet and DiffUNet can take "
        "a while. To reduce the computation time, the input image is automatically "
        "resized so that its longer side is 512 pixels (DiffUNet uses a 256x256 "
        "input). For more advanced models, please run the code locally."
    ),
)

# Only start the web server when run as a script, so importing this module
# does not launch it as a side effect.
if __name__ == '__main__':
    demo.launch()