File size: 2,545 Bytes
545f79d
 
 
 
159cb3e
 
 
545f79d
 
 
da6c643
545f79d
 
159cb3e
545f79d
159cb3e
545f79d
f603c60
545f79d
159cb3e
 
 
 
 
 
 
 
 
 
 
545f79d
 
 
159cb3e
 
 
 
 
 
 
 
 
 
545f79d
 
e84698c
159cb3e
 
 
545f79d
 
159cb3e
545f79d
159cb3e
 
 
545f79d
159cb3e
 
545f79d
 
 
 
 
 
 
159cb3e
545f79d
 
 
 
159cb3e
b0d6f6a
 
159cb3e
 
545f79d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr 
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F

from archs import DarkIR



# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Converters between PIL images and tensors, used around the model call.
pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()

network = 'DarkIR'

# Checkpoint file, resolved relative to the working directory.
PATH_MODEL = './DarkIR_384.pt'

model = DarkIR(img_channel=3,
               width=32,
               middle_blk_num_enc=2,
               middle_blk_num_dec=2,
               enc_blk_nums=[1, 2, 3],
               dec_blk_nums=[3, 1, 1],
               dilations=[1, 4, 9],
               extra_depth_wise=True)

# NOTE(review): torch.load unpickles the file, so only point PATH_MODEL at a
# checkpoint from a trusted source.
checkpoints = torch.load(PATH_MODEL, map_location=device)
# The weights are stored under the 'params' key of the checkpoint.
model.load_state_dict(checkpoints['params'])

model = model.to(device)
# This script only runs inference: switch to eval mode so that any layer that
# behaves differently during training (dropout, batch norm, ...) is put in
# its inference behaviour.
model.eval()


def pad_tensor(tensor, multiple = 8):
    '''Zero-pad a 4-D tensor so its last two dimensions are multiples of `multiple`.

    Padding is appended after the existing data along both of the last two
    dimensions (the right and bottom edges of an image batch), so the
    original values keep their indices and can be recovered by slicing.

    Args:
        tensor: tensor with exactly four dimensions.
        multiple: value that both trailing dimensions must be divisible by.

    Returns:
        The padded tensor; sizes that are already multiples are unchanged.
    '''
    _batch, _channels, height, width = tensor.shape
    # -x % m is the distance from x up to the next multiple of m (0 if aligned).
    extra_rows = -height % multiple
    extra_cols = -width % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(tensor, (0, extra_cols, 0, extra_rows), mode='constant', value=0)

def process_img(image):
    '''Run the loaded model on one image and return the result.

    Args:
        image: PIL image coming from the Gradio input component, or None
            when the user submits without providing an image.

    Returns:
        A PIL image cropped back to the input's height and width, or None
        when no image was given.
    '''
    # Submitting an empty input hands us None; return an empty output
    # instead of failing inside the PIL-to-tensor conversion.
    if image is None:
        return None

    # The model in this file is built with img_channel=3, so make sure
    # grayscale / palette / alpha images are turned into 3-channel RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    tensor = pil_to_tensor(image).unsqueeze(0).to(device)
    # Remember the original size so the padding can be cropped away again.
    _, _, H, W = tensor.shape

    tensor = pad_tensor(tensor)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(tensor, side_loss=False)

    # Keep values in the [0, 1] range before converting back to an image.
    output = torch.clamp(output, 0., 1.)
    # Drop the padded rows/columns and the batch dimension.
    output = output[:, :, :H, :W].squeeze(0)
    return tensor_to_pil(output)

# Heading passed to the Gradio interface as its title.
title = "DarkIR ✏️🖼️ 🤗"
# Markdown passed to the Gradio interface as its description: project link,
# author, affiliation and a short usage disclaimer.
description = ''' ## [ DarkIR: Robust Low-Light Image Restoration](https://github.com/cidautai/DarkIR)

[Daniel Feijoo](https://github.com/danifei)

Fundación Cidaut


> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some Low-Light degradations.**

<br>
'''

# Sample inputs offered by the demo: one single-element row per image,
# since the interface takes a single image input.
examples = [[f'examples/{file_name}'] for file_name in (
    '0010.png',
    '1001.png',
    '1100.png',
    'low00733_low.png',
    '0087.png',
)]

# Custom CSS handed to the Gradio interface: for <img> elements inside
# .image-frame / .image-container it sets automatic width/height and removes
# the max-width limit.
# NOTE(review): these class names are presumably the ones Gradio uses for its
# image components -- confirm against the installed Gradio version.
css = """
    .image-frame img, .image-container img {
        width: auto;
        height: auto;
        max-width: none;
    }
"""

# Gradio UI: one PIL image in, one PIL image out, processed by process_img.
demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type='pil', label='input')],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == '__main__':
    demo.launch()