Spaces:
Sleeping
Sleeping
File size: 3,003 Bytes
077a255 b6d8eef 0e089a6 c472fe6 077a255 eb570bb d63f7ea 0e089a6 c472fe6 de4a2f9 5a3ed26 d63f7ea 7f13cff 5a3ed26 7f13cff 5a3ed26 de4a2f9 0e089a6 c472fe6 0e089a6 c472fe6 0e089a6 de4a2f9 c472fe6 0e089a6 077a255 c472fe6 eb570bb 7c90ba3 eb570bb 9a82551 e49a0f8 eb570bb 5a3ed26 51daf88 de4a2f9 e49a0f8 7c90ba3 51daf88 eb570bb c472fe6 263948f 0e089a6 d7af0b8 f964bf0 0e089a6 7f13cff c472fe6 7c90ba3 288f88f eb570bb 2f1dbb2 0e089a6 7f13cff 288f88f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import functools

import deepinv as dinv
import gradio as gr
import numpy as np
import PIL.Image
import torch
def pil_to_torch(image, ref_size=256):
    """Convert an 8-bit image into a resized ``(1, 3, H, W)`` float tensor.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray
        Anything ``np.asarray`` turns into an ``(H, W)``, ``(H, W, 1)``,
        ``(H, W, 3)`` or ``(H, W, 4)`` array. Values are divided by 255, so
        8-bit (0..255) pixel data is assumed.
    ref_size : int
        Target size. ``128`` is special-cased to a square ``128 x 128``
        output. Any other value scales the longer side to ``ref_size`` and
        keeps the aspect ratio; the shorter side is at least 1 pixel.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(1, 3, H', W')`` with values in ``[0, 1]``.

    Raises
    ------
    ValueError
        If *image* is missing, empty, or has an unsupported channel layout.
    """
    array = np.asarray(image)

    # Normalise the channel layout to (H, W, 3): grayscale is replicated and
    # an alpha channel is dropped. The denoisers used with this helper are
    # presumably 3-channel (RGB) models -- confirm when adding new ones.
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[2] == 1:
        array = np.repeat(array, 3, axis=2)
    elif array.ndim == 3 and array.shape[2] == 4:
        array = array[:, :, :3]
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError(f"Unsupported image array shape {array.shape}")

    height, width = array.shape[:2]
    if height == 0 or width == 0:
        raise ValueError("Cannot process an empty image")

    # HWC -> CHW, scale to [0, 1], add a batch dimension.
    tensor = torch.tensor(array.transpose((2, 0, 1))).float() / 255
    tensor = tensor.unsqueeze(0)

    if ref_size == 128:
        size = (ref_size, ref_size)
    elif height > width:
        # max(1, ...) keeps very elongated images from collapsing to size 0.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)
    return torch.nn.functional.interpolate(tensor, size=size, mode='bilinear')
def torch_to_pil(image):
    """Convert a ``(1, C, H, W)`` tensor with values in ``[0, 1]`` to a PIL image.

    Values are clipped to ``[0, 1]`` (NaNs become 0), scaled to ``0..255`` and
    rounded to the nearest integer. A single-channel tensor becomes a 2-D
    grayscale image, because ``PIL.Image.fromarray`` rejects ``(H, W, 1)``
    arrays.

    Parameters
    ----------
    image : torch.Tensor
        Tensor with a leading batch dimension of size 1, on any device and of
        any floating dtype (it is detached and converted to float32 first).

    Returns
    -------
    PIL.Image.Image
        The 8-bit image.
    """
    # .float() also lets half / bfloat16 tensors go through .numpy().
    array = image.detach().squeeze(0).float().cpu().numpy()
    array = array.transpose((1, 2, 0))
    if array.shape[2] == 1:
        array = array[:, :, 0]
    # NaN -> 0 before the uint8 cast, whose result on NaN is undefined.
    array = np.nan_to_num(array, nan=0.0)
    array = np.round(np.clip(array, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(array)
def _eval_mode(model):
    """Switch *model* to inference mode if it is a torch module; return it."""
    if isinstance(model, torch.nn.Module):
        model.eval()
    return model


def _load_dncnn():
    """Build DnCNN wrapped so it can be called as ``f(x, sigma)`` for any sigma."""
    net = _eval_mode(dinv.models.DnCNN())
    sigma0 = 2 / 255

    def denoise(x, sigma):
        # Rescale the input so its noise level becomes sigma0 (presumably the
        # level the pretrained weights target), denoise, then undo the scaling.
        return net(x * sigma0 / sigma) * sigma / sigma0

    return denoise


# Dropdown name -> zero-argument constructor. The lambdas defer the
# ``dinv.models`` attribute lookup until a model is actually requested.
_DENOISER_FACTORIES = {
    'DnCNN': _load_dncnn,
    'MedianFilter': lambda: dinv.models.MedianFilter(kernel_size=5),
    'BM3D': lambda: dinv.models.BM3D(),
    'TV': lambda: dinv.models.TVDenoiser(),
    'TGV': lambda: dinv.models.TGVDenoiser(),
    'Wavelets': lambda: dinv.models.WaveletPrior(),
    'DiffUNet': lambda: dinv.models.DiffUNet(),
    'DRUNet': lambda: dinv.models.DRUNet(),
}


@functools.lru_cache(maxsize=None)  # key space is the fixed set of names above
def _load_denoiser(name):
    """Return a callable ``f(x, sigma)`` for *name*, built once and then cached.

    Raises
    ------
    ValueError
        If *name* is not a known denoiser. Failures are not cached.
    """
    try:
        factory = _DENOISER_FACTORIES[name]
    except KeyError:
        raise ValueError("Invalid denoiser") from None
    return _eval_mode(factory())


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to *image* and denoise it with the selected model.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray or None
        Input image as delivered by the Gradio ``Image`` component.
    noise_level : float or str
        Standard deviation of the added noise, on the [0, 1] intensity scale.
        It is also passed to the denoiser as ``sigma``.
    denoiser : str
        Name of a key in ``_DENOISER_FACTORIES``, e.g. ``'DnCNN'``.

    Returns
    -------
    tuple of PIL.Image.Image
        ``(noisy, denoised)`` images.

    Raises
    ------
    gradio.Error
        If no image was provided.
    ValueError
        If *denoiser* is not a known name.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    model = _load_denoiser(denoiser)
    sigma = float(noise_level)

    # DiffUNet gets a square 128x128 input; everything else gets longer side 256.
    clean = pil_to_torch(image, ref_size=128 if denoiser == 'DiffUNet' else 256)
    noisy = clean + torch.randn_like(clean) * sigma
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        estimated = model(noisy, sigma)
    return torch_to_pil(noisy), torch_to_pil(estimated)
# --- Gradio UI: the components wired into image_mod's inputs and outputs. ---
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')
denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DnCNN', label='Denoiser')

# Inputs are passed positionally to image_mod(image, noise_level, denoiser);
# outputs receive its (noisy, denoised) return tuple in order.
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://deepinv.github.io/deepinv/_static/deepinv_logolarge.png', 0.1, 'DnCNN']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU. We also automatically resize the input image so that its longer side is 256 pixels (128x128 for DiffUNet) to reduce the computation time. For more advanced models, please run the code locally.",
)

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported without side effects.
if __name__ == "__main__":
    demo.launch()