|
import functools

import deepinv as dinv
import gradio as gr
import numpy as np
import PIL.Image
import torch
|
|
|
|
|
def pil_to_torch(image):
    """Turn an H x W x C image into a batched, channels-first float tensor.

    ``image`` may be a PIL image or anything ``np.array`` accepts; it must
    have exactly three axes (height, width, channels). Pixel values are
    divided by 255, which maps 8-bit data onto the [0, 1] range.

    Returns:
        A float32 tensor of shape (1, C, H, W).
    """
    pixels_hwc = np.array(image)
    # Reorder to channels-first: (H, W, C) -> (C, H, W).
    pixels_chw = pixels_hwc.transpose((2, 0, 1))
    scaled = torch.tensor(pixels_chw).float() / 255
    # Prepend a batch axis of size one.
    return scaled[None]
|
|
|
|
|
def torch_to_pil(image):
    """Convert a batched channels-first tensor back into an 8-bit PIL image.

    A leading size-1 axis is removed with ``squeeze(0)``. The remaining
    (C, H, W) array is moved to (H, W, C), clipped to [0, 1], scaled to
    [0, 255] and rounded to the nearest integer before the ``uint8`` cast.

    Args:
        image: Tensor presumably shaped (1, C, H, W) with values on a [0, 1]
            scale, as produced by ``pil_to_torch`` (TODO confirm for other
            callers). Gradients are detached and data is moved to the CPU.

    Returns:
        A ``PIL.Image.Image``. Single-channel images are passed to PIL as an
        (H, W) array, because PIL cannot build an image from (H, W, 1).
    """
    array = image.squeeze(0).detach().cpu().numpy()
    array = array.transpose((1, 2, 0))
    # Use np.clip, not torch.clip: `array` is already a NumPy array here and
    # torch.clip rejects non-tensor inputs. Rounding (instead of plain
    # truncation by astype) keeps float error such as 127.99999 from
    # becoming 127.
    array = np.round(np.clip(array, 0, 1) * 255).astype(np.uint8)
    if array.shape[-1] == 1:
        array = array[..., 0]
    return PIL.Image.fromarray(array)
|
|
|
|
|
# Names accepted by `image_mod`; each is a class looked up on `dinv.models`.
_DENOISER_NAMES = ("DnCNN", "MedianFilter", "BM3D", "DRUNet")


@functools.lru_cache(maxsize=len(_DENOISER_NAMES))
def _load_denoiser(name):
    """Instantiate ``dinv.models.<name>()`` once and reuse it across requests."""
    # The lookup is lazy, so a model class missing from the installed deepinv
    # version only breaks requests that select that model.
    return getattr(dinv.models, name)()


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to ``image`` and denoise the noisy version.

    Args:
        image: Image received from the input ``gr.Image`` component. It is
            converted with ``pil_to_torch`` (H x W x C, 8-bit values assumed).
        noise_level: Standard deviation of the additive Gaussian noise on the
            /255-scaled intensities. It is also passed to the denoiser as its
            noise-level argument.
        denoiser: Model name, one of ``'DnCNN'``, ``'MedianFilter'``,
            ``'BM3D'`` or ``'DRUNet'``.

    Returns:
        A ``(noisy, denoised)`` pair of PIL images.

    Raises:
        ValueError: If no image is given or ``denoiser`` is not a supported
            name.
    """
    if image is None:
        raise ValueError("No input image provided")
    if denoiser not in _DENOISER_NAMES:
        raise ValueError("Invalid denoiser")

    # Dropdown values may arrive as strings depending on the Gradio version.
    sigma = float(noise_level)
    clean = pil_to_torch(image)
    model = _load_denoiser(denoiser)

    noisy = clean + torch.randn_like(clean) * sigma
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        estimated = model(noisy, sigma)

    return torch_to_pil(noisy), torch_to_pil(estimated)
|
|
|
|
|
# --- UI components -------------------------------------------------------
input_image = gr.Image(label='Input Image')

output_images = gr.Image(label='Denoised Image')

noise_image = gr.Image(label='Noisy Image')

# Standard deviation of the Gaussian noise added by `image_mod`.
noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.4, 0.5], value=0.1, label='Noise Level')

# Each choice must be a name that `image_mod` accepts.
denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'BM3D', 'MedianFilter'], value='DnCNN', label='Denoiser')

# --- Interface -----------------------------------------------------------
# Output order matches `image_mod`'s return value: (noisy, denoised).
# NOTE(review): the description mentions only lightweight models, but DRUNet
# and BM3D are also offered — confirm the wording.
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://deepinv.github.io/deepinv/_static/deepinv_logolarge.png', 0.1, 'DnCNN']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU.",
)

# Launch only when run as a script. This keeps `demo` importable (e.g. for
# Gradio's reload mode) without starting a server as a side effect.
if __name__ == "__main__":
    demo.launch()