diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..edcdc72315a8d2b3bddf830cccdb3cc8866ce8a8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -19,6 +19,7 @@ *.pb filter=lfs diff=lfs merge=lfs -text *.pickle filter=lfs diff=lfs merge=lfs -text *.pkl filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text *.rar filter=lfs diff=lfs merge=lfs -text diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/__pycache__/inference.cpython-310.pyc b/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68a6c072dd65d1bd86a6f025904df3d59dc4ee2d Binary files /dev/null and b/__pycache__/inference.cpython-310.pyc differ diff --git a/app.py b/app.py index 22df67e3786f851e62d6ed61cea675530042e781..da0ce6b20aebc7c25e6088eab23431e73e281e19 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,45 @@ import gradio as gr +from inference import infer -def greet(name): - return "Hello " + name + "!!" 
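+# `greet` is the Gradio callback: it forwards the uploaded image and the
+# free-form instruction to `infer` (inference.py) and returns the restored
+# image. The name `greet` is kept from the original demo template.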
+def greet(image, prompt):
+    restore_img = infer(img=image, text_prompt=prompt)
+    return restore_img
-demo = gr.Interface(fn=greet, inputs=[gr.components.Image(), "Text Instruction"], outputs=gr.components.Image())
-demo.launch()
\ No newline at end of file
+
+
+title = "🖼️ ICDR 🖼️"
+description = ''' ## ICDR: Image Restoration Framework for Composite Degradation following Human Instructions
+Our Github : https://github.com/
+
+Siwon Kim, Donghyeon Yoon
+
+Ajou Univ
+'''
+
+
+article = '''
+ICDR
+'''
+
+#### Image / prompt examples
+examples = [['input/00010.png', "I love this photo, could you remove the haze and more brighter?"],
+            ['input/00058.png', "I have to post an emotional shot on Instagram, but it was shot too foggy and too dark. Change it like a sunny day and brighten it up!"]]
+
+css = """
+    .image-frame img, .image-container img {
+        width: auto;
+        height: auto;
+        max-width: none;
+    }
+"""
+
+
+demo = gr.Interface(
+    fn=greet,
+    inputs=[gr.Image(type="pil", label="Input"),
+            gr.Text(label="Prompt")],
+    outputs=[gr.Image(type="pil", label="Output")],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+    css=css,
+)
+demo.launch(share=True)
\ No newline at end of file
diff --git a/ckpt/epoch_287.pth b/ckpt/epoch_287.pth
new file mode 100644
index 0000000000000000000000000000000000000000..2c0a593483b6ae92cba8451db3123669b26bcae0
--- /dev/null
+++ b/ckpt/epoch_287.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db279692728bd4614759c08a0478d9d07200768e5fb7fa893e78aaa05f3ca707
+size 48705338
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e3271a9d29ca3e435edd8139379668854a88b8
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,59 @@
+import argparse
+import subprocess
+from tqdm import tqdm
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader
+
+from utils.dataset_utils_CDD import DerainDehazeDataset
+from utils.val_utils import AverageMeter, compute_psnr_ssim
+from utils.image_io import save_image_tensor
+
+from text_net.model import AirNet
+
+def test_Derain_Dehaze(opt, net, dataset, task="derain"):
+    output_path = opt.output_path + task + '/'
+    subprocess.check_output(['mkdir', '-p', output_path])
+
+    # dataset.set_dataset(task)
+    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
+    print(len(testloader))
+
+    with torch.no_grad():
+        for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
+            degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()
+            restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt=text_prompt)
+
+            return save_image_tensor(restored)
+
+
+def infer(text_prompt="", img=None):
+    parser = argparse.ArgumentParser()
+    # Input Parameters
+    parser.add_argument('--cuda', type=int, default=0)
+    parser.add_argument('--derain_path', type=str, default="data/Test_prompting/", help='path of the rainy test images')
+    parser.add_argument('--output_path', type=str, default="output/demo11", help='output save path')
+    parser.add_argument('--ckpt_path', type=str, default="ckpt/epoch_287.pth", help='checkpoint save path')
+    # parser.add_argument('--text_prompt', type=str, default="derain")
+
+    opt = parser.parse_args()
+    # opt.text_prompt = text_prompt
+
+    np.random.seed(0)
+    torch.manual_seed(0)
+    torch.cuda.set_device(opt.cuda)
+
+    opt.batch_size = 7
+    ckpt_path = opt.ckpt_path
+
+    derain_set = DerainDehazeDataset(opt, img=img, text_prompt=text_prompt)
+
+    # Make network
+    net = AirNet(opt).cuda()
+    net.eval()
+    net.load_state_dict(torch.load(ckpt_path, map_location=torch.device(opt.cuda)))
+
+    restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")
+
+    return restored
diff --git a/text_net/DGRN.py b/text_net/DGRN.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7593c83d20dabb34bd73621ac070514579190e
--- /dev/null
+++ b/text_net/DGRN.py
@@ -0,0 +1,232 @@
+import torch.nn as nn
+import torch
+from .deform_conv import DCN_layer
+import clip
+
+clip_model, preprocess = clip.load("ViT-B/32", device='cuda')
+
+# Fetch the text embedding dimension dynamically from the CLIP model
+text_embed_dim = clip_model.text_projection.shape[1]
+
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
+
+
+class DGM(nn.Module):
+    def __init__(self, channels_in, channels_out, kernel_size):
+        super(DGM, self).__init__()
+        self.channels_out = channels_out
+        self.channels_in = channels_in
+        self.kernel_size = kernel_size
+
+        self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
+                             padding=(kernel_size - 1) // 2, bias=False)
+        self.sft = SFT_layer(self.channels_in, self.channels_out)
+
+        self.relu = nn.LeakyReLU(0.1, True)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation map: B * C * H * W
+        '''
+        dcn_out = self.dcn(x, inter)
+        sft_out = self.sft(x, inter, text_prompt)
+        out = dcn_out + sft_out
+        out = x + out
+
+        return out
+
+# Projection head that maps CLIP text embeddings into the feature space
+class TextProjectionHead(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(TextProjectionHead, self).__init__()
+        self.proj = nn.Sequential(
+            nn.Linear(input_dim, output_dim),
+            nn.ReLU(),
+            nn.Linear(output_dim, output_dim)
+        ).float()
+
+    def forward(self, x):
+        return self.proj(x.float())
+
+
+
+class SFT_layer(nn.Module):
+    def __init__(self, channels_in, channels_out):
+        super(SFT_layer, self).__init__()
+        self.conv_gamma = nn.Sequential(
+            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        )
+        self.conv_beta = nn.Sequential(
+            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        )
+
+        self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)
+
+        '''
+        self.text_gamma = nn.Sequential(
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        ).float()
+        self.text_beta = nn.Sequential(
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        ).float()
+        '''
+
+        self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)
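+    # The forward pass below implements a text-conditioned SFT (spatial
+    # feature transform): gamma/beta maps are predicted from the degradation
+    # features fused with the projected CLIP text embedding and applied as
+    # out = x * gamma + beta. (Summary inferred from the code, not stated
+    # elsewhere in this patch.)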
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation intermediate representation map: B * C * H * W
+        '''
+        # img_gamma = self.conv_gamma(inter)
+        # img_beta = self.conv_beta(inter)
+
+        B, C, H, W = inter.shape  # cross attention
+
+
+        text_tokens = clip.tokenize(text_prompt).to(x.device)  # Tokenize the text prompts (batch size)
+        with torch.no_grad():
+            text_embed = clip_model.encode_text(text_tokens)
+
+        text_proj = self.text_proj_head(text_embed).float()
+
+        # Broadcast the text embedding to (B, C, H, W)
+        # text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, self.conv_gamma[0].out_channels, H, W)
+        text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
+
+        # Fuse the intermediate image representation with the text embedding
+        # (element-wise product; the concat variant is kept commented out)
+        combined = inter * text_proj_expanded
+        # combined = torch.cat([inter, text_proj_expanded], dim=1)
+
+        # Compute gamma and beta from the fused image/text features
+        img_gamma = self.conv_gamma(combined)
+        img_beta = self.conv_beta(combined)
+
+        ''' simple concat
+        text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
+        text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
+        '''
+
+        '''
+        text_proj = text_proj.unsqueeze(1).expand(-1, H*W, -1)  # B * (H*W) * C
+
+        # Reshape the intermediate image representation to B * (H*W) * C
+        inter_flat = inter.view(B, C, -1).permute(2, 0, 1)  # (H*W) * B * C
+
+        # Apply cross-attention
+        attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
+        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W
+
+        # Compute gamma and beta
+        img_gamma = self.conv_gamma(attn_output)
+        img_beta = self.conv_beta(attn_output)
+        '''
+        # (an experiment with concat-based text fusion is kept above)
+        return x * img_gamma + img_beta
+
+
+class DGB(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size):
+        super(DGB, self).__init__()
+
+        # self.da_conv1 = DGM(n_feat, n_feat, kernel_size)
+        # self.da_conv2 = DGM(n_feat, n_feat, kernel_size)
+        self.dgm1 = DGM(n_feat, n_feat, kernel_size)
+        self.dgm2 = DGM(n_feat, n_feat, kernel_size)
+        self.conv1 = conv(n_feat, n_feat, kernel_size)
+        self.conv2 = conv(n_feat, n_feat, kernel_size)
+
+        self.relu = nn.LeakyReLU(0.1, True)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation representation: B * C * H * W
+        '''
+
+        out = self.relu(self.dgm1(x, inter, text_prompt))
+        out = self.relu(self.conv1(out))
+        out = self.relu(self.dgm2(out, inter, text_prompt))
+        out = self.conv2(out) + x
+
+        return out
+
+
+class DGG(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size, n_blocks):
+        super(DGG, self).__init__()
+        self.n_blocks = n_blocks
+        modules_body = [
+            DGB(conv, n_feat, kernel_size) \
+            for _ in range(n_blocks)
+        ]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation representation: B * C * H * W
+        '''
+        res = x
+        for i in range(self.n_blocks):
+            res = self.body[i](res, inter, text_prompt)
+        res = self.body[-1](res)
+        res = res + x
+
+        return res
+
+
+class DGRN(nn.Module):
+    def __init__(self, opt, conv=default_conv):
+        super(DGRN, self).__init__()
+
+        self.n_groups = 5
+        n_blocks = 5
+        n_feats = 64
+        kernel_size = 3
+
+        # head module
+        modules_head = [conv(3, n_feats, kernel_size)]
+        self.head = nn.Sequential(*modules_head)
+
+        # body
+        modules_body = [
+            DGG(default_conv, n_feats, kernel_size, n_blocks) \
+            for _ in range(self.n_groups)
+        ]
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+        # tail
+        modules_tail = [conv(n_feats, 3, kernel_size)]
+        self.tail = nn.Sequential(*modules_tail)
+
+    def forward(self, x, inter, text_prompt):
+        # head
+        x = self.head(x)
+
+        # body
+        res = x
+        for i in range(self.n_groups):
+            res = self.body[i](res, inter, text_prompt)
+        res = self.body[-1](res)
+        res = res + x
+
+        # tail
+        x = self.tail(res)
+
+        return x
diff --git a/text_net/__pycache__/DGRN.cpython-310.pyc b/text_net/__pycache__/DGRN.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..add7bd93939e3a00cd46b9ab8cbd11579f92b40a
Binary files /dev/null and b/text_net/__pycache__/DGRN.cpython-310.pyc differ
diff --git a/text_net/__pycache__/DGRN.cpython-38.pyc
b/text_net/__pycache__/DGRN.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e0f4a38a4455a7476836d22a9f316c75af4d31f Binary files /dev/null and b/text_net/__pycache__/DGRN.cpython-38.pyc differ diff --git a/text_net/__pycache__/deform_conv.cpython-310.pyc b/text_net/__pycache__/deform_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5d101f989f91a4d207e193234511f7f89ff29e Binary files /dev/null and b/text_net/__pycache__/deform_conv.cpython-310.pyc differ diff --git a/text_net/__pycache__/deform_conv.cpython-36.pyc b/text_net/__pycache__/deform_conv.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f08c69c038d6929a76930639daaf07bc54de296 Binary files /dev/null and b/text_net/__pycache__/deform_conv.cpython-36.pyc differ diff --git a/text_net/__pycache__/deform_conv.cpython-38.pyc b/text_net/__pycache__/deform_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cba6ecc3675f79d8a1a5977754cb5a9d9d21893c Binary files /dev/null and b/text_net/__pycache__/deform_conv.cpython-38.pyc differ diff --git a/text_net/__pycache__/encoder.cpython-310.pyc b/text_net/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f7a4f0cbabd47d03c00b57a6f5de55562358c0a Binary files /dev/null and b/text_net/__pycache__/encoder.cpython-310.pyc differ diff --git a/text_net/__pycache__/encoder.cpython-36.pyc b/text_net/__pycache__/encoder.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9015c707df37f278843a49dd7287f9f3e15d41b Binary files /dev/null and b/text_net/__pycache__/encoder.cpython-36.pyc differ diff --git a/text_net/__pycache__/encoder.cpython-38.pyc b/text_net/__pycache__/encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a74750768bc18d1575a67f54bafd4c7465bccd9e Binary files /dev/null and b/text_net/__pycache__/encoder.cpython-38.pyc differ diff --git a/text_net/__pycache__/moco.cpython-310.pyc b/text_net/__pycache__/moco.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c190557fca0214e75f96a9295ae9e9beb7ec7cb Binary files /dev/null and b/text_net/__pycache__/moco.cpython-310.pyc differ diff --git a/text_net/__pycache__/moco.cpython-36.pyc b/text_net/__pycache__/moco.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a443457cd055c02a4581d22b845baaf06a6ce836 Binary files /dev/null and b/text_net/__pycache__/moco.cpython-36.pyc differ diff --git a/text_net/__pycache__/moco.cpython-38.pyc b/text_net/__pycache__/moco.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b643c513699e32939fa5ed9dcae4c4bd983387bd Binary files /dev/null and b/text_net/__pycache__/moco.cpython-38.pyc differ diff --git a/text_net/__pycache__/model.cpython-310.pyc b/text_net/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db13fd363630e4c00db4f2468a69eefe20e8de89 Binary files /dev/null and b/text_net/__pycache__/model.cpython-310.pyc differ diff --git a/text_net/__pycache__/model.cpython-36.pyc b/text_net/__pycache__/model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce82a9823d22eb41c046651b3d7865f49a786c8c Binary files /dev/null and b/text_net/__pycache__/model.cpython-36.pyc differ diff --git a/text_net/__pycache__/model.cpython-38.pyc 
b/text_net/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c0ffe5bb44052b986374e66c05fc9b9414a3fc Binary files /dev/null and b/text_net/__pycache__/model.cpython-38.pyc differ diff --git a/text_net/deform_conv.py b/text_net/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..314ae232a28ee5e060a400a6dd5ac513c40a3875 --- /dev/null +++ b/text_net/deform_conv.py @@ -0,0 +1,65 @@ +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair + +from mmcv.ops import modulated_deform_conv2d + + +class DCN_layer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, + groups=1, deformable_groups=1, bias=True, extra_offset_mask=True): + super(DCN_layer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + + self.extra_offset_mask = extra_offset_mask + self.conv_offset_mask = nn.Conv2d( + self.in_channels * 2, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), + bias=True + ) + + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + self.init_offset() + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input_feat, inter): + feat_degradation = torch.cat([input_feat, inter], dim=1) + + out = self.conv_offset_mask(feat_degradation) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + return modulated_deform_conv2d(input_feat.contiguous(), offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.deformable_groups) diff --git a/text_net/encoder.py b/text_net/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4610be2421e4e0e91f9eff777d660f28649dfcab --- /dev/null +++ b/text_net/encoder.py @@ -0,0 +1,67 @@ +from torch import nn +from text_net.moco import MoCo + + +class ResBlock(nn.Module): + def __init__(self, in_feat, out_feat, stride=1): + super(ResBlock, self).__init__() + self.backbone = nn.Sequential( + nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_feat), + nn.LeakyReLU(0.1, True), + nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_feat), + ) + self.shortcut = nn.Sequential( + nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_feat) + ) + + def forward(self, x): + return nn.LeakyReLU(0.1, True)(self.backbone(x) + self.shortcut(x)) + + +class ResEncoder(nn.Module): + def __init__(self): + super(ResEncoder, self).__init__() + + self.E_pre = ResBlock(in_feat=3, out_feat=64, stride=1) + self.E = nn.Sequential( + ResBlock(in_feat=64, 
out_feat=128, stride=2), + ResBlock(in_feat=128, out_feat=256, stride=2), + nn.AdaptiveAvgPool2d(1) + ) + + self.mlp = nn.Sequential( + nn.Linear(256, 256), + nn.LeakyReLU(0.1, True), + nn.Linear(256, 256), + ) + + def forward(self, x): + inter = self.E_pre(x) + fea = self.E(inter).squeeze(-1).squeeze(-1) + out = self.mlp(fea) + + return fea, out, inter + + +class CBDE(nn.Module): + def __init__(self, opt): + super(CBDE, self).__init__() + + dim = 256 + + # Encoder + self.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim) + + def forward(self, x_query, x_key): + if self.training: + # degradation-aware represenetion learning + fea, logits, labels, inter = self.E(x_query, x_key) + + return fea, logits, labels, inter + else: + # degradation-aware represenetion learning + fea, inter = self.E(x_query, x_query) + return fea, inter diff --git a/text_net/moco.py b/text_net/moco.py new file mode 100644 index 0000000000000000000000000000000000000000..07eee315ea2d02d3d64b5b9c2103b8e5f43ca192 --- /dev/null +++ b/text_net/moco.py @@ -0,0 +1,166 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import torch +import torch.nn as nn + + +class MoCo(nn.Module): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05722 + """ + def __init__(self, base_encoder, dim=256, K=3*256, m=0.999, T=0.07, mlp=False): + """ + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 65536) + m: moco momentum of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + """ + super(MoCo, self).__init__() + + self.K = K + self.m = m + self.T = T + + # create the encoders + # num_classes is the output fc dimension + self.encoder_q = base_encoder() + self.encoder_k = base_encoder() + + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + # create the queue + self.register_buffer("queue", torch.randn(dim, K)) + self.queue = nn.functional.normalize(self.queue, dim=0) + + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys): + # gather keys before updating queue + # keys = concat_all_gather(keys) + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.K % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.K # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _batch_shuffle_ddp(self, x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. 
*** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @torch.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] + + def forward(self, im_q, im_k): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + if self.training: + # compute query features + embedding, q, inter = self.encoder_q(im_q) # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + _, k, _ = self.encoder_k(im_k) # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # logits: Nx(1+K) + logits = torch.cat([l_pos, l_neg], dim=1) + + # apply temperature + logits /= self.T + + # labels: positive key indicators + labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() + # dequeue and enqueue + self._dequeue_and_enqueue(k) + + return embedding, logits, labels, inter + else: + embedding, _, inter = self.encoder_q(im_q) + + return embedding, inter + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/text_net/model.py b/text_net/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bbefbc44cd1da3c31b8f444b70a1df2ac6c36647 --- /dev/null +++ b/text_net/model.py @@ -0,0 +1,29 @@ +from torch import nn + +from text_net.encoder import CBDE +from text_net.DGRN import DGRN + + +class AirNet(nn.Module): + def __init__(self, opt): + super(AirNet, self).__init__() + + # Restorer + self.R = DGRN(opt) + + # Encoder + self.E = CBDE(opt) + + def forward(self, x_query, x_key, text_prompt): + if self.training: + fea, logits, labels, inter = self.E(x_query, x_key) + + restored = self.R(x_query, inter, text_prompt) + + return restored, logits, labels + else: + fea, inter = self.E(x_query, x_query) + + restored = self.R(x_query, inter, text_prompt) + + return restored diff --git a/utils/.DS_Store b/utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..abe3c7b57d03002bebd1ed83b3245a1358ce0071 Binary files /dev/null and b/utils/.DS_Store differ diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad418447b6c8b2cceb2c58332bcbdccd923fa816 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea2f732be63e554911d2ebe3c7cd2854885f63d1 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2fba75b5ed418d60b601a7562fa8b3e0bf3eaea Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset_utils.cpython-310.pyc b/utils/__pycache__/dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8f4d2e59eef15ec0af60c1a04146de466289e1f Binary files /dev/null and b/utils/__pycache__/dataset_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/dataset_utils.cpython-36.pyc b/utils/__pycache__/dataset_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef65204ee758781e483ebd48f0c06d7cf413de7 Binary files /dev/null and b/utils/__pycache__/dataset_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/dataset_utils.cpython-38.pyc b/utils/__pycache__/dataset_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb58e1ba2dabd5d1660cedeba9924756291b66b4 Binary files /dev/null and b/utils/__pycache__/dataset_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset_utils_CDD.cpython-310.pyc b/utils/__pycache__/dataset_utils_CDD.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c348fff445900cca2c1af2c229176f50306861c Binary files /dev/null and b/utils/__pycache__/dataset_utils_CDD.cpython-310.pyc differ diff --git 
a/utils/__pycache__/degradation_utils.cpython-310.pyc b/utils/__pycache__/degradation_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9b7bc1944eac6485bd39f36fec12fbe04477821 Binary files /dev/null and b/utils/__pycache__/degradation_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/degradation_utils.cpython-36.pyc b/utils/__pycache__/degradation_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f47525b4ab570ef6abbce8fb958a227c73ba4a Binary files /dev/null and b/utils/__pycache__/degradation_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/degradation_utils.cpython-38.pyc b/utils/__pycache__/degradation_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34ec2c750e9c30b11711112d43eb0fdf3629ffdc Binary files /dev/null and b/utils/__pycache__/degradation_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/image_io.cpython-310.pyc b/utils/__pycache__/image_io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03408b6f959a20c91d0a8dd6b16a5769ebd75d4e Binary files /dev/null and b/utils/__pycache__/image_io.cpython-310.pyc differ diff --git a/utils/__pycache__/image_io.cpython-36.pyc b/utils/__pycache__/image_io.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24cc0678c513d1d6554f44f233cc9501fd09e08 Binary files /dev/null and b/utils/__pycache__/image_io.cpython-36.pyc differ diff --git a/utils/__pycache__/image_io.cpython-38.pyc b/utils/__pycache__/image_io.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4ab0fd70d454f85a1528d4cc3460cd05d78441 Binary files /dev/null and b/utils/__pycache__/image_io.cpython-38.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-310.pyc b/utils/__pycache__/image_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..556d4405112890a150b4b743e72d4ca22105d88b Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-36.pyc b/utils/__pycache__/image_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..530ff420619815f5d8850c8364199df237071857 Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-38.pyc b/utils/__pycache__/image_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98f60f51a707e4c7b43cc6e282e2495eb17feb9a Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/imresize.cpython-36.pyc b/utils/__pycache__/imresize.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..568d96f012f822f4c89092c4e15c3df90d4a721b Binary files /dev/null and b/utils/__pycache__/imresize.cpython-36.pyc differ diff --git a/utils/__pycache__/imresize.cpython-38.pyc b/utils/__pycache__/imresize.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf5f53968c7cab1d6837de25c5918fb58152f0b Binary files /dev/null and b/utils/__pycache__/imresize.cpython-38.pyc differ diff --git a/utils/__pycache__/loss_utils.cpython-38.pyc b/utils/__pycache__/loss_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae9e8913dd6174eedf01a89b0473c2d1e74e8c08 Binary files /dev/null and b/utils/__pycache__/loss_utils.cpython-38.pyc 
differ
diff --git a/utils/__pycache__/val_utils.cpython-310.pyc b/utils/__pycache__/val_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3880b9cb0800ba357e51ac11fabdaa14f968e87
Binary files /dev/null and b/utils/__pycache__/val_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/val_utils.cpython-36.pyc b/utils/__pycache__/val_utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c2c587b7f349ba4f3c6518b0715d1da6de43f3f
Binary files /dev/null and b/utils/__pycache__/val_utils.cpython-36.pyc differ
diff --git a/utils/__pycache__/val_utils.cpython-38.pyc b/utils/__pycache__/val_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04fba9e79e5b43beb5b9165188232e03fe39a491
Binary files /dev/null and b/utils/__pycache__/val_utils.cpython-38.pyc differ
diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..872c3b295710181fe4632358338e1468ce805b52
--- /dev/null
+++ b/utils/dataset_utils.py
@@ -0,0 +1,309 @@
+import os
+import random
+import copy
+from PIL import Image
+import numpy as np
+
+from torch.utils.data import Dataset
+from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
+
+from utils.image_utils import random_augmentation, crop_img
+from utils.degradation_utils import Degradation
+
+
+class TrainDataset(Dataset):
+    def __init__(self, args):
+        super(TrainDataset, self).__init__()
+        self.args = args
+        self.rs_ids = []
+        self.hazy_ids = []
+        self.D = Degradation(args)
+        self.de_temp = 0
+        self.de_type = self.args.de_type
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+
+        self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}
+
+        self._init_ids()
+
+        self.crop_transform = Compose([
+            ToPILImage(),
+            RandomCrop(args.patch_size),
+        ])
+
+        self.toTensor = ToTensor()
+
+    def _init_ids(self):
+        if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
+            self._init_clean_ids()
+        if 'derain' in self.de_type:
+            self._init_rs_ids()
+        if 'dehaze' in self.de_type:
+            self._init_hazy_ids()
+
+        random.shuffle(self.de_type)
+
+    def _init_clean_ids(self):
+        clean_ids = []
+        # Keep only image files from the directory listing
+        name_list = os.listdir(self.args.denoise_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+
+        clean_ids += [self.args.denoise_dir + id_ for id_ in name_list]
+
+        if 'denoise_15' in self.de_type:
+            self.s15_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s15_ids)
+            self.s15_counter = 0
+        if 'denoise_25' in self.de_type:
+            self.s25_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s25_ids)
+            self.s25_counter = 0
+        if 'denoise_50' in self.de_type:
+            self.s50_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s50_ids)
+            self.s50_counter = 0
+
+        # print(clean_ids)
+
+        self.num_clean = len(clean_ids)
+
+    def _init_hazy_ids(self):
+        # Keep only image files from the directory listing
+        dehaze_ids = []
+        name_list = os.listdir(self.args.dehaze_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        dehaze_ids += [self.args.dehaze_dir + id_ for id_ in name_list]
+        self.hazy_ids = dehaze_ids
+
+        self.hazy_counter = 0
+        self.num_hazy = len(self.hazy_ids)
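+    # Assumed paired-data layout, inferred from _get_gt_name/_get_nonhazy_name
+    # below (not stated explicitly elsewhere in this patch):
+    #   rain-<id>.png  pairs with  data/Target/Derain/norain-<id>.png
+    #   haze-<id>.png  pairs with  data/Target/Dehaze/nohaze-<id>.png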
+    def _init_rs_ids(self):
+        # Keep only image files from the directory listing
+        derain_ids = []
+        name_list = os.listdir(self.args.derain_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        derain_ids += [self.args.derain_dir + id_ for id_ in name_list]
+        self.rs_ids = derain_ids
+
+        self.rl_counter = 0
+        # print(derain_ids)
+
+        self.num_rl = len(self.rs_ids)
+
+    def _crop_patch(self, img_1, img_2):
+        H = img_1.shape[0]
+        W = img_1.shape[1]
+        ind_H = random.randint(0, H - self.args.patch_size)
+        ind_W = random.randint(0, W - self.args.patch_size)
+
+        patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+        patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+
+        return patch_1, patch_2
+
+    def _get_gt_name(self, rainy_name):
+        gt_name = 'data/' + 'Target/Derain/norain-' + rainy_name.split('rain-')[-1]
+        return gt_name
+
+    def _get_nonhazy_name(self, hazy_name):
+        gt_name = 'data/' + 'Target/Dehaze/nohaze-' + hazy_name.split('haze-')[-1]
+        return gt_name
+
+    def __getitem__(self, _):
+        de_id = self.de_dict[self.de_type[self.de_temp]]
+
+        if de_id < 3:
+            if de_id == 0:
+                clean_id = self.s15_ids[self.s15_counter]
+                self.s15_counter = (self.s15_counter + 1) % self.num_clean
+                if self.s15_counter == 0:
+                    random.shuffle(self.s15_ids)
+            elif de_id == 1:
+                clean_id = self.s25_ids[self.s25_counter]
+                self.s25_counter = (self.s25_counter + 1) % self.num_clean
+                if self.s25_counter == 0:
+                    random.shuffle(self.s25_ids)
+            elif de_id == 2:
+                clean_id = self.s50_ids[self.s50_counter]
+                self.s50_counter = (self.s50_counter + 1) % self.num_clean
+                if self.s50_counter == 0:
+                    random.shuffle(self.s50_ids)
+
+            # clean_id = random.randint(0, len(self.clean_ids) - 1)
+            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
+            clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
+            clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)
+
+            # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
+            clean_name = clean_id.split("/")[-1].split('.')[0]
+
+            clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)
+            degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)
+        else:
+            if de_id == 3:
+                # Rain Streak Removal
+                # rl_id = random.randint(0, len(self.rl_ids) - 1)
+                degrad_img = crop_img(np.array(Image.open(self.rs_ids[self.rl_counter]).convert('RGB')), base=16)
+                clean_name = self._get_gt_name(self.rs_ids[self.rl_counter])
+                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+
+                self.rl_counter = (self.rl_counter + 1) % self.num_rl
+                if self.rl_counter == 0:
+                    random.shuffle(self.rs_ids)
+            elif de_id == 4:
+                # Dehazing with SOTS outdoor training set
+                # hazy_id = random.randint(0, len(self.hazy_ids) - 1)
+                degrad_img = crop_img(np.array(Image.open(self.hazy_ids[self.hazy_counter]).convert('RGB')), base=16)
+                clean_name = self._get_nonhazy_name(self.hazy_ids[self.hazy_counter])
+                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+
+                self.hazy_counter = (self.hazy_counter + 1) % self.num_hazy
+                if self.hazy_counter == 0:
+                    random.shuffle(self.hazy_ids)
+            degrad_patch_1, clean_patch_1 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+            degrad_patch_2, clean_patch_2 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+
+        clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)
+        degrad_patch_1, degrad_patch_2 =
self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2) + + self.de_temp = (self.de_temp + 1) % len(self.de_type) + if self.de_temp == 0: + random.shuffle(self.de_type) + + return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2 + + def __len__(self): + return 400 * len(self.args.de_type) + + +class DenoiseTestDataset(Dataset): + def __init__(self, args): + super(DenoiseTestDataset, self).__init__() + self.args = args + self.clean_ids = [] + self.sigma = 15 + self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] + + self._init_clean_ids() + + self.toTensor = ToTensor() + + def _init_clean_ids(self): + clean_ids = [] + # ํŒŒ์ผ ๋ชฉ๋ก ์ค‘ ์ด๋ฏธ์ง€ ํŒŒ์ผ๋งŒ ํ•„ํ„ฐ๋ง + name_list = os.listdir(self.args.denoise_path) + name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions] + self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list] + + self.num_clean = len(self.clean_ids) + + def _add_gaussian_noise(self, clean_patch): + noise = np.random.randn(*clean_patch.shape) + noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8) + return noisy_patch, clean_patch + + def set_sigma(self, sigma): + self.sigma = sigma + + def __getitem__(self, clean_id): + clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16) + clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0] + + noisy_img, _ = self._add_gaussian_noise(clean_img) + clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img) + + return [clean_name], noisy_img, clean_img + + def __len__(self): + return self.num_clean + + +class DerainDehazeDataset(Dataset): + def __init__(self, args, task="derain"): + super(DerainDehazeDataset, self).__init__() + self.ids = [] + self.task_idx = 0 + self.args = args + self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] + + self.task_dict = {'derain': 0, 'dehaze': 1} + self.toTensor = ToTensor() + + self.set_dataset(task) + + def _init_input_ids(self): + if self.task_idx == 0: + self.ids = [] + # ํŒŒ์ผ ๋ชฉ๋ก ์ค‘ ์ด๋ฏธ์ง€ ํŒŒ์ผ๋งŒ ํ•„ํ„ฐ๋ง + name_list = os.listdir(self.args.derain_path + 'input/') + name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions] + self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list] + elif self.task_idx == 1: + self.ids = [] + # ํŒŒ์ผ ๋ชฉ๋ก ์ค‘ ์ด๋ฏธ์ง€ ํŒŒ์ผ๋งŒ ํ•„ํ„ฐ๋ง + name_list = os.listdir(self.args.dehaze_path + 'input/') + name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions] + self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list] + + self.length = len(self.ids) + + def _get_gt_path(self, degraded_name): + if self.task_idx == 0: + gt_name = '/'.join(degraded_name.replace("input", "target").split('/')[:-1] + degraded_name.replace("input", "target").replace("rain", "norain").split('/')[-1:]) + print(gt_name) + elif self.task_idx == 1: + dir_name = degraded_name.split("input")[0] + 'target/' + name = degraded_name.split('/')[-1].split('_')[0] + '.png' + gt_name = dir_name + name + return gt_name + + def set_dataset(self, task): + self.task_idx = self.task_dict[task] + self._init_input_ids() + + def __getitem__(self, idx): + degraded_path = self.ids[idx] + clean_path = self._get_gt_path(degraded_path) + + degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16) + clean_img = 
crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)
+
+        clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
+        degraded_name = degraded_path.split('/')[-1][:-4]
+
+        return [degraded_name], degraded_img, clean_img
+
+    def __len__(self):
+        return self.length
+
+
+class TestSpecificDataset(Dataset):
+    def __init__(self, args):
+        super(TestSpecificDataset, self).__init__()
+        self.args = args
+        self.degraded_ids = []
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']  # was missing; _init_clean_ids below relies on it
+        self._init_clean_ids(args.test_path)
+
+        self.toTensor = ToTensor()
+
+    def _init_clean_ids(self, root):
+        degraded_ids = []
+        # Keep only image files from the directory listing
+        name_list = os.listdir(root)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        self.degraded_ids += [root + id_ for id_ in name_list]
+
+        self.num_img = len(self.degraded_ids)
+
+    def __getitem__(self, idx):
+        degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
+        name = self.degraded_ids[idx].split('/')[-1][:-4]
+
+        degraded_img = self.toTensor(degraded_img)
+
+        return [name], degraded_img
+
+    def __len__(self):
+        return self.num_img
diff --git a/utils/dataset_utils_CDD.py b/utils/dataset_utils_CDD.py
new file mode 100644
index 0000000000000000000000000000000000000000..7297549511feb17b2945147a64a6621c6bac48a8
--- /dev/null
+++ b/utils/dataset_utils_CDD.py
@@ -0,0 +1,39 @@
+import os
+import random
+import copy
+from PIL import Image
+import numpy as np
+import json
+
+from torch.utils.data import Dataset
+from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
+
+from utils.image_utils import random_augmentation, crop_img
+from utils.degradation_utils import Degradation
+
+
+class DerainDehazeDataset(Dataset):
+    def __init__(self, args, img, text_prompt, task="derain"):
+        super(DerainDehazeDataset, self).__init__()
+        self.args = args
+        self.toTensor = ToTensor()
+        self.img = img
+        self.text_prompt = text_prompt
+
+    def __getitem__(self, idx):
+        degraded_inp = self.img
+        clean_path = ""
+        degradation = ""
+
+        text_prompt = self.text_prompt
+
+        degraded_img = crop_img(np.array(degraded_inp.convert('RGB')), base=16)
+        clean_img = crop_img(np.array(degraded_inp.convert('RGB')), base=16)
+
+        clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
+        degraded_name = [""]
+
+        return [degraded_name], degradation, degraded_img, clean_img, text_prompt
+
+    def __len__(self):
+        return 1
\ No newline at end of file
diff --git a/utils/degradation_utils.py b/utils/degradation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a50ddeccf6f8316f5abf8b3bb6e36942aed1bbff
--- /dev/null
+++ b/utils/degradation_utils.py
@@ -0,0 +1,50 @@
+import torch
+from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor, Grayscale
+
+from PIL import Image
+import random
+import numpy as np
+
+from utils.image_utils import crop_img
+
+
+class Degradation(object):
+    def __init__(self, args):
+        super(Degradation, self).__init__()
+        self.args = args
+        self.toTensor = ToTensor()
+        self.crop_transform = Compose([
+            ToPILImage(),
+            RandomCrop(args.patch_size),
+        ])
+
+    def _add_gaussian_noise(self, clean_patch, sigma):
+        # noise = torch.randn(*(clean_patch.shape))
+        # clean_patch = self.toTensor(clean_patch)
+        noise = np.random.randn(*clean_patch.shape)
+        noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
+        # noisy_patch = torch.clamp(clean_patch + noise * sigma, 0, 255).type(torch.int32)
+
return noisy_patch, clean_patch + + def _degrade_by_type(self, clean_patch, degrade_type): + if degrade_type == 0: + # denoise sigma=15 + degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15) + elif degrade_type == 1: + # denoise sigma=25 + degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25) + elif degrade_type == 2: + # denoise sigma=50 + degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50) + + return degraded_patch, clean_patch + + def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None): + if degrade_type == None: + degrade_type = random.randint(0, 3) + else: + degrade_type = degrade_type + + degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type) + degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type) + return degrad_patch_1, degrad_patch_2 diff --git a/utils/image_io.py b/utils/image_io.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b8d406e16eaf0e6552746687683d2bf4c8887d --- /dev/null +++ b/utils/image_io.py @@ -0,0 +1,414 @@ +import glob + +import torch +import torchvision +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image + +# import skvideo.io + +matplotlib.use('agg') + + +def prepare_hazy_image(file_name): + img_pil = crop_image(get_image(file_name, -1)[0], d=32) + return pil_to_np(img_pil) + + +def prepare_gt_img(file_name, SOTS=True): + if SOTS: + img_pil = crop_image(crop_a_image(get_image(file_name, -1)[0], d=10), d=32) + else: + img_pil = crop_image(get_image(file_name, -1)[0], d=32) + + return pil_to_np(img_pil) + + +def crop_a_image(img, d=10): + bbox = [ + int((d)), + int((d)), + int((img.size[0] - d)), + int((img.size[1] - d)), + ] + img_cropped = img.crop(bbox) + return img_cropped + + +def crop_image(img, d=32): + """ + Make dimensions divisible by d + + :param pil img: + :param d: + :return: + """ + + new_size = (img.size[0] - img.size[0] % d, + img.size[1] - img.size[1] % d) + + bbox = [ + int((img.size[0] - new_size[0]) / 2), + int((img.size[1] - new_size[1]) / 2), + int((img.size[0] + new_size[0]) / 2), + int((img.size[1] + new_size[1]) / 2), + ] + + img_cropped = img.crop(bbox) + return img_cropped + + +def crop_np_image(img_np, d=32): + return torch_to_np(crop_torch_image(np_to_torch(img_np), d)) + + +def crop_torch_image(img, d=32): + """ + Make dimensions divisible by d + image is [1, 3, W, H] or [3, W, H] + :param pil img: + :param d: + :return: + """ + new_size = (img.shape[-2] - img.shape[-2] % d, + img.shape[-1] - img.shape[-1] % d) + pad = ((img.shape[-2] - new_size[-2]) // 2, (img.shape[-1] - new_size[-1]) // 2) + + if len(img.shape) == 4: + return img[:, :, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] + assert len(img.shape) == 3 + return img[:, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] + + +def get_params(opt_over, net, net_input, downsampler=None): + """ + Returns parameters that we want to optimize over. + :param opt_over: comma separated list, e.g. 
"net,input" or "net" + :param net: network + :param net_input: torch.Tensor that stores input `z` + :param downsampler: + :return: + """ + + opt_over_list = opt_over.split(',') + params = [] + + for opt in opt_over_list: + + if opt == 'net': + params += [x for x in net.parameters()] + elif opt == 'down': + assert downsampler is not None + params = [x for x in downsampler.parameters()] + elif opt == 'input': + net_input.requires_grad = True + params += [net_input] + else: + assert False, 'what is it?' + + return params + + +def get_image_grid(images_np, nrow=8): + """ + Creates a grid from a list of images by concatenating them. + :param images_np: + :param nrow: + :return: + """ + images_torch = [torch.from_numpy(x).type(torch.FloatTensor) for x in images_np] + torch_grid = torchvision.utils.make_grid(images_torch, nrow) + + return torch_grid.numpy() + + +def plot_image_grid(name, images_np, interpolation='lanczos', output_path="output/"): + """ + Draws images in a grid + + Args: + images_np: list of images, each image is np.array of size 3xHxW or 1xHxW + nrow: how many images will be in one row + interpolation: interpolation used in plt.imshow + """ + assert len(images_np) == 2 + n_channels = max(x.shape[0] for x in images_np) + assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels" + + images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np] + + grid = get_image_grid(images_np, 2) + + if images_np[0].shape[0] == 1: + plt.imshow(grid[0], cmap='gray', interpolation=interpolation) + else: + plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation) + + plt.savefig(output_path + "{}.png".format(name)) + + +def save_image_np(name, image_np, output_path="output/"): + p = np_to_pil(image_np) + p.save(output_path + "{}.png".format(name)) + + +def save_image_tensor(image_tensor, output_path="output/"): + image_np = torch_to_np(image_tensor) + # print(image_np.shape) + p = np_to_pil(image_np) + return p + + +def video_to_images(file_name, name): + video = prepare_video(file_name) + for i, f in enumerate(video): + save_image(name + "_{0:03d}".format(i), f) + + +def images_to_video(images_dir, name, gray=True): + num = len(glob.glob(images_dir + "/*.jpg")) + c = [] + for i in range(num): + if gray: + img = prepare_gray_image(images_dir + "/" + name + "_{}.jpg".format(i)) + else: + img = prepare_image(images_dir + "/" + name + "_{}.jpg".format(i)) + print(img.shape) + c.append(img) + save_video(name, np.array(c)) + + +def save_heatmap(name, image_np): + cmap = plt.get_cmap('jet') + + rgba_img = cmap(image_np) + rgb_img = np.delete(rgba_img, 3, 2) + save_image(name, rgb_img.transpose(2, 0, 1)) + + +def save_graph(name, graph_list, output_path="output/"): + plt.clf() + plt.plot(graph_list) + plt.savefig(output_path + name + ".png") + + +def create_augmentations(np_image): + """ + convention: original, left, upside-down, right, rot1, rot2, rot3 + :param np_image: + :return: + """ + aug = [np_image.copy(), np.rot90(np_image, 1, (1, 2)).copy(), + np.rot90(np_image, 2, (1, 2)).copy(), np.rot90(np_image, 3, (1, 2)).copy()] + flipped = np_image[:, ::-1, :].copy() + aug += [flipped.copy(), np.rot90(flipped, 1, (1, 2)).copy(), np.rot90(flipped, 2, (1, 2)).copy(), + np.rot90(flipped, 3, (1, 2)).copy()] + return aug + + +def create_video_augmentations(np_video): + """ + convention: original, left, upside-down, right, rot1, rot2, rot3 + :param np_video: + :return: + """ + aug = [np_video.copy(), np.rot90(np_video, 1, (2, 3)).copy(), 
+ np.rot90(np_video, 2, (2, 3)).copy(), np.rot90(np_video, 3, (2, 3)).copy()] + flipped = np_video[:, :, ::-1, :].copy() + aug += [flipped.copy(), np.rot90(flipped, 1, (2, 3)).copy(), np.rot90(flipped, 2, (2, 3)).copy(), + np.rot90(flipped, 3, (2, 3)).copy()] + return aug + + +def save_graphs(name, graph_dict, output_path="output/"): + """ + + :param name: + :param dict graph_dict: a dict from the name of the list to the list itself. + :return: + """ + plt.clf() + fig, ax = plt.subplots() + for k, v in graph_dict.items(): + ax.plot(v, label=k) + # ax.semilogy(v, label=k) + ax.set_xlabel('iterations') + # ax.set_ylabel(name) + ax.set_ylabel('MSE-loss') + # ax.set_ylabel('PSNR') + plt.legend() + plt.savefig(output_path + name + ".png") + + +def load(path): + """Load PIL image.""" + img = Image.open(path) + return img + + +def get_image(path, imsize=-1): + """Load an image and resize to a cpecific size. + + Args: + path: path to image + imsize: tuple or scalar with dimensions; -1 for `no resize` + """ + img = load(path) + if isinstance(imsize, int): + imsize = (imsize, imsize) + + if imsize[0] != -1 and img.size != imsize: + if imsize[0] > img.size[0]: + img = img.resize(imsize, Image.BICUBIC) + else: + img = img.resize(imsize, Image.ANTIALIAS) + + img_np = pil_to_np(img) + # 3*460*620 + # print(np.shape(img_np)) + + return img, img_np + + +def prepare_gt(file_name): + """ + loads makes it divisible + :param file_name: + :return: the numpy representation of the image + """ + img = get_image(file_name, -1) + # print(img[0].size) + + img_pil = img[0].crop([10, 10, img[0].size[0] - 10, img[0].size[1] - 10]) + + img_pil = crop_image(img_pil, d=32) + + # img_pil = get_image(file_name, -1)[0] + # print(img_pil.size) + return pil_to_np(img_pil) + + +def prepare_image(file_name): + """ + loads makes it divisible + :param file_name: + :return: the numpy representation of the image + """ + img = get_image(file_name, -1) + # print(img[0].size) + # img_pil = img[0] + img_pil = crop_image(img[0], d=16) + # img_pil = get_image(file_name, -1)[0] + # print(img_pil.size) + return pil_to_np(img_pil) + + +# def prepare_video(file_name, folder="output/"): +# data = skvideo.io.vread(folder + file_name) +# return crop_torch_image(data.transpose(0, 3, 1, 2).astype(np.float32) / 255.)[:35] +# +# +# def save_video(name, video_np, output_path="output/"): +# outputdata = video_np * 255 +# outputdata = outputdata.astype(np.uint8) +# skvideo.io.vwrite(output_path + "{}.mp4".format(name), outputdata.transpose(0, 2, 3, 1)) + + +def prepare_gray_image(file_name): + img = prepare_image(file_name) + return np.array([np.mean(img, axis=0)]) + + +def pil_to_np(img_PIL, with_transpose=True): + """ + Converts image in PIL format to np.array. + + From W x H x C [0...255] to C x W x H [0..1] + """ + ar = np.array(img_PIL) + if len(ar.shape) == 3 and ar.shape[-1] == 4: + ar = ar[:, :, :3] + # this is alpha channel + if with_transpose: + if len(ar.shape) == 3: + ar = ar.transpose(2, 0, 1) + else: + ar = ar[None, ...] + + return ar.astype(np.float32) / 255. 
+ + +def median(img_np_list): + """ + assumes C x W x H [0..1] + :param img_np_list: + :return: + """ + assert len(img_np_list) > 0 + l = len(img_np_list) + shape = img_np_list[0].shape + result = np.zeros(shape) + for c in range(shape[0]): + for w in range(shape[1]): + for h in range(shape[2]): + result[c, w, h] = sorted(i[c, w, h] for i in img_np_list)[l // 2] + return result + + +def average(img_np_list): + """ + assumes C x W x H [0..1] + :param img_np_list: + :return: + """ + assert len(img_np_list) > 0 + l = len(img_np_list) + shape = img_np_list[0].shape + result = np.zeros(shape) + for i in img_np_list: + result += i + return result / l + + +def np_to_pil(img_np): + """ + Converts image in np.array format to PIL image. + + From C x W x H [0..1] to W x H x C [0...255] + :param img_np: + :return: + """ + ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) + + if img_np.shape[0] == 1: + ar = ar[0] + else: + assert img_np.shape[0] == 3, img_np.shape + ar = ar.transpose(1, 2, 0) + + return Image.fromarray(ar) + + +def np_to_torch(img_np): + """ + Converts image in numpy.array to torch.Tensor. + + From C x W x H [0..1] to C x W x H [0..1] + + :param img_np: + :return: + """ + return torch.from_numpy(img_np)[None, :] + + +def torch_to_np(img_var): + """ + Converts an image in torch.Tensor format to np.array. + + From 1 x C x W x H [0..1] to C x W x H [0..1] + :param img_var: + :return: + """ + return img_var.detach().cpu().numpy()[0] diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e327b36bc5194364da9df4ab28b8053943d4b81 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,303 @@ +""" +Created on 2020/9/8 + +@author: Boyun Li +""" +import os +import numpy as np +import torch +import random +import torch.nn as nn +from torch.nn import init +from PIL import Image + +class EdgeComputation(nn.Module): + def __init__(self, test=False): + super(EdgeComputation, self).__init__() + self.test = test + def forward(self, x): + if self.test: + x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]) + x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]) + + # y = torch.Tensor(x.size()).cuda() + y = torch.Tensor(x.size()) + y.fill_(0) + y[:, :, :, 1:] += x_diffx + y[:, :, :, :-1] += x_diffx + y[:, :, 1:, :] += x_diffy + y[:, :, :-1, :] += x_diffy + y = torch.sum(y, 1, keepdim=True) / 3 + y /= 4 + return y + else: + x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1]) + x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :]) + + y = torch.Tensor(x.size()) + y.fill_(0) + y[:, :, 1:] += x_diffx + y[:, :, :-1] += x_diffx + y[:, 1:, :] += x_diffy + y[:, :-1, :] += x_diffy + y = torch.sum(y, 0) / 3 + y /= 4 + return y.unsqueeze(0) + + +# randomly crop a patch from image +def crop_patch(im, pch_size): + H = im.shape[0] + W = im.shape[1] + ind_H = random.randint(0, H - pch_size) + ind_W = random.randint(0, W - pch_size) + pch = im[ind_H:ind_H + pch_size, ind_W:ind_W + pch_size] + return pch + + +# crop an image to the multiple of base +def crop_img(image, base=64): + h = image.shape[0] + w = image.shape[1] + crop_h = h % base + crop_w = w % base + return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :] + + +# image (H, W, C) -> patches (B, H, W, C) +def slice_image2patches(image, patch_size=64, overlap=0): + assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0 + H = image.shape[0] + W = image.shape[1] + patches = [] + image_padding = np.pad(image, ((overlap, overlap), (overlap, overlap), 
(0, 0)), mode='edge') + for h in range(H // patch_size): + for w in range(W // patch_size): + idx_h = [h * patch_size, (h + 1) * patch_size + overlap] + idx_w = [w * patch_size, (w + 1) * patch_size + overlap] + patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0)) + return np.concatenate(patches, axis=0) + + +# patches (B, H, W, C) -> image (H, W, C) +def splice_patches2image(patches, image_size, overlap=0): + assert len(image_size) > 1 + assert patches.shape[-3] == patches.shape[-2] + H = image_size[0] + W = image_size[1] + patch_size = patches.shape[-2] - overlap + image = np.zeros(image_size) + idx = 0 + for h in range(H // patch_size): + for w in range(W // patch_size): + image[h * patch_size:(h + 1) * patch_size, w * patch_size:(w + 1) * patch_size, :] = patches[idx, + overlap:patch_size + overlap, + overlap:patch_size + overlap, + :] + idx += 1 + return image + + +# def data_augmentation(image, mode): +# if mode == 0: +# # original +# out = image.numpy() +# elif mode == 1: +# # flip up and down +# out = np.flipud(image) +# elif mode == 2: +# # rotate counterwise 90 degree +# out = np.rot90(image, axes=(1, 2)) +# elif mode == 3: +# # rotate 90 degree and flip up and down +# out = np.rot90(image, axes=(1, 2)) +# out = np.flipud(out) +# elif mode == 4: +# # rotate 180 degree +# out = np.rot90(image, k=2, axes=(1, 2)) +# elif mode == 5: +# # rotate 180 degree and flip +# out = np.rot90(image, k=2, axes=(1, 2)) +# out = np.flipud(out) +# elif mode == 6: +# # rotate 270 degree +# out = np.rot90(image, k=3, axes=(1, 2)) +# elif mode == 7: +# # rotate 270 degree and flip +# out = np.rot90(image, k=3, axes=(1, 2)) +# out = np.flipud(out) +# else: +# raise Exception('Invalid choice of image transformation') +# return out + +def data_augmentation(image, mode): + if mode == 0: + # original + out = image.numpy() + elif mode == 1: + # flip up and down + out = np.flipud(image) + elif mode == 2: + # rotate counterwise 90 degree + out = np.rot90(image) + elif mode == 3: + # rotate 90 degree and flip up and down + out = np.rot90(image) + out = np.flipud(out) + elif mode == 4: + # rotate 180 degree + out = np.rot90(image, k=2) + elif mode == 5: + # rotate 180 degree and flip + out = np.rot90(image, k=2) + out = np.flipud(out) + elif mode == 6: + # rotate 270 degree + out = np.rot90(image, k=3) + elif mode == 7: + # rotate 270 degree and flip + out = np.rot90(image, k=3) + out = np.flipud(out) + else: + raise Exception('Invalid choice of image transformation') + return out + + +# def random_augmentation(*args): +# out = [] +# if random.randint(0, 1) == 1: +# flag_aug = random.randint(1, 7) +# for data in args: +# out.append(data_augmentation(data, flag_aug).copy()) +# else: +# for data in args: +# out.append(data) +# return out + +def random_augmentation(*args): + out = [] + flag_aug = random.randint(1, 7) + for data in args: + out.append(data_augmentation(data, flag_aug).copy()) + return out + + +def weights_init_normal_(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + init.uniform(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.uniform(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_normal(m): + classname = m.__class__.__name__ + if classname.find('Conv2d') != -1: + m.apply(weights_init_normal_) + elif classname.find('Linear') != -1: + init.uniform(m.weight.data, 0.0, 0.02) + elif 
classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + init.xavier_normal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.xavier_normal(m.weight.data, gain=1) + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + print(classname) + if classname.find('Conv') != -1: + init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal(m.weight.data, gain=1) + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + +def np_to_torch(img_np): + """ + Converts image in numpy.array to torch.Tensor. + + From C x W x H [0..1] to C x W x H [0..1] + + :param img_np: + :return: + """ + return torch.from_numpy(img_np)[None, :] + + +def torch_to_np(img_var): + """ + Converts an image in torch.Tensor format to np.array. + + From 1 x C x W x H [0..1] to C x W x H [0..1] + :param img_var: + :return: + """ + return img_var.detach().cpu().numpy() + # return img_var.detach().cpu().numpy()[0] + + +def save_image(name, image_np, output_path="output/normal/"): + if not os.path.exists(output_path): + os.mkdir(output_path) + + p = np_to_pil(image_np) + p.save(output_path + "{}.png".format(name)) + + +def np_to_pil(img_np): + """ + Converts image in np.array format to PIL image. 
+
+    From C x H x W [0..1] to H x W x C [0...255]
+    :param img_np:
+    :return:
+    """
+    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
+
+    if img_np.shape[0] == 1:
+        ar = ar[0]
+    else:
+        assert img_np.shape[0] == 3, img_np.shape
+        ar = ar.transpose(1, 2, 0)
+
+    return Image.fromarray(ar)
\ No newline at end of file
diff --git a/utils/imresize.py b/utils/imresize.py
new file mode 100644
index 0000000000000000000000000000000000000000..80cb1158b988fa496dab27054d14ad19ab28ccda
--- /dev/null
+++ b/utils/imresize.py
@@ -0,0 +1,232 @@
+import numpy as np
+# the scipy.ndimage.filters/measurements/interpolation submodules were removed
+# in recent SciPy releases; use scipy.ndimage directly instead
+from scipy import ndimage
+from math import pi
+
+
+def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
+    # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
+    scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor)
+
+    # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only)
+    if type(kernel) == np.ndarray and scale_factor[0] <= 1:
+        return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag)
+
+    # Choose the interpolation method; each method has a matching kernel size
+    method, kernel_width = {
+        "cubic": (cubic, 4.0),
+        "lanczos2": (lanczos2, 4.0),
+        "lanczos3": (lanczos3, 6.0),
+        "box": (box, 1.0),
+        "linear": (linear, 2.0),
+        None: (cubic, 4.0)  # set default interpolation method as cubic
+    }.get(kernel)
+
+    # Antialiasing is only used when downscaling
+    antialiasing *= (scale_factor[0] < 1)
+
+    # Sort indices of dimensions according to the scale of each dimension. Since we go dim by dim, this is efficient
+    sorted_dims = np.argsort(np.array(scale_factor)).tolist()
+
+    # Iterate over dimensions to calculate local weights for resizing, and resize each time in one direction
+    out_im = np.copy(im)
+    for dim in sorted_dims:
+        # No point doing calculations for scale-factor 1. Nothing will happen anyway
+        if scale_factor[dim] == 1.0:
+            continue
+
+        # For each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
+        # weights that multiply the values there to get its result.
+        weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim],
+                                               method, kernel_width, antialiasing)
+
+        # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
+        out_im = resize_along_dim(out_im, dim, weights, field_of_view)
+
+    return out_im
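+
+
+# --- Illustrative sketch (not part of the original file) --------------------
+# Typical call patterns: give either a scale factor or an output shape, not
+# both. The array is H x W x C in [0..1]; the shapes below are placeholders.
+# im = np.random.rand(64, 48, 3)
+# small = imresize(im, scale_factor=0.5)        # -> (32, 24, 3), antialiased cubic
+# big = imresize(im, output_shape=(128, 96))    # -> (128, 96, 3), cubic
+# lcz = imresize(im, scale_factor=2, kernel="lanczos3")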
+
+
+def fix_scale_and_size(input_shape, output_shape, scale_factor):
+    # First fix the scale-factor (if given) to be the standardized form the function expects:
+    # a list of scale factors, one per input dimension
+    if scale_factor is not None:
+        # By default, if the scale-factor is a scalar we assume 2d resizing and duplicate it.
+        if np.isscalar(scale_factor):
+            scale_factor = [scale_factor, scale_factor]
+
+        # We extend the scale-factor list to the size of the input by assigning 1 to all the unspecified scales
+        scale_factor = list(scale_factor)
+        scale_factor.extend([1] * (len(input_shape) - len(scale_factor)))
+
+    # Fix the output-shape (if given): extend it to the size of the input-shape by assigning the original input size
+    # to all the unspecified dimensions
+    if output_shape is not None:
+        output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):])
+
+    # Deal with the case of a non-given scale-factor, calculating it according to the output-shape. Note that this is
+    # sub-optimal, because there can be different scales to the same output-shape.
+    if scale_factor is None:
+        scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
+
+    # Deal with a missing output-shape, calculating it according to the scale-factor
+    if output_shape is None:
+        output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
+
+    return scale_factor, output_shape
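+
+
+# --- Illustrative sketch (not part of the original file) --------------------
+# What the normalization above produces for a 3-channel image, assuming an
+# input of shape (64, 48, 3):
+#   fix_scale_and_size((64, 48, 3), None, 0.5)       -> [0.5, 0.5, 1], (32, 24, 3)
+#   fix_scale_and_size((64, 48, 3), (128, 96), None) -> [2.0, 2.0, 1.0], [128, 96, 3]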
+
+
+def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
+    # such that each position from the field_of_view will be multiplied with a matching filter from the
+    # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
+    # around it. This is only done for one dimension of the image.
+
+    # When anti-aliasing is activated (default, and only for downscaling) the receptive field is stretched to a size
+    # of 1/sf. This makes the filtering act more like a low-pass filter.
+    fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
+    kernel_width *= 1.0 / scale if antialiasing else 1.0
+
+    # These are the coordinates of the output image
+    out_coordinates = np.arange(1, out_length + 1)
+
+    # These are the matching positions of the output-coordinates on the input image coordinates.
+    # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
+    # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
+    # The scaling is done between the distances and not the pixel numbers (the right boundary of pixel 4 is transformed
+    # to the right boundary of pixel 2. Pixel 1 in the small image matches the boundary between pixels 1 and 2 in the
+    # big one, and not pixel 2. This means the position is not just a multiplication of the old position by the
+    # scale-factor). So if we measure distance from the left border, the middle of pixel 1 is at d=0.5, the border
+    # between pixels 1 and 2 is at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf), which means:
+    # (p_new - 0.5 = (p_old - 0.5) / sf) -> p_new = p_old/sf + 0.5 * (1 - 1/sf)
+    match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)
+
+    # This is the left boundary to start multiplying the filter from; it depends on the size of the filter
+    left_boundary = np.floor(match_coordinates - kernel_width / 2)
+
+    # The kernel width needs to be enlarged because when the coverage has sub-pixel borders, it must 'see' the pixel
+    # centers of the pixels it only partly covers. So we add one pixel at each side (weights can zeroize them)
+    expanded_kernel_width = np.ceil(kernel_width) + 2
+
+    # Determine a set of field_of_view for each output position: these are the pixels in the input image
+    # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and
+    # the vertical dim is the pixels it 'sees' (kernel_size + 2)
+    field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
+
+    # Assign a weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the
+    # vertical dim is a list of weights matching the pixels in the field of view (that are specified in
+    # 'field_of_view')
+    weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
+
+    # Normalize weights to sum up to 1. Be careful not to divide by 0
+    sum_weights = np.sum(weights, axis=1)
+    sum_weights[sum_weights == 0] = 1.0
+    weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
+
+    # We use this mirror structure as a trick for reflection padding at the boundaries
+    mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
+    field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
+
+    # Get rid of weights and pixel positions that are of zero weight
+    non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
+    weights = np.squeeze(weights[:, non_zero_out_pixels])
+    field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
+
+    # The final products are the relative positions and the matching weights, both of shape output_size x fixed_kernel_size
+    return weights, field_of_view
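+
+
+# --- Illustrative sketch (not part of the original file) --------------------
+# Worked example of the coordinate mapping derived in the comments above,
+# p_new = p_old/sf + 0.5 * (1 - 1/sf). Downscaling 4 pixels by sf=0.5:
+#   out pixel 1 -> 1/0.5 + 0.5*(1 - 1/0.5) = 2 - 0.5 = 1.5  (between input pixels 1 and 2)
+#   out pixel 2 -> 2/0.5 + 0.5*(1 - 1/0.5) = 4 - 0.5 = 3.5  (between input pixels 3 and 4)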
+
+
+def resize_along_dim(im, dim, weights, field_of_view):
+    # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
+    tmp_im = np.swapaxes(im, dim, 0)
+
+    # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
+    # tmp_im[field_of_view.T], (bsxfun style)
+    weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
+
+    # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
+    # For each pixel in the output image it matches the positions that influence it from the input image (along 1 dim
+    # only, which is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
+    # the matching set of weights. We do this with one big element-wise tensor multiplication (MATLAB bsxfun style:
+    # matching dims are multiplied element-wise, while singletons mean that the matching dim is all multiplied by the
+    # same number).
+    tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
+
+    # Finally we swap back the axes to the original order
+    return np.swapaxes(tmp_out_im, dim, 0)
+
+
+def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
+    # See the kernel_shift function to understand what this is
+    if kernel_shift_flag:
+        kernel = kernel_shift(kernel, scale_factor)
+
+    # First run a correlation (convolution with a flipped kernel)
+    out_im = np.zeros_like(im)
+    for channel in range(np.ndim(im)):
+        out_im[:, :, channel] = ndimage.correlate(im[:, :, channel], kernel)
+
+    # Then subsample and return
+    return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
+                  np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
+
+
+def kernel_shift(kernel, sf):
+    # There are two reasons for shifting the kernel:
+    # 1. The center of mass is not in the center of the kernel, which creates ambiguity. There is no possible way to
+    #    know whether the degradation process included shifting, so we always assume the center of mass is the center
+    #    of the kernel.
+    # 2. We further shift the kernel center so that the top-left result pixel corresponds to the middle of the sf x sf
+    #    first pixels. The default is for an odd-sized kernel to be in the middle of the first pixel and for an
+    #    even-sized kernel to be at the top-left corner of the first pixel. That is why a different shift size is
+    #    needed for odd and even sizes.
+    # Given that these two conditions are fulfilled, we are happy and aligned. The way to test it is as follows:
+    # the input image, when interpolated (regular bicubic), is exactly aligned with the ground truth.
+
+    # First calculate the current center of mass for the kernel
+    current_center_of_mass = ndimage.center_of_mass(kernel)
+
+    # The second term ("+ 0.5 * ...") is for applying condition 2 from the comments above
+    wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2))
+
+    # Define the shift vector for the kernel shifting (x, y)
+    shift_vec = wanted_center_of_mass - current_center_of_mass
+
+    # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
+    # (biggest shift among dims + 1 for safety). np.int was removed from NumPy, so use the builtin int.
+    kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')
+
+    # Finally shift the kernel and return
+    return ndimage.shift(kernel, shift_vec)
+
+
+# These next functions are all interpolation methods.
x is the distance from the left pixel center + + +def cubic(x): + absx = np.abs(x) + absx2 = absx ** 2 + absx3 = absx ** 3 + return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) + + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2))) + + +def lanczos2(x): + return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) / + ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) + * (abs(x) < 2)) + + +def box(x): + return ((-0.5 <= x) & (x < 0.5)) * 1.0 + + +def lanczos3(x): + return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) / + ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) + * (abs(x) < 3)) + + +def linear(x): + return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) + + +def np_imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): + return np.clip(imresize(im.transpose(1, 2, 0), scale_factor, output_shape, kernel, antialiasing, + kernel_shift_flag).transpose(2, 0, 1), 0, 1) \ No newline at end of file diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2b21c6e428300f8c0ddd4a5fab9b90668554a5 --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +from torch.nn.functional import mse_loss + + +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or(self.real_label_var.numel() != input.numel())) + # pdb.set_trace() + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + # self.real_label_var = Variable(real_tensor, requires_grad=False) + # self.real_label_var = torch.Tensor(real_tensor) + self.real_label_var = real_tensor + target_tensor = self.real_label_var + else: + # pdb.set_trace() + create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + # self.fake_label_var = Variable(fake_tensor, requires_grad=False) + # self.fake_label_var = torch.Tensor(fake_tensor) + self.fake_label_var = fake_tensor + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + # pdb.set_trace() + return self.loss(input, target_tensor) + diff --git a/utils/pytorch_ssim/__init__.py b/utils/pytorch_ssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b24134a65a803e405632e2a68198a8cdd14fcad6 --- /dev/null +++ b/utils/pytorch_ssim/__init__.py @@ -0,0 +1,78 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +# Matlab style 1D gaussian filter. +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +# Matlab style n_D gaussian filter. 
+def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + + # I added this for sm +# ssim_map = torch.exp(1 + ssim_map) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/utils/pytorch_ssim/__init__.pyc b/utils/pytorch_ssim/__init__.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e710ebc50971e08c2fed4bde40a5ba833bbeec0f Binary files /dev/null and b/utils/pytorch_ssim/__init__.pyc differ diff --git a/utils/pytorch_ssim/__pycache__/__init__.cpython-36.pyc b/utils/pytorch_ssim/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36140813b96330811d16726e82f430e09837218d Binary files /dev/null and b/utils/pytorch_ssim/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/pytorch_ssim/__pycache__/__init__.cpython-38.pyc b/utils/pytorch_ssim/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9adfdd57b44597751120b6e088e9c7169e20e5dc Binary files /dev/null and b/utils/pytorch_ssim/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/val_utils.py b/utils/val_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..afa21746f744927343bac03fc8cdc7aae4f54a17 --- /dev/null +++ b/utils/val_utils.py @@ -0,0 +1,97 @@ + +import time +import numpy as np +from skimage.metrics import peak_signal_noise_ratio, structural_similarity +from skvideo.measure import niqe + + +class AverageMeter(): + """ Computes and stores the average and current 
value """ + + def __init__(self): + self.reset() + + def reset(self): + """ Reset all statistics """ + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + """ Update statistics """ + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(1.0 / batch_size)) + + return res + + +def compute_psnr_ssim(recoverd, clean): + assert recoverd.shape == clean.shape + recoverd = np.clip(recoverd.detach().cpu().numpy(), 0, 1) + clean = np.clip(clean.detach().cpu().numpy(), 0, 1) + + recoverd = recoverd.transpose(0, 2, 3, 1) + clean = clean.transpose(0, 2, 3, 1) + psnr = 0 + ssim = 0 + for i in range(recoverd.shape[0]): + print(f"Clean patch size: {clean[i].shape}, Restored size: {recoverd[i].shape}") + # psnr_val += compare_psnr(clean[i], recoverd[i]) + # ssim += compare_ssim(clean[i], recoverd[i], multichannel=True) + psnr += peak_signal_noise_ratio(clean[i], recoverd[i], data_range=1) + ssim += structural_similarity(clean[i], recoverd[i], data_range=1, multichannel=True, win_size=3) + + return psnr / recoverd.shape[0], ssim / recoverd.shape[0], recoverd.shape[0] + + +def compute_niqe(image): + image = np.clip(image.detach().cpu().numpy(), 0, 1) + image = image.transpose(0, 2, 3, 1) + niqe_val = niqe(image) + + return niqe_val.mean() + +class timer(): + def __init__(self): + self.acc = 0 + self.tic() + + def tic(self): + self.t0 = time.time() + + def toc(self): + return time.time() - self.t0 + + def hold(self): + self.acc += self.toc() + + def release(self): + ret = self.acc + self.acc = 0 + + return ret + + def reset(self): + self.acc = 0 \ No newline at end of file