Spaces:
Running
on
Zero
Running
on
Zero
gokaygokay
commited on
Commit
•
2f4febc
1
Parent(s):
f17a2ad
full_files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +146 -201
- configs/inference/controlnet_c_3b_canny.yaml +14 -0
- configs/inference/controlnet_c_3b_identity.yaml +17 -0
- configs/inference/controlnet_c_3b_inpainting.yaml +15 -0
- configs/inference/controlnet_c_3b_sr.yaml +15 -0
- configs/inference/lora_c_3b.yaml +15 -0
- configs/inference/stage_b_1b.yaml +13 -0
- configs/inference/stage_b_3b.yaml +13 -0
- configs/inference/stage_c_1b.yaml +7 -0
- configs/inference/stage_c_3b.yaml +7 -0
- configs/training/cfg_control_lr.yaml +47 -0
- configs/training/lora_personalization.yaml +37 -0
- configs/training/t2i.yaml +29 -0
- core/__init__.py +372 -0
- core/data/__init__.py +69 -0
- core/data/bucketeer.py +88 -0
- core/data/bucketeer_deg.py +91 -0
- core/data/deg_kair_utils/utils_alignfaces.py +263 -0
- core/data/deg_kair_utils/utils_blindsr.py +631 -0
- core/data/deg_kair_utils/utils_bnorm.py +91 -0
- core/data/deg_kair_utils/utils_deblur.py +655 -0
- core/data/deg_kair_utils/utils_dist.py +201 -0
- core/data/deg_kair_utils/utils_googledownload.py +93 -0
- core/data/deg_kair_utils/utils_image.py +1016 -0
- core/data/deg_kair_utils/utils_lmdb.py +205 -0
- core/data/deg_kair_utils/utils_logger.py +66 -0
- core/data/deg_kair_utils/utils_mat.py +88 -0
- core/data/deg_kair_utils/utils_matconvnet.py +197 -0
- core/data/deg_kair_utils/utils_model.py +330 -0
- core/data/deg_kair_utils/utils_modelsummary.py +485 -0
- core/data/deg_kair_utils/utils_option.py +255 -0
- core/data/deg_kair_utils/utils_params.py +135 -0
- core/data/deg_kair_utils/utils_receptivefield.py +62 -0
- core/data/deg_kair_utils/utils_regularizers.py +104 -0
- core/data/deg_kair_utils/utils_sisr.py +848 -0
- core/data/deg_kair_utils/utils_video.py +493 -0
- core/data/deg_kair_utils/utils_videoio.py +555 -0
- core/scripts/__init__.py +0 -0
- core/scripts/cli.py +41 -0
- core/templates/__init__.py +1 -0
- core/templates/diffusion.py +236 -0
- core/utils/__init__.py +9 -0
- core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
- core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
- core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
- core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
- core/utils/base_dto.py +56 -0
- core/utils/save_and_load.py +59 -0
app.py
CHANGED
@@ -1,213 +1,158 @@
|
|
1 |
import spaces
|
2 |
-
import json
|
3 |
-
import subprocess
|
4 |
import os
|
5 |
-
import
|
6 |
-
|
7 |
-
|
8 |
-
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
9 |
-
output, error = process.communicate()
|
10 |
-
if process.returncode != 0:
|
11 |
-
print(f"Error executing command: {command}")
|
12 |
-
print(error.decode('utf-8'))
|
13 |
-
exit(1)
|
14 |
-
return output.decode('utf-8')
|
15 |
-
|
16 |
-
# Download CUDA installer
|
17 |
-
download_command = "wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
|
18 |
-
result = run_command(download_command)
|
19 |
-
if result is None:
|
20 |
-
print("Failed to download CUDA installer.")
|
21 |
-
exit(1)
|
22 |
-
|
23 |
-
# Run CUDA installer in silent mode
|
24 |
-
install_command = "sh cuda_12.2.0_535.54.03_linux.run --silent --toolkit --samples --override"
|
25 |
-
result = run_command(install_command)
|
26 |
-
if result is None:
|
27 |
-
print("Failed to run CUDA installer.")
|
28 |
-
exit(1)
|
29 |
-
|
30 |
-
print("CUDA installation process completed.")
|
31 |
-
|
32 |
-
def install_packages():
|
33 |
-
|
34 |
-
# Clone the repository with submodules
|
35 |
-
run_command("git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git")
|
36 |
-
|
37 |
-
# Change to the cloned directory
|
38 |
-
os.chdir("llama-cpp-python")
|
39 |
-
|
40 |
-
# Checkout the specific commit in the llama.cpp submodule
|
41 |
-
os.chdir("vendor/llama.cpp")
|
42 |
-
run_command("git checkout 50e0535")
|
43 |
-
os.chdir("../..")
|
44 |
-
|
45 |
-
# Upgrade pip
|
46 |
-
run_command("pip install --upgrade pip")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
# Install all optional dependencies with CUDA support
|
51 |
-
run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
|
52 |
-
|
53 |
-
run_command("make clean && GGML_OPENBLAS=1 make -j")
|
54 |
-
|
55 |
-
# Reinstall the package with CUDA support
|
56 |
-
run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
|
57 |
-
|
58 |
-
# Install llama-cpp-agent
|
59 |
-
run_command("pip install llama-cpp-agent")
|
60 |
-
|
61 |
-
run_command("export PYTHONPATH=$PYTHONPATH:$(pwd)")
|
62 |
-
|
63 |
-
print("Installation complete!")
|
64 |
-
|
65 |
-
try:
|
66 |
-
install_packages()
|
67 |
-
|
68 |
-
# Add a delay to allow for package registration
|
69 |
-
import time
|
70 |
-
time.sleep(5)
|
71 |
-
|
72 |
-
# Force Python to reload the site packages
|
73 |
-
import site
|
74 |
-
import importlib
|
75 |
-
importlib.reload(site)
|
76 |
-
|
77 |
-
# Now try to import the libraries
|
78 |
-
from llama_cpp import Llama
|
79 |
-
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
|
80 |
-
from llama_cpp_agent.providers import LlamaCppPythonProvider
|
81 |
-
from llama_cpp_agent.chat_history import BasicChatHistory
|
82 |
-
from llama_cpp_agent.chat_history.messages import Roles
|
83 |
-
|
84 |
-
print("Libraries imported successfully!")
|
85 |
-
except Exception as e:
|
86 |
-
print(f"Installation failed or libraries couldn't be imported: {str(e)}")
|
87 |
-
sys.exit(1)
|
88 |
-
|
89 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
from huggingface_hub import hf_hub_download
|
91 |
|
92 |
-
hf_hub_download(
|
93 |
-
repo_id="MaziyarPanahi/Mistral-Nemo-Instruct-2407-GGUF",
|
94 |
-
filename="Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
|
95 |
-
local_dir="./models"
|
96 |
-
)
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
)
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
settings.top_k = top_k
|
132 |
-
settings.top_p = top_p
|
133 |
-
settings.max_tokens = max_tokens
|
134 |
-
settings.repeat_penalty = repeat_penalty
|
135 |
-
settings.stream = True
|
136 |
-
|
137 |
-
messages = BasicChatHistory()
|
138 |
-
|
139 |
-
for msn in history:
|
140 |
-
user = {
|
141 |
-
'role': Roles.user,
|
142 |
-
'content': msn[0]
|
143 |
-
}
|
144 |
-
assistant = {
|
145 |
-
'role': Roles.assistant,
|
146 |
-
'content': msn[1]
|
147 |
-
}
|
148 |
-
messages.add_message(user)
|
149 |
-
messages.add_message(assistant)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
gr.Textbox(
|
175 |
-
gr.Slider(minimum=
|
176 |
-
gr.Slider(minimum=
|
177 |
-
gr.
|
178 |
-
minimum=0.1,
|
179 |
-
maximum=1.0,
|
180 |
-
value=0.95,
|
181 |
-
step=0.05,
|
182 |
-
label="Top-p",
|
183 |
-
),
|
184 |
-
gr.Slider(
|
185 |
-
minimum=0,
|
186 |
-
maximum=100,
|
187 |
-
value=40,
|
188 |
-
step=1,
|
189 |
-
label="Top-k",
|
190 |
-
),
|
191 |
-
gr.Slider(
|
192 |
-
minimum=0.0,
|
193 |
-
maximum=2.0,
|
194 |
-
value=1.1,
|
195 |
-
step=0.1,
|
196 |
-
label="Repetition penalty",
|
197 |
-
),
|
198 |
],
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
title="Chat with Mistral-NeMo using llama.cpp",
|
204 |
-
description=description,
|
205 |
-
chatbot=gr.Chatbot(
|
206 |
-
scale=1,
|
207 |
-
likeable=False,
|
208 |
-
show_copy_button=True
|
209 |
-
)
|
210 |
)
|
211 |
|
212 |
-
|
213 |
-
demo.launch(debug=True)
|
|
|
1 |
import spaces
|
|
|
|
|
2 |
import os
|
3 |
+
import requests
|
4 |
+
import yaml
|
5 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import gradio as gr
|
7 |
+
from PIL import Image
|
8 |
+
import sys
|
9 |
+
sys.path.append(os.path.abspath('./'))
|
10 |
+
from inference.utils import *
|
11 |
+
from core.utils import load_or_fail
|
12 |
+
from train import WurstCoreB
|
13 |
+
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
|
14 |
+
from train import WurstCore_t2i as WurstCoreC
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from core.utils import load_or_fail
|
17 |
+
import numpy as np
|
18 |
+
import random
|
19 |
+
import math
|
20 |
+
from einops import rearrange
|
21 |
from huggingface_hub import hf_hub_download
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
def download_file(url, folder_path, filename):
|
25 |
+
if not os.path.exists(folder_path):
|
26 |
+
os.makedirs(folder_path)
|
27 |
+
file_path = os.path.join(folder_path, filename)
|
28 |
+
|
29 |
+
if os.path.isfile(file_path):
|
30 |
+
print(f"File already exists: {file_path}")
|
31 |
+
else:
|
32 |
+
response = requests.get(url, stream=True)
|
33 |
+
if response.status_code == 200:
|
34 |
+
with open(file_path, 'wb') as file:
|
35 |
+
for chunk in response.iter_content(chunk_size=1024):
|
36 |
+
file.write(chunk)
|
37 |
+
print(f"File successfully downloaded and saved: {file_path}")
|
38 |
+
else:
|
39 |
+
print(f"Error downloading the file. Status code: {response.status_code}")
|
40 |
+
|
41 |
+
def download_models():
|
42 |
+
models = {
|
43 |
+
"STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
|
44 |
+
"STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
|
45 |
+
"STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
|
46 |
+
"STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
|
47 |
+
"STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
|
48 |
+
"ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
|
49 |
+
"ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
|
50 |
+
}
|
51 |
+
|
52 |
+
for model, (url, folder, filename) in models.items():
|
53 |
+
download_file(url, folder, filename)
|
54 |
+
|
55 |
+
download_models()
|
56 |
+
|
57 |
+
# Global variables
|
58 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
+
dtype = torch.bfloat16
|
60 |
+
|
61 |
+
# Load configs and setup models
|
62 |
+
with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
|
63 |
+
config_c = yaml.safe_load(file)
|
64 |
+
|
65 |
+
with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
|
66 |
+
config_b = yaml.safe_load(file)
|
67 |
+
|
68 |
+
core = WurstCoreC(config_dict=config_c, device=device, training=False)
|
69 |
+
core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
|
70 |
+
|
71 |
+
extras = core.setup_extras_pre()
|
72 |
+
models = core.setup_models(extras)
|
73 |
+
models.generator.eval().requires_grad_(False)
|
74 |
+
|
75 |
+
extras_b = core_b.setup_extras_pre()
|
76 |
+
models_b = core_b.setup_models(extras_b, skip_clip=True)
|
77 |
+
models_b = WurstCoreB.Models(
|
78 |
+
**{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
|
79 |
)
|
80 |
+
models_b.generator.bfloat16().eval().requires_grad_(False)
|
81 |
+
|
82 |
+
# Load pretrained model
|
83 |
+
pretrained_path = "models/ultrapixel_t2i.safetensors"
|
84 |
+
sdd = torch.load(pretrained_path, map_location='cpu')
|
85 |
+
collect_sd = {k[7:]: v for k, v in sdd.items()}
|
86 |
+
models.train_norm.load_state_dict(collect_sd)
|
87 |
+
models.generator.eval()
|
88 |
+
models.train_norm.eval()
|
89 |
+
|
90 |
+
# Set up sampling configurations
|
91 |
+
extras.sampling_configs.update({
|
92 |
+
'cfg': 4,
|
93 |
+
'shift': 1,
|
94 |
+
'timesteps': 20,
|
95 |
+
't_start': 1.0,
|
96 |
+
'sampler': DDPMSampler(extras.gdf)
|
97 |
+
})
|
98 |
+
|
99 |
+
extras_b.sampling_configs.update({
|
100 |
+
'cfg': 1.1,
|
101 |
+
'shift': 1,
|
102 |
+
'timesteps': 10,
|
103 |
+
't_start': 1.0
|
104 |
+
})
|
105 |
+
|
106 |
+
@spaces.GPU
|
107 |
+
def generate_image(prompt, height, width, seed):
|
108 |
+
torch.manual_seed(seed)
|
109 |
+
random.seed(seed)
|
110 |
+
np.random.seed(seed)
|
111 |
+
|
112 |
+
batch_size = 1
|
113 |
+
height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
|
114 |
+
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
|
115 |
+
stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
|
116 |
+
|
117 |
+
batch = {'captions': [prompt] * batch_size}
|
118 |
+
conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
|
119 |
+
unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
|
120 |
|
121 |
+
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
122 |
+
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
with torch.no_grad():
|
125 |
+
models.generator.cuda()
|
126 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
127 |
+
sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
|
128 |
+
|
129 |
+
models.generator.cpu()
|
130 |
+
torch.cuda.empty_cache()
|
131 |
+
|
132 |
+
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
133 |
+
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
134 |
+
conditions_b['effnet'] = sampled_c
|
135 |
+
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
136 |
+
|
137 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
138 |
+
sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
|
139 |
+
|
140 |
+
torch.cuda.empty_cache()
|
141 |
+
imgs = show_images(sampled)
|
142 |
+
return imgs[0]
|
143 |
+
|
144 |
+
iface = gr.Interface(
|
145 |
+
fn=generate_image,
|
146 |
+
inputs=[
|
147 |
+
gr.Textbox(label="Prompt"),
|
148 |
+
gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
|
149 |
+
gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
|
150 |
+
gr.Number(label="Seed", value=42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
],
|
152 |
+
outputs=gr.Image(type="pil"),
|
153 |
+
title="UltraPixel Image Generation",
|
154 |
+
description="Generate high-resolution images using UltraPixel model.",
|
155 |
+
theme='bethecloud/storj_theme'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
)
|
157 |
|
158 |
+
iface.launch()
|
|
configs/inference/controlnet_c_3b_canny.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
7 |
+
controlnet_filter: CannyFilter
|
8 |
+
controlnet_filter_params:
|
9 |
+
resize: 224
|
10 |
+
|
11 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
12 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
13 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
14 |
+
controlnet_checkpoint_path: models/canny.safetensors
|
configs/inference/controlnet_c_3b_identity.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_bottleneck_mode: 'simple'
|
7 |
+
controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
|
8 |
+
controlnet_filter: IdentityFilter
|
9 |
+
controlnet_filter_params:
|
10 |
+
max_faces: 4
|
11 |
+
p_drop: 0.00
|
12 |
+
p_full: 0.0
|
13 |
+
|
14 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
15 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
16 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
17 |
+
controlnet_checkpoint_path:
|
configs/inference/controlnet_c_3b_inpainting.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
7 |
+
controlnet_filter: InpaintFilter
|
8 |
+
controlnet_filter_params:
|
9 |
+
thresold: [0.04, 0.4]
|
10 |
+
p_outpaint: 0.4
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
controlnet_checkpoint_path: models/inpainting.safetensors
|
configs/inference/controlnet_c_3b_sr.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_bottleneck_mode: 'large'
|
7 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
8 |
+
controlnet_filter: SREffnetFilter
|
9 |
+
controlnet_filter_params:
|
10 |
+
scale_factor: 0.5
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
controlnet_checkpoint_path: models/super_resolution.safetensors
|
configs/inference/lora_c_3b.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# LoRA specific
|
6 |
+
module_filters: ['.attn']
|
7 |
+
rank: 4
|
8 |
+
train_tokens:
|
9 |
+
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
|
10 |
+
- ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
lora_checkpoint_path: models/lora_fernando_10k.safetensors
|
configs/inference/stage_b_1b.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 700M
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# For demonstration purposes in reconstruct_images.ipynb
|
6 |
+
webdataset_path: path to your dataset
|
7 |
+
batch_size: 1
|
8 |
+
image_size: 2048
|
9 |
+
grad_accum_steps: 1
|
10 |
+
|
11 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
12 |
+
stage_a_checkpoint_path: models/stage_a.safetensors
|
13 |
+
generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
|
configs/inference/stage_b_3b.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# For demonstration purposes in reconstruct_images.ipynb
|
6 |
+
webdataset_path: path to your dataset
|
7 |
+
batch_size: 4
|
8 |
+
image_size: 1024
|
9 |
+
grad_accum_steps: 1
|
10 |
+
|
11 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
12 |
+
stage_a_checkpoint_path: models/stage_a.safetensors
|
13 |
+
generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
|
configs/inference/stage_c_1b.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 1B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
6 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
7 |
+
generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
|
configs/inference/stage_c_3b.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
6 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
7 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/cfg_control_lr.yaml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: Ultrapixel_controlnet
|
3 |
+
|
4 |
+
checkpoint_path: checkpoint output path
|
5 |
+
output_path: visual results output path
|
6 |
+
model_version: 3.6B
|
7 |
+
dtype: float32
|
8 |
+
# # WandB
|
9 |
+
# wandb_project: StableCascade
|
10 |
+
# wandb_entity: wandb_username
|
11 |
+
#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
|
12 |
+
#rank: 32
|
13 |
+
# TRAINING PARAMS
|
14 |
+
lr: 1.0e-4
|
15 |
+
batch_size: 12
|
16 |
+
#image_size: [1536, 2048, 2560, 3072, 4096]
|
17 |
+
image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
|
18 |
+
#image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
|
19 |
+
#image_size: [ 1024, 1280]
|
20 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
21 |
+
grad_accum_steps: 2
|
22 |
+
updates: 40000
|
23 |
+
backup_every: 5000
|
24 |
+
save_every: 256
|
25 |
+
warmup_updates: 1
|
26 |
+
use_fsdp: True
|
27 |
+
|
28 |
+
# ControlNet specific
|
29 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
30 |
+
controlnet_filter: CannyFilter
|
31 |
+
controlnet_filter_params:
|
32 |
+
resize: 224
|
33 |
+
# offset_noise: 0.1
|
34 |
+
|
35 |
+
# GDF
|
36 |
+
adaptive_loss_weight: True
|
37 |
+
|
38 |
+
ema_start_iters: 10
|
39 |
+
ema_iters: 50
|
40 |
+
ema_beta: 0.9
|
41 |
+
|
42 |
+
webdataset_path: path to your training dataset
|
43 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
44 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
45 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
46 |
+
controlnet_checkpoint_path: pretrained controlnet path
|
47 |
+
|
configs/training/lora_personalization.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: roubao_cat_personalized
|
3 |
+
|
4 |
+
checkpoint_path: checkpoint output path
|
5 |
+
output_path: visual results output path
|
6 |
+
model_version: 3.6B
|
7 |
+
dtype: float32
|
8 |
+
|
9 |
+
module_filters: [ '.attn']
|
10 |
+
rank: 4
|
11 |
+
train_tokens:
|
12 |
+
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
|
13 |
+
- ['[roubaobao]', '^cat</w>'] # custom token [snail], initialize as avg of snail & snails
|
14 |
+
# TRAINING PARAMS
|
15 |
+
lr: 1.0e-4
|
16 |
+
batch_size: 4
|
17 |
+
|
18 |
+
image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
|
19 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
20 |
+
grad_accum_steps: 2
|
21 |
+
updates: 40000
|
22 |
+
backup_every: 5000
|
23 |
+
save_every: 512
|
24 |
+
warmup_updates: 1
|
25 |
+
use_ddp: True
|
26 |
+
|
27 |
+
# GDF
|
28 |
+
adaptive_loss_weight: True
|
29 |
+
|
30 |
+
|
31 |
+
tmp_prompt: a photo of a cat [roubaobao]
|
32 |
+
webdataset_path: path to your personalized training dataset
|
33 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
34 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
35 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
36 |
+
ultrapixel_path: models/ultrapixel_t2i.safetensors
|
37 |
+
|
configs/training/t2i.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: ultrapixel_t2i
|
3 |
+
#strc_fixlrt_norm3_lite_1024_hrft_newdata
|
4 |
+
checkpoint_path: checkpoint output path #output model directory
|
5 |
+
output_path: visual results output path #experiment output directory
|
6 |
+
model_version: 3.6B # finetune large stage c model of stablecascade
|
7 |
+
dtype: float32
|
8 |
+
|
9 |
+
|
10 |
+
# TRAINING PARAMS
|
11 |
+
lr: 1.0e-4
|
12 |
+
batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps
|
13 |
+
image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution
|
14 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
15 |
+
grad_accum_steps: 2
|
16 |
+
updates: 40000
|
17 |
+
backup_every: 5000
|
18 |
+
save_every: 256
|
19 |
+
warmup_updates: 1
|
20 |
+
use_ddp: True
|
21 |
+
|
22 |
+
# GDF
|
23 |
+
adaptive_loss_weight: True
|
24 |
+
|
25 |
+
|
26 |
+
webdataset_path: path to your personalized training dataset
|
27 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
28 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
29 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
core/__init__.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import wandb
|
6 |
+
import json
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
|
11 |
+
from torch.distributed import init_process_group, destroy_process_group, barrier
|
12 |
+
from torch.distributed.fsdp import (
|
13 |
+
FullyShardedDataParallel as FSDP,
|
14 |
+
FullStateDictConfig,
|
15 |
+
MixedPrecision,
|
16 |
+
ShardingStrategy,
|
17 |
+
StateDictType
|
18 |
+
)
|
19 |
+
|
20 |
+
from .utils import Base, EXPECTED, EXPECTED_TRAIN
|
21 |
+
from .utils import create_folder_if_necessary, safe_save, load_or_fail
|
22 |
+
|
23 |
+
# pylint: disable=unused-argument
|
24 |
+
class WarpCore(ABC):
|
25 |
+
@dataclass(frozen=True)
|
26 |
+
class Config(Base):
|
27 |
+
experiment_id: str = EXPECTED_TRAIN
|
28 |
+
checkpoint_path: str = EXPECTED_TRAIN
|
29 |
+
output_path: str = EXPECTED_TRAIN
|
30 |
+
checkpoint_extension: str = "safetensors"
|
31 |
+
dist_file_subfolder: str = ""
|
32 |
+
allow_tf32: bool = True
|
33 |
+
|
34 |
+
wandb_project: str = None
|
35 |
+
wandb_entity: str = None
|
36 |
+
|
37 |
+
@dataclass() # not frozen, means that fields are mutable
|
38 |
+
class Info(): # not inheriting from Base, because we don't want to enforce the default fields
|
39 |
+
wandb_run_id: str = None
|
40 |
+
total_steps: int = 0
|
41 |
+
iter: int = 0
|
42 |
+
|
43 |
+
@dataclass(frozen=True)
|
44 |
+
class Data(Base):
|
45 |
+
dataset: Dataset = EXPECTED
|
46 |
+
dataloader: DataLoader = EXPECTED
|
47 |
+
iterator: any = EXPECTED
|
48 |
+
|
49 |
+
@dataclass(frozen=True)
|
50 |
+
class Models(Base):
|
51 |
+
pass
|
52 |
+
|
53 |
+
@dataclass(frozen=True)
|
54 |
+
class Optimizers(Base):
|
55 |
+
pass
|
56 |
+
|
57 |
+
@dataclass(frozen=True)
|
58 |
+
class Schedulers(Base):
|
59 |
+
pass
|
60 |
+
|
61 |
+
@dataclass(frozen=True)
|
62 |
+
class Extras(Base):
|
63 |
+
pass
|
64 |
+
# ---------------------------------------
|
65 |
+
info: Info
|
66 |
+
config: Config
|
67 |
+
|
68 |
+
# FSDP stuff
|
69 |
+
fsdp_defaults = {
|
70 |
+
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
|
71 |
+
"cpu_offload": None,
|
72 |
+
"mixed_precision": MixedPrecision(
|
73 |
+
param_dtype=torch.bfloat16,
|
74 |
+
reduce_dtype=torch.bfloat16,
|
75 |
+
buffer_dtype=torch.bfloat16,
|
76 |
+
),
|
77 |
+
"limit_all_gathers": True,
|
78 |
+
}
|
79 |
+
fsdp_fullstate_save_policy = FullStateDictConfig(
|
80 |
+
offload_to_cpu=True, rank0_only=True
|
81 |
+
)
|
82 |
+
# ------------
|
83 |
+
|
84 |
+
# OVERRIDEABLE METHODS
|
85 |
+
|
86 |
+
# [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
|
87 |
+
def setup_extras_pre(self) -> Extras:
|
88 |
+
return self.Extras()
|
89 |
+
|
90 |
+
# setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
|
91 |
+
@abstractmethod
|
92 |
+
def setup_data(self, extras: Extras) -> Data:
|
93 |
+
raise NotImplementedError("This method needs to be overriden")
|
94 |
+
|
95 |
+
# return a dict with all models that are going to be used in the training
|
96 |
+
@abstractmethod
|
97 |
+
def setup_models(self, extras: Extras) -> Models:
|
98 |
+
raise NotImplementedError("This method needs to be overriden")
|
99 |
+
|
100 |
+
# return a dict with all optimizers that are going to be used in the training
|
101 |
+
@abstractmethod
|
102 |
+
def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
|
103 |
+
raise NotImplementedError("This method needs to be overriden")
|
104 |
+
|
105 |
+
# [optionally] return a dict with all schedulers that are going to be used in the training
|
106 |
+
def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
|
107 |
+
return self.Schedulers()
|
108 |
+
|
109 |
+
# [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
|
110 |
+
def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
|
111 |
+
return self.Extras.from_dict(extras.to_dict())
|
112 |
+
|
113 |
+
# perform the training here
|
114 |
+
@abstractmethod
|
115 |
+
def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
|
116 |
+
raise NotImplementedError("This method needs to be overriden")
|
117 |
+
# ------------
|
118 |
+
|
119 |
+
def setup_info(self, full_path=None) -> Info:
|
120 |
+
if full_path is None:
|
121 |
+
full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
|
122 |
+
info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
|
123 |
+
info_dto = self.Info(**info_dict)
|
124 |
+
if info_dto.total_steps > 0 and self.is_main_node:
|
125 |
+
print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
|
126 |
+
return info_dto
|
127 |
+
|
128 |
+
def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
|
129 |
+
if config_file_path is not None:
|
130 |
+
if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
|
131 |
+
with open(config_file_path, "r", encoding="utf-8") as file:
|
132 |
+
loaded_config = yaml.safe_load(file)
|
133 |
+
elif config_file_path.endswith(".json"):
|
134 |
+
with open(config_file_path, "r", encoding="utf-8") as file:
|
135 |
+
loaded_config = json.load(file)
|
136 |
+
else:
|
137 |
+
raise ValueError("Config file must be either a .yml|.yaml or .json file")
|
138 |
+
return self.Config.from_dict({**loaded_config, 'training': training})
|
139 |
+
if config_dict is not None:
|
140 |
+
return self.Config.from_dict({**config_dict, 'training': training})
|
141 |
+
return self.Config(training=training)
|
142 |
+
|
143 |
+
def setup_ddp(self, experiment_id, single_gpu=False):
|
144 |
+
if not single_gpu:
|
145 |
+
local_rank = int(os.environ.get("SLURM_LOCALID"))
|
146 |
+
process_id = int(os.environ.get("SLURM_PROCID"))
|
147 |
+
world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
|
148 |
+
|
149 |
+
self.process_id = process_id
|
150 |
+
self.is_main_node = process_id == 0
|
151 |
+
self.device = torch.device(local_rank)
|
152 |
+
self.world_size = world_size
|
153 |
+
|
154 |
+
dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
|
155 |
+
# if os.path.exists(dist_file_path) and self.is_main_node:
|
156 |
+
# os.remove(dist_file_path)
|
157 |
+
|
158 |
+
torch.cuda.set_device(local_rank)
|
159 |
+
init_process_group(
|
160 |
+
backend="nccl",
|
161 |
+
rank=process_id,
|
162 |
+
world_size=world_size,
|
163 |
+
init_method=f"file://{dist_file_path}",
|
164 |
+
)
|
165 |
+
print(f"[GPU {process_id}] READY")
|
166 |
+
else:
|
167 |
+
print("Running in single thread, DDP not enabled.")
|
168 |
+
|
169 |
+
def setup_wandb(self):
|
170 |
+
if self.is_main_node and self.config.wandb_project is not None:
|
171 |
+
self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
|
172 |
+
wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
|
173 |
+
|
174 |
+
if self.info.total_steps > 0:
|
175 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
|
176 |
+
else:
|
177 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
|
178 |
+
|
179 |
+
# LOAD UTILITIES ----------
|
180 |
+
def load_model(self, model, model_id=None, full_path=None, strict=True):
|
181 |
+
print('in line 181 load model', type(model), model_id, full_path, strict)
|
182 |
+
if model_id is not None and full_path is None:
|
183 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
|
184 |
+
elif full_path is None and model_id is None:
|
185 |
+
raise ValueError(
|
186 |
+
"This method expects either 'model_id' or 'full_path' to be defined"
|
187 |
+
)
|
188 |
+
|
189 |
+
checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
|
190 |
+
if checkpoint is not None:
|
191 |
+
model.load_state_dict(checkpoint, strict=strict)
|
192 |
+
del checkpoint
|
193 |
+
|
194 |
+
return model
|
195 |
+
|
196 |
+
def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
|
197 |
+
if optim_id is not None and full_path is None:
|
198 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
|
199 |
+
elif full_path is None and optim_id is None:
|
200 |
+
raise ValueError(
|
201 |
+
"This method expects either 'optim_id' or 'full_path' to be defined"
|
202 |
+
)
|
203 |
+
|
204 |
+
checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
|
205 |
+
if checkpoint is not None:
|
206 |
+
try:
|
207 |
+
if fsdp_model is not None:
|
208 |
+
sharded_optimizer_state_dict = (
|
209 |
+
FSDP.scatter_full_optim_state_dict( # <---- FSDP
|
210 |
+
checkpoint
|
211 |
+
if (
|
212 |
+
self.is_main_node
|
213 |
+
or self.fsdp_defaults["sharding_strategy"]
|
214 |
+
== ShardingStrategy.NO_SHARD
|
215 |
+
)
|
216 |
+
else None,
|
217 |
+
fsdp_model,
|
218 |
+
)
|
219 |
+
)
|
220 |
+
optim.load_state_dict(sharded_optimizer_state_dict)
|
221 |
+
del checkpoint, sharded_optimizer_state_dict
|
222 |
+
else:
|
223 |
+
optim.load_state_dict(checkpoint)
|
224 |
+
# pylint: disable=broad-except
|
225 |
+
except Exception as e:
|
226 |
+
print("!!! Failed loading optimizer, skipping... Exception:", e)
|
227 |
+
|
228 |
+
return optim
|
229 |
+
|
230 |
+
# SAVE UTILITIES ----------
|
231 |
+
def save_info(self, info, suffix=""):
|
232 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
|
233 |
+
create_folder_if_necessary(full_path)
|
234 |
+
if self.is_main_node:
|
235 |
+
safe_save(vars(self.info), full_path)
|
236 |
+
|
237 |
+
def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
|
238 |
+
if model_id is not None and full_path is None:
|
239 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
|
240 |
+
elif full_path is None and model_id is None:
|
241 |
+
raise ValueError(
|
242 |
+
"This method expects either 'model_id' or 'full_path' to be defined"
|
243 |
+
)
|
244 |
+
create_folder_if_necessary(full_path)
|
245 |
+
if is_fsdp:
|
246 |
+
with FSDP.summon_full_params(model):
|
247 |
+
pass
|
248 |
+
with FSDP.state_dict_type(
|
249 |
+
model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
|
250 |
+
):
|
251 |
+
checkpoint = model.state_dict()
|
252 |
+
if self.is_main_node:
|
253 |
+
safe_save(checkpoint, full_path)
|
254 |
+
del checkpoint
|
255 |
+
else:
|
256 |
+
if self.is_main_node:
|
257 |
+
checkpoint = model.state_dict()
|
258 |
+
safe_save(checkpoint, full_path)
|
259 |
+
del checkpoint
|
260 |
+
|
261 |
+
def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
|
262 |
+
if optim_id is not None and full_path is None:
|
263 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
|
264 |
+
elif full_path is None and optim_id is None:
|
265 |
+
raise ValueError(
|
266 |
+
"This method expects either 'optim_id' or 'full_path' to be defined"
|
267 |
+
)
|
268 |
+
create_folder_if_necessary(full_path)
|
269 |
+
if fsdp_model is not None:
|
270 |
+
optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
|
271 |
+
if self.is_main_node:
|
272 |
+
safe_save(optim_statedict, full_path)
|
273 |
+
del optim_statedict
|
274 |
+
else:
|
275 |
+
if self.is_main_node:
|
276 |
+
checkpoint = optim.state_dict()
|
277 |
+
safe_save(checkpoint, full_path)
|
278 |
+
del checkpoint
|
279 |
+
# -----
|
280 |
+
|
281 |
+
def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
|
282 |
+
# Temporary setup, will be overriden by setup_ddp if required
|
283 |
+
self.device = device
|
284 |
+
self.process_id = 0
|
285 |
+
self.is_main_node = True
|
286 |
+
self.world_size = 1
|
287 |
+
# ----
|
288 |
+
|
289 |
+
self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
|
290 |
+
self.info: self.Info = self.setup_info()
|
291 |
+
|
292 |
+
def __call__(self, single_gpu=False):
|
293 |
+
self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
|
294 |
+
self.setup_wandb()
|
295 |
+
if self.config.allow_tf32:
|
296 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
297 |
+
torch.backends.cudnn.allow_tf32 = True
|
298 |
+
|
299 |
+
if self.is_main_node:
|
300 |
+
print()
|
301 |
+
print("**STARTIG JOB WITH CONFIG:**")
|
302 |
+
print(yaml.dump(self.config.to_dict(), default_flow_style=False))
|
303 |
+
print("------------------------------------")
|
304 |
+
print()
|
305 |
+
print("**INFO:**")
|
306 |
+
print(yaml.dump(vars(self.info), default_flow_style=False))
|
307 |
+
print("------------------------------------")
|
308 |
+
print()
|
309 |
+
|
310 |
+
# SETUP STUFF
|
311 |
+
extras = self.setup_extras_pre()
|
312 |
+
assert extras is not None, "setup_extras_pre() must return a DTO"
|
313 |
+
|
314 |
+
data = self.setup_data(extras)
|
315 |
+
assert data is not None, "setup_data() must return a DTO"
|
316 |
+
if self.is_main_node:
|
317 |
+
print("**DATA:**")
|
318 |
+
print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
|
319 |
+
print("------------------------------------")
|
320 |
+
print()
|
321 |
+
|
322 |
+
models = self.setup_models(extras)
|
323 |
+
assert models is not None, "setup_models() must return a DTO"
|
324 |
+
if self.is_main_node:
|
325 |
+
print("**MODELS:**")
|
326 |
+
print(yaml.dump({
|
327 |
+
k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
|
328 |
+
}, default_flow_style=False))
|
329 |
+
print("------------------------------------")
|
330 |
+
print()
|
331 |
+
|
332 |
+
optimizers = self.setup_optimizers(extras, models)
|
333 |
+
assert optimizers is not None, "setup_optimizers() must return a DTO"
|
334 |
+
if self.is_main_node:
|
335 |
+
print("**OPTIMIZERS:**")
|
336 |
+
print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
|
337 |
+
print("------------------------------------")
|
338 |
+
print()
|
339 |
+
|
340 |
+
schedulers = self.setup_schedulers(extras, models, optimizers)
|
341 |
+
assert schedulers is not None, "setup_schedulers() must return a DTO"
|
342 |
+
if self.is_main_node:
|
343 |
+
print("**SCHEDULERS:**")
|
344 |
+
print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
|
345 |
+
print("------------------------------------")
|
346 |
+
print()
|
347 |
+
|
348 |
+
post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
|
349 |
+
assert post_extras is not None, "setup_extras_post() must return a DTO"
|
350 |
+
extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
|
351 |
+
if self.is_main_node:
|
352 |
+
print("**EXTRAS:**")
|
353 |
+
print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
|
354 |
+
print("------------------------------------")
|
355 |
+
print()
|
356 |
+
# -------
|
357 |
+
|
358 |
+
# TRAIN
|
359 |
+
if self.is_main_node:
|
360 |
+
print("**TRAINING STARTING...**")
|
361 |
+
self.train(data, extras, models, optimizers, schedulers)
|
362 |
+
|
363 |
+
if single_gpu is False:
|
364 |
+
barrier()
|
365 |
+
destroy_process_group()
|
366 |
+
if self.is_main_node:
|
367 |
+
print()
|
368 |
+
print("------------------------------------")
|
369 |
+
print()
|
370 |
+
print("**TRAINING COMPLETE**")
|
371 |
+
if self.config.wandb_project is not None:
|
372 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
|
core/data/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import subprocess
|
3 |
+
import yaml
|
4 |
+
import os
|
5 |
+
from .bucketeer import Bucketeer
|
6 |
+
|
7 |
+
class MultiFilter():
|
8 |
+
def __init__(self, rules, default=False):
|
9 |
+
self.rules = rules
|
10 |
+
self.default = default
|
11 |
+
|
12 |
+
def __call__(self, x):
|
13 |
+
try:
|
14 |
+
x_json = x['json']
|
15 |
+
if isinstance(x_json, bytes):
|
16 |
+
x_json = json.loads(x_json)
|
17 |
+
validations = []
|
18 |
+
for k, r in self.rules.items():
|
19 |
+
if isinstance(k, tuple):
|
20 |
+
v = r(*[x_json[kv] for kv in k])
|
21 |
+
else:
|
22 |
+
v = r(x_json[k])
|
23 |
+
validations.append(v)
|
24 |
+
return all(validations)
|
25 |
+
except Exception:
|
26 |
+
return False
|
27 |
+
|
28 |
+
class MultiGetter():
|
29 |
+
def __init__(self, rules):
|
30 |
+
self.rules = rules
|
31 |
+
|
32 |
+
def __call__(self, x_json):
|
33 |
+
if isinstance(x_json, bytes):
|
34 |
+
x_json = json.loads(x_json)
|
35 |
+
outputs = []
|
36 |
+
for k, r in self.rules.items():
|
37 |
+
if isinstance(k, tuple):
|
38 |
+
v = r(*[x_json[kv] for kv in k])
|
39 |
+
else:
|
40 |
+
v = r(x_json[k])
|
41 |
+
outputs.append(v)
|
42 |
+
if len(outputs) == 1:
|
43 |
+
outputs = outputs[0]
|
44 |
+
return outputs
|
45 |
+
|
46 |
+
def setup_webdataset_path(paths, cache_path=None):
|
47 |
+
if cache_path is None or not os.path.exists(cache_path):
|
48 |
+
tar_paths = []
|
49 |
+
if isinstance(paths, str):
|
50 |
+
paths = [paths]
|
51 |
+
for path in paths:
|
52 |
+
if path.strip().endswith(".tar"):
|
53 |
+
# Avoid looking up s3 if we already have a tar file
|
54 |
+
tar_paths.append(path)
|
55 |
+
continue
|
56 |
+
bucket = "/".join(path.split("/")[:3])
|
57 |
+
result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
|
58 |
+
files = result.stdout.decode('utf-8').split()
|
59 |
+
files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
|
60 |
+
tar_paths += files
|
61 |
+
|
62 |
+
with open(cache_path, 'w', encoding='utf-8') as outfile:
|
63 |
+
yaml.dump(tar_paths, outfile, default_flow_style=False)
|
64 |
+
else:
|
65 |
+
with open(cache_path, 'r', encoding='utf-8') as file:
|
66 |
+
tar_paths = yaml.safe_load(file)
|
67 |
+
|
68 |
+
tar_paths_str = ",".join([f"{p}" for p in tar_paths])
|
69 |
+
return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
|
core/data/bucketeer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
from torchtools.transforms import SmartCrop
|
5 |
+
import math
|
6 |
+
|
7 |
+
class Bucketeer():
|
8 |
+
def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
|
9 |
+
assert crop_mode in ['center', 'random', 'smart']
|
10 |
+
self.crop_mode = crop_mode
|
11 |
+
self.ratios = ratios
|
12 |
+
if reverse_list:
|
13 |
+
for r in list(ratios):
|
14 |
+
if 1/r not in self.ratios:
|
15 |
+
self.ratios.append(1/r)
|
16 |
+
self.sizes = {}
|
17 |
+
for dd in density:
|
18 |
+
self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
|
19 |
+
|
20 |
+
self.batch_size = dataloader.batch_size
|
21 |
+
self.iterator = iter(dataloader)
|
22 |
+
all_sizes = []
|
23 |
+
for k, vs in self.sizes.items():
|
24 |
+
all_sizes += vs
|
25 |
+
self.buckets = {s: [] for s in all_sizes}
|
26 |
+
self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
|
27 |
+
self.p_random_ratio = p_random_ratio
|
28 |
+
self.interpolate_nearest = interpolate_nearest
|
29 |
+
|
30 |
+
def get_available_batch(self):
|
31 |
+
for b in self.buckets:
|
32 |
+
if len(self.buckets[b]) >= self.batch_size:
|
33 |
+
batch = self.buckets[b][:self.batch_size]
|
34 |
+
self.buckets[b] = self.buckets[b][self.batch_size:]
|
35 |
+
return batch
|
36 |
+
return None
|
37 |
+
|
38 |
+
def get_closest_size(self, x):
|
39 |
+
w, h = x.size(-1), x.size(-2)
|
40 |
+
|
41 |
+
|
42 |
+
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
|
43 |
+
find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
|
44 |
+
min_ = find_dict[list(find_dict.keys())[0]]
|
45 |
+
find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
|
46 |
+
for dd, val in find_dict.items():
|
47 |
+
if val < min_:
|
48 |
+
min_ = val
|
49 |
+
find_size = self.sizes[dd][best_size_idx]
|
50 |
+
|
51 |
+
return find_size
|
52 |
+
|
53 |
+
def get_resize_size(self, orig_size, tgt_size):
|
54 |
+
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
|
55 |
+
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
|
56 |
+
resize_size = max(alt_min, min(tgt_size))
|
57 |
+
else:
|
58 |
+
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
|
59 |
+
resize_size = max(alt_max, max(tgt_size))
|
60 |
+
|
61 |
+
return resize_size
|
62 |
+
|
63 |
+
def __next__(self):
|
64 |
+
batch = self.get_available_batch()
|
65 |
+
while batch is None:
|
66 |
+
elements = next(self.iterator)
|
67 |
+
for dct in elements:
|
68 |
+
img = dct['images']
|
69 |
+
size = self.get_closest_size(img)
|
70 |
+
resize_size = self.get_resize_size(img.shape[-2:], size)
|
71 |
+
|
72 |
+
if self.interpolate_nearest:
|
73 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
|
74 |
+
else:
|
75 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
|
76 |
+
if self.crop_mode == 'center':
|
77 |
+
img = torchvision.transforms.functional.center_crop(img, size)
|
78 |
+
elif self.crop_mode == 'random':
|
79 |
+
img = torchvision.transforms.RandomCrop(size)(img)
|
80 |
+
elif self.crop_mode == 'smart':
|
81 |
+
self.smartcrop.output_size = size
|
82 |
+
img = self.smartcrop(img)
|
83 |
+
|
84 |
+
self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
|
85 |
+
batch = self.get_available_batch()
|
86 |
+
|
87 |
+
out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
|
88 |
+
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
|
core/data/bucketeer_deg.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
from torchtools.transforms import SmartCrop
|
5 |
+
import math
|
6 |
+
|
7 |
+
class Bucketeer():
|
8 |
+
def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
|
9 |
+
assert crop_mode in ['center', 'random', 'smart']
|
10 |
+
self.crop_mode = crop_mode
|
11 |
+
self.ratios = ratios
|
12 |
+
if reverse_list:
|
13 |
+
for r in list(ratios):
|
14 |
+
if 1/r not in self.ratios:
|
15 |
+
self.ratios.append(1/r)
|
16 |
+
self.sizes = {}
|
17 |
+
for dd in density:
|
18 |
+
self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
|
19 |
+
print('in line 17 buckteer', self.sizes)
|
20 |
+
self.batch_size = dataloader.batch_size
|
21 |
+
self.iterator = iter(dataloader)
|
22 |
+
all_sizes = []
|
23 |
+
for k, vs in self.sizes.items():
|
24 |
+
all_sizes += vs
|
25 |
+
self.buckets = {s: [] for s in all_sizes}
|
26 |
+
self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
|
27 |
+
self.p_random_ratio = p_random_ratio
|
28 |
+
self.interpolate_nearest = interpolate_nearest
|
29 |
+
|
30 |
+
def get_available_batch(self):
|
31 |
+
for b in self.buckets:
|
32 |
+
if len(self.buckets[b]) >= self.batch_size:
|
33 |
+
batch = self.buckets[b][:self.batch_size]
|
34 |
+
self.buckets[b] = self.buckets[b][self.batch_size:]
|
35 |
+
return batch
|
36 |
+
return None
|
37 |
+
|
38 |
+
def get_closest_size(self, x):
|
39 |
+
w, h = x.size(-1), x.size(-2)
|
40 |
+
#if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
|
41 |
+
# best_size_idx = np.random.randint(len(self.ratios))
|
42 |
+
#print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio)
|
43 |
+
#else:
|
44 |
+
|
45 |
+
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
|
46 |
+
find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
|
47 |
+
min_ = find_dict[list(find_dict.keys())[0]]
|
48 |
+
find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
|
49 |
+
for dd, val in find_dict.items():
|
50 |
+
if val < min_:
|
51 |
+
min_ = val
|
52 |
+
find_size = self.sizes[dd][best_size_idx]
|
53 |
+
|
54 |
+
return find_size
|
55 |
+
|
56 |
+
def get_resize_size(self, orig_size, tgt_size):
|
57 |
+
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
|
58 |
+
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
|
59 |
+
resize_size = max(alt_min, min(tgt_size))
|
60 |
+
else:
|
61 |
+
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
|
62 |
+
resize_size = max(alt_max, max(tgt_size))
|
63 |
+
#print('in line 50', orig_size, tgt_size, resize_size)
|
64 |
+
return resize_size
|
65 |
+
|
66 |
+
def __next__(self):
|
67 |
+
batch = self.get_available_batch()
|
68 |
+
while batch is None:
|
69 |
+
elements = next(self.iterator)
|
70 |
+
for dct in elements:
|
71 |
+
img = dct['images']
|
72 |
+
size = self.get_closest_size(img)
|
73 |
+
resize_size = self.get_resize_size(img.shape[-2:], size)
|
74 |
+
#print('in line 74', img.size(), resize_size)
|
75 |
+
if self.interpolate_nearest:
|
76 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
|
77 |
+
else:
|
78 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
|
79 |
+
if self.crop_mode == 'center':
|
80 |
+
img = torchvision.transforms.functional.center_crop(img, size)
|
81 |
+
elif self.crop_mode == 'random':
|
82 |
+
img = torchvision.transforms.RandomCrop(size)(img)
|
83 |
+
elif self.crop_mode == 'smart':
|
84 |
+
self.smartcrop.output_size = size
|
85 |
+
img = self.smartcrop(img)
|
86 |
+
print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
|
87 |
+
self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
|
88 |
+
batch = self.get_available_batch()
|
89 |
+
|
90 |
+
out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
|
91 |
+
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
|
core/data/deg_kair_utils/utils_alignfaces.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon Apr 24 15:43:29 2017
|
4 |
+
@author: zhaoy
|
5 |
+
"""
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from skimage import transform as trans
|
9 |
+
|
10 |
+
# reference facial points, a list of coordinates (x,y)
|
11 |
+
REFERENCE_FACIAL_POINTS = [
|
12 |
+
[30.29459953, 51.69630051],
|
13 |
+
[65.53179932, 51.50139999],
|
14 |
+
[48.02519989, 71.73660278],
|
15 |
+
[33.54930115, 92.3655014],
|
16 |
+
[62.72990036, 92.20410156]
|
17 |
+
]
|
18 |
+
|
19 |
+
DEFAULT_CROP_SIZE = (96, 112)
|
20 |
+
|
21 |
+
|
22 |
+
def _umeyama(src, dst, estimate_scale=True, scale=1.0):
|
23 |
+
"""Estimate N-D similarity transformation with or without scaling.
|
24 |
+
Parameters
|
25 |
+
----------
|
26 |
+
src : (M, N) array
|
27 |
+
Source coordinates.
|
28 |
+
dst : (M, N) array
|
29 |
+
Destination coordinates.
|
30 |
+
estimate_scale : bool
|
31 |
+
Whether to estimate scaling factor.
|
32 |
+
Returns
|
33 |
+
-------
|
34 |
+
T : (N + 1, N + 1)
|
35 |
+
The homogeneous similarity transformation matrix. The matrix contains
|
36 |
+
NaN values only if the problem is not well-conditioned.
|
37 |
+
References
|
38 |
+
----------
|
39 |
+
.. [1] "Least-squares estimation of transformation parameters between two
|
40 |
+
point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
|
41 |
+
"""
|
42 |
+
|
43 |
+
num = src.shape[0]
|
44 |
+
dim = src.shape[1]
|
45 |
+
|
46 |
+
# Compute mean of src and dst.
|
47 |
+
src_mean = src.mean(axis=0)
|
48 |
+
dst_mean = dst.mean(axis=0)
|
49 |
+
|
50 |
+
# Subtract mean from src and dst.
|
51 |
+
src_demean = src - src_mean
|
52 |
+
dst_demean = dst - dst_mean
|
53 |
+
|
54 |
+
# Eq. (38).
|
55 |
+
A = dst_demean.T @ src_demean / num
|
56 |
+
|
57 |
+
# Eq. (39).
|
58 |
+
d = np.ones((dim,), dtype=np.double)
|
59 |
+
if np.linalg.det(A) < 0:
|
60 |
+
d[dim - 1] = -1
|
61 |
+
|
62 |
+
T = np.eye(dim + 1, dtype=np.double)
|
63 |
+
|
64 |
+
U, S, V = np.linalg.svd(A)
|
65 |
+
|
66 |
+
# Eq. (40) and (43).
|
67 |
+
rank = np.linalg.matrix_rank(A)
|
68 |
+
if rank == 0:
|
69 |
+
return np.nan * T
|
70 |
+
elif rank == dim - 1:
|
71 |
+
if np.linalg.det(U) * np.linalg.det(V) > 0:
|
72 |
+
T[:dim, :dim] = U @ V
|
73 |
+
else:
|
74 |
+
s = d[dim - 1]
|
75 |
+
d[dim - 1] = -1
|
76 |
+
T[:dim, :dim] = U @ np.diag(d) @ V
|
77 |
+
d[dim - 1] = s
|
78 |
+
else:
|
79 |
+
T[:dim, :dim] = U @ np.diag(d) @ V
|
80 |
+
|
81 |
+
if estimate_scale:
|
82 |
+
# Eq. (41) and (42).
|
83 |
+
scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
|
84 |
+
else:
|
85 |
+
scale = scale
|
86 |
+
|
87 |
+
T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
|
88 |
+
T[:dim, :dim] *= scale
|
89 |
+
|
90 |
+
return T, scale
|
91 |
+
|
92 |
+
|
93 |
+
class FaceWarpException(Exception):
|
94 |
+
def __str__(self):
|
95 |
+
return 'In File {}:{}'.format(
|
96 |
+
__file__, super.__str__(self))
|
97 |
+
|
98 |
+
|
99 |
+
def get_reference_facial_points(output_size=None,
|
100 |
+
inner_padding_factor=0.0,
|
101 |
+
outer_padding=(0, 0),
|
102 |
+
default_square=False):
|
103 |
+
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
|
104 |
+
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
|
105 |
+
|
106 |
+
# 0) make the inner region a square
|
107 |
+
if default_square:
|
108 |
+
size_diff = max(tmp_crop_size) - tmp_crop_size
|
109 |
+
tmp_5pts += size_diff / 2
|
110 |
+
tmp_crop_size += size_diff
|
111 |
+
|
112 |
+
if (output_size and
|
113 |
+
output_size[0] == tmp_crop_size[0] and
|
114 |
+
output_size[1] == tmp_crop_size[1]):
|
115 |
+
print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
|
116 |
+
return tmp_5pts
|
117 |
+
|
118 |
+
if (inner_padding_factor == 0 and
|
119 |
+
outer_padding == (0, 0)):
|
120 |
+
if output_size is None:
|
121 |
+
print('No paddings to do: return default reference points')
|
122 |
+
return tmp_5pts
|
123 |
+
else:
|
124 |
+
raise FaceWarpException(
|
125 |
+
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
|
126 |
+
|
127 |
+
# check output size
|
128 |
+
if not (0 <= inner_padding_factor <= 1.0):
|
129 |
+
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
|
130 |
+
|
131 |
+
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
|
132 |
+
and output_size is None):
|
133 |
+
output_size = tmp_crop_size * \
|
134 |
+
(1 + inner_padding_factor * 2).astype(np.int32)
|
135 |
+
output_size += np.array(outer_padding)
|
136 |
+
print(' deduced from paddings, output_size = ', output_size)
|
137 |
+
|
138 |
+
if not (outer_padding[0] < output_size[0]
|
139 |
+
and outer_padding[1] < output_size[1]):
|
140 |
+
raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
|
141 |
+
'and outer_padding[1] < output_size[1])')
|
142 |
+
|
143 |
+
# 1) pad the inner region according inner_padding_factor
|
144 |
+
# print('---> STEP1: pad the inner region according inner_padding_factor')
|
145 |
+
if inner_padding_factor > 0:
|
146 |
+
size_diff = tmp_crop_size * inner_padding_factor * 2
|
147 |
+
tmp_5pts += size_diff / 2
|
148 |
+
tmp_crop_size += np.round(size_diff).astype(np.int32)
|
149 |
+
|
150 |
+
# print(' crop_size = ', tmp_crop_size)
|
151 |
+
# print(' reference_5pts = ', tmp_5pts)
|
152 |
+
|
153 |
+
# 2) resize the padded inner region
|
154 |
+
# print('---> STEP2: resize the padded inner region')
|
155 |
+
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
|
156 |
+
# print(' crop_size = ', tmp_crop_size)
|
157 |
+
# print(' size_bf_outer_pad = ', size_bf_outer_pad)
|
158 |
+
|
159 |
+
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
|
160 |
+
raise FaceWarpException('Must have (output_size - outer_padding)'
|
161 |
+
'= some_scale * (crop_size * (1.0 + inner_padding_factor)')
|
162 |
+
|
163 |
+
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
|
164 |
+
# print(' resize scale_factor = ', scale_factor)
|
165 |
+
tmp_5pts = tmp_5pts * scale_factor
|
166 |
+
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
|
167 |
+
# tmp_5pts = tmp_5pts + size_diff / 2
|
168 |
+
tmp_crop_size = size_bf_outer_pad
|
169 |
+
# print(' crop_size = ', tmp_crop_size)
|
170 |
+
# print(' reference_5pts = ', tmp_5pts)
|
171 |
+
|
172 |
+
# 3) add outer_padding to make output_size
|
173 |
+
reference_5point = tmp_5pts + np.array(outer_padding)
|
174 |
+
tmp_crop_size = output_size
|
175 |
+
# print('---> STEP3: add outer_padding to make output_size')
|
176 |
+
# print(' crop_size = ', tmp_crop_size)
|
177 |
+
# print(' reference_5pts = ', tmp_5pts)
|
178 |
+
#
|
179 |
+
# print('===> end get_reference_facial_points\n')
|
180 |
+
|
181 |
+
return reference_5point
|
182 |
+
|
183 |
+
|
184 |
+
def get_affine_transform_matrix(src_pts, dst_pts):
|
185 |
+
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
|
186 |
+
n_pts = src_pts.shape[0]
|
187 |
+
ones = np.ones((n_pts, 1), src_pts.dtype)
|
188 |
+
src_pts_ = np.hstack([src_pts, ones])
|
189 |
+
dst_pts_ = np.hstack([dst_pts, ones])
|
190 |
+
|
191 |
+
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
|
192 |
+
|
193 |
+
if rank == 3:
|
194 |
+
tfm = np.float32([
|
195 |
+
[A[0, 0], A[1, 0], A[2, 0]],
|
196 |
+
[A[0, 1], A[1, 1], A[2, 1]]
|
197 |
+
])
|
198 |
+
elif rank == 2:
|
199 |
+
tfm = np.float32([
|
200 |
+
[A[0, 0], A[1, 0], 0],
|
201 |
+
[A[0, 1], A[1, 1], 0]
|
202 |
+
])
|
203 |
+
|
204 |
+
return tfm
|
205 |
+
|
206 |
+
|
207 |
+
def warp_and_crop_face(src_img,
|
208 |
+
facial_pts,
|
209 |
+
reference_pts=None,
|
210 |
+
crop_size=(96, 112),
|
211 |
+
align_type='smilarity'): #smilarity cv2_affine affine
|
212 |
+
if reference_pts is None:
|
213 |
+
if crop_size[0] == 96 and crop_size[1] == 112:
|
214 |
+
reference_pts = REFERENCE_FACIAL_POINTS
|
215 |
+
else:
|
216 |
+
default_square = False
|
217 |
+
inner_padding_factor = 0
|
218 |
+
outer_padding = (0, 0)
|
219 |
+
output_size = crop_size
|
220 |
+
|
221 |
+
reference_pts = get_reference_facial_points(output_size,
|
222 |
+
inner_padding_factor,
|
223 |
+
outer_padding,
|
224 |
+
default_square)
|
225 |
+
|
226 |
+
ref_pts = np.float32(reference_pts)
|
227 |
+
ref_pts_shp = ref_pts.shape
|
228 |
+
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
|
229 |
+
raise FaceWarpException(
|
230 |
+
'reference_pts.shape must be (K,2) or (2,K) and K>2')
|
231 |
+
|
232 |
+
if ref_pts_shp[0] == 2:
|
233 |
+
ref_pts = ref_pts.T
|
234 |
+
|
235 |
+
src_pts = np.float32(facial_pts)
|
236 |
+
src_pts_shp = src_pts.shape
|
237 |
+
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
|
238 |
+
raise FaceWarpException(
|
239 |
+
'facial_pts.shape must be (K,2) or (2,K) and K>2')
|
240 |
+
|
241 |
+
if src_pts_shp[0] == 2:
|
242 |
+
src_pts = src_pts.T
|
243 |
+
|
244 |
+
if src_pts.shape != ref_pts.shape:
|
245 |
+
raise FaceWarpException(
|
246 |
+
'facial_pts and reference_pts must have the same shape')
|
247 |
+
|
248 |
+
if align_type is 'cv2_affine':
|
249 |
+
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
|
250 |
+
tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
|
251 |
+
elif align_type is 'affine':
|
252 |
+
tfm = get_affine_transform_matrix(src_pts, ref_pts)
|
253 |
+
tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
|
254 |
+
else:
|
255 |
+
params, scale = _umeyama(src_pts, ref_pts)
|
256 |
+
tfm = params[:2, :]
|
257 |
+
|
258 |
+
params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
|
259 |
+
tfm_inv = params[:2, :]
|
260 |
+
|
261 |
+
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
|
262 |
+
|
263 |
+
return face_img, tfm_inv
|
core/data/deg_kair_utils/utils_blindsr.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from core.data.deg_kair_utils import utils_image as util
|
7 |
+
|
8 |
+
import random
|
9 |
+
from scipy import ndimage
|
10 |
+
import scipy
|
11 |
+
import scipy.stats as ss
|
12 |
+
from scipy.interpolate import interp2d
|
13 |
+
from scipy.linalg import orth
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
"""
|
19 |
+
# --------------------------------------------
|
20 |
+
# Super-Resolution
|
21 |
+
# --------------------------------------------
|
22 |
+
#
|
23 |
+
# Kai Zhang ([email protected])
|
24 |
+
# https://github.com/cszn
|
25 |
+
# From 2019/03--2021/08
|
26 |
+
# --------------------------------------------
|
27 |
+
"""
|
28 |
+
|
29 |
+
def modcrop_np(img, sf):
|
30 |
+
'''
|
31 |
+
Args:
|
32 |
+
img: numpy image, WxH or WxHxC
|
33 |
+
sf: scale factor
|
34 |
+
|
35 |
+
Return:
|
36 |
+
cropped image
|
37 |
+
'''
|
38 |
+
w, h = img.shape[:2]
|
39 |
+
im = np.copy(img)
|
40 |
+
return im[:w - w % sf, :h - h % sf, ...]
|
41 |
+
|
42 |
+
|
43 |
+
"""
|
44 |
+
# --------------------------------------------
|
45 |
+
# anisotropic Gaussian kernels
|
46 |
+
# --------------------------------------------
|
47 |
+
"""
|
48 |
+
def analytic_kernel(k):
|
49 |
+
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
50 |
+
k_size = k.shape[0]
|
51 |
+
# Calculate the big kernels size
|
52 |
+
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
53 |
+
# Loop over the small kernel to fill the big one
|
54 |
+
for r in range(k_size):
|
55 |
+
for c in range(k_size):
|
56 |
+
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
57 |
+
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
58 |
+
crop = k_size // 2
|
59 |
+
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
60 |
+
# Normalize to 1
|
61 |
+
return cropped_big_k / cropped_big_k.sum()
|
62 |
+
|
63 |
+
|
64 |
+
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
65 |
+
""" generate an anisotropic Gaussian kernel
|
66 |
+
Args:
|
67 |
+
ksize : e.g., 15, kernel size
|
68 |
+
theta : [0, pi], rotation angle range
|
69 |
+
l1 : [0.1,50], scaling of eigenvalues
|
70 |
+
l2 : [0.1,l1], scaling of eigenvalues
|
71 |
+
If l1 = l2, will get an isotropic Gaussian kernel.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
k : kernel
|
75 |
+
"""
|
76 |
+
|
77 |
+
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
78 |
+
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
79 |
+
D = np.array([[l1, 0], [0, l2]])
|
80 |
+
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
81 |
+
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
82 |
+
|
83 |
+
return k
|
84 |
+
|
85 |
+
|
86 |
+
def gm_blur_kernel(mean, cov, size=15):
|
87 |
+
center = size / 2.0 + 0.5
|
88 |
+
k = np.zeros([size, size])
|
89 |
+
for y in range(size):
|
90 |
+
for x in range(size):
|
91 |
+
cy = y - center + 1
|
92 |
+
cx = x - center + 1
|
93 |
+
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
94 |
+
|
95 |
+
k = k / np.sum(k)
|
96 |
+
return k
|
97 |
+
|
98 |
+
|
99 |
+
def shift_pixel(x, sf, upper_left=True):
|
100 |
+
"""shift pixel for super-resolution with different scale factors
|
101 |
+
Args:
|
102 |
+
x: WxHxC or WxH
|
103 |
+
sf: scale factor
|
104 |
+
upper_left: shift direction
|
105 |
+
"""
|
106 |
+
h, w = x.shape[:2]
|
107 |
+
shift = (sf-1)*0.5
|
108 |
+
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
109 |
+
if upper_left:
|
110 |
+
x1 = xv + shift
|
111 |
+
y1 = yv + shift
|
112 |
+
else:
|
113 |
+
x1 = xv - shift
|
114 |
+
y1 = yv - shift
|
115 |
+
|
116 |
+
x1 = np.clip(x1, 0, w-1)
|
117 |
+
y1 = np.clip(y1, 0, h-1)
|
118 |
+
|
119 |
+
if x.ndim == 2:
|
120 |
+
x = interp2d(xv, yv, x)(x1, y1)
|
121 |
+
if x.ndim == 3:
|
122 |
+
for i in range(x.shape[-1]):
|
123 |
+
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
124 |
+
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
def blur(x, k):
|
129 |
+
'''
|
130 |
+
x: image, NxcxHxW
|
131 |
+
k: kernel, Nx1xhxw
|
132 |
+
'''
|
133 |
+
n, c = x.shape[:2]
|
134 |
+
p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2
|
135 |
+
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
136 |
+
k = k.repeat(1,c,1,1)
|
137 |
+
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
138 |
+
x = x.view(1, -1, x.shape[2], x.shape[3])
|
139 |
+
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
|
140 |
+
x = x.view(n, c, x.shape[2], x.shape[3])
|
141 |
+
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
147 |
+
""""
|
148 |
+
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
149 |
+
# Kai Zhang
|
150 |
+
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
151 |
+
# max_var = 2.5 * sf
|
152 |
+
"""
|
153 |
+
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
154 |
+
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
155 |
+
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
156 |
+
theta = np.random.rand() * np.pi # random theta
|
157 |
+
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
158 |
+
|
159 |
+
# Set COV matrix using Lambdas and Theta
|
160 |
+
LAMBDA = np.diag([lambda_1, lambda_2])
|
161 |
+
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
162 |
+
[np.sin(theta), np.cos(theta)]])
|
163 |
+
SIGMA = Q @ LAMBDA @ Q.T
|
164 |
+
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
165 |
+
|
166 |
+
# Set expectation position (shifting kernel for aligned image)
|
167 |
+
MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
168 |
+
MU = MU[None, None, :, None]
|
169 |
+
|
170 |
+
# Create meshgrid for Gaussian
|
171 |
+
[X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
172 |
+
Z = np.stack([X, Y], 2)[:, :, :, None]
|
173 |
+
|
174 |
+
# Calcualte Gaussian for every pixel of the kernel
|
175 |
+
ZZ = Z-MU
|
176 |
+
ZZ_t = ZZ.transpose(0,1,3,2)
|
177 |
+
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
178 |
+
|
179 |
+
# shift the kernel so it will be centered
|
180 |
+
#raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
181 |
+
|
182 |
+
# Normalize the kernel and return
|
183 |
+
#kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
184 |
+
kernel = raw_kernel / np.sum(raw_kernel)
|
185 |
+
return kernel
|
186 |
+
|
187 |
+
|
188 |
+
def fspecial_gaussian(hsize, sigma):
|
189 |
+
hsize = [hsize, hsize]
|
190 |
+
siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
|
191 |
+
std = sigma
|
192 |
+
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
|
193 |
+
arg = -(x*x + y*y)/(2*std*std)
|
194 |
+
h = np.exp(arg)
|
195 |
+
h[h < scipy.finfo(float).eps * h.max()] = 0
|
196 |
+
sumh = h.sum()
|
197 |
+
if sumh != 0:
|
198 |
+
h = h/sumh
|
199 |
+
return h
|
200 |
+
|
201 |
+
|
202 |
+
def fspecial_laplacian(alpha):
|
203 |
+
alpha = max([0, min([alpha,1])])
|
204 |
+
h1 = alpha/(alpha+1)
|
205 |
+
h2 = (1-alpha)/(alpha+1)
|
206 |
+
h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
|
207 |
+
h = np.array(h)
|
208 |
+
return h
|
209 |
+
|
210 |
+
|
211 |
+
def fspecial(filter_type, *args, **kwargs):
|
212 |
+
'''
|
213 |
+
python code from:
|
214 |
+
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
215 |
+
'''
|
216 |
+
if filter_type == 'gaussian':
|
217 |
+
return fspecial_gaussian(*args, **kwargs)
|
218 |
+
if filter_type == 'laplacian':
|
219 |
+
return fspecial_laplacian(*args, **kwargs)
|
220 |
+
|
221 |
+
"""
|
222 |
+
# --------------------------------------------
|
223 |
+
# degradation models
|
224 |
+
# --------------------------------------------
|
225 |
+
"""
|
226 |
+
|
227 |
+
|
228 |
+
def bicubic_degradation(x, sf=3):
|
229 |
+
'''
|
230 |
+
Args:
|
231 |
+
x: HxWxC image, [0, 1]
|
232 |
+
sf: down-scale factor
|
233 |
+
|
234 |
+
Return:
|
235 |
+
bicubicly downsampled LR image
|
236 |
+
'''
|
237 |
+
x = util.imresize_np(x, scale=1/sf)
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
def srmd_degradation(x, k, sf=3):
|
242 |
+
''' blur + bicubic downsampling
|
243 |
+
|
244 |
+
Args:
|
245 |
+
x: HxWxC image, [0, 1]
|
246 |
+
k: hxw, double
|
247 |
+
sf: down-scale factor
|
248 |
+
|
249 |
+
Return:
|
250 |
+
downsampled LR image
|
251 |
+
|
252 |
+
Reference:
|
253 |
+
@inproceedings{zhang2018learning,
|
254 |
+
title={Learning a single convolutional super-resolution network for multiple degradations},
|
255 |
+
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
256 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
257 |
+
pages={3262--3271},
|
258 |
+
year={2018}
|
259 |
+
}
|
260 |
+
'''
|
261 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
262 |
+
x = bicubic_degradation(x, sf=sf)
|
263 |
+
return x
|
264 |
+
|
265 |
+
|
266 |
+
def dpsr_degradation(x, k, sf=3):
|
267 |
+
|
268 |
+
''' bicubic downsampling + blur
|
269 |
+
|
270 |
+
Args:
|
271 |
+
x: HxWxC image, [0, 1]
|
272 |
+
k: hxw, double
|
273 |
+
sf: down-scale factor
|
274 |
+
|
275 |
+
Return:
|
276 |
+
downsampled LR image
|
277 |
+
|
278 |
+
Reference:
|
279 |
+
@inproceedings{zhang2019deep,
|
280 |
+
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
281 |
+
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
282 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
283 |
+
pages={1671--1681},
|
284 |
+
year={2019}
|
285 |
+
}
|
286 |
+
'''
|
287 |
+
x = bicubic_degradation(x, sf=sf)
|
288 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
289 |
+
return x
|
290 |
+
|
291 |
+
|
292 |
+
def classical_degradation(x, k, sf=3):
|
293 |
+
''' blur + downsampling
|
294 |
+
|
295 |
+
Args:
|
296 |
+
x: HxWxC image, [0, 1]/[0, 255]
|
297 |
+
k: hxw, double
|
298 |
+
sf: down-scale factor
|
299 |
+
|
300 |
+
Return:
|
301 |
+
downsampled LR image
|
302 |
+
'''
|
303 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
304 |
+
#x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
305 |
+
st = 0
|
306 |
+
return x[st::sf, st::sf, ...]
|
307 |
+
|
308 |
+
|
309 |
+
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
310 |
+
"""USM sharpening. borrowed from real-ESRGAN
|
311 |
+
Input image: I; Blurry image: B.
|
312 |
+
1. K = I + weight * (I - B)
|
313 |
+
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
314 |
+
3. Blur mask:
|
315 |
+
4. Out = Mask * K + (1 - Mask) * I
|
316 |
+
Args:
|
317 |
+
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
318 |
+
weight (float): Sharp weight. Default: 1.
|
319 |
+
radius (float): Kernel size of Gaussian blur. Default: 50.
|
320 |
+
threshold (int):
|
321 |
+
"""
|
322 |
+
if radius % 2 == 0:
|
323 |
+
radius += 1
|
324 |
+
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
325 |
+
residual = img - blur
|
326 |
+
mask = np.abs(residual) * 255 > threshold
|
327 |
+
mask = mask.astype('float32')
|
328 |
+
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
329 |
+
|
330 |
+
K = img + weight * residual
|
331 |
+
K = np.clip(K, 0, 1)
|
332 |
+
return soft_mask * K + (1 - soft_mask) * img
|
333 |
+
|
334 |
+
|
335 |
+
def add_blur(img, sf=4):
|
336 |
+
wd2 = 4.0 + sf
|
337 |
+
wd = 2.0 + 0.2*sf
|
338 |
+
if random.random() < 0.5:
|
339 |
+
l1 = wd2*random.random()
|
340 |
+
l2 = wd2*random.random()
|
341 |
+
k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
|
342 |
+
else:
|
343 |
+
k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
|
344 |
+
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
345 |
+
|
346 |
+
return img
|
347 |
+
|
348 |
+
|
349 |
+
def add_resize(img, sf=4):
|
350 |
+
rnum = np.random.rand()
|
351 |
+
if rnum > 0.8: # up
|
352 |
+
sf1 = random.uniform(1, 2)
|
353 |
+
elif rnum < 0.7: # down
|
354 |
+
sf1 = random.uniform(0.5/sf, 1)
|
355 |
+
else:
|
356 |
+
sf1 = 1.0
|
357 |
+
img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
358 |
+
img = np.clip(img, 0.0, 1.0)
|
359 |
+
|
360 |
+
return img
|
361 |
+
|
362 |
+
|
363 |
+
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
364 |
+
noise_level = random.randint(noise_level1, noise_level2)
|
365 |
+
rnum = np.random.rand()
|
366 |
+
if rnum > 0.6: # add color Gaussian noise
|
367 |
+
img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
|
368 |
+
elif rnum < 0.4: # add grayscale Gaussian noise
|
369 |
+
img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
|
370 |
+
else: # add noise
|
371 |
+
L = noise_level2/255.
|
372 |
+
D = np.diag(np.random.rand(3))
|
373 |
+
U = orth(np.random.rand(3,3))
|
374 |
+
conv = np.dot(np.dot(np.transpose(U), D), U)
|
375 |
+
img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
|
376 |
+
img = np.clip(img, 0.0, 1.0)
|
377 |
+
return img
|
378 |
+
|
379 |
+
|
380 |
+
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
381 |
+
noise_level = random.randint(noise_level1, noise_level2)
|
382 |
+
img = np.clip(img, 0.0, 1.0)
|
383 |
+
rnum = random.random()
|
384 |
+
if rnum > 0.6:
|
385 |
+
img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
|
386 |
+
elif rnum < 0.4:
|
387 |
+
img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
|
388 |
+
else:
|
389 |
+
L = noise_level2/255.
|
390 |
+
D = np.diag(np.random.rand(3))
|
391 |
+
U = orth(np.random.rand(3,3))
|
392 |
+
conv = np.dot(np.dot(np.transpose(U), D), U)
|
393 |
+
img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
|
394 |
+
img = np.clip(img, 0.0, 1.0)
|
395 |
+
return img
|
396 |
+
|
397 |
+
|
398 |
+
def add_Poisson_noise(img):
|
399 |
+
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
400 |
+
vals = 10**(2*random.random()+2.0) # [2, 4]
|
401 |
+
if random.random() < 0.5:
|
402 |
+
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
403 |
+
else:
|
404 |
+
img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
|
405 |
+
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
406 |
+
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
407 |
+
img += noise_gray[:, :, np.newaxis]
|
408 |
+
img = np.clip(img, 0.0, 1.0)
|
409 |
+
return img
|
410 |
+
|
411 |
+
|
412 |
+
def add_JPEG_noise(img):
|
413 |
+
quality_factor = random.randint(30, 95)
|
414 |
+
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
415 |
+
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
416 |
+
img = cv2.imdecode(encimg, 1)
|
417 |
+
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
418 |
+
return img
|
419 |
+
|
420 |
+
|
421 |
+
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
422 |
+
h, w = lq.shape[:2]
|
423 |
+
rnd_h = random.randint(0, h-lq_patchsize)
|
424 |
+
rnd_w = random.randint(0, w-lq_patchsize)
|
425 |
+
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
426 |
+
|
427 |
+
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
428 |
+
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
|
429 |
+
return lq, hq
|
430 |
+
|
431 |
+
|
432 |
+
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
433 |
+
"""
|
434 |
+
This is the degradation model of BSRGAN from the paper
|
435 |
+
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
436 |
+
----------
|
437 |
+
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
438 |
+
sf: scale factor
|
439 |
+
isp_model: camera ISP model
|
440 |
+
|
441 |
+
Returns
|
442 |
+
-------
|
443 |
+
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
444 |
+
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
445 |
+
"""
|
446 |
+
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
447 |
+
sf_ori = sf
|
448 |
+
|
449 |
+
h1, w1 = img.shape[:2]
|
450 |
+
img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
|
451 |
+
h, w = img.shape[:2]
|
452 |
+
|
453 |
+
if h < lq_patchsize*sf or w < lq_patchsize*sf:
|
454 |
+
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
455 |
+
|
456 |
+
hq = img.copy()
|
457 |
+
|
458 |
+
if sf == 4 and random.random() < scale2_prob: # downsample1
|
459 |
+
if np.random.rand() < 0.5:
|
460 |
+
img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
|
461 |
+
else:
|
462 |
+
img = util.imresize_np(img, 1/2, True)
|
463 |
+
img = np.clip(img, 0.0, 1.0)
|
464 |
+
sf = 2
|
465 |
+
|
466 |
+
shuffle_order = random.sample(range(7), 7)
|
467 |
+
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
468 |
+
if idx1 > idx2: # keep downsample3 last
|
469 |
+
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
470 |
+
|
471 |
+
for i in shuffle_order:
|
472 |
+
|
473 |
+
if i == 0:
|
474 |
+
img = add_blur(img, sf=sf)
|
475 |
+
|
476 |
+
elif i == 1:
|
477 |
+
img = add_blur(img, sf=sf)
|
478 |
+
|
479 |
+
elif i == 2:
|
480 |
+
a, b = img.shape[1], img.shape[0]
|
481 |
+
# downsample2
|
482 |
+
if random.random() < 0.75:
|
483 |
+
sf1 = random.uniform(1,2*sf)
|
484 |
+
img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
|
485 |
+
else:
|
486 |
+
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
|
487 |
+
k_shifted = shift_pixel(k, sf)
|
488 |
+
k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
|
489 |
+
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
490 |
+
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
491 |
+
img = np.clip(img, 0.0, 1.0)
|
492 |
+
|
493 |
+
elif i == 3:
|
494 |
+
# downsample3
|
495 |
+
img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
|
496 |
+
img = np.clip(img, 0.0, 1.0)
|
497 |
+
|
498 |
+
elif i == 4:
|
499 |
+
# add Gaussian noise
|
500 |
+
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
501 |
+
|
502 |
+
elif i == 5:
|
503 |
+
# add JPEG noise
|
504 |
+
if random.random() < jpeg_prob:
|
505 |
+
img = add_JPEG_noise(img)
|
506 |
+
|
507 |
+
elif i == 6:
|
508 |
+
# add processed camera sensor noise
|
509 |
+
if random.random() < isp_prob and isp_model is not None:
|
510 |
+
with torch.no_grad():
|
511 |
+
img, hq = isp_model.forward(img.copy(), hq)
|
512 |
+
|
513 |
+
# add final JPEG compression noise
|
514 |
+
img = add_JPEG_noise(img)
|
515 |
+
|
516 |
+
# random crop
|
517 |
+
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
518 |
+
|
519 |
+
return img, hq
|
520 |
+
|
521 |
+
|
522 |
+
|
523 |
+
|
524 |
+
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
|
525 |
+
"""
|
526 |
+
This is an extended degradation model by combining
|
527 |
+
the degradation models of BSRGAN and Real-ESRGAN
|
528 |
+
----------
|
529 |
+
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
530 |
+
sf: scale factor
|
531 |
+
use_shuffle: the degradation shuffle
|
532 |
+
use_sharp: sharpening the img
|
533 |
+
|
534 |
+
Returns
|
535 |
+
-------
|
536 |
+
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
537 |
+
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
538 |
+
"""
|
539 |
+
|
540 |
+
h1, w1 = img.shape[:2]
|
541 |
+
img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
|
542 |
+
h, w = img.shape[:2]
|
543 |
+
|
544 |
+
if h < lq_patchsize*sf or w < lq_patchsize*sf:
|
545 |
+
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
546 |
+
|
547 |
+
if use_sharp:
|
548 |
+
img = add_sharpening(img)
|
549 |
+
hq = img.copy()
|
550 |
+
|
551 |
+
if random.random() < shuffle_prob:
|
552 |
+
shuffle_order = random.sample(range(13), 13)
|
553 |
+
else:
|
554 |
+
shuffle_order = list(range(13))
|
555 |
+
# local shuffle for noise, JPEG is always the last one
|
556 |
+
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
557 |
+
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
558 |
+
|
559 |
+
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
560 |
+
|
561 |
+
for i in shuffle_order:
|
562 |
+
if i == 0:
|
563 |
+
img = add_blur(img, sf=sf)
|
564 |
+
elif i == 1:
|
565 |
+
img = add_resize(img, sf=sf)
|
566 |
+
elif i == 2:
|
567 |
+
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
568 |
+
elif i == 3:
|
569 |
+
if random.random() < poisson_prob:
|
570 |
+
img = add_Poisson_noise(img)
|
571 |
+
elif i == 4:
|
572 |
+
if random.random() < speckle_prob:
|
573 |
+
img = add_speckle_noise(img)
|
574 |
+
elif i == 5:
|
575 |
+
if random.random() < isp_prob and isp_model is not None:
|
576 |
+
with torch.no_grad():
|
577 |
+
img, hq = isp_model.forward(img.copy(), hq)
|
578 |
+
elif i == 6:
|
579 |
+
img = add_JPEG_noise(img)
|
580 |
+
elif i == 7:
|
581 |
+
img = add_blur(img, sf=sf)
|
582 |
+
elif i == 8:
|
583 |
+
img = add_resize(img, sf=sf)
|
584 |
+
elif i == 9:
|
585 |
+
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
586 |
+
elif i == 10:
|
587 |
+
if random.random() < poisson_prob:
|
588 |
+
img = add_Poisson_noise(img)
|
589 |
+
elif i == 11:
|
590 |
+
if random.random() < speckle_prob:
|
591 |
+
img = add_speckle_noise(img)
|
592 |
+
elif i == 12:
|
593 |
+
if random.random() < isp_prob and isp_model is not None:
|
594 |
+
with torch.no_grad():
|
595 |
+
img, hq = isp_model.forward(img.copy(), hq)
|
596 |
+
else:
|
597 |
+
print('check the shuffle!')
|
598 |
+
|
599 |
+
# resize to desired size
|
600 |
+
img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))
|
601 |
+
|
602 |
+
# add final JPEG compression noise
|
603 |
+
img = add_JPEG_noise(img)
|
604 |
+
|
605 |
+
# random crop
|
606 |
+
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
607 |
+
|
608 |
+
return img, hq
|
609 |
+
|
610 |
+
|
611 |
+
|
612 |
+
if __name__ == '__main__':
|
613 |
+
img = util.imread_uint('utils/test.png', 3)
|
614 |
+
img = util.uint2single(img)
|
615 |
+
sf = 4
|
616 |
+
|
617 |
+
for i in range(20):
|
618 |
+
img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
|
619 |
+
print(i)
|
620 |
+
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
|
621 |
+
img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
|
622 |
+
util.imsave(img_concat, str(i)+'.png')
|
623 |
+
|
624 |
+
# for i in range(10):
|
625 |
+
# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
|
626 |
+
# print(i)
|
627 |
+
# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
|
628 |
+
# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
|
629 |
+
# util.imsave(img_concat, str(i)+'.png')
|
630 |
+
|
631 |
+
# run utils/utils_blindsr.py
|
core/data/deg_kair_utils/utils_bnorm.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
"""
|
6 |
+
# --------------------------------------------
|
7 |
+
# Batch Normalization
|
8 |
+
# --------------------------------------------
|
9 |
+
|
10 |
+
# Kai Zhang ([email protected])
|
11 |
+
# https://github.com/cszn
|
12 |
+
# 01/Jan/2019
|
13 |
+
# --------------------------------------------
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
# --------------------------------------------
|
18 |
+
# remove/delete specified layer
|
19 |
+
# --------------------------------------------
|
20 |
+
def deleteLayer(model, layer_type=nn.BatchNorm2d):
|
21 |
+
''' Kai Zhang, 11/Jan/2019.
|
22 |
+
'''
|
23 |
+
for k, m in list(model.named_children()):
|
24 |
+
if isinstance(m, layer_type):
|
25 |
+
del model._modules[k]
|
26 |
+
deleteLayer(m, layer_type)
|
27 |
+
|
28 |
+
|
29 |
+
# --------------------------------------------
|
30 |
+
# merge bn, "conv+bn" --> "conv"
|
31 |
+
# --------------------------------------------
|
32 |
+
def merge_bn(model):
|
33 |
+
''' Kai Zhang, 11/Jan/2019.
|
34 |
+
merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
|
35 |
+
based on https://github.com/pytorch/pytorch/pull/901
|
36 |
+
'''
|
37 |
+
prev_m = None
|
38 |
+
for k, m in list(model.named_children()):
|
39 |
+
if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
|
40 |
+
|
41 |
+
w = prev_m.weight.data
|
42 |
+
|
43 |
+
if prev_m.bias is None:
|
44 |
+
zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
|
45 |
+
prev_m.bias = nn.Parameter(zeros)
|
46 |
+
b = prev_m.bias.data
|
47 |
+
|
48 |
+
invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
|
49 |
+
if isinstance(prev_m, nn.ConvTranspose2d):
|
50 |
+
w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
|
51 |
+
else:
|
52 |
+
w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
|
53 |
+
b.add_(-m.running_mean).mul_(invstd)
|
54 |
+
if m.affine:
|
55 |
+
if isinstance(prev_m, nn.ConvTranspose2d):
|
56 |
+
w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
|
57 |
+
else:
|
58 |
+
w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
|
59 |
+
b.mul_(m.weight.data).add_(m.bias.data)
|
60 |
+
|
61 |
+
del model._modules[k]
|
62 |
+
prev_m = m
|
63 |
+
merge_bn(m)
|
64 |
+
|
65 |
+
|
66 |
+
# --------------------------------------------
|
67 |
+
# add bn, "conv" --> "conv+bn"
|
68 |
+
# --------------------------------------------
|
69 |
+
def add_bn(model):
|
70 |
+
''' Kai Zhang, 11/Jan/2019.
|
71 |
+
'''
|
72 |
+
for k, m in list(model.named_children()):
|
73 |
+
if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
|
74 |
+
b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
|
75 |
+
b.weight.data.fill_(1)
|
76 |
+
new_m = nn.Sequential(model._modules[k], b)
|
77 |
+
model._modules[k] = new_m
|
78 |
+
add_bn(m)
|
79 |
+
|
80 |
+
|
81 |
+
# --------------------------------------------
|
82 |
+
# tidy model after removing bn
|
83 |
+
# --------------------------------------------
|
84 |
+
def tidy_sequential(model):
|
85 |
+
''' Kai Zhang, 11/Jan/2019.
|
86 |
+
'''
|
87 |
+
for k, m in list(model.named_children()):
|
88 |
+
if isinstance(m, nn.Sequential):
|
89 |
+
if m.__len__() == 1:
|
90 |
+
model._modules[k] = m.__getitem__(0)
|
91 |
+
tidy_sequential(m)
|
core/data/deg_kair_utils/utils_deblur.py
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import scipy
|
4 |
+
from scipy import fftpack
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from math import cos, sin
|
8 |
+
from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
|
9 |
+
from numpy.random import randn, rand
|
10 |
+
from scipy.signal import convolve2d
|
11 |
+
import cv2
|
12 |
+
import random
|
13 |
+
# import utils_image as util
|
14 |
+
|
15 |
+
'''
|
16 |
+
modified by Kai Zhang (github: https://github.com/cszn)
|
17 |
+
03/03/2019
|
18 |
+
'''
|
19 |
+
|
20 |
+
|
21 |
+
def get_uperleft_denominator(img, kernel):
|
22 |
+
'''
|
23 |
+
img: HxWxC
|
24 |
+
kernel: hxw
|
25 |
+
denominator: HxWx1
|
26 |
+
upperleft: HxWxC
|
27 |
+
'''
|
28 |
+
V = psf2otf(kernel, img.shape[:2])
|
29 |
+
denominator = np.expand_dims(np.abs(V)**2, axis=2)
|
30 |
+
upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
|
31 |
+
return upperleft, denominator
|
32 |
+
|
33 |
+
|
34 |
+
def get_uperleft_denominator_pytorch(img, kernel):
|
35 |
+
'''
|
36 |
+
img: NxCxHxW
|
37 |
+
kernel: Nx1xhxw
|
38 |
+
denominator: Nx1xHxW
|
39 |
+
upperleft: NxCxHxWx2
|
40 |
+
'''
|
41 |
+
V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
|
42 |
+
denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
|
43 |
+
upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
|
44 |
+
return upperleft, denominator
|
45 |
+
|
46 |
+
|
47 |
+
def c2c(x):
|
48 |
+
return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
|
49 |
+
|
50 |
+
|
51 |
+
def r2c(x):
|
52 |
+
return torch.stack([x, torch.zeros_like(x)], -1)
|
53 |
+
|
54 |
+
|
55 |
+
def cdiv(x, y):
|
56 |
+
a, b = x[..., 0], x[..., 1]
|
57 |
+
c, d = y[..., 0], y[..., 1]
|
58 |
+
cd2 = c**2 + d**2
|
59 |
+
return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
|
60 |
+
|
61 |
+
|
62 |
+
def cabs(x):
|
63 |
+
return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
|
64 |
+
|
65 |
+
|
66 |
+
def cmul(t1, t2):
|
67 |
+
'''
|
68 |
+
complex multiplication
|
69 |
+
t1: NxCxHxWx2
|
70 |
+
output: NxCxHxWx2
|
71 |
+
'''
|
72 |
+
real1, imag1 = t1[..., 0], t1[..., 1]
|
73 |
+
real2, imag2 = t2[..., 0], t2[..., 1]
|
74 |
+
return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
|
75 |
+
|
76 |
+
|
77 |
+
def cconj(t, inplace=False):
|
78 |
+
'''
|
79 |
+
# complex's conjugation
|
80 |
+
t: NxCxHxWx2
|
81 |
+
output: NxCxHxWx2
|
82 |
+
'''
|
83 |
+
c = t.clone() if not inplace else t
|
84 |
+
c[..., 1] *= -1
|
85 |
+
return c
|
86 |
+
|
87 |
+
|
88 |
+
def rfft(t):
|
89 |
+
return torch.rfft(t, 2, onesided=False)
|
90 |
+
|
91 |
+
|
92 |
+
def irfft(t):
|
93 |
+
return torch.irfft(t, 2, onesided=False)
|
94 |
+
|
95 |
+
|
96 |
+
def fft(t):
|
97 |
+
return torch.fft(t, 2)
|
98 |
+
|
99 |
+
|
100 |
+
def ifft(t):
|
101 |
+
return torch.ifft(t, 2)
|
102 |
+
|
103 |
+
|
104 |
+
def p2o(psf, shape):
|
105 |
+
'''
|
106 |
+
# psf: NxCxhxw
|
107 |
+
# shape: [H,W]
|
108 |
+
# otf: NxCxHxWx2
|
109 |
+
'''
|
110 |
+
otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
|
111 |
+
otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
|
112 |
+
for axis, axis_size in enumerate(psf.shape[2:]):
|
113 |
+
otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
|
114 |
+
otf = torch.rfft(otf, 2, onesided=False)
|
115 |
+
n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
|
116 |
+
otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
|
117 |
+
return otf
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
|
122 |
+
def otf2psf(otf, outsize=None):
|
123 |
+
insize = np.array(otf.shape)
|
124 |
+
psf = np.fft.ifftn(otf, axes=(0, 1))
|
125 |
+
for axis, axis_size in enumerate(insize):
|
126 |
+
psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
|
127 |
+
if type(outsize) != type(None):
|
128 |
+
insize = np.array(otf.shape)
|
129 |
+
outsize = np.array(outsize)
|
130 |
+
n = max(np.size(outsize), np.size(insize))
|
131 |
+
# outsize = postpad(outsize(:), n, 1);
|
132 |
+
# insize = postpad(insize(:) , n, 1);
|
133 |
+
colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
|
134 |
+
colvec_in = insize.flatten().reshape((np.size(insize), 1))
|
135 |
+
outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
|
136 |
+
insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")
|
137 |
+
|
138 |
+
pad = (insize - outsize) / 2
|
139 |
+
if np.any(pad < 0):
|
140 |
+
print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
|
141 |
+
prepad = np.floor(pad)
|
142 |
+
postpad = np.ceil(pad)
|
143 |
+
dims_start = prepad.astype(int)
|
144 |
+
dims_end = (insize - postpad).astype(int)
|
145 |
+
for i in range(len(dims_start.shape)):
|
146 |
+
psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
|
147 |
+
n_ops = np.sum(otf.size * np.log2(otf.shape))
|
148 |
+
psf = np.real_if_close(psf, tol=n_ops)
|
149 |
+
return psf
|
150 |
+
|
151 |
+
|
152 |
+
# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
|
153 |
+
def psf2otf(psf, shape=None):
|
154 |
+
"""
|
155 |
+
Convert point-spread function to optical transfer function.
|
156 |
+
Compute the Fast Fourier Transform (FFT) of the point-spread
|
157 |
+
function (PSF) array and creates the optical transfer function (OTF)
|
158 |
+
array that is not influenced by the PSF off-centering.
|
159 |
+
By default, the OTF array is the same size as the PSF array.
|
160 |
+
To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
|
161 |
+
post-pads the PSF array (down or to the right) with zeros to match
|
162 |
+
dimensions specified in OUTSIZE, then circularly shifts the values of
|
163 |
+
the PSF array up (or to the left) until the central pixel reaches (1,1)
|
164 |
+
position.
|
165 |
+
Parameters
|
166 |
+
----------
|
167 |
+
psf : `numpy.ndarray`
|
168 |
+
PSF array
|
169 |
+
shape : int
|
170 |
+
Output shape of the OTF array
|
171 |
+
Returns
|
172 |
+
-------
|
173 |
+
otf : `numpy.ndarray`
|
174 |
+
OTF array
|
175 |
+
Notes
|
176 |
+
-----
|
177 |
+
Adapted from MATLAB psf2otf function
|
178 |
+
"""
|
179 |
+
if type(shape) == type(None):
|
180 |
+
shape = psf.shape
|
181 |
+
shape = np.array(shape)
|
182 |
+
if np.all(psf == 0):
|
183 |
+
# return np.zeros_like(psf)
|
184 |
+
return np.zeros(shape)
|
185 |
+
if len(psf.shape) == 1:
|
186 |
+
psf = psf.reshape((1, psf.shape[0]))
|
187 |
+
inshape = psf.shape
|
188 |
+
psf = zero_pad(psf, shape, position='corner')
|
189 |
+
for axis, axis_size in enumerate(inshape):
|
190 |
+
psf = np.roll(psf, -int(axis_size / 2), axis=axis)
|
191 |
+
# Compute the OTF
|
192 |
+
otf = np.fft.fft2(psf, axes=(0, 1))
|
193 |
+
# Estimate the rough number of operations involved in the FFT
|
194 |
+
# and discard the PSF imaginary part if within roundoff error
|
195 |
+
# roundoff error = machine epsilon = sys.float_info.epsilon
|
196 |
+
# or np.finfo().eps
|
197 |
+
n_ops = np.sum(psf.size * np.log2(psf.shape))
|
198 |
+
otf = np.real_if_close(otf, tol=n_ops)
|
199 |
+
return otf
|
200 |
+
|
201 |
+
|
202 |
+
def zero_pad(image, shape, position='corner'):
|
203 |
+
"""
|
204 |
+
Extends image to a certain size with zeros
|
205 |
+
Parameters
|
206 |
+
----------
|
207 |
+
image: real 2d `numpy.ndarray`
|
208 |
+
Input image
|
209 |
+
shape: tuple of int
|
210 |
+
Desired output shape of the image
|
211 |
+
position : str, optional
|
212 |
+
The position of the input image in the output one:
|
213 |
+
* 'corner'
|
214 |
+
top-left corner (default)
|
215 |
+
* 'center'
|
216 |
+
centered
|
217 |
+
Returns
|
218 |
+
-------
|
219 |
+
padded_img: real `numpy.ndarray`
|
220 |
+
The zero-padded image
|
221 |
+
"""
|
222 |
+
shape = np.asarray(shape, dtype=int)
|
223 |
+
imshape = np.asarray(image.shape, dtype=int)
|
224 |
+
if np.alltrue(imshape == shape):
|
225 |
+
return image
|
226 |
+
if np.any(shape <= 0):
|
227 |
+
raise ValueError("ZERO_PAD: null or negative shape given")
|
228 |
+
dshape = shape - imshape
|
229 |
+
if np.any(dshape < 0):
|
230 |
+
raise ValueError("ZERO_PAD: target size smaller than source one")
|
231 |
+
pad_img = np.zeros(shape, dtype=image.dtype)
|
232 |
+
idx, idy = np.indices(imshape)
|
233 |
+
if position == 'center':
|
234 |
+
if np.any(dshape % 2 != 0):
|
235 |
+
raise ValueError("ZERO_PAD: source and target shapes "
|
236 |
+
"have different parity.")
|
237 |
+
offx, offy = dshape // 2
|
238 |
+
else:
|
239 |
+
offx, offy = (0, 0)
|
240 |
+
pad_img[idx + offx, idy + offy] = image
|
241 |
+
return pad_img
|
242 |
+
|
243 |
+
|
244 |
+
'''
|
245 |
+
Reducing boundary artifacts
|
246 |
+
'''
|
247 |
+
|
248 |
+
|
249 |
+
def opt_fft_size(n):
|
250 |
+
'''
|
251 |
+
Kai Zhang (github: https://github.com/cszn)
|
252 |
+
03/03/2019
|
253 |
+
# opt_fft_size.m
|
254 |
+
# compute an optimal data length for Fourier transforms
|
255 |
+
# written by Sunghyun Cho ([email protected])
|
256 |
+
# persistent opt_fft_size_LUT;
|
257 |
+
'''
|
258 |
+
|
259 |
+
LUT_size = 2048
|
260 |
+
# print("generate opt_fft_size_LUT")
|
261 |
+
opt_fft_size_LUT = np.zeros(LUT_size)
|
262 |
+
|
263 |
+
e2 = 1
|
264 |
+
while e2 <= LUT_size:
|
265 |
+
e3 = e2
|
266 |
+
while e3 <= LUT_size:
|
267 |
+
e5 = e3
|
268 |
+
while e5 <= LUT_size:
|
269 |
+
e7 = e5
|
270 |
+
while e7 <= LUT_size:
|
271 |
+
if e7 <= LUT_size:
|
272 |
+
opt_fft_size_LUT[e7-1] = e7
|
273 |
+
if e7*11 <= LUT_size:
|
274 |
+
opt_fft_size_LUT[e7*11-1] = e7*11
|
275 |
+
if e7*13 <= LUT_size:
|
276 |
+
opt_fft_size_LUT[e7*13-1] = e7*13
|
277 |
+
e7 = e7 * 7
|
278 |
+
e5 = e5 * 5
|
279 |
+
e3 = e3 * 3
|
280 |
+
e2 = e2 * 2
|
281 |
+
|
282 |
+
nn = 0
|
283 |
+
for i in range(LUT_size, 0, -1):
|
284 |
+
if opt_fft_size_LUT[i-1] != 0:
|
285 |
+
nn = i-1
|
286 |
+
else:
|
287 |
+
opt_fft_size_LUT[i-1] = nn+1
|
288 |
+
|
289 |
+
m = np.zeros(len(n))
|
290 |
+
for c in range(len(n)):
|
291 |
+
nn = n[c]
|
292 |
+
if nn <= LUT_size:
|
293 |
+
m[c] = opt_fft_size_LUT[nn-1]
|
294 |
+
else:
|
295 |
+
m[c] = -1
|
296 |
+
return m
|
297 |
+
|
298 |
+
|
299 |
+
def wrap_boundary_liu(img, img_size):
|
300 |
+
|
301 |
+
"""
|
302 |
+
Reducing boundary artifacts in image deconvolution
|
303 |
+
Renting Liu, Jiaya Jia
|
304 |
+
ICIP 2008
|
305 |
+
"""
|
306 |
+
if img.ndim == 2:
|
307 |
+
ret = wrap_boundary(img, img_size)
|
308 |
+
elif img.ndim == 3:
|
309 |
+
ret = [wrap_boundary(img[:, :, i], img_size) for i in range(3)]
|
310 |
+
ret = np.stack(ret, 2)
|
311 |
+
return ret
|
312 |
+
|
313 |
+
|
314 |
+
def wrap_boundary(img, img_size):
|
315 |
+
|
316 |
+
"""
|
317 |
+
python code from:
|
318 |
+
https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
|
319 |
+
Reducing boundary artifacts in image deconvolution
|
320 |
+
Renting Liu, Jiaya Jia
|
321 |
+
ICIP 2008
|
322 |
+
"""
|
323 |
+
(H, W) = np.shape(img)
|
324 |
+
H_w = int(img_size[0]) - H
|
325 |
+
W_w = int(img_size[1]) - W
|
326 |
+
|
327 |
+
# ret = np.zeros((img_size[0], img_size[1]));
|
328 |
+
alpha = 1
|
329 |
+
HG = img[:, :]
|
330 |
+
|
331 |
+
r_A = np.zeros((alpha*2+H_w, W))
|
332 |
+
r_A[:alpha, :] = HG[-alpha:, :]
|
333 |
+
r_A[-alpha:, :] = HG[:alpha, :]
|
334 |
+
a = np.arange(H_w)/(H_w-1)
|
335 |
+
# r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
|
336 |
+
r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
|
337 |
+
# r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
|
338 |
+
r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]
|
339 |
+
|
340 |
+
r_B = np.zeros((H, alpha*2+W_w))
|
341 |
+
r_B[:, :alpha] = HG[:, -alpha:]
|
342 |
+
r_B[:, -alpha:] = HG[:, :alpha]
|
343 |
+
a = np.arange(W_w)/(W_w-1)
|
344 |
+
r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
|
345 |
+
r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]
|
346 |
+
|
347 |
+
if alpha == 1:
|
348 |
+
A2 = solve_min_laplacian(r_A[alpha-1:, :])
|
349 |
+
B2 = solve_min_laplacian(r_B[:, alpha-1:])
|
350 |
+
r_A[alpha-1:, :] = A2
|
351 |
+
r_B[:, alpha-1:] = B2
|
352 |
+
else:
|
353 |
+
A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
|
354 |
+
r_A[alpha-1:-alpha+1, :] = A2
|
355 |
+
B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
|
356 |
+
r_B[:, alpha-1:-alpha+1] = B2
|
357 |
+
A = r_A
|
358 |
+
B = r_B
|
359 |
+
|
360 |
+
r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
|
361 |
+
r_C[:alpha, :] = B[-alpha:, :]
|
362 |
+
r_C[-alpha:, :] = B[:alpha, :]
|
363 |
+
r_C[:, :alpha] = A[:, -alpha:]
|
364 |
+
r_C[:, -alpha:] = A[:, :alpha]
|
365 |
+
|
366 |
+
if alpha == 1:
|
367 |
+
C2 = C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
|
368 |
+
r_C[alpha-1:, alpha-1:] = C2
|
369 |
+
else:
|
370 |
+
C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
|
371 |
+
r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
|
372 |
+
C = r_C
|
373 |
+
# return C
|
374 |
+
A = A[alpha-1:-alpha-1, :]
|
375 |
+
B = B[:, alpha:-alpha]
|
376 |
+
C = C[alpha:-alpha, alpha:-alpha]
|
377 |
+
ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
|
378 |
+
return ret
|
379 |
+
|
380 |
+
|
381 |
+
def solve_min_laplacian(boundary_image):
|
382 |
+
(H, W) = np.shape(boundary_image)
|
383 |
+
|
384 |
+
# Laplacian
|
385 |
+
f = np.zeros((H, W))
|
386 |
+
# boundary image contains image intensities at boundaries
|
387 |
+
boundary_image[1:-1, 1:-1] = 0
|
388 |
+
j = np.arange(2, H)-1
|
389 |
+
k = np.arange(2, W)-1
|
390 |
+
f_bp = np.zeros((H, W))
|
391 |
+
f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]
|
392 |
+
|
393 |
+
del(j, k)
|
394 |
+
f1 = f - f_bp # subtract boundary points contribution
|
395 |
+
del(f_bp, f)
|
396 |
+
|
397 |
+
# DST Sine Transform algo starts here
|
398 |
+
f2 = f1[1:-1,1:-1]
|
399 |
+
del(f1)
|
400 |
+
|
401 |
+
# compute sine tranform
|
402 |
+
if f2.shape[1] == 1:
|
403 |
+
tt = fftpack.dst(f2, type=1, axis=0)/2
|
404 |
+
else:
|
405 |
+
tt = fftpack.dst(f2, type=1)/2
|
406 |
+
|
407 |
+
if tt.shape[0] == 1:
|
408 |
+
f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
|
409 |
+
else:
|
410 |
+
f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2)
|
411 |
+
del(f2)
|
412 |
+
|
413 |
+
# compute Eigen Values
|
414 |
+
[x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
|
415 |
+
denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)
|
416 |
+
|
417 |
+
# divide
|
418 |
+
f3 = f2sin/denom
|
419 |
+
del(f2sin, x, y)
|
420 |
+
|
421 |
+
# compute Inverse Sine Transform
|
422 |
+
if f3.shape[0] == 1:
|
423 |
+
tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
|
424 |
+
else:
|
425 |
+
tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
|
426 |
+
del(f3)
|
427 |
+
if tt.shape[1] == 1:
|
428 |
+
img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
|
429 |
+
else:
|
430 |
+
img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
|
431 |
+
del(tt)
|
432 |
+
|
433 |
+
# put solution in inner points; outer points obtained from boundary image
|
434 |
+
img_direct = boundary_image
|
435 |
+
img_direct[1:-1, 1:-1] = 0
|
436 |
+
img_direct[1:-1, 1:-1] = img_tt
|
437 |
+
return img_direct
|
438 |
+
|
439 |
+
|
440 |
+
"""
|
441 |
+
Created on Thu Jan 18 15:36:32 2018
|
442 |
+
@author: italo
|
443 |
+
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
444 |
+
"""
|
445 |
+
|
446 |
+
"""
|
447 |
+
Syntax
|
448 |
+
h = fspecial(type)
|
449 |
+
h = fspecial('average',hsize)
|
450 |
+
h = fspecial('disk',radius)
|
451 |
+
h = fspecial('gaussian',hsize,sigma)
|
452 |
+
h = fspecial('laplacian',alpha)
|
453 |
+
h = fspecial('log',hsize,sigma)
|
454 |
+
h = fspecial('motion',len,theta)
|
455 |
+
h = fspecial('prewitt')
|
456 |
+
h = fspecial('sobel')
|
457 |
+
"""
|
458 |
+
|
459 |
+
|
460 |
+
def fspecial_average(hsize=3):
|
461 |
+
"""Smoothing filter"""
|
462 |
+
return np.ones((hsize, hsize))/hsize**2
|
463 |
+
|
464 |
+
|
465 |
+
def fspecial_disk(radius):
|
466 |
+
"""Disk filter"""
|
467 |
+
raise(NotImplemented)
|
468 |
+
rad = 0.6
|
469 |
+
crad = np.ceil(rad-0.5)
|
470 |
+
[x, y] = np.meshgrid(np.arange(-crad, crad+1), np.arange(-crad, crad+1))
|
471 |
+
maxxy = np.zeros(x.shape)
|
472 |
+
maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
|
473 |
+
maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
|
474 |
+
minxy = np.zeros(x.shape)
|
475 |
+
minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
|
476 |
+
minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
|
477 |
+
m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
|
478 |
+
(rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
|
479 |
+
np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
|
480 |
+
m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
|
481 |
+
(rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
|
482 |
+
np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
|
483 |
+
h = None
|
484 |
+
return h
|
485 |
+
|
486 |
+
|
487 |
+
def fspecial_gaussian(hsize, sigma):
|
488 |
+
hsize = [hsize, hsize]
|
489 |
+
siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
|
490 |
+
std = sigma
|
491 |
+
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
|
492 |
+
arg = -(x*x + y*y)/(2*std*std)
|
493 |
+
h = np.exp(arg)
|
494 |
+
h[h < scipy.finfo(float).eps * h.max()] = 0
|
495 |
+
sumh = h.sum()
|
496 |
+
if sumh != 0:
|
497 |
+
h = h/sumh
|
498 |
+
return h
|
499 |
+
|
500 |
+
|
501 |
+
def fspecial_laplacian(alpha):
|
502 |
+
alpha = max([0, min([alpha,1])])
|
503 |
+
h1 = alpha/(alpha+1)
|
504 |
+
h2 = (1-alpha)/(alpha+1)
|
505 |
+
h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
|
506 |
+
h = np.array(h)
|
507 |
+
return h
|
508 |
+
|
509 |
+
|
510 |
+
def fspecial_log(hsize, sigma):
|
511 |
+
raise(NotImplemented)
|
512 |
+
|
513 |
+
|
514 |
+
def fspecial_motion(motion_len, theta):
|
515 |
+
raise(NotImplemented)
|
516 |
+
|
517 |
+
|
518 |
+
def fspecial_prewitt():
|
519 |
+
return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
|
520 |
+
|
521 |
+
|
522 |
+
def fspecial_sobel():
|
523 |
+
return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
|
524 |
+
|
525 |
+
|
526 |
+
def fspecial(filter_type, *args, **kwargs):
|
527 |
+
'''
|
528 |
+
python code from:
|
529 |
+
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
530 |
+
'''
|
531 |
+
if filter_type == 'average':
|
532 |
+
return fspecial_average(*args, **kwargs)
|
533 |
+
if filter_type == 'disk':
|
534 |
+
return fspecial_disk(*args, **kwargs)
|
535 |
+
if filter_type == 'gaussian':
|
536 |
+
return fspecial_gaussian(*args, **kwargs)
|
537 |
+
if filter_type == 'laplacian':
|
538 |
+
return fspecial_laplacian(*args, **kwargs)
|
539 |
+
if filter_type == 'log':
|
540 |
+
return fspecial_log(*args, **kwargs)
|
541 |
+
if filter_type == 'motion':
|
542 |
+
return fspecial_motion(*args, **kwargs)
|
543 |
+
if filter_type == 'prewitt':
|
544 |
+
return fspecial_prewitt(*args, **kwargs)
|
545 |
+
if filter_type == 'sobel':
|
546 |
+
return fspecial_sobel(*args, **kwargs)
|
547 |
+
|
548 |
+
|
549 |
+
def fspecial_gauss(size, sigma):
|
550 |
+
x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
|
551 |
+
g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
|
552 |
+
return g / g.sum()
|
553 |
+
|
554 |
+
|
555 |
+
def blurkernel_synthesis(h=37, w=None):
|
556 |
+
# https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
|
557 |
+
w = h if w is None else w
|
558 |
+
kdims = [h, w]
|
559 |
+
x = randomTrajectory(250)
|
560 |
+
k = None
|
561 |
+
while k is None:
|
562 |
+
k = kernelFromTrajectory(x)
|
563 |
+
|
564 |
+
# center pad to kdims
|
565 |
+
pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
|
566 |
+
pad_width = [(pad_width[0],), (pad_width[1],)]
|
567 |
+
|
568 |
+
if pad_width[0][0]<0 or pad_width[1][0]<0:
|
569 |
+
k = k[0:h, 0:h]
|
570 |
+
else:
|
571 |
+
k = pad(k, pad_width, "constant")
|
572 |
+
x1,x2 = k.shape
|
573 |
+
if np.random.randint(0, 4) == 1:
|
574 |
+
k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
|
575 |
+
y1, y2 = k.shape
|
576 |
+
k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
|
577 |
+
|
578 |
+
if sum(k)<0.1:
|
579 |
+
k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
|
580 |
+
k = k / sum(k)
|
581 |
+
# import matplotlib.pyplot as plt
|
582 |
+
# plt.imshow(k, interpolation="nearest", cmap="gray")
|
583 |
+
# plt.show()
|
584 |
+
return k
|
585 |
+
|
586 |
+
|
587 |
+
def kernelFromTrajectory(x):
|
588 |
+
h = 5 - log(rand()) / 0.15
|
589 |
+
h = round(min([h, 27])).astype(int)
|
590 |
+
h = h + 1 - h % 2
|
591 |
+
w = h
|
592 |
+
k = zeros((h, w))
|
593 |
+
|
594 |
+
xmin = min(x[0])
|
595 |
+
xmax = max(x[0])
|
596 |
+
ymin = min(x[1])
|
597 |
+
ymax = max(x[1])
|
598 |
+
xthr = arange(xmin, xmax, (xmax - xmin) / w)
|
599 |
+
ythr = arange(ymin, ymax, (ymax - ymin) / h)
|
600 |
+
|
601 |
+
for i in range(1, xthr.size):
|
602 |
+
for j in range(1, ythr.size):
|
603 |
+
idx = (
|
604 |
+
(x[0, :] >= xthr[i - 1])
|
605 |
+
& (x[0, :] < xthr[i])
|
606 |
+
& (x[1, :] >= ythr[j - 1])
|
607 |
+
& (x[1, :] < ythr[j])
|
608 |
+
)
|
609 |
+
k[i - 1, j - 1] = sum(idx)
|
610 |
+
if sum(k) == 0:
|
611 |
+
return
|
612 |
+
k = k / sum(k)
|
613 |
+
k = convolve2d(k, fspecial_gauss(3, 1), "same")
|
614 |
+
k = k / sum(k)
|
615 |
+
return k
|
616 |
+
|
617 |
+
|
618 |
+
def randomTrajectory(T):
|
619 |
+
x = zeros((3, T))
|
620 |
+
v = randn(3, T)
|
621 |
+
r = zeros((3, T))
|
622 |
+
trv = 1 / 1
|
623 |
+
trr = 2 * pi / T
|
624 |
+
for t in range(1, T):
|
625 |
+
F_rot = randn(3) / (t + 1) + r[:, t - 1]
|
626 |
+
F_trans = randn(3) / (t + 1)
|
627 |
+
r[:, t] = r[:, t - 1] + trr * F_rot
|
628 |
+
v[:, t] = v[:, t - 1] + trv * F_trans
|
629 |
+
st = v[:, t]
|
630 |
+
st = rot3D(st, r[:, t])
|
631 |
+
x[:, t] = x[:, t - 1] + st
|
632 |
+
return x
|
633 |
+
|
634 |
+
|
635 |
+
def rot3D(x, r):
|
636 |
+
Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
|
637 |
+
Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
|
638 |
+
Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
|
639 |
+
R = Rz @ Ry @ Rx
|
640 |
+
x = R @ x
|
641 |
+
return x
|
642 |
+
|
643 |
+
|
644 |
+
if __name__ == '__main__':
|
645 |
+
a = opt_fft_size([111])
|
646 |
+
print(a)
|
647 |
+
|
648 |
+
print(fspecial('gaussian', 5, 1))
|
649 |
+
|
650 |
+
print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
|
651 |
+
|
652 |
+
k = blurkernel_synthesis(11)
|
653 |
+
import matplotlib.pyplot as plt
|
654 |
+
plt.imshow(k, interpolation="nearest", cmap="gray")
|
655 |
+
plt.show()
|
core/data/deg_kair_utils/utils_dist.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import torch.multiprocessing as mp
|
8 |
+
|
9 |
+
|
10 |
+
# ----------------------------------
|
11 |
+
# init
|
12 |
+
# ----------------------------------
|
13 |
+
def init_dist(launcher, backend='nccl', **kwargs):
|
14 |
+
if mp.get_start_method(allow_none=True) is None:
|
15 |
+
mp.set_start_method('spawn')
|
16 |
+
if launcher == 'pytorch':
|
17 |
+
_init_dist_pytorch(backend, **kwargs)
|
18 |
+
elif launcher == 'slurm':
|
19 |
+
_init_dist_slurm(backend, **kwargs)
|
20 |
+
else:
|
21 |
+
raise ValueError(f'Invalid launcher type: {launcher}')
|
22 |
+
|
23 |
+
|
24 |
+
def _init_dist_pytorch(backend, **kwargs):
|
25 |
+
rank = int(os.environ['RANK'])
|
26 |
+
num_gpus = torch.cuda.device_count()
|
27 |
+
torch.cuda.set_device(rank % num_gpus)
|
28 |
+
dist.init_process_group(backend=backend, **kwargs)
|
29 |
+
|
30 |
+
|
31 |
+
def _init_dist_slurm(backend, port=None):
|
32 |
+
"""Initialize slurm distributed training environment.
|
33 |
+
If argument ``port`` is not specified, then the master port will be system
|
34 |
+
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
35 |
+
environment variable, then a default port ``29500`` will be used.
|
36 |
+
Args:
|
37 |
+
backend (str): Backend of torch.distributed.
|
38 |
+
port (int, optional): Master port. Defaults to None.
|
39 |
+
"""
|
40 |
+
proc_id = int(os.environ['SLURM_PROCID'])
|
41 |
+
ntasks = int(os.environ['SLURM_NTASKS'])
|
42 |
+
node_list = os.environ['SLURM_NODELIST']
|
43 |
+
num_gpus = torch.cuda.device_count()
|
44 |
+
torch.cuda.set_device(proc_id % num_gpus)
|
45 |
+
addr = subprocess.getoutput(
|
46 |
+
f'scontrol show hostname {node_list} | head -n1')
|
47 |
+
# specify master port
|
48 |
+
if port is not None:
|
49 |
+
os.environ['MASTER_PORT'] = str(port)
|
50 |
+
elif 'MASTER_PORT' in os.environ:
|
51 |
+
pass # use MASTER_PORT in the environment variable
|
52 |
+
else:
|
53 |
+
# 29500 is torch.distributed default port
|
54 |
+
os.environ['MASTER_PORT'] = '29500'
|
55 |
+
os.environ['MASTER_ADDR'] = addr
|
56 |
+
os.environ['WORLD_SIZE'] = str(ntasks)
|
57 |
+
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
58 |
+
os.environ['RANK'] = str(proc_id)
|
59 |
+
dist.init_process_group(backend=backend)
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
# ----------------------------------
|
64 |
+
# get rank and world_size
|
65 |
+
# ----------------------------------
|
66 |
+
def get_dist_info():
|
67 |
+
if dist.is_available():
|
68 |
+
initialized = dist.is_initialized()
|
69 |
+
else:
|
70 |
+
initialized = False
|
71 |
+
if initialized:
|
72 |
+
rank = dist.get_rank()
|
73 |
+
world_size = dist.get_world_size()
|
74 |
+
else:
|
75 |
+
rank = 0
|
76 |
+
world_size = 1
|
77 |
+
return rank, world_size
|
78 |
+
|
79 |
+
|
80 |
+
def get_rank():
|
81 |
+
if not dist.is_available():
|
82 |
+
return 0
|
83 |
+
|
84 |
+
if not dist.is_initialized():
|
85 |
+
return 0
|
86 |
+
|
87 |
+
return dist.get_rank()
|
88 |
+
|
89 |
+
|
90 |
+
def get_world_size():
|
91 |
+
if not dist.is_available():
|
92 |
+
return 1
|
93 |
+
|
94 |
+
if not dist.is_initialized():
|
95 |
+
return 1
|
96 |
+
|
97 |
+
return dist.get_world_size()
|
98 |
+
|
99 |
+
|
100 |
+
def master_only(func):
|
101 |
+
|
102 |
+
@functools.wraps(func)
|
103 |
+
def wrapper(*args, **kwargs):
|
104 |
+
rank, _ = get_dist_info()
|
105 |
+
if rank == 0:
|
106 |
+
return func(*args, **kwargs)
|
107 |
+
|
108 |
+
return wrapper
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
# ----------------------------------
|
116 |
+
# operation across ranks
|
117 |
+
# ----------------------------------
|
118 |
+
def reduce_sum(tensor):
|
119 |
+
if not dist.is_available():
|
120 |
+
return tensor
|
121 |
+
|
122 |
+
if not dist.is_initialized():
|
123 |
+
return tensor
|
124 |
+
|
125 |
+
tensor = tensor.clone()
|
126 |
+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
127 |
+
|
128 |
+
return tensor
|
129 |
+
|
130 |
+
|
131 |
+
def gather_grad(params):
|
132 |
+
world_size = get_world_size()
|
133 |
+
|
134 |
+
if world_size == 1:
|
135 |
+
return
|
136 |
+
|
137 |
+
for param in params:
|
138 |
+
if param.grad is not None:
|
139 |
+
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
|
140 |
+
param.grad.data.div_(world_size)
|
141 |
+
|
142 |
+
|
143 |
+
def all_gather(data):
|
144 |
+
world_size = get_world_size()
|
145 |
+
|
146 |
+
if world_size == 1:
|
147 |
+
return [data]
|
148 |
+
|
149 |
+
buffer = pickle.dumps(data)
|
150 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
151 |
+
tensor = torch.ByteTensor(storage).to('cuda')
|
152 |
+
|
153 |
+
local_size = torch.IntTensor([tensor.numel()]).to('cuda')
|
154 |
+
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
|
155 |
+
dist.all_gather(size_list, local_size)
|
156 |
+
size_list = [int(size.item()) for size in size_list]
|
157 |
+
max_size = max(size_list)
|
158 |
+
|
159 |
+
tensor_list = []
|
160 |
+
for _ in size_list:
|
161 |
+
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
|
162 |
+
|
163 |
+
if local_size != max_size:
|
164 |
+
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
|
165 |
+
tensor = torch.cat((tensor, padding), 0)
|
166 |
+
|
167 |
+
dist.all_gather(tensor_list, tensor)
|
168 |
+
|
169 |
+
data_list = []
|
170 |
+
|
171 |
+
for size, tensor in zip(size_list, tensor_list):
|
172 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
173 |
+
data_list.append(pickle.loads(buffer))
|
174 |
+
|
175 |
+
return data_list
|
176 |
+
|
177 |
+
|
178 |
+
def reduce_loss_dict(loss_dict):
|
179 |
+
world_size = get_world_size()
|
180 |
+
|
181 |
+
if world_size < 2:
|
182 |
+
return loss_dict
|
183 |
+
|
184 |
+
with torch.no_grad():
|
185 |
+
keys = []
|
186 |
+
losses = []
|
187 |
+
|
188 |
+
for k in sorted(loss_dict.keys()):
|
189 |
+
keys.append(k)
|
190 |
+
losses.append(loss_dict[k])
|
191 |
+
|
192 |
+
losses = torch.stack(losses, 0)
|
193 |
+
dist.reduce(losses, dst=0)
|
194 |
+
|
195 |
+
if dist.get_rank() == 0:
|
196 |
+
losses /= world_size
|
197 |
+
|
198 |
+
reduced_losses = {k: v for k, v in zip(keys, losses)}
|
199 |
+
|
200 |
+
return reduced_losses
|
201 |
+
|
core/data/deg_kair_utils/utils_googledownload.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
'''
|
7 |
+
borrowed from
|
8 |
+
https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
|
9 |
+
'''
|
10 |
+
|
11 |
+
|
12 |
+
def sizeof_fmt(size, suffix='B'):
|
13 |
+
"""Get human readable file size.
|
14 |
+
Args:
|
15 |
+
size (int): File size.
|
16 |
+
suffix (str): Suffix. Default: 'B'.
|
17 |
+
Return:
|
18 |
+
str: Formated file siz.
|
19 |
+
"""
|
20 |
+
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
21 |
+
if abs(size) < 1024.0:
|
22 |
+
return f'{size:3.1f} {unit}{suffix}'
|
23 |
+
size /= 1024.0
|
24 |
+
return f'{size:3.1f} Y{suffix}'
|
25 |
+
|
26 |
+
|
27 |
+
def download_file_from_google_drive(file_id, save_path):
|
28 |
+
"""Download files from google drive.
|
29 |
+
Ref:
|
30 |
+
https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
|
31 |
+
Args:
|
32 |
+
file_id (str): File id.
|
33 |
+
save_path (str): Save path.
|
34 |
+
"""
|
35 |
+
|
36 |
+
session = requests.Session()
|
37 |
+
URL = 'https://docs.google.com/uc?export=download'
|
38 |
+
params = {'id': file_id}
|
39 |
+
|
40 |
+
response = session.get(URL, params=params, stream=True)
|
41 |
+
token = get_confirm_token(response)
|
42 |
+
if token:
|
43 |
+
params['confirm'] = token
|
44 |
+
response = session.get(URL, params=params, stream=True)
|
45 |
+
|
46 |
+
# get file size
|
47 |
+
response_file_size = session.get(
|
48 |
+
URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
|
49 |
+
if 'Content-Range' in response_file_size.headers:
|
50 |
+
file_size = int(
|
51 |
+
response_file_size.headers['Content-Range'].split('/')[1])
|
52 |
+
else:
|
53 |
+
file_size = None
|
54 |
+
|
55 |
+
save_response_content(response, save_path, file_size)
|
56 |
+
|
57 |
+
|
58 |
+
def get_confirm_token(response):
|
59 |
+
for key, value in response.cookies.items():
|
60 |
+
if key.startswith('download_warning'):
|
61 |
+
return value
|
62 |
+
return None
|
63 |
+
|
64 |
+
|
65 |
+
def save_response_content(response,
|
66 |
+
destination,
|
67 |
+
file_size=None,
|
68 |
+
chunk_size=32768):
|
69 |
+
if file_size is not None:
|
70 |
+
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
|
71 |
+
|
72 |
+
readable_file_size = sizeof_fmt(file_size)
|
73 |
+
else:
|
74 |
+
pbar = None
|
75 |
+
|
76 |
+
with open(destination, 'wb') as f:
|
77 |
+
downloaded_size = 0
|
78 |
+
for chunk in response.iter_content(chunk_size):
|
79 |
+
downloaded_size += chunk_size
|
80 |
+
if pbar is not None:
|
81 |
+
pbar.update(1)
|
82 |
+
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
|
83 |
+
f'/ {readable_file_size}')
|
84 |
+
if chunk: # filter out keep-alive new chunks
|
85 |
+
f.write(chunk)
|
86 |
+
if pbar is not None:
|
87 |
+
pbar.close()
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
|
92 |
+
save_path = 'BSRGAN.pth'
|
93 |
+
download_file_from_google_drive(file_id, save_path)
|
core/data/deg_kair_utils/utils_image.py
ADDED
@@ -0,0 +1,1016 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
from torchvision.utils import make_grid
|
8 |
+
from datetime import datetime
|
9 |
+
# import torchvision.transforms as transforms
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from mpl_toolkits.mplot3d import Axes3D
|
12 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
13 |
+
|
14 |
+
|
15 |
+
'''
|
16 |
+
# --------------------------------------------
|
17 |
+
# Kai Zhang (github: https://github.com/cszn)
|
18 |
+
# 03/Mar/2019
|
19 |
+
# --------------------------------------------
|
20 |
+
# https://github.com/twhui/SRGAN-pyTorch
|
21 |
+
# https://github.com/xinntao/BasicSR
|
22 |
+
# --------------------------------------------
|
23 |
+
'''
|
24 |
+
|
25 |
+
|
26 |
+
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
|
27 |
+
|
28 |
+
|
29 |
+
def is_image_file(filename):
|
30 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
31 |
+
|
32 |
+
|
33 |
+
def get_timestamp():
|
34 |
+
return datetime.now().strftime('%y%m%d-%H%M%S')
|
35 |
+
|
36 |
+
|
37 |
+
def imshow(x, title=None, cbar=False, figsize=None):
|
38 |
+
plt.figure(figsize=figsize)
|
39 |
+
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
|
40 |
+
if title:
|
41 |
+
plt.title(title)
|
42 |
+
if cbar:
|
43 |
+
plt.colorbar()
|
44 |
+
plt.show()
|
45 |
+
|
46 |
+
|
47 |
+
def surf(Z, cmap='rainbow', figsize=None):
|
48 |
+
plt.figure(figsize=figsize)
|
49 |
+
ax3 = plt.axes(projection='3d')
|
50 |
+
|
51 |
+
w, h = Z.shape[:2]
|
52 |
+
xx = np.arange(0,w,1)
|
53 |
+
yy = np.arange(0,h,1)
|
54 |
+
X, Y = np.meshgrid(xx, yy)
|
55 |
+
ax3.plot_surface(X,Y,Z,cmap=cmap)
|
56 |
+
#ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
57 |
+
plt.show()
|
58 |
+
|
59 |
+
|
60 |
+
'''
|
61 |
+
# --------------------------------------------
|
62 |
+
# get image pathes
|
63 |
+
# --------------------------------------------
|
64 |
+
'''
|
65 |
+
|
66 |
+
|
67 |
+
def get_image_paths(dataroot):
|
68 |
+
paths = None # return None if dataroot is None
|
69 |
+
if isinstance(dataroot, str):
|
70 |
+
paths = sorted(_get_paths_from_images(dataroot))
|
71 |
+
elif isinstance(dataroot, list):
|
72 |
+
paths = []
|
73 |
+
for i in dataroot:
|
74 |
+
paths += sorted(_get_paths_from_images(i))
|
75 |
+
return paths
|
76 |
+
|
77 |
+
|
78 |
+
def _get_paths_from_images(path):
|
79 |
+
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
80 |
+
images = []
|
81 |
+
for dirpath, _, fnames in sorted(os.walk(path)):
|
82 |
+
for fname in sorted(fnames):
|
83 |
+
if is_image_file(fname):
|
84 |
+
img_path = os.path.join(dirpath, fname)
|
85 |
+
images.append(img_path)
|
86 |
+
assert images, '{:s} has no valid image file'.format(path)
|
87 |
+
return images
|
88 |
+
|
89 |
+
|
90 |
+
'''
|
91 |
+
# --------------------------------------------
|
92 |
+
# split large images into small images
|
93 |
+
# --------------------------------------------
|
94 |
+
'''
|
95 |
+
|
96 |
+
|
97 |
+
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
98 |
+
w, h = img.shape[:2]
|
99 |
+
patches = []
|
100 |
+
if w > p_max and h > p_max:
|
101 |
+
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
|
102 |
+
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
|
103 |
+
w1.append(w-p_size)
|
104 |
+
h1.append(h-p_size)
|
105 |
+
# print(w1)
|
106 |
+
# print(h1)
|
107 |
+
for i in w1:
|
108 |
+
for j in h1:
|
109 |
+
patches.append(img[i:i+p_size, j:j+p_size,:])
|
110 |
+
else:
|
111 |
+
patches.append(img)
|
112 |
+
|
113 |
+
return patches
|
114 |
+
|
115 |
+
|
116 |
+
def imssave(imgs, img_path):
|
117 |
+
"""
|
118 |
+
imgs: list, N images of size WxHxC
|
119 |
+
"""
|
120 |
+
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
121 |
+
for i, img in enumerate(imgs):
|
122 |
+
if img.ndim == 3:
|
123 |
+
img = img[:, :, [2, 1, 0]]
|
124 |
+
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
|
125 |
+
cv2.imwrite(new_path, img)
|
126 |
+
|
127 |
+
|
128 |
+
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
|
129 |
+
"""
|
130 |
+
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
|
131 |
+
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
|
132 |
+
will be splitted.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
original_dataroot:
|
136 |
+
taget_dataroot:
|
137 |
+
p_size: size of small images
|
138 |
+
p_overlap: patch size in training is a good choice
|
139 |
+
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
|
140 |
+
"""
|
141 |
+
paths = get_image_paths(original_dataroot)
|
142 |
+
for img_path in paths:
|
143 |
+
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
144 |
+
img = imread_uint(img_path, n_channels=n_channels)
|
145 |
+
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
146 |
+
imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
|
147 |
+
#if original_dataroot == taget_dataroot:
|
148 |
+
#del img_path
|
149 |
+
|
150 |
+
'''
|
151 |
+
# --------------------------------------------
|
152 |
+
# makedir
|
153 |
+
# --------------------------------------------
|
154 |
+
'''
|
155 |
+
|
156 |
+
|
157 |
+
def mkdir(path):
|
158 |
+
if not os.path.exists(path):
|
159 |
+
os.makedirs(path)
|
160 |
+
|
161 |
+
|
162 |
+
def mkdirs(paths):
|
163 |
+
if isinstance(paths, str):
|
164 |
+
mkdir(paths)
|
165 |
+
else:
|
166 |
+
for path in paths:
|
167 |
+
mkdir(path)
|
168 |
+
|
169 |
+
|
170 |
+
def mkdir_and_rename(path):
|
171 |
+
if os.path.exists(path):
|
172 |
+
new_name = path + '_archived_' + get_timestamp()
|
173 |
+
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
174 |
+
os.rename(path, new_name)
|
175 |
+
os.makedirs(path)
|
176 |
+
|
177 |
+
|
178 |
+
'''
|
179 |
+
# --------------------------------------------
|
180 |
+
# read image from path
|
181 |
+
# opencv is fast, but read BGR numpy image
|
182 |
+
# --------------------------------------------
|
183 |
+
'''
|
184 |
+
|
185 |
+
|
186 |
+
# --------------------------------------------
|
187 |
+
# get uint8 image of size HxWxn_channles (RGB)
|
188 |
+
# --------------------------------------------
|
189 |
+
def imread_uint(path, n_channels=3):
|
190 |
+
# input: path
|
191 |
+
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
192 |
+
if n_channels == 1:
|
193 |
+
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
194 |
+
img = np.expand_dims(img, axis=2) # HxWx1
|
195 |
+
elif n_channels == 3:
|
196 |
+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
197 |
+
if img.ndim == 2:
|
198 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
199 |
+
else:
|
200 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
201 |
+
return img
|
202 |
+
|
203 |
+
|
204 |
+
# --------------------------------------------
|
205 |
+
# matlab's imwrite
|
206 |
+
# --------------------------------------------
|
207 |
+
def imsave(img, img_path):
|
208 |
+
img = np.squeeze(img)
|
209 |
+
if img.ndim == 3:
|
210 |
+
img = img[:, :, [2, 1, 0]]
|
211 |
+
cv2.imwrite(img_path, img)
|
212 |
+
|
213 |
+
def imwrite(img, img_path):
|
214 |
+
img = np.squeeze(img)
|
215 |
+
if img.ndim == 3:
|
216 |
+
img = img[:, :, [2, 1, 0]]
|
217 |
+
cv2.imwrite(img_path, img)
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
# --------------------------------------------
|
222 |
+
# get single image of size HxWxn_channles (BGR)
|
223 |
+
# --------------------------------------------
|
224 |
+
def read_img(path):
|
225 |
+
# read image by cv2
|
226 |
+
# return: Numpy float32, HWC, BGR, [0,1]
|
227 |
+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
228 |
+
img = img.astype(np.float32) / 255.
|
229 |
+
if img.ndim == 2:
|
230 |
+
img = np.expand_dims(img, axis=2)
|
231 |
+
# some images have 4 channels
|
232 |
+
if img.shape[2] > 3:
|
233 |
+
img = img[:, :, :3]
|
234 |
+
return img
|
235 |
+
|
236 |
+
|
237 |
+
'''
|
238 |
+
# --------------------------------------------
|
239 |
+
# image format conversion
|
240 |
+
# --------------------------------------------
|
241 |
+
# numpy(single) <---> numpy(uint)
|
242 |
+
# numpy(single) <---> tensor
|
243 |
+
# numpy(uint) <---> tensor
|
244 |
+
# --------------------------------------------
|
245 |
+
'''
|
246 |
+
|
247 |
+
|
248 |
+
# --------------------------------------------
|
249 |
+
# numpy(single) [0, 1] <---> numpy(uint)
|
250 |
+
# --------------------------------------------
|
251 |
+
|
252 |
+
|
253 |
+
def uint2single(img):
|
254 |
+
|
255 |
+
return np.float32(img/255.)
|
256 |
+
|
257 |
+
|
258 |
+
def single2uint(img):
|
259 |
+
|
260 |
+
return np.uint8((img.clip(0, 1)*255.).round())
|
261 |
+
|
262 |
+
|
263 |
+
def uint162single(img):
|
264 |
+
|
265 |
+
return np.float32(img/65535.)
|
266 |
+
|
267 |
+
|
268 |
+
def single2uint16(img):
|
269 |
+
|
270 |
+
return np.uint16((img.clip(0, 1)*65535.).round())
|
271 |
+
|
272 |
+
|
273 |
+
# --------------------------------------------
|
274 |
+
# numpy(uint) (HxWxC or HxW) <---> tensor
|
275 |
+
# --------------------------------------------
|
276 |
+
|
277 |
+
|
278 |
+
# convert uint to 4-dimensional torch tensor
|
279 |
+
def uint2tensor4(img):
|
280 |
+
if img.ndim == 2:
|
281 |
+
img = np.expand_dims(img, axis=2)
|
282 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
|
283 |
+
|
284 |
+
|
285 |
+
# convert uint to 3-dimensional torch tensor
|
286 |
+
def uint2tensor3(img):
|
287 |
+
if img.ndim == 2:
|
288 |
+
img = np.expand_dims(img, axis=2)
|
289 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
|
290 |
+
|
291 |
+
|
292 |
+
# convert 2/3/4-dimensional torch tensor to uint
|
293 |
+
def tensor2uint(img):
|
294 |
+
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
295 |
+
if img.ndim == 3:
|
296 |
+
img = np.transpose(img, (1, 2, 0))
|
297 |
+
return np.uint8((img*255.0).round())
|
298 |
+
|
299 |
+
|
300 |
+
# --------------------------------------------
|
301 |
+
# numpy(single) (HxWxC) <---> tensor
|
302 |
+
# --------------------------------------------
|
303 |
+
|
304 |
+
|
305 |
+
# convert single (HxWxC) to 3-dimensional torch tensor
|
306 |
+
def single2tensor3(img):
|
307 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
308 |
+
|
309 |
+
|
310 |
+
# convert single (HxWxC) to 4-dimensional torch tensor
|
311 |
+
def single2tensor4(img):
|
312 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
313 |
+
|
314 |
+
|
315 |
+
# convert torch tensor to single
|
316 |
+
def tensor2single(img):
|
317 |
+
img = img.data.squeeze().float().cpu().numpy()
|
318 |
+
if img.ndim == 3:
|
319 |
+
img = np.transpose(img, (1, 2, 0))
|
320 |
+
|
321 |
+
return img
|
322 |
+
|
323 |
+
# convert torch tensor to single
|
324 |
+
def tensor2single3(img):
|
325 |
+
img = img.data.squeeze().float().cpu().numpy()
|
326 |
+
if img.ndim == 3:
|
327 |
+
img = np.transpose(img, (1, 2, 0))
|
328 |
+
elif img.ndim == 2:
|
329 |
+
img = np.expand_dims(img, axis=2)
|
330 |
+
return img
|
331 |
+
|
332 |
+
|
333 |
+
def single2tensor5(img):
|
334 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
335 |
+
|
336 |
+
|
337 |
+
def single32tensor5(img):
|
338 |
+
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
339 |
+
|
340 |
+
|
341 |
+
def single42tensor4(img):
|
342 |
+
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
343 |
+
|
344 |
+
|
345 |
+
# from skimage.io import imread, imsave
|
346 |
+
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
347 |
+
'''
|
348 |
+
Converts a torch Tensor into an image Numpy array of BGR channel order
|
349 |
+
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
350 |
+
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
351 |
+
'''
|
352 |
+
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
353 |
+
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
354 |
+
n_dim = tensor.dim()
|
355 |
+
if n_dim == 4:
|
356 |
+
n_img = len(tensor)
|
357 |
+
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
358 |
+
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
359 |
+
elif n_dim == 3:
|
360 |
+
img_np = tensor.numpy()
|
361 |
+
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
362 |
+
elif n_dim == 2:
|
363 |
+
img_np = tensor.numpy()
|
364 |
+
else:
|
365 |
+
raise TypeError(
|
366 |
+
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
367 |
+
if out_type == np.uint8:
|
368 |
+
img_np = (img_np * 255.0).round()
|
369 |
+
# Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
|
370 |
+
return img_np.astype(out_type)
|
371 |
+
|
372 |
+
|
373 |
+
'''
|
374 |
+
# --------------------------------------------
|
375 |
+
# Augmentation, flipe and/or rotate
|
376 |
+
# --------------------------------------------
|
377 |
+
# The following two are enough.
|
378 |
+
# (1) augmet_img: numpy image of WxHxC or WxH
|
379 |
+
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
380 |
+
# --------------------------------------------
|
381 |
+
'''
|
382 |
+
|
383 |
+
|
384 |
+
def augment_img(img, mode=0):
|
385 |
+
'''Kai Zhang (github: https://github.com/cszn)
|
386 |
+
'''
|
387 |
+
if mode == 0:
|
388 |
+
return img
|
389 |
+
elif mode == 1:
|
390 |
+
return np.flipud(np.rot90(img))
|
391 |
+
elif mode == 2:
|
392 |
+
return np.flipud(img)
|
393 |
+
elif mode == 3:
|
394 |
+
return np.rot90(img, k=3)
|
395 |
+
elif mode == 4:
|
396 |
+
return np.flipud(np.rot90(img, k=2))
|
397 |
+
elif mode == 5:
|
398 |
+
return np.rot90(img)
|
399 |
+
elif mode == 6:
|
400 |
+
return np.rot90(img, k=2)
|
401 |
+
elif mode == 7:
|
402 |
+
return np.flipud(np.rot90(img, k=3))
|
403 |
+
|
404 |
+
|
405 |
+
def augment_img_tensor4(img, mode=0):
|
406 |
+
'''Kai Zhang (github: https://github.com/cszn)
|
407 |
+
'''
|
408 |
+
if mode == 0:
|
409 |
+
return img
|
410 |
+
elif mode == 1:
|
411 |
+
return img.rot90(1, [2, 3]).flip([2])
|
412 |
+
elif mode == 2:
|
413 |
+
return img.flip([2])
|
414 |
+
elif mode == 3:
|
415 |
+
return img.rot90(3, [2, 3])
|
416 |
+
elif mode == 4:
|
417 |
+
return img.rot90(2, [2, 3]).flip([2])
|
418 |
+
elif mode == 5:
|
419 |
+
return img.rot90(1, [2, 3])
|
420 |
+
elif mode == 6:
|
421 |
+
return img.rot90(2, [2, 3])
|
422 |
+
elif mode == 7:
|
423 |
+
return img.rot90(3, [2, 3]).flip([2])
|
424 |
+
|
425 |
+
|
426 |
+
def augment_img_tensor(img, mode=0):
|
427 |
+
'''Kai Zhang (github: https://github.com/cszn)
|
428 |
+
'''
|
429 |
+
img_size = img.size()
|
430 |
+
img_np = img.data.cpu().numpy()
|
431 |
+
if len(img_size) == 3:
|
432 |
+
img_np = np.transpose(img_np, (1, 2, 0))
|
433 |
+
elif len(img_size) == 4:
|
434 |
+
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
435 |
+
img_np = augment_img(img_np, mode=mode)
|
436 |
+
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
437 |
+
if len(img_size) == 3:
|
438 |
+
img_tensor = img_tensor.permute(2, 0, 1)
|
439 |
+
elif len(img_size) == 4:
|
440 |
+
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
441 |
+
|
442 |
+
return img_tensor.type_as(img)
|
443 |
+
|
444 |
+
|
445 |
+
def augment_img_np3(img, mode=0):
|
446 |
+
if mode == 0:
|
447 |
+
return img
|
448 |
+
elif mode == 1:
|
449 |
+
return img.transpose(1, 0, 2)
|
450 |
+
elif mode == 2:
|
451 |
+
return img[::-1, :, :]
|
452 |
+
elif mode == 3:
|
453 |
+
img = img[::-1, :, :]
|
454 |
+
img = img.transpose(1, 0, 2)
|
455 |
+
return img
|
456 |
+
elif mode == 4:
|
457 |
+
return img[:, ::-1, :]
|
458 |
+
elif mode == 5:
|
459 |
+
img = img[:, ::-1, :]
|
460 |
+
img = img.transpose(1, 0, 2)
|
461 |
+
return img
|
462 |
+
elif mode == 6:
|
463 |
+
img = img[:, ::-1, :]
|
464 |
+
img = img[::-1, :, :]
|
465 |
+
return img
|
466 |
+
elif mode == 7:
|
467 |
+
img = img[:, ::-1, :]
|
468 |
+
img = img[::-1, :, :]
|
469 |
+
img = img.transpose(1, 0, 2)
|
470 |
+
return img
|
471 |
+
|
472 |
+
|
473 |
+
def augment_imgs(img_list, hflip=True, rot=True):
|
474 |
+
# horizontal flip OR rotate
|
475 |
+
hflip = hflip and random.random() < 0.5
|
476 |
+
vflip = rot and random.random() < 0.5
|
477 |
+
rot90 = rot and random.random() < 0.5
|
478 |
+
|
479 |
+
def _augment(img):
|
480 |
+
if hflip:
|
481 |
+
img = img[:, ::-1, :]
|
482 |
+
if vflip:
|
483 |
+
img = img[::-1, :, :]
|
484 |
+
if rot90:
|
485 |
+
img = img.transpose(1, 0, 2)
|
486 |
+
return img
|
487 |
+
|
488 |
+
return [_augment(img) for img in img_list]
|
489 |
+
|
490 |
+
|
491 |
+
'''
|
492 |
+
# --------------------------------------------
|
493 |
+
# modcrop and shave
|
494 |
+
# --------------------------------------------
|
495 |
+
'''
|
496 |
+
|
497 |
+
|
498 |
+
def modcrop(img_in, scale):
|
499 |
+
# img_in: Numpy, HWC or HW
|
500 |
+
img = np.copy(img_in)
|
501 |
+
if img.ndim == 2:
|
502 |
+
H, W = img.shape
|
503 |
+
H_r, W_r = H % scale, W % scale
|
504 |
+
img = img[:H - H_r, :W - W_r]
|
505 |
+
elif img.ndim == 3:
|
506 |
+
H, W, C = img.shape
|
507 |
+
H_r, W_r = H % scale, W % scale
|
508 |
+
img = img[:H - H_r, :W - W_r, :]
|
509 |
+
else:
|
510 |
+
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
511 |
+
return img
|
512 |
+
|
513 |
+
|
514 |
+
def shave(img_in, border=0):
|
515 |
+
# img_in: Numpy, HWC or HW
|
516 |
+
img = np.copy(img_in)
|
517 |
+
h, w = img.shape[:2]
|
518 |
+
img = img[border:h-border, border:w-border]
|
519 |
+
return img
|
520 |
+
|
521 |
+
|
522 |
+
'''
|
523 |
+
# --------------------------------------------
|
524 |
+
# image processing process on numpy image
|
525 |
+
# channel_convert(in_c, tar_type, img_list):
|
526 |
+
# rgb2ycbcr(img, only_y=True):
|
527 |
+
# bgr2ycbcr(img, only_y=True):
|
528 |
+
# ycbcr2rgb(img):
|
529 |
+
# --------------------------------------------
|
530 |
+
'''
|
531 |
+
|
532 |
+
|
533 |
+
def rgb2ycbcr(img, only_y=True):
|
534 |
+
'''same as matlab rgb2ycbcr
|
535 |
+
only_y: only return Y channel
|
536 |
+
Input:
|
537 |
+
uint8, [0, 255]
|
538 |
+
float, [0, 1]
|
539 |
+
'''
|
540 |
+
in_img_type = img.dtype
|
541 |
+
img.astype(np.float32)
|
542 |
+
if in_img_type != np.uint8:
|
543 |
+
img *= 255.
|
544 |
+
# convert
|
545 |
+
if only_y:
|
546 |
+
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
547 |
+
else:
|
548 |
+
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
549 |
+
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
550 |
+
if in_img_type == np.uint8:
|
551 |
+
rlt = rlt.round()
|
552 |
+
else:
|
553 |
+
rlt /= 255.
|
554 |
+
return rlt.astype(in_img_type)
|
555 |
+
|
556 |
+
|
557 |
+
def ycbcr2rgb(img):
|
558 |
+
'''same as matlab ycbcr2rgb
|
559 |
+
Input:
|
560 |
+
uint8, [0, 255]
|
561 |
+
float, [0, 1]
|
562 |
+
'''
|
563 |
+
in_img_type = img.dtype
|
564 |
+
img.astype(np.float32)
|
565 |
+
if in_img_type != np.uint8:
|
566 |
+
img *= 255.
|
567 |
+
# convert
|
568 |
+
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
569 |
+
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
570 |
+
rlt = np.clip(rlt, 0, 255)
|
571 |
+
if in_img_type == np.uint8:
|
572 |
+
rlt = rlt.round()
|
573 |
+
else:
|
574 |
+
rlt /= 255.
|
575 |
+
return rlt.astype(in_img_type)
|
576 |
+
|
577 |
+
|
578 |
+
def bgr2ycbcr(img, only_y=True):
|
579 |
+
'''bgr version of rgb2ycbcr
|
580 |
+
only_y: only return Y channel
|
581 |
+
Input:
|
582 |
+
uint8, [0, 255]
|
583 |
+
float, [0, 1]
|
584 |
+
'''
|
585 |
+
in_img_type = img.dtype
|
586 |
+
img.astype(np.float32)
|
587 |
+
if in_img_type != np.uint8:
|
588 |
+
img *= 255.
|
589 |
+
# convert
|
590 |
+
if only_y:
|
591 |
+
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
592 |
+
else:
|
593 |
+
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
594 |
+
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
595 |
+
if in_img_type == np.uint8:
|
596 |
+
rlt = rlt.round()
|
597 |
+
else:
|
598 |
+
rlt /= 255.
|
599 |
+
return rlt.astype(in_img_type)
|
600 |
+
|
601 |
+
|
602 |
+
def channel_convert(in_c, tar_type, img_list):
|
603 |
+
# conversion among BGR, gray and y
|
604 |
+
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
605 |
+
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
606 |
+
return [np.expand_dims(img, axis=2) for img in gray_list]
|
607 |
+
elif in_c == 3 and tar_type == 'y': # BGR to y
|
608 |
+
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
609 |
+
return [np.expand_dims(img, axis=2) for img in y_list]
|
610 |
+
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
611 |
+
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
612 |
+
else:
|
613 |
+
return img_list
|
614 |
+
|
615 |
+
|
616 |
+
'''
|
617 |
+
# --------------------------------------------
|
618 |
+
# metric, PSNR, SSIM and PSNRB
|
619 |
+
# --------------------------------------------
|
620 |
+
'''
|
621 |
+
|
622 |
+
|
623 |
+
# --------------------------------------------
|
624 |
+
# PSNR
|
625 |
+
# --------------------------------------------
|
626 |
+
def calculate_psnr(img1, img2, border=0):
|
627 |
+
# img1 and img2 have range [0, 255]
|
628 |
+
#img1 = img1.squeeze()
|
629 |
+
#img2 = img2.squeeze()
|
630 |
+
if not img1.shape == img2.shape:
|
631 |
+
raise ValueError('Input images must have the same dimensions.')
|
632 |
+
h, w = img1.shape[:2]
|
633 |
+
img1 = img1[border:h-border, border:w-border]
|
634 |
+
img2 = img2[border:h-border, border:w-border]
|
635 |
+
|
636 |
+
img1 = img1.astype(np.float64)
|
637 |
+
img2 = img2.astype(np.float64)
|
638 |
+
mse = np.mean((img1 - img2)**2)
|
639 |
+
if mse == 0:
|
640 |
+
return float('inf')
|
641 |
+
return 20 * math.log10(255.0 / math.sqrt(mse))
|
642 |
+
|
643 |
+
|
644 |
+
# --------------------------------------------
|
645 |
+
# SSIM
|
646 |
+
# --------------------------------------------
|
647 |
+
def calculate_ssim(img1, img2, border=0):
|
648 |
+
'''calculate SSIM
|
649 |
+
the same outputs as MATLAB's
|
650 |
+
img1, img2: [0, 255]
|
651 |
+
'''
|
652 |
+
#img1 = img1.squeeze()
|
653 |
+
#img2 = img2.squeeze()
|
654 |
+
if not img1.shape == img2.shape:
|
655 |
+
raise ValueError('Input images must have the same dimensions.')
|
656 |
+
h, w = img1.shape[:2]
|
657 |
+
img1 = img1[border:h-border, border:w-border]
|
658 |
+
img2 = img2[border:h-border, border:w-border]
|
659 |
+
|
660 |
+
if img1.ndim == 2:
|
661 |
+
return ssim(img1, img2)
|
662 |
+
elif img1.ndim == 3:
|
663 |
+
if img1.shape[2] == 3:
|
664 |
+
ssims = []
|
665 |
+
for i in range(3):
|
666 |
+
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
|
667 |
+
return np.array(ssims).mean()
|
668 |
+
elif img1.shape[2] == 1:
|
669 |
+
return ssim(np.squeeze(img1), np.squeeze(img2))
|
670 |
+
else:
|
671 |
+
raise ValueError('Wrong input image dimensions.')
|
672 |
+
|
673 |
+
|
674 |
+
def ssim(img1, img2):
|
675 |
+
C1 = (0.01 * 255)**2
|
676 |
+
C2 = (0.03 * 255)**2
|
677 |
+
|
678 |
+
img1 = img1.astype(np.float64)
|
679 |
+
img2 = img2.astype(np.float64)
|
680 |
+
kernel = cv2.getGaussianKernel(11, 1.5)
|
681 |
+
window = np.outer(kernel, kernel.transpose())
|
682 |
+
|
683 |
+
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
684 |
+
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
685 |
+
mu1_sq = mu1**2
|
686 |
+
mu2_sq = mu2**2
|
687 |
+
mu1_mu2 = mu1 * mu2
|
688 |
+
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
689 |
+
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
690 |
+
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
691 |
+
|
692 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
693 |
+
(sigma1_sq + sigma2_sq + C2))
|
694 |
+
return ssim_map.mean()
|
695 |
+
|
696 |
+
|
697 |
+
def _blocking_effect_factor(im):
|
698 |
+
block_size = 8
|
699 |
+
|
700 |
+
block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
|
701 |
+
block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
|
702 |
+
|
703 |
+
horizontal_block_difference = (
|
704 |
+
(im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
|
705 |
+
3).sum(2).sum(1)
|
706 |
+
vertical_block_difference = (
|
707 |
+
(im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
|
708 |
+
2).sum(1)
|
709 |
+
|
710 |
+
nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
|
711 |
+
nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
|
712 |
+
|
713 |
+
horizontal_nonblock_difference = (
|
714 |
+
(im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
|
715 |
+
3).sum(2).sum(1)
|
716 |
+
vertical_nonblock_difference = (
|
717 |
+
(im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
|
718 |
+
3).sum(2).sum(1)
|
719 |
+
|
720 |
+
n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
|
721 |
+
n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
|
722 |
+
boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
|
723 |
+
n_boundary_horiz + n_boundary_vert)
|
724 |
+
|
725 |
+
n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
|
726 |
+
n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
|
727 |
+
nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
|
728 |
+
n_nonboundary_horiz + n_nonboundary_vert)
|
729 |
+
|
730 |
+
scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
|
731 |
+
bef = scaler * (boundary_difference - nonboundary_difference)
|
732 |
+
|
733 |
+
bef[boundary_difference <= nonboundary_difference] = 0
|
734 |
+
return bef
|
735 |
+
|
736 |
+
|
737 |
+
def calculate_psnrb(img1, img2, border=0):
|
738 |
+
"""Calculate PSNR-B (Peak Signal-to-Noise Ratio).
|
739 |
+
Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
|
740 |
+
# https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
|
741 |
+
Args:
|
742 |
+
img1 (ndarray): Images with range [0, 255].
|
743 |
+
img2 (ndarray): Images with range [0, 255].
|
744 |
+
border (int): Cropped pixels in each edge of an image. These
|
745 |
+
pixels are not involved in the PSNR calculation.
|
746 |
+
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
747 |
+
Returns:
|
748 |
+
float: psnr result.
|
749 |
+
"""
|
750 |
+
|
751 |
+
if not img1.shape == img2.shape:
|
752 |
+
raise ValueError('Input images must have the same dimensions.')
|
753 |
+
|
754 |
+
if img1.ndim == 2:
|
755 |
+
img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
|
756 |
+
|
757 |
+
h, w = img1.shape[:2]
|
758 |
+
img1 = img1[border:h-border, border:w-border]
|
759 |
+
img2 = img2[border:h-border, border:w-border]
|
760 |
+
|
761 |
+
img1 = img1.astype(np.float64)
|
762 |
+
img2 = img2.astype(np.float64)
|
763 |
+
|
764 |
+
# follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
|
765 |
+
img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
|
766 |
+
img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
|
767 |
+
|
768 |
+
total = 0
|
769 |
+
for c in range(img1.shape[1]):
|
770 |
+
mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
|
771 |
+
bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
|
772 |
+
|
773 |
+
mse = mse.view(mse.shape[0], -1).mean(1)
|
774 |
+
total += 10 * torch.log10(1 / (mse + bef))
|
775 |
+
|
776 |
+
return float(total) / img1.shape[1]
|
777 |
+
|
778 |
+
'''
|
779 |
+
# --------------------------------------------
|
780 |
+
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
781 |
+
# --------------------------------------------
|
782 |
+
'''
|
783 |
+
|
784 |
+
|
785 |
+
# matlab 'imresize' function, now only support 'bicubic'
|
786 |
+
def cubic(x):
|
787 |
+
absx = torch.abs(x)
|
788 |
+
absx2 = absx**2
|
789 |
+
absx3 = absx**3
|
790 |
+
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
|
791 |
+
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
|
792 |
+
|
793 |
+
|
794 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
795 |
+
if (scale < 1) and (antialiasing):
|
796 |
+
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
797 |
+
kernel_width = kernel_width / scale
|
798 |
+
|
799 |
+
# Output-space coordinates
|
800 |
+
x = torch.linspace(1, out_length, out_length)
|
801 |
+
|
802 |
+
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
803 |
+
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
804 |
+
# space maps to 1.5 in input space.
|
805 |
+
u = x / scale + 0.5 * (1 - 1 / scale)
|
806 |
+
|
807 |
+
# What is the left-most pixel that can be involved in the computation?
|
808 |
+
left = torch.floor(u - kernel_width / 2)
|
809 |
+
|
810 |
+
# What is the maximum number of pixels that can be involved in the
|
811 |
+
# computation? Note: it's OK to use an extra pixel here; if the
|
812 |
+
# corresponding weights are all zero, it will be eliminated at the end
|
813 |
+
# of this function.
|
814 |
+
P = math.ceil(kernel_width) + 2
|
815 |
+
|
816 |
+
# The indices of the input pixels involved in computing the k-th output
|
817 |
+
# pixel are in row k of the indices matrix.
|
818 |
+
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
819 |
+
1, P).expand(out_length, P)
|
820 |
+
|
821 |
+
# The weights used to compute the k-th output pixel are in row k of the
|
822 |
+
# weights matrix.
|
823 |
+
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
824 |
+
# apply cubic kernel
|
825 |
+
if (scale < 1) and (antialiasing):
|
826 |
+
weights = scale * cubic(distance_to_center * scale)
|
827 |
+
else:
|
828 |
+
weights = cubic(distance_to_center)
|
829 |
+
# Normalize the weights matrix so that each row sums to 1.
|
830 |
+
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
831 |
+
weights = weights / weights_sum.expand(out_length, P)
|
832 |
+
|
833 |
+
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
834 |
+
weights_zero_tmp = torch.sum((weights == 0), 0)
|
835 |
+
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
836 |
+
indices = indices.narrow(1, 1, P - 2)
|
837 |
+
weights = weights.narrow(1, 1, P - 2)
|
838 |
+
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
839 |
+
indices = indices.narrow(1, 0, P - 2)
|
840 |
+
weights = weights.narrow(1, 0, P - 2)
|
841 |
+
weights = weights.contiguous()
|
842 |
+
indices = indices.contiguous()
|
843 |
+
sym_len_s = -indices.min() + 1
|
844 |
+
sym_len_e = indices.max() - in_length
|
845 |
+
indices = indices + sym_len_s - 1
|
846 |
+
return weights, indices, int(sym_len_s), int(sym_len_e)
|
847 |
+
|
848 |
+
|
849 |
+
# --------------------------------------------
|
850 |
+
# imresize for tensor image [0, 1]
|
851 |
+
# --------------------------------------------
|
852 |
+
def imresize(img, scale, antialiasing=True):
|
853 |
+
# Now the scale should be the same for H and W
|
854 |
+
# input: img: pytorch tensor, CHW or HW [0,1]
|
855 |
+
# output: CHW or HW [0,1] w/o round
|
856 |
+
need_squeeze = True if img.dim() == 2 else False
|
857 |
+
if need_squeeze:
|
858 |
+
img.unsqueeze_(0)
|
859 |
+
in_C, in_H, in_W = img.size()
|
860 |
+
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
861 |
+
kernel_width = 4
|
862 |
+
kernel = 'cubic'
|
863 |
+
|
864 |
+
# Return the desired dimension order for performing the resize. The
|
865 |
+
# strategy is to perform the resize first along the dimension with the
|
866 |
+
# smallest scale factor.
|
867 |
+
# Now we do not support this.
|
868 |
+
|
869 |
+
# get weights and indices
|
870 |
+
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
871 |
+
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
872 |
+
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
873 |
+
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
874 |
+
# process H dimension
|
875 |
+
# symmetric copying
|
876 |
+
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
877 |
+
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
878 |
+
|
879 |
+
sym_patch = img[:, :sym_len_Hs, :]
|
880 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
881 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
882 |
+
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
883 |
+
|
884 |
+
sym_patch = img[:, -sym_len_He:, :]
|
885 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
886 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
887 |
+
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
888 |
+
|
889 |
+
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
890 |
+
kernel_width = weights_H.size(1)
|
891 |
+
for i in range(out_H):
|
892 |
+
idx = int(indices_H[i][0])
|
893 |
+
for j in range(out_C):
|
894 |
+
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
895 |
+
|
896 |
+
# process W dimension
|
897 |
+
# symmetric copying
|
898 |
+
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
899 |
+
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
900 |
+
|
901 |
+
sym_patch = out_1[:, :, :sym_len_Ws]
|
902 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
903 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
904 |
+
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
905 |
+
|
906 |
+
sym_patch = out_1[:, :, -sym_len_We:]
|
907 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
908 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
909 |
+
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
910 |
+
|
911 |
+
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
912 |
+
kernel_width = weights_W.size(1)
|
913 |
+
for i in range(out_W):
|
914 |
+
idx = int(indices_W[i][0])
|
915 |
+
for j in range(out_C):
|
916 |
+
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
|
917 |
+
if need_squeeze:
|
918 |
+
out_2.squeeze_()
|
919 |
+
return out_2
|
920 |
+
|
921 |
+
|
922 |
+
# --------------------------------------------
|
923 |
+
# imresize for numpy image [0, 1]
|
924 |
+
# --------------------------------------------
|
925 |
+
def imresize_np(img, scale, antialiasing=True):
|
926 |
+
# Now the scale should be the same for H and W
|
927 |
+
# input: img: Numpy, HWC or HW [0,1]
|
928 |
+
# output: HWC or HW [0,1] w/o round
|
929 |
+
img = torch.from_numpy(img)
|
930 |
+
need_squeeze = True if img.dim() == 2 else False
|
931 |
+
if need_squeeze:
|
932 |
+
img.unsqueeze_(2)
|
933 |
+
|
934 |
+
in_H, in_W, in_C = img.size()
|
935 |
+
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
936 |
+
kernel_width = 4
|
937 |
+
kernel = 'cubic'
|
938 |
+
|
939 |
+
# Return the desired dimension order for performing the resize. The
|
940 |
+
# strategy is to perform the resize first along the dimension with the
|
941 |
+
# smallest scale factor.
|
942 |
+
# Now we do not support this.
|
943 |
+
|
944 |
+
# get weights and indices
|
945 |
+
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
946 |
+
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
947 |
+
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
948 |
+
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
949 |
+
# process H dimension
|
950 |
+
# symmetric copying
|
951 |
+
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
952 |
+
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
953 |
+
|
954 |
+
sym_patch = img[:sym_len_Hs, :, :]
|
955 |
+
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
956 |
+
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
957 |
+
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
958 |
+
|
959 |
+
sym_patch = img[-sym_len_He:, :, :]
|
960 |
+
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
961 |
+
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
962 |
+
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
963 |
+
|
964 |
+
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
965 |
+
kernel_width = weights_H.size(1)
|
966 |
+
for i in range(out_H):
|
967 |
+
idx = int(indices_H[i][0])
|
968 |
+
for j in range(out_C):
|
969 |
+
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
970 |
+
|
971 |
+
# process W dimension
|
972 |
+
# symmetric copying
|
973 |
+
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
974 |
+
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
975 |
+
|
976 |
+
sym_patch = out_1[:, :sym_len_Ws, :]
|
977 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
978 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
979 |
+
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
980 |
+
|
981 |
+
sym_patch = out_1[:, -sym_len_We:, :]
|
982 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
983 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
984 |
+
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
985 |
+
|
986 |
+
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
987 |
+
kernel_width = weights_W.size(1)
|
988 |
+
for i in range(out_W):
|
989 |
+
idx = int(indices_W[i][0])
|
990 |
+
for j in range(out_C):
|
991 |
+
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
|
992 |
+
if need_squeeze:
|
993 |
+
out_2.squeeze_()
|
994 |
+
|
995 |
+
return out_2.numpy()
|
996 |
+
|
997 |
+
|
998 |
+
if __name__ == '__main__':
|
999 |
+
img = imread_uint('test.bmp', 3)
|
1000 |
+
# img = uint2single(img)
|
1001 |
+
# img_bicubic = imresize_np(img, 1/4)
|
1002 |
+
# imshow(single2uint(img_bicubic))
|
1003 |
+
#
|
1004 |
+
# img_tensor = single2tensor4(img)
|
1005 |
+
# for i in range(8):
|
1006 |
+
# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
|
1007 |
+
|
1008 |
+
# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
|
1009 |
+
# imssave(patches,'a.png')
|
1010 |
+
|
1011 |
+
|
1012 |
+
|
1013 |
+
|
1014 |
+
|
1015 |
+
|
1016 |
+
|
core/data/deg_kair_utils/utils_lmdb.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import lmdb
|
3 |
+
import sys
|
4 |
+
from multiprocessing import Pool
|
5 |
+
from os import path as osp
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
def make_lmdb_from_imgs(data_path,
|
10 |
+
lmdb_path,
|
11 |
+
img_path_list,
|
12 |
+
keys,
|
13 |
+
batch=5000,
|
14 |
+
compress_level=1,
|
15 |
+
multiprocessing_read=False,
|
16 |
+
n_thread=40,
|
17 |
+
map_size=None):
|
18 |
+
"""Make lmdb from images.
|
19 |
+
|
20 |
+
Contents of lmdb. The file structure is:
|
21 |
+
example.lmdb
|
22 |
+
├── data.mdb
|
23 |
+
├── lock.mdb
|
24 |
+
├── meta_info.txt
|
25 |
+
|
26 |
+
The data.mdb and lock.mdb are standard lmdb files and you can refer to
|
27 |
+
https://lmdb.readthedocs.io/en/release/ for more details.
|
28 |
+
|
29 |
+
The meta_info.txt is a specified txt file to record the meta information
|
30 |
+
of our datasets. It will be automatically created when preparing
|
31 |
+
datasets by our provided dataset tools.
|
32 |
+
Each line in the txt file records 1)image name (with extension),
|
33 |
+
2)image shape, and 3)compression level, separated by a white space.
|
34 |
+
|
35 |
+
For example, the meta information could be:
|
36 |
+
`000_00000000.png (720,1280,3) 1`, which means:
|
37 |
+
1) image name (with extension): 000_00000000.png;
|
38 |
+
2) image shape: (720,1280,3);
|
39 |
+
3) compression level: 1
|
40 |
+
|
41 |
+
We use the image name without extension as the lmdb key.
|
42 |
+
|
43 |
+
If `multiprocessing_read` is True, it will read all the images to memory
|
44 |
+
using multiprocessing. Thus, your server needs to have enough memory.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
data_path (str): Data path for reading images.
|
48 |
+
lmdb_path (str): Lmdb save path.
|
49 |
+
img_path_list (str): Image path list.
|
50 |
+
keys (str): Used for lmdb keys.
|
51 |
+
batch (int): After processing batch images, lmdb commits.
|
52 |
+
Default: 5000.
|
53 |
+
compress_level (int): Compress level when encoding images. Default: 1.
|
54 |
+
multiprocessing_read (bool): Whether use multiprocessing to read all
|
55 |
+
the images to memory. Default: False.
|
56 |
+
n_thread (int): For multiprocessing.
|
57 |
+
map_size (int | None): Map size for lmdb env. If None, use the
|
58 |
+
estimated size from images. Default: None
|
59 |
+
"""
|
60 |
+
|
61 |
+
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
|
62 |
+
f'but got {len(img_path_list)} and {len(keys)}')
|
63 |
+
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
|
64 |
+
print(f'Totoal images: {len(img_path_list)}')
|
65 |
+
if not lmdb_path.endswith('.lmdb'):
|
66 |
+
raise ValueError("lmdb_path must end with '.lmdb'.")
|
67 |
+
if osp.exists(lmdb_path):
|
68 |
+
print(f'Folder {lmdb_path} already exists. Exit.')
|
69 |
+
sys.exit(1)
|
70 |
+
|
71 |
+
if multiprocessing_read:
|
72 |
+
# read all the images to memory (multiprocessing)
|
73 |
+
dataset = {} # use dict to keep the order for multiprocessing
|
74 |
+
shapes = {}
|
75 |
+
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
|
76 |
+
pbar = tqdm(total=len(img_path_list), unit='image')
|
77 |
+
|
78 |
+
def callback(arg):
|
79 |
+
"""get the image data and update pbar."""
|
80 |
+
key, dataset[key], shapes[key] = arg
|
81 |
+
pbar.update(1)
|
82 |
+
pbar.set_description(f'Read {key}')
|
83 |
+
|
84 |
+
pool = Pool(n_thread)
|
85 |
+
for path, key in zip(img_path_list, keys):
|
86 |
+
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
|
87 |
+
pool.close()
|
88 |
+
pool.join()
|
89 |
+
pbar.close()
|
90 |
+
print(f'Finish reading {len(img_path_list)} images.')
|
91 |
+
|
92 |
+
# create lmdb environment
|
93 |
+
if map_size is None:
|
94 |
+
# obtain data size for one image
|
95 |
+
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
|
96 |
+
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
|
97 |
+
data_size_per_img = img_byte.nbytes
|
98 |
+
print('Data size per image is: ', data_size_per_img)
|
99 |
+
data_size = data_size_per_img * len(img_path_list)
|
100 |
+
map_size = data_size * 10
|
101 |
+
|
102 |
+
env = lmdb.open(lmdb_path, map_size=map_size)
|
103 |
+
|
104 |
+
# write data to lmdb
|
105 |
+
pbar = tqdm(total=len(img_path_list), unit='chunk')
|
106 |
+
txn = env.begin(write=True)
|
107 |
+
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
|
108 |
+
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
|
109 |
+
pbar.update(1)
|
110 |
+
pbar.set_description(f'Write {key}')
|
111 |
+
key_byte = key.encode('ascii')
|
112 |
+
if multiprocessing_read:
|
113 |
+
img_byte = dataset[key]
|
114 |
+
h, w, c = shapes[key]
|
115 |
+
else:
|
116 |
+
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
|
117 |
+
h, w, c = img_shape
|
118 |
+
|
119 |
+
txn.put(key_byte, img_byte)
|
120 |
+
# write meta information
|
121 |
+
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
|
122 |
+
if idx % batch == 0:
|
123 |
+
txn.commit()
|
124 |
+
txn = env.begin(write=True)
|
125 |
+
pbar.close()
|
126 |
+
txn.commit()
|
127 |
+
env.close()
|
128 |
+
txt_file.close()
|
129 |
+
print('\nFinish writing lmdb.')
|
130 |
+
|
131 |
+
|
132 |
+
def read_img_worker(path, key, compress_level):
|
133 |
+
"""Read image worker.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
path (str): Image path.
|
137 |
+
key (str): Image key.
|
138 |
+
compress_level (int): Compress level when encoding images.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
str: Image key.
|
142 |
+
byte: Image byte.
|
143 |
+
tuple[int]: Image shape.
|
144 |
+
"""
|
145 |
+
|
146 |
+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
147 |
+
# deal with `libpng error: Read Error`
|
148 |
+
if img is None:
|
149 |
+
print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
|
150 |
+
from PIL import Image
|
151 |
+
import numpy as np
|
152 |
+
img = Image.open(path)
|
153 |
+
img = np.asanyarray(img)
|
154 |
+
img = img[:, :, [2, 1, 0]]
|
155 |
+
|
156 |
+
if img.ndim == 2:
|
157 |
+
h, w = img.shape
|
158 |
+
c = 1
|
159 |
+
else:
|
160 |
+
h, w, c = img.shape
|
161 |
+
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
|
162 |
+
return (key, img_byte, (h, w, c))
|
163 |
+
|
164 |
+
|
165 |
+
class LmdbMaker():
|
166 |
+
"""LMDB Maker.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
lmdb_path (str): Lmdb save path.
|
170 |
+
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
|
171 |
+
batch (int): After processing batch images, lmdb commits.
|
172 |
+
Default: 5000.
|
173 |
+
compress_level (int): Compress level when encoding images. Default: 1.
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
|
177 |
+
if not lmdb_path.endswith('.lmdb'):
|
178 |
+
raise ValueError("lmdb_path must end with '.lmdb'.")
|
179 |
+
if osp.exists(lmdb_path):
|
180 |
+
print(f'Folder {lmdb_path} already exists. Exit.')
|
181 |
+
sys.exit(1)
|
182 |
+
|
183 |
+
self.lmdb_path = lmdb_path
|
184 |
+
self.batch = batch
|
185 |
+
self.compress_level = compress_level
|
186 |
+
self.env = lmdb.open(lmdb_path, map_size=map_size)
|
187 |
+
self.txn = self.env.begin(write=True)
|
188 |
+
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
|
189 |
+
self.counter = 0
|
190 |
+
|
191 |
+
def put(self, img_byte, key, img_shape):
|
192 |
+
self.counter += 1
|
193 |
+
key_byte = key.encode('ascii')
|
194 |
+
self.txn.put(key_byte, img_byte)
|
195 |
+
# write meta information
|
196 |
+
h, w, c = img_shape
|
197 |
+
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
|
198 |
+
if self.counter % self.batch == 0:
|
199 |
+
self.txn.commit()
|
200 |
+
self.txn = self.env.begin(write=True)
|
201 |
+
|
202 |
+
def close(self):
|
203 |
+
self.txn.commit()
|
204 |
+
self.env.close()
|
205 |
+
self.txt_file.close()
|
core/data/deg_kair_utils/utils_logger.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import datetime
|
3 |
+
import logging
|
4 |
+
|
5 |
+
|
6 |
+
'''
|
7 |
+
# --------------------------------------------
|
8 |
+
# Kai Zhang (github: https://github.com/cszn)
|
9 |
+
# 03/Mar/2019
|
10 |
+
# --------------------------------------------
|
11 |
+
# https://github.com/xinntao/BasicSR
|
12 |
+
# --------------------------------------------
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
def log(*args, **kwargs):
|
17 |
+
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
|
18 |
+
|
19 |
+
|
20 |
+
'''
|
21 |
+
# --------------------------------------------
|
22 |
+
# logger
|
23 |
+
# --------------------------------------------
|
24 |
+
'''
|
25 |
+
|
26 |
+
|
27 |
+
def logger_info(logger_name, log_path='default_logger.log'):
|
28 |
+
''' set up logger
|
29 |
+
modified by Kai Zhang (github: https://github.com/cszn)
|
30 |
+
'''
|
31 |
+
log = logging.getLogger(logger_name)
|
32 |
+
if log.hasHandlers():
|
33 |
+
print('LogHandlers exist!')
|
34 |
+
else:
|
35 |
+
print('LogHandlers setup!')
|
36 |
+
level = logging.INFO
|
37 |
+
formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
|
38 |
+
fh = logging.FileHandler(log_path, mode='a')
|
39 |
+
fh.setFormatter(formatter)
|
40 |
+
log.setLevel(level)
|
41 |
+
log.addHandler(fh)
|
42 |
+
# print(len(log.handlers))
|
43 |
+
|
44 |
+
sh = logging.StreamHandler()
|
45 |
+
sh.setFormatter(formatter)
|
46 |
+
log.addHandler(sh)
|
47 |
+
|
48 |
+
|
49 |
+
'''
|
50 |
+
# --------------------------------------------
|
51 |
+
# print to file and std_out simultaneously
|
52 |
+
# --------------------------------------------
|
53 |
+
'''
|
54 |
+
|
55 |
+
|
56 |
+
class logger_print(object):
|
57 |
+
def __init__(self, log_path="default.log"):
|
58 |
+
self.terminal = sys.stdout
|
59 |
+
self.log = open(log_path, 'a')
|
60 |
+
|
61 |
+
def write(self, message):
|
62 |
+
self.terminal.write(message)
|
63 |
+
self.log.write(message) # write the message
|
64 |
+
|
65 |
+
def flush(self):
|
66 |
+
pass
|
core/data/deg_kair_utils/utils_mat.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import scipy.io as spio
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
def loadmat(filename):
|
8 |
+
'''
|
9 |
+
this function should be called instead of direct spio.loadmat
|
10 |
+
as it cures the problem of not properly recovering python dictionaries
|
11 |
+
from mat files. It calls the function check keys to cure all entries
|
12 |
+
which are still mat-objects
|
13 |
+
'''
|
14 |
+
data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
|
15 |
+
return dict_to_nonedict(_check_keys(data))
|
16 |
+
|
17 |
+
def _check_keys(dict):
|
18 |
+
'''
|
19 |
+
checks if entries in dictionary are mat-objects. If yes
|
20 |
+
todict is called to change them to nested dictionaries
|
21 |
+
'''
|
22 |
+
for key in dict:
|
23 |
+
if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
|
24 |
+
dict[key] = _todict(dict[key])
|
25 |
+
return dict
|
26 |
+
|
27 |
+
def _todict(matobj):
|
28 |
+
'''
|
29 |
+
A recursive function which constructs from matobjects nested dictionaries
|
30 |
+
'''
|
31 |
+
dict = {}
|
32 |
+
for strg in matobj._fieldnames:
|
33 |
+
elem = matobj.__dict__[strg]
|
34 |
+
if isinstance(elem, spio.matlab.mio5_params.mat_struct):
|
35 |
+
dict[strg] = _todict(elem)
|
36 |
+
else:
|
37 |
+
dict[strg] = elem
|
38 |
+
return dict
|
39 |
+
|
40 |
+
|
41 |
+
def dict_to_nonedict(opt):
|
42 |
+
if isinstance(opt, dict):
|
43 |
+
new_opt = dict()
|
44 |
+
for key, sub_opt in opt.items():
|
45 |
+
new_opt[key] = dict_to_nonedict(sub_opt)
|
46 |
+
return NoneDict(**new_opt)
|
47 |
+
elif isinstance(opt, list):
|
48 |
+
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
49 |
+
else:
|
50 |
+
return opt
|
51 |
+
|
52 |
+
|
53 |
+
class NoneDict(dict):
|
54 |
+
def __missing__(self, key):
|
55 |
+
return None
|
56 |
+
|
57 |
+
|
58 |
+
def mat2json(mat_path=None, filepath = None):
|
59 |
+
"""
|
60 |
+
Converts .mat file to .json and writes new file
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
mat_path: Str
|
64 |
+
path/filename .mat存放路径
|
65 |
+
filepath: Str
|
66 |
+
如果需要保存成json, 添加这一路径. 否则不保存
|
67 |
+
Returns
|
68 |
+
返回转化的字典
|
69 |
+
-------
|
70 |
+
None
|
71 |
+
Examples
|
72 |
+
--------
|
73 |
+
>>> mat2json(blah blah)
|
74 |
+
"""
|
75 |
+
|
76 |
+
matlabFile = loadmat(mat_path)
|
77 |
+
#pop all those dumb fields that don't let you jsonize file
|
78 |
+
matlabFile.pop('__header__')
|
79 |
+
matlabFile.pop('__version__')
|
80 |
+
matlabFile.pop('__globals__')
|
81 |
+
#jsonize the file - orientation is 'index'
|
82 |
+
matlabFile = pd.Series(matlabFile).to_json()
|
83 |
+
|
84 |
+
if filepath:
|
85 |
+
json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
|
86 |
+
with open(json_path, 'w') as f:
|
87 |
+
f.write(matlabFile)
|
88 |
+
return matlabFile
|
core/data/deg_kair_utils/utils_matconvnet.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
# import scipy.io as io
|
7 |
+
import hdf5storage
|
8 |
+
|
9 |
+
"""
|
10 |
+
# --------------------------------------------
|
11 |
+
# Convert matconvnet SimpleNN model into pytorch model
|
12 |
+
# --------------------------------------------
|
13 |
+
# Kai Zhang ([email protected])
|
14 |
+
# https://github.com/cszn
|
15 |
+
# 28/Nov/2019
|
16 |
+
# --------------------------------------------
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
|
21 |
+
"""Modified version of https://github.com/albanie/pytorch-mcn
|
22 |
+
Adjust memory layout and load weights as torch tensor
|
23 |
+
Args:
|
24 |
+
x (ndaray): a numpy array, corresponding to a set of network weights
|
25 |
+
stored in column major order
|
26 |
+
squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
|
27 |
+
singletons from the trailing dimensions. So after converting to
|
28 |
+
pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
|
29 |
+
it will be reshaped to a matrix with shape (A,B).
|
30 |
+
in_features (int :: None): used to reshape weights for a linear block.
|
31 |
+
out_features (int :: None): used to reshape weights for a linear block.
|
32 |
+
Returns:
|
33 |
+
torch.tensor: a permuted sets of weights, matching the pytorch layout
|
34 |
+
convention
|
35 |
+
"""
|
36 |
+
if x.ndim == 4:
|
37 |
+
x = x.transpose((3, 2, 0, 1))
|
38 |
+
# for FFDNet, pixel-shuffle layer
|
39 |
+
# if x.shape[1]==13:
|
40 |
+
# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
|
41 |
+
# if x.shape[0]==12:
|
42 |
+
# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
|
43 |
+
# if x.shape[1]==5:
|
44 |
+
# x=x[:,[0,2,1,3, 4],:,:]
|
45 |
+
# if x.shape[0]==4:
|
46 |
+
# x=x[[0,2,1,3],:,:,:]
|
47 |
+
## for SRMD, pixel-shuffle layer
|
48 |
+
# if x.shape[0]==12:
|
49 |
+
# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
|
50 |
+
# if x.shape[0]==27:
|
51 |
+
# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
|
52 |
+
# if x.shape[0]==48:
|
53 |
+
# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
|
54 |
+
|
55 |
+
elif x.ndim == 3: # add by Kai
|
56 |
+
x = x[:,:,:,None]
|
57 |
+
x = x.transpose((3, 2, 0, 1))
|
58 |
+
elif x.ndim == 2:
|
59 |
+
if x.shape[1] == 1:
|
60 |
+
x = x.flatten()
|
61 |
+
if squeeze:
|
62 |
+
if in_features and out_features:
|
63 |
+
x = x.reshape((out_features, in_features))
|
64 |
+
x = np.squeeze(x)
|
65 |
+
return torch.from_numpy(np.ascontiguousarray(x))
|
66 |
+
|
67 |
+
|
68 |
+
def save_model(network, save_path):
|
69 |
+
state_dict = network.state_dict()
|
70 |
+
for key, param in state_dict.items():
|
71 |
+
state_dict[key] = param.cpu()
|
72 |
+
torch.save(state_dict, save_path)
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
|
77 |
+
|
78 |
+
# from utils import utils_logger
|
79 |
+
# import logging
|
80 |
+
# utils_logger.logger_info('a', 'a.log')
|
81 |
+
# logger = logging.getLogger('a')
|
82 |
+
#
|
83 |
+
# mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
|
84 |
+
mcn = hdf5storage.loadmat('models/modelcolor.mat')
|
85 |
+
|
86 |
+
|
87 |
+
#logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
|
88 |
+
|
89 |
+
mat_net = OrderedDict()
|
90 |
+
for idx in range(25):
|
91 |
+
mat_net[str(idx)] = OrderedDict()
|
92 |
+
count = -1
|
93 |
+
|
94 |
+
print(idx)
|
95 |
+
for i in range(13):
|
96 |
+
|
97 |
+
if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':
|
98 |
+
|
99 |
+
count += 1
|
100 |
+
w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
|
101 |
+
# print(w.shape)
|
102 |
+
w = weights2tensor(w)
|
103 |
+
# print(w.shape)
|
104 |
+
|
105 |
+
b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
|
106 |
+
b = weights2tensor(b)
|
107 |
+
print(b.shape)
|
108 |
+
|
109 |
+
mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
|
110 |
+
mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b
|
111 |
+
|
112 |
+
torch.save(mat_net, 'model_zoo/modelcolor.pth')
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
# from models.network_dncnn import IRCNN as net
|
117 |
+
# network = net(in_nc=3, out_nc=3, nc=64)
|
118 |
+
# state_dict = network.state_dict()
|
119 |
+
#
|
120 |
+
# #show_kv(state_dict)
|
121 |
+
#
|
122 |
+
# for i in range(len(mcn['net'][0][0][0])):
|
123 |
+
# print(mcn['net'][0][0][0][i][0][0][0][0])
|
124 |
+
#
|
125 |
+
# count = -1
|
126 |
+
# mat_net = OrderedDict()
|
127 |
+
# for i in range(len(mcn['net'][0][0][0])):
|
128 |
+
# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
|
129 |
+
#
|
130 |
+
# count += 1
|
131 |
+
# w = mcn['net'][0][0][0][i][0][1][0][0]
|
132 |
+
# print(w.shape)
|
133 |
+
# w = weights2tensor(w)
|
134 |
+
# print(w.shape)
|
135 |
+
#
|
136 |
+
# b = mcn['net'][0][0][0][i][0][1][0][1]
|
137 |
+
# b = weights2tensor(b)
|
138 |
+
# print(b.shape)
|
139 |
+
#
|
140 |
+
# mat_net['model.{:d}.weight'.format(count*2)] = w
|
141 |
+
# mat_net['model.{:d}.bias'.format(count*2)] = b
|
142 |
+
#
|
143 |
+
# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
|
144 |
+
#
|
145 |
+
#
|
146 |
+
#
|
147 |
+
# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
|
148 |
+
# def show_kv(net):
|
149 |
+
# for k, v in net.items():
|
150 |
+
# print(k)
|
151 |
+
#
|
152 |
+
# show_kv(crt_net)
|
153 |
+
|
154 |
+
|
155 |
+
# from models.network_dncnn import DnCNN as net
|
156 |
+
# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
|
157 |
+
|
158 |
+
# from models.network_srmd import SRMD as net
|
159 |
+
# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
|
160 |
+
# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
|
161 |
+
#
|
162 |
+
# from models.network_rrdb import RRDB as net
|
163 |
+
# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
|
164 |
+
#
|
165 |
+
# state_dict = network.state_dict()
|
166 |
+
# for key, param in state_dict.items():
|
167 |
+
# print(key)
|
168 |
+
# from models.network_imdn import IMDN as net
|
169 |
+
# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
|
170 |
+
# state_dict = network.state_dict()
|
171 |
+
# mat_net = OrderedDict()
|
172 |
+
# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
|
173 |
+
# mat_net[key] = param2
|
174 |
+
# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
|
175 |
+
#
|
176 |
+
|
177 |
+
# net_old = torch.load('net_old.pth')
|
178 |
+
# def show_kv(net):
|
179 |
+
# for k, v in net.items():
|
180 |
+
# print(k)
|
181 |
+
#
|
182 |
+
# show_kv(net_old)
|
183 |
+
# from models.network_dpsr import MSRResNet_prior as net
|
184 |
+
# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
|
185 |
+
# state_dict = network.state_dict()
|
186 |
+
# net_new = OrderedDict()
|
187 |
+
# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
|
188 |
+
# net_new[key] = param_old
|
189 |
+
# torch.save(net_new, 'net_new.pth')
|
190 |
+
|
191 |
+
|
192 |
+
# print(key)
|
193 |
+
# print(param.size())
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
# run utils/utils_matconvnet.py
|
core/data/deg_kair_utils/utils_model.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from utils import utils_image as util
|
5 |
+
import re
|
6 |
+
import glob
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
'''
|
11 |
+
# --------------------------------------------
|
12 |
+
# Model
|
13 |
+
# --------------------------------------------
|
14 |
+
# Kai Zhang (github: https://github.com/cszn)
|
15 |
+
# 03/Mar/2019
|
16 |
+
# --------------------------------------------
|
17 |
+
'''
|
18 |
+
|
19 |
+
|
20 |
+
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
|
21 |
+
"""
|
22 |
+
# ---------------------------------------
|
23 |
+
# Kai Zhang (github: https://github.com/cszn)
|
24 |
+
# 03/Mar/2019
|
25 |
+
# ---------------------------------------
|
26 |
+
Args:
|
27 |
+
save_dir: model folder
|
28 |
+
net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
|
29 |
+
pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
|
30 |
+
|
31 |
+
Return:
|
32 |
+
init_iter: iteration number
|
33 |
+
init_path: model path
|
34 |
+
# ---------------------------------------
|
35 |
+
"""
|
36 |
+
|
37 |
+
file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
|
38 |
+
if file_list:
|
39 |
+
iter_exist = []
|
40 |
+
for file_ in file_list:
|
41 |
+
iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
|
42 |
+
iter_exist.append(int(iter_current[0]))
|
43 |
+
init_iter = max(iter_exist)
|
44 |
+
init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
|
45 |
+
else:
|
46 |
+
init_iter = 0
|
47 |
+
init_path = pretrained_path
|
48 |
+
return init_iter, init_path
|
49 |
+
|
50 |
+
|
51 |
+
def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
|
52 |
+
'''
|
53 |
+
# ---------------------------------------
|
54 |
+
# Kai Zhang (github: https://github.com/cszn)
|
55 |
+
# 03/Mar/2019
|
56 |
+
# ---------------------------------------
|
57 |
+
Args:
|
58 |
+
model: trained model
|
59 |
+
L: input Low-quality image
|
60 |
+
mode:
|
61 |
+
(0) normal: test(model, L)
|
62 |
+
(1) pad: test_pad(model, L, modulo=16)
|
63 |
+
(2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
|
64 |
+
(3) x8: test_x8(model, L, modulo=1) ^_^
|
65 |
+
(4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
|
66 |
+
refield: effective receptive filed of the network, 32 is enough
|
67 |
+
useful when split, i.e., mode=2, 4
|
68 |
+
min_size: min_sizeXmin_size image, e.g., 256X256 image
|
69 |
+
useful when split, i.e., mode=2, 4
|
70 |
+
sf: scale factor for super-resolution, otherwise 1
|
71 |
+
modulo: 1 if split
|
72 |
+
useful when pad, i.e., mode=1
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
E: estimated image
|
76 |
+
# ---------------------------------------
|
77 |
+
'''
|
78 |
+
if mode == 0:
|
79 |
+
E = test(model, L)
|
80 |
+
elif mode == 1:
|
81 |
+
E = test_pad(model, L, modulo, sf)
|
82 |
+
elif mode == 2:
|
83 |
+
E = test_split(model, L, refield, min_size, sf, modulo)
|
84 |
+
elif mode == 3:
|
85 |
+
E = test_x8(model, L, modulo, sf)
|
86 |
+
elif mode == 4:
|
87 |
+
E = test_split_x8(model, L, refield, min_size, sf, modulo)
|
88 |
+
return E
|
89 |
+
|
90 |
+
|
91 |
+
'''
|
92 |
+
# --------------------------------------------
|
93 |
+
# normal (0)
|
94 |
+
# --------------------------------------------
|
95 |
+
'''
|
96 |
+
|
97 |
+
|
98 |
+
def test(model, L):
|
99 |
+
E = model(L)
|
100 |
+
return E
|
101 |
+
|
102 |
+
|
103 |
+
'''
|
104 |
+
# --------------------------------------------
|
105 |
+
# pad (1)
|
106 |
+
# --------------------------------------------
|
107 |
+
'''
|
108 |
+
|
109 |
+
|
110 |
+
def test_pad(model, L, modulo=16, sf=1):
|
111 |
+
h, w = L.size()[-2:]
|
112 |
+
paddingBottom = int(np.ceil(h/modulo)*modulo-h)
|
113 |
+
paddingRight = int(np.ceil(w/modulo)*modulo-w)
|
114 |
+
L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
|
115 |
+
E = model(L)
|
116 |
+
E = E[..., :h*sf, :w*sf]
|
117 |
+
return E
|
118 |
+
|
119 |
+
|
120 |
+
'''
|
121 |
+
# --------------------------------------------
|
122 |
+
# split (function)
|
123 |
+
# --------------------------------------------
|
124 |
+
'''
|
125 |
+
|
126 |
+
|
127 |
+
def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
|
128 |
+
"""
|
129 |
+
Args:
|
130 |
+
model: trained model
|
131 |
+
L: input Low-quality image
|
132 |
+
refield: effective receptive filed of the network, 32 is enough
|
133 |
+
min_size: min_sizeXmin_size image, e.g., 256X256 image
|
134 |
+
sf: scale factor for super-resolution, otherwise 1
|
135 |
+
modulo: 1 if split
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
E: estimated result
|
139 |
+
"""
|
140 |
+
h, w = L.size()[-2:]
|
141 |
+
if h*w <= min_size**2:
|
142 |
+
L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
|
143 |
+
E = model(L)
|
144 |
+
E = E[..., :h*sf, :w*sf]
|
145 |
+
else:
|
146 |
+
top = slice(0, (h//2//refield+1)*refield)
|
147 |
+
bottom = slice(h - (h//2//refield+1)*refield, h)
|
148 |
+
left = slice(0, (w//2//refield+1)*refield)
|
149 |
+
right = slice(w - (w//2//refield+1)*refield, w)
|
150 |
+
Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
|
151 |
+
|
152 |
+
if h * w <= 4*(min_size**2):
|
153 |
+
Es = [model(Ls[i]) for i in range(4)]
|
154 |
+
else:
|
155 |
+
Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
|
156 |
+
|
157 |
+
b, c = Es[0].size()[:2]
|
158 |
+
E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
|
159 |
+
|
160 |
+
E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
|
161 |
+
E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
|
162 |
+
E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
|
163 |
+
E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
|
164 |
+
return E
|
165 |
+
|
166 |
+
|
167 |
+
'''
|
168 |
+
# --------------------------------------------
|
169 |
+
# split (2)
|
170 |
+
# --------------------------------------------
|
171 |
+
'''
|
172 |
+
|
173 |
+
|
174 |
+
def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
|
175 |
+
E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
|
176 |
+
return E
|
177 |
+
|
178 |
+
|
179 |
+
'''
|
180 |
+
# --------------------------------------------
|
181 |
+
# x8 (3)
|
182 |
+
# --------------------------------------------
|
183 |
+
'''
|
184 |
+
|
185 |
+
|
186 |
+
def test_x8(model, L, modulo=1, sf=1):
|
187 |
+
E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
|
188 |
+
for i in range(len(E_list)):
|
189 |
+
if i == 3 or i == 5:
|
190 |
+
E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
|
191 |
+
else:
|
192 |
+
E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
|
193 |
+
output_cat = torch.stack(E_list, dim=0)
|
194 |
+
E = output_cat.mean(dim=0, keepdim=False)
|
195 |
+
return E
|
196 |
+
|
197 |
+
|
198 |
+
'''
|
199 |
+
# --------------------------------------------
|
200 |
+
# split and x8 (4)
|
201 |
+
# --------------------------------------------
|
202 |
+
'''
|
203 |
+
|
204 |
+
|
205 |
+
def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
|
206 |
+
E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
|
207 |
+
for k, i in enumerate(range(len(E_list))):
|
208 |
+
if i==3 or i==5:
|
209 |
+
E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i)
|
210 |
+
else:
|
211 |
+
E_list[k] = util.augment_img_tensor4(E_list[k], mode=i)
|
212 |
+
output_cat = torch.stack(E_list, dim=0)
|
213 |
+
E = output_cat.mean(dim=0, keepdim=False)
|
214 |
+
return E
|
215 |
+
|
216 |
+
|
217 |
+
'''
|
218 |
+
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
|
219 |
+
# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
|
220 |
+
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
|
221 |
+
'''
|
222 |
+
|
223 |
+
|
224 |
+
'''
|
225 |
+
# --------------------------------------------
|
226 |
+
# print
|
227 |
+
# --------------------------------------------
|
228 |
+
'''
|
229 |
+
|
230 |
+
|
231 |
+
# --------------------------------------------
|
232 |
+
# print model
|
233 |
+
# --------------------------------------------
|
234 |
+
def print_model(model):
|
235 |
+
msg = describe_model(model)
|
236 |
+
print(msg)
|
237 |
+
|
238 |
+
|
239 |
+
# --------------------------------------------
|
240 |
+
# print params
|
241 |
+
# --------------------------------------------
|
242 |
+
def print_params(model):
|
243 |
+
msg = describe_params(model)
|
244 |
+
print(msg)
|
245 |
+
|
246 |
+
|
247 |
+
'''
|
248 |
+
# --------------------------------------------
|
249 |
+
# information
|
250 |
+
# --------------------------------------------
|
251 |
+
'''
|
252 |
+
|
253 |
+
|
254 |
+
# --------------------------------------------
|
255 |
+
# model inforation
|
256 |
+
# --------------------------------------------
|
257 |
+
def info_model(model):
|
258 |
+
msg = describe_model(model)
|
259 |
+
return msg
|
260 |
+
|
261 |
+
|
262 |
+
# --------------------------------------------
|
263 |
+
# params inforation
|
264 |
+
# --------------------------------------------
|
265 |
+
def info_params(model):
|
266 |
+
msg = describe_params(model)
|
267 |
+
return msg
|
268 |
+
|
269 |
+
|
270 |
+
'''
|
271 |
+
# --------------------------------------------
|
272 |
+
# description
|
273 |
+
# --------------------------------------------
|
274 |
+
'''
|
275 |
+
|
276 |
+
|
277 |
+
# --------------------------------------------
|
278 |
+
# model name and total number of parameters
|
279 |
+
# --------------------------------------------
|
280 |
+
def describe_model(model):
|
281 |
+
if isinstance(model, torch.nn.DataParallel):
|
282 |
+
model = model.module
|
283 |
+
msg = '\n'
|
284 |
+
msg += 'models name: {}'.format(model.__class__.__name__) + '\n'
|
285 |
+
msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
|
286 |
+
msg += 'Net structure:\n{}'.format(str(model)) + '\n'
|
287 |
+
return msg
|
288 |
+
|
289 |
+
|
290 |
+
# --------------------------------------------
|
291 |
+
# parameters description
|
292 |
+
# --------------------------------------------
|
293 |
+
def describe_params(model):
|
294 |
+
if isinstance(model, torch.nn.DataParallel):
|
295 |
+
model = model.module
|
296 |
+
msg = '\n'
|
297 |
+
msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
|
298 |
+
for name, param in model.state_dict().items():
|
299 |
+
if not 'num_batches_tracked' in name:
|
300 |
+
v = param.data.clone().float()
|
301 |
+
msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
|
302 |
+
return msg
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == '__main__':
|
306 |
+
|
307 |
+
class Net(torch.nn.Module):
|
308 |
+
def __init__(self, in_channels=3, out_channels=3):
|
309 |
+
super(Net, self).__init__()
|
310 |
+
self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
|
311 |
+
|
312 |
+
def forward(self, x):
|
313 |
+
x = self.conv(x)
|
314 |
+
return x
|
315 |
+
|
316 |
+
start = torch.cuda.Event(enable_timing=True)
|
317 |
+
end = torch.cuda.Event(enable_timing=True)
|
318 |
+
|
319 |
+
model = Net()
|
320 |
+
model = model.eval()
|
321 |
+
print_model(model)
|
322 |
+
print_params(model)
|
323 |
+
x = torch.randn((2,3,401,401))
|
324 |
+
torch.cuda.empty_cache()
|
325 |
+
with torch.no_grad():
|
326 |
+
for mode in range(5):
|
327 |
+
y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
|
328 |
+
print(y.shape)
|
329 |
+
|
330 |
+
# run utils/utils_model.py
|
core/data/deg_kair_utils/utils_modelsummary.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
'''
|
6 |
+
---- 1) FLOPs: floating point operations
|
7 |
+
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
|
8 |
+
---- 3) #Conv2d: the number of ‘Conv2d’ layers
|
9 |
+
# --------------------------------------------
|
10 |
+
# Kai Zhang (github: https://github.com/cszn)
|
11 |
+
# 21/July/2020
|
12 |
+
# --------------------------------------------
|
13 |
+
# Reference
|
14 |
+
https://github.com/sovrasov/flops-counter.pytorch.git
|
15 |
+
|
16 |
+
# If you use this code, please consider the following citation:
|
17 |
+
|
18 |
+
@inproceedings{zhang2020aim, %
|
19 |
+
title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
|
20 |
+
author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
|
21 |
+
booktitle={European Conference on Computer Vision Workshops},
|
22 |
+
year={2020}
|
23 |
+
}
|
24 |
+
# --------------------------------------------
|
25 |
+
'''
|
26 |
+
|
27 |
+
def get_model_flops(model, input_res, print_per_layer_stat=True,
|
28 |
+
input_constructor=None):
|
29 |
+
assert type(input_res) is tuple, 'Please provide the size of the input image.'
|
30 |
+
assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
|
31 |
+
flops_model = add_flops_counting_methods(model)
|
32 |
+
flops_model.eval().start_flops_count()
|
33 |
+
if input_constructor:
|
34 |
+
input = input_constructor(input_res)
|
35 |
+
_ = flops_model(**input)
|
36 |
+
else:
|
37 |
+
device = list(flops_model.parameters())[-1].device
|
38 |
+
batch = torch.FloatTensor(1, *input_res).to(device)
|
39 |
+
_ = flops_model(batch)
|
40 |
+
|
41 |
+
if print_per_layer_stat:
|
42 |
+
print_model_with_flops(flops_model)
|
43 |
+
flops_count = flops_model.compute_average_flops_cost()
|
44 |
+
flops_model.stop_flops_count()
|
45 |
+
|
46 |
+
return flops_count
|
47 |
+
|
48 |
+
def get_model_activation(model, input_res, input_constructor=None):
|
49 |
+
assert type(input_res) is tuple, 'Please provide the size of the input image.'
|
50 |
+
assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
|
51 |
+
activation_model = add_activation_counting_methods(model)
|
52 |
+
activation_model.eval().start_activation_count()
|
53 |
+
if input_constructor:
|
54 |
+
input = input_constructor(input_res)
|
55 |
+
_ = activation_model(**input)
|
56 |
+
else:
|
57 |
+
device = list(activation_model.parameters())[-1].device
|
58 |
+
batch = torch.FloatTensor(1, *input_res).to(device)
|
59 |
+
_ = activation_model(batch)
|
60 |
+
|
61 |
+
activation_count, num_conv = activation_model.compute_average_activation_cost()
|
62 |
+
activation_model.stop_activation_count()
|
63 |
+
|
64 |
+
return activation_count, num_conv
|
65 |
+
|
66 |
+
|
67 |
+
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
|
68 |
+
input_constructor=None):
|
69 |
+
assert type(input_res) is tuple
|
70 |
+
assert len(input_res) >= 3
|
71 |
+
flops_model = add_flops_counting_methods(model)
|
72 |
+
flops_model.eval().start_flops_count()
|
73 |
+
if input_constructor:
|
74 |
+
input = input_constructor(input_res)
|
75 |
+
_ = flops_model(**input)
|
76 |
+
else:
|
77 |
+
batch = torch.FloatTensor(1, *input_res)
|
78 |
+
_ = flops_model(batch)
|
79 |
+
|
80 |
+
if print_per_layer_stat:
|
81 |
+
print_model_with_flops(flops_model)
|
82 |
+
flops_count = flops_model.compute_average_flops_cost()
|
83 |
+
params_count = get_model_parameters_number(flops_model)
|
84 |
+
flops_model.stop_flops_count()
|
85 |
+
|
86 |
+
if as_strings:
|
87 |
+
return flops_to_string(flops_count), params_to_string(params_count)
|
88 |
+
|
89 |
+
return flops_count, params_count
|
90 |
+
|
91 |
+
|
92 |
+
def flops_to_string(flops, units='GMac', precision=2):
|
93 |
+
if units is None:
|
94 |
+
if flops // 10**9 > 0:
|
95 |
+
return str(round(flops / 10.**9, precision)) + ' GMac'
|
96 |
+
elif flops // 10**6 > 0:
|
97 |
+
return str(round(flops / 10.**6, precision)) + ' MMac'
|
98 |
+
elif flops // 10**3 > 0:
|
99 |
+
return str(round(flops / 10.**3, precision)) + ' KMac'
|
100 |
+
else:
|
101 |
+
return str(flops) + ' Mac'
|
102 |
+
else:
|
103 |
+
if units == 'GMac':
|
104 |
+
return str(round(flops / 10.**9, precision)) + ' ' + units
|
105 |
+
elif units == 'MMac':
|
106 |
+
return str(round(flops / 10.**6, precision)) + ' ' + units
|
107 |
+
elif units == 'KMac':
|
108 |
+
return str(round(flops / 10.**3, precision)) + ' ' + units
|
109 |
+
else:
|
110 |
+
return str(flops) + ' Mac'
|
111 |
+
|
112 |
+
|
113 |
+
def params_to_string(params_num):
|
114 |
+
if params_num // 10 ** 6 > 0:
|
115 |
+
return str(round(params_num / 10 ** 6, 2)) + ' M'
|
116 |
+
elif params_num // 10 ** 3:
|
117 |
+
return str(round(params_num / 10 ** 3, 2)) + ' k'
|
118 |
+
else:
|
119 |
+
return str(params_num)
|
120 |
+
|
121 |
+
|
122 |
+
def print_model_with_flops(model, units='GMac', precision=3):
|
123 |
+
total_flops = model.compute_average_flops_cost()
|
124 |
+
|
125 |
+
def accumulate_flops(self):
|
126 |
+
if is_supported_instance(self):
|
127 |
+
return self.__flops__ / model.__batch_counter__
|
128 |
+
else:
|
129 |
+
sum = 0
|
130 |
+
for m in self.children():
|
131 |
+
sum += m.accumulate_flops()
|
132 |
+
return sum
|
133 |
+
|
134 |
+
def flops_repr(self):
|
135 |
+
accumulated_flops_cost = self.accumulate_flops()
|
136 |
+
return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
|
137 |
+
'{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
|
138 |
+
self.original_extra_repr()])
|
139 |
+
|
140 |
+
def add_extra_repr(m):
|
141 |
+
m.accumulate_flops = accumulate_flops.__get__(m)
|
142 |
+
flops_extra_repr = flops_repr.__get__(m)
|
143 |
+
if m.extra_repr != flops_extra_repr:
|
144 |
+
m.original_extra_repr = m.extra_repr
|
145 |
+
m.extra_repr = flops_extra_repr
|
146 |
+
assert m.extra_repr != m.original_extra_repr
|
147 |
+
|
148 |
+
def del_extra_repr(m):
|
149 |
+
if hasattr(m, 'original_extra_repr'):
|
150 |
+
m.extra_repr = m.original_extra_repr
|
151 |
+
del m.original_extra_repr
|
152 |
+
if hasattr(m, 'accumulate_flops'):
|
153 |
+
del m.accumulate_flops
|
154 |
+
|
155 |
+
model.apply(add_extra_repr)
|
156 |
+
print(model)
|
157 |
+
model.apply(del_extra_repr)
|
158 |
+
|
159 |
+
|
160 |
+
def get_model_parameters_number(model):
|
161 |
+
params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
162 |
+
return params_num
|
163 |
+
|
164 |
+
|
165 |
+
def add_flops_counting_methods(net_main_module):
|
166 |
+
# adding additional methods to the existing module object,
|
167 |
+
# this is done this way so that each function has access to self object
|
168 |
+
# embed()
|
169 |
+
net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
|
170 |
+
net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
|
171 |
+
net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
|
172 |
+
net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
|
173 |
+
|
174 |
+
net_main_module.reset_flops_count()
|
175 |
+
return net_main_module
|
176 |
+
|
177 |
+
|
178 |
+
def compute_average_flops_cost(self):
|
179 |
+
"""
|
180 |
+
A method that will be available after add_flops_counting_methods() is called
|
181 |
+
on a desired net object.
|
182 |
+
|
183 |
+
Returns current mean flops consumption per image.
|
184 |
+
|
185 |
+
"""
|
186 |
+
|
187 |
+
flops_sum = 0
|
188 |
+
for module in self.modules():
|
189 |
+
if is_supported_instance(module):
|
190 |
+
flops_sum += module.__flops__
|
191 |
+
|
192 |
+
return flops_sum
|
193 |
+
|
194 |
+
|
195 |
+
def start_flops_count(self):
|
196 |
+
"""
|
197 |
+
A method that will be available after add_flops_counting_methods() is called
|
198 |
+
on a desired net object.
|
199 |
+
|
200 |
+
Activates the computation of mean flops consumption per image.
|
201 |
+
Call it before you run the network.
|
202 |
+
|
203 |
+
"""
|
204 |
+
self.apply(add_flops_counter_hook_function)
|
205 |
+
|
206 |
+
|
207 |
+
def stop_flops_count(self):
|
208 |
+
"""
|
209 |
+
A method that will be available after add_flops_counting_methods() is called
|
210 |
+
on a desired net object.
|
211 |
+
|
212 |
+
Stops computing the mean flops consumption per image.
|
213 |
+
Call whenever you want to pause the computation.
|
214 |
+
|
215 |
+
"""
|
216 |
+
self.apply(remove_flops_counter_hook_function)
|
217 |
+
|
218 |
+
|
219 |
+
def reset_flops_count(self):
|
220 |
+
"""
|
221 |
+
A method that will be available after add_flops_counting_methods() is called
|
222 |
+
on a desired net object.
|
223 |
+
|
224 |
+
Resets statistics computed so far.
|
225 |
+
|
226 |
+
"""
|
227 |
+
self.apply(add_flops_counter_variable_or_reset)
|
228 |
+
|
229 |
+
|
230 |
+
def add_flops_counter_hook_function(module):
|
231 |
+
if is_supported_instance(module):
|
232 |
+
if hasattr(module, '__flops_handle__'):
|
233 |
+
return
|
234 |
+
|
235 |
+
if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
|
236 |
+
handle = module.register_forward_hook(conv_flops_counter_hook)
|
237 |
+
elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
|
238 |
+
handle = module.register_forward_hook(relu_flops_counter_hook)
|
239 |
+
elif isinstance(module, nn.Linear):
|
240 |
+
handle = module.register_forward_hook(linear_flops_counter_hook)
|
241 |
+
elif isinstance(module, (nn.BatchNorm2d)):
|
242 |
+
handle = module.register_forward_hook(bn_flops_counter_hook)
|
243 |
+
else:
|
244 |
+
handle = module.register_forward_hook(empty_flops_counter_hook)
|
245 |
+
module.__flops_handle__ = handle
|
246 |
+
|
247 |
+
|
248 |
+
def remove_flops_counter_hook_function(module):
|
249 |
+
if is_supported_instance(module):
|
250 |
+
if hasattr(module, '__flops_handle__'):
|
251 |
+
module.__flops_handle__.remove()
|
252 |
+
del module.__flops_handle__
|
253 |
+
|
254 |
+
|
255 |
+
def add_flops_counter_variable_or_reset(module):
|
256 |
+
if is_supported_instance(module):
|
257 |
+
module.__flops__ = 0
|
258 |
+
|
259 |
+
|
260 |
+
# ---- Internal functions
|
261 |
+
def is_supported_instance(module):
|
262 |
+
if isinstance(module,
|
263 |
+
(
|
264 |
+
nn.Conv2d, nn.ConvTranspose2d,
|
265 |
+
nn.BatchNorm2d,
|
266 |
+
nn.Linear,
|
267 |
+
nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
|
268 |
+
)):
|
269 |
+
return True
|
270 |
+
|
271 |
+
return False
|
272 |
+
|
273 |
+
|
274 |
+
def conv_flops_counter_hook(conv_module, input, output):
|
275 |
+
# Can have multiple inputs, getting the first one
|
276 |
+
# input = input[0]
|
277 |
+
|
278 |
+
batch_size = output.shape[0]
|
279 |
+
output_dims = list(output.shape[2:])
|
280 |
+
|
281 |
+
kernel_dims = list(conv_module.kernel_size)
|
282 |
+
in_channels = conv_module.in_channels
|
283 |
+
out_channels = conv_module.out_channels
|
284 |
+
groups = conv_module.groups
|
285 |
+
|
286 |
+
filters_per_channel = out_channels // groups
|
287 |
+
conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
|
288 |
+
|
289 |
+
active_elements_count = batch_size * np.prod(output_dims)
|
290 |
+
overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)
|
291 |
+
|
292 |
+
# overall_flops = overall_conv_flops
|
293 |
+
|
294 |
+
conv_module.__flops__ += int(overall_conv_flops)
|
295 |
+
# conv_module.__output_dims__ = output_dims
|
296 |
+
|
297 |
+
|
298 |
+
def relu_flops_counter_hook(module, input, output):
|
299 |
+
active_elements_count = output.numel()
|
300 |
+
module.__flops__ += int(active_elements_count)
|
301 |
+
# print(module.__flops__, id(module))
|
302 |
+
# print(module)
|
303 |
+
|
304 |
+
|
305 |
+
def linear_flops_counter_hook(module, input, output):
|
306 |
+
input = input[0]
|
307 |
+
if len(input.shape) == 1:
|
308 |
+
batch_size = 1
|
309 |
+
module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
|
310 |
+
else:
|
311 |
+
batch_size = input.shape[0]
|
312 |
+
module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
|
313 |
+
|
314 |
+
|
315 |
+
def bn_flops_counter_hook(module, input, output):
|
316 |
+
# input = input[0]
|
317 |
+
# TODO: need to check here
|
318 |
+
# batch_flops = np.prod(input.shape)
|
319 |
+
# if module.affine:
|
320 |
+
# batch_flops *= 2
|
321 |
+
# module.__flops__ += int(batch_flops)
|
322 |
+
batch = output.shape[0]
|
323 |
+
output_dims = output.shape[2:]
|
324 |
+
channels = module.num_features
|
325 |
+
batch_flops = batch * channels * np.prod(output_dims)
|
326 |
+
if module.affine:
|
327 |
+
batch_flops *= 2
|
328 |
+
module.__flops__ += int(batch_flops)
|
329 |
+
|
330 |
+
|
331 |
+
# ---- Count the number of convolutional layers and the activation
|
332 |
+
def add_activation_counting_methods(net_main_module):
|
333 |
+
# adding additional methods to the existing module object,
|
334 |
+
# this is done this way so that each function has access to self object
|
335 |
+
# embed()
|
336 |
+
net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
|
337 |
+
net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
|
338 |
+
net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
|
339 |
+
net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
|
340 |
+
|
341 |
+
net_main_module.reset_activation_count()
|
342 |
+
return net_main_module
|
343 |
+
|
344 |
+
|
345 |
+
def compute_average_activation_cost(self):
|
346 |
+
"""
|
347 |
+
A method that will be available after add_activation_counting_methods() is called
|
348 |
+
on a desired net object.
|
349 |
+
|
350 |
+
Returns current mean activation consumption per image.
|
351 |
+
|
352 |
+
"""
|
353 |
+
|
354 |
+
activation_sum = 0
|
355 |
+
num_conv = 0
|
356 |
+
for module in self.modules():
|
357 |
+
if is_supported_instance_for_activation(module):
|
358 |
+
activation_sum += module.__activation__
|
359 |
+
num_conv += module.__num_conv__
|
360 |
+
return activation_sum, num_conv
|
361 |
+
|
362 |
+
|
363 |
+
def start_activation_count(self):
|
364 |
+
"""
|
365 |
+
A method that will be available after add_activation_counting_methods() is called
|
366 |
+
on a desired net object.
|
367 |
+
|
368 |
+
Activates the computation of mean activation consumption per image.
|
369 |
+
Call it before you run the network.
|
370 |
+
|
371 |
+
"""
|
372 |
+
self.apply(add_activation_counter_hook_function)
|
373 |
+
|
374 |
+
|
375 |
+
def stop_activation_count(self):
|
376 |
+
"""
|
377 |
+
A method that will be available after add_activation_counting_methods() is called
|
378 |
+
on a desired net object.
|
379 |
+
|
380 |
+
Stops computing the mean activation consumption per image.
|
381 |
+
Call whenever you want to pause the computation.
|
382 |
+
|
383 |
+
"""
|
384 |
+
self.apply(remove_activation_counter_hook_function)
|
385 |
+
|
386 |
+
|
387 |
+
def reset_activation_count(self):
|
388 |
+
"""
|
389 |
+
A method that will be available after add_activation_counting_methods() is called
|
390 |
+
on a desired net object.
|
391 |
+
|
392 |
+
Resets statistics computed so far.
|
393 |
+
|
394 |
+
"""
|
395 |
+
self.apply(add_activation_counter_variable_or_reset)
|
396 |
+
|
397 |
+
|
398 |
+
def add_activation_counter_hook_function(module):
|
399 |
+
if is_supported_instance_for_activation(module):
|
400 |
+
if hasattr(module, '__activation_handle__'):
|
401 |
+
return
|
402 |
+
|
403 |
+
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
404 |
+
handle = module.register_forward_hook(conv_activation_counter_hook)
|
405 |
+
module.__activation_handle__ = handle
|
406 |
+
|
407 |
+
|
408 |
+
def remove_activation_counter_hook_function(module):
|
409 |
+
if is_supported_instance_for_activation(module):
|
410 |
+
if hasattr(module, '__activation_handle__'):
|
411 |
+
module.__activation_handle__.remove()
|
412 |
+
del module.__activation_handle__
|
413 |
+
|
414 |
+
|
415 |
+
def add_activation_counter_variable_or_reset(module):
|
416 |
+
if is_supported_instance_for_activation(module):
|
417 |
+
module.__activation__ = 0
|
418 |
+
module.__num_conv__ = 0
|
419 |
+
|
420 |
+
|
421 |
+
def is_supported_instance_for_activation(module):
|
422 |
+
if isinstance(module,
|
423 |
+
(
|
424 |
+
nn.Conv2d, nn.ConvTranspose2d,
|
425 |
+
)):
|
426 |
+
return True
|
427 |
+
|
428 |
+
return False
|
429 |
+
|
430 |
+
def conv_activation_counter_hook(module, input, output):
|
431 |
+
"""
|
432 |
+
Calculate the activations in the convolutional operation.
|
433 |
+
Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
|
434 |
+
:param module:
|
435 |
+
:param input:
|
436 |
+
:param output:
|
437 |
+
:return:
|
438 |
+
"""
|
439 |
+
module.__activation__ += output.numel()
|
440 |
+
module.__num_conv__ += 1
|
441 |
+
|
442 |
+
|
443 |
+
def empty_flops_counter_hook(module, input, output):
|
444 |
+
module.__flops__ += 0
|
445 |
+
|
446 |
+
|
447 |
+
def upsample_flops_counter_hook(module, input, output):
|
448 |
+
output_size = output[0]
|
449 |
+
batch_size = output_size.shape[0]
|
450 |
+
output_elements_count = batch_size
|
451 |
+
for val in output_size.shape[1:]:
|
452 |
+
output_elements_count *= val
|
453 |
+
module.__flops__ += int(output_elements_count)
|
454 |
+
|
455 |
+
|
456 |
+
def pool_flops_counter_hook(module, input, output):
|
457 |
+
input = input[0]
|
458 |
+
module.__flops__ += int(np.prod(input.shape))
|
459 |
+
|
460 |
+
|
461 |
+
def dconv_flops_counter_hook(dconv_module, input, output):
|
462 |
+
input = input[0]
|
463 |
+
|
464 |
+
batch_size = input.shape[0]
|
465 |
+
output_dims = list(output.shape[2:])
|
466 |
+
|
467 |
+
m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
|
468 |
+
out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
|
469 |
+
# groups = dconv_module.groups
|
470 |
+
|
471 |
+
# filters_per_channel = out_channels // groups
|
472 |
+
conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
|
473 |
+
conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
|
474 |
+
active_elements_count = batch_size * np.prod(output_dims)
|
475 |
+
|
476 |
+
overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
|
477 |
+
overall_flops = overall_conv_flops
|
478 |
+
|
479 |
+
dconv_module.__flops__ += int(overall_flops)
|
480 |
+
# dconv_module.__output_dims__ = output_dims
|
481 |
+
|
482 |
+
|
483 |
+
|
484 |
+
|
485 |
+
|
core/data/deg_kair_utils/utils_option.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import OrderedDict
|
3 |
+
from datetime import datetime
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import glob
|
7 |
+
|
8 |
+
|
9 |
+
'''
|
10 |
+
# --------------------------------------------
|
11 |
+
# Kai Zhang (github: https://github.com/cszn)
|
12 |
+
# 03/Mar/2019
|
13 |
+
# --------------------------------------------
|
14 |
+
# https://github.com/xinntao/BasicSR
|
15 |
+
# --------------------------------------------
|
16 |
+
'''
|
17 |
+
|
18 |
+
|
19 |
+
def get_timestamp():
|
20 |
+
return datetime.now().strftime('_%y%m%d_%H%M%S')
|
21 |
+
|
22 |
+
|
23 |
+
def parse(opt_path, is_train=True):
|
24 |
+
|
25 |
+
# ----------------------------------------
|
26 |
+
# remove comments starting with '//'
|
27 |
+
# ----------------------------------------
|
28 |
+
json_str = ''
|
29 |
+
with open(opt_path, 'r') as f:
|
30 |
+
for line in f:
|
31 |
+
line = line.split('//')[0] + '\n'
|
32 |
+
json_str += line
|
33 |
+
|
34 |
+
# ----------------------------------------
|
35 |
+
# initialize opt
|
36 |
+
# ----------------------------------------
|
37 |
+
opt = json.loads(json_str, object_pairs_hook=OrderedDict)
|
38 |
+
|
39 |
+
opt['opt_path'] = opt_path
|
40 |
+
opt['is_train'] = is_train
|
41 |
+
|
42 |
+
# ----------------------------------------
|
43 |
+
# set default
|
44 |
+
# ----------------------------------------
|
45 |
+
if 'merge_bn' not in opt:
|
46 |
+
opt['merge_bn'] = False
|
47 |
+
opt['merge_bn_startpoint'] = -1
|
48 |
+
|
49 |
+
if 'scale' not in opt:
|
50 |
+
opt['scale'] = 1
|
51 |
+
|
52 |
+
# ----------------------------------------
|
53 |
+
# datasets
|
54 |
+
# ----------------------------------------
|
55 |
+
for phase, dataset in opt['datasets'].items():
|
56 |
+
phase = phase.split('_')[0]
|
57 |
+
dataset['phase'] = phase
|
58 |
+
dataset['scale'] = opt['scale'] # broadcast
|
59 |
+
dataset['n_channels'] = opt['n_channels'] # broadcast
|
60 |
+
if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
|
61 |
+
dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
|
62 |
+
if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
|
63 |
+
dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
|
64 |
+
|
65 |
+
# ----------------------------------------
|
66 |
+
# path
|
67 |
+
# ----------------------------------------
|
68 |
+
for key, path in opt['path'].items():
|
69 |
+
if path and key in opt['path']:
|
70 |
+
opt['path'][key] = os.path.expanduser(path)
|
71 |
+
|
72 |
+
path_task = os.path.join(opt['path']['root'], opt['task'])
|
73 |
+
opt['path']['task'] = path_task
|
74 |
+
opt['path']['log'] = path_task
|
75 |
+
opt['path']['options'] = os.path.join(path_task, 'options')
|
76 |
+
|
77 |
+
if is_train:
|
78 |
+
opt['path']['models'] = os.path.join(path_task, 'models')
|
79 |
+
opt['path']['images'] = os.path.join(path_task, 'images')
|
80 |
+
else: # test
|
81 |
+
opt['path']['images'] = os.path.join(path_task, 'test_images')
|
82 |
+
|
83 |
+
# ----------------------------------------
|
84 |
+
# network
|
85 |
+
# ----------------------------------------
|
86 |
+
opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
|
87 |
+
|
88 |
+
# ----------------------------------------
|
89 |
+
# GPU devices
|
90 |
+
# ----------------------------------------
|
91 |
+
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
92 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
93 |
+
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
94 |
+
|
95 |
+
# ----------------------------------------
|
96 |
+
# default setting for distributeddataparallel
|
97 |
+
# ----------------------------------------
|
98 |
+
if 'find_unused_parameters' not in opt:
|
99 |
+
opt['find_unused_parameters'] = True
|
100 |
+
if 'use_static_graph' not in opt:
|
101 |
+
opt['use_static_graph'] = False
|
102 |
+
if 'dist' not in opt:
|
103 |
+
opt['dist'] = False
|
104 |
+
opt['num_gpu'] = len(opt['gpu_ids'])
|
105 |
+
print('number of GPUs is: ' + str(opt['num_gpu']))
|
106 |
+
|
107 |
+
# ----------------------------------------
|
108 |
+
# default setting for perceptual loss
|
109 |
+
# ----------------------------------------
|
110 |
+
if 'F_feature_layer' not in opt['train']:
|
111 |
+
opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34]
|
112 |
+
if 'F_weights' not in opt['train']:
|
113 |
+
opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0]
|
114 |
+
if 'F_lossfn_type' not in opt['train']:
|
115 |
+
opt['train']['F_lossfn_type'] = 'l1'
|
116 |
+
if 'F_use_input_norm' not in opt['train']:
|
117 |
+
opt['train']['F_use_input_norm'] = True
|
118 |
+
if 'F_use_range_norm' not in opt['train']:
|
119 |
+
opt['train']['F_use_range_norm'] = False
|
120 |
+
|
121 |
+
# ----------------------------------------
|
122 |
+
# default setting for optimizer
|
123 |
+
# ----------------------------------------
|
124 |
+
if 'G_optimizer_type' not in opt['train']:
|
125 |
+
opt['train']['G_optimizer_type'] = "adam"
|
126 |
+
if 'G_optimizer_betas' not in opt['train']:
|
127 |
+
opt['train']['G_optimizer_betas'] = [0.9,0.999]
|
128 |
+
if 'G_scheduler_restart_weights' not in opt['train']:
|
129 |
+
opt['train']['G_scheduler_restart_weights'] = 1
|
130 |
+
if 'G_optimizer_wd' not in opt['train']:
|
131 |
+
opt['train']['G_optimizer_wd'] = 0
|
132 |
+
if 'G_optimizer_reuse' not in opt['train']:
|
133 |
+
opt['train']['G_optimizer_reuse'] = False
|
134 |
+
if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
|
135 |
+
opt['train']['D_optimizer_reuse'] = False
|
136 |
+
|
137 |
+
# ----------------------------------------
|
138 |
+
# default setting of strict for model loading
|
139 |
+
# ----------------------------------------
|
140 |
+
if 'G_param_strict' not in opt['train']:
|
141 |
+
opt['train']['G_param_strict'] = True
|
142 |
+
if 'netD' in opt and 'D_param_strict' not in opt['path']:
|
143 |
+
opt['train']['D_param_strict'] = True
|
144 |
+
if 'E_param_strict' not in opt['path']:
|
145 |
+
opt['train']['E_param_strict'] = True
|
146 |
+
|
147 |
+
# ----------------------------------------
|
148 |
+
# Exponential Moving Average
|
149 |
+
# ----------------------------------------
|
150 |
+
if 'E_decay' not in opt['train']:
|
151 |
+
opt['train']['E_decay'] = 0
|
152 |
+
|
153 |
+
# ----------------------------------------
|
154 |
+
# default setting for discriminator
|
155 |
+
# ----------------------------------------
|
156 |
+
if 'netD' in opt:
|
157 |
+
if 'net_type' not in opt['netD']:
|
158 |
+
opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet
|
159 |
+
if 'in_nc' not in opt['netD']:
|
160 |
+
opt['netD']['in_nc'] = 3
|
161 |
+
if 'base_nc' not in opt['netD']:
|
162 |
+
opt['netD']['base_nc'] = 64
|
163 |
+
if 'n_layers' not in opt['netD']:
|
164 |
+
opt['netD']['n_layers'] = 3
|
165 |
+
if 'norm_type' not in opt['netD']:
|
166 |
+
opt['netD']['norm_type'] = 'spectral'
|
167 |
+
|
168 |
+
|
169 |
+
return opt
|
170 |
+
|
171 |
+
|
172 |
+
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
|
173 |
+
"""
|
174 |
+
Args:
|
175 |
+
save_dir: model folder
|
176 |
+
net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
|
177 |
+
pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
|
178 |
+
|
179 |
+
Return:
|
180 |
+
init_iter: iteration number
|
181 |
+
init_path: model path
|
182 |
+
"""
|
183 |
+
file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
|
184 |
+
if file_list:
|
185 |
+
iter_exist = []
|
186 |
+
for file_ in file_list:
|
187 |
+
iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
|
188 |
+
iter_exist.append(int(iter_current[0]))
|
189 |
+
init_iter = max(iter_exist)
|
190 |
+
init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
|
191 |
+
else:
|
192 |
+
init_iter = 0
|
193 |
+
init_path = pretrained_path
|
194 |
+
return init_iter, init_path
|
195 |
+
|
196 |
+
|
197 |
+
'''
|
198 |
+
# --------------------------------------------
|
199 |
+
# convert the opt into json file
|
200 |
+
# --------------------------------------------
|
201 |
+
'''
|
202 |
+
|
203 |
+
|
204 |
+
def save(opt):
|
205 |
+
opt_path = opt['opt_path']
|
206 |
+
opt_path_copy = opt['path']['options']
|
207 |
+
dirname, filename_ext = os.path.split(opt_path)
|
208 |
+
filename, ext = os.path.splitext(filename_ext)
|
209 |
+
dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
|
210 |
+
with open(dump_path, 'w') as dump_file:
|
211 |
+
json.dump(opt, dump_file, indent=2)
|
212 |
+
|
213 |
+
|
214 |
+
'''
|
215 |
+
# --------------------------------------------
|
216 |
+
# dict to string for logger
|
217 |
+
# --------------------------------------------
|
218 |
+
'''
|
219 |
+
|
220 |
+
|
221 |
+
def dict2str(opt, indent_l=1):
|
222 |
+
msg = ''
|
223 |
+
for k, v in opt.items():
|
224 |
+
if isinstance(v, dict):
|
225 |
+
msg += ' ' * (indent_l * 2) + k + ':[\n'
|
226 |
+
msg += dict2str(v, indent_l + 1)
|
227 |
+
msg += ' ' * (indent_l * 2) + ']\n'
|
228 |
+
else:
|
229 |
+
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
|
230 |
+
return msg
|
231 |
+
|
232 |
+
|
233 |
+
'''
|
234 |
+
# --------------------------------------------
|
235 |
+
# convert OrderedDict to NoneDict,
|
236 |
+
# return None for missing key
|
237 |
+
# --------------------------------------------
|
238 |
+
'''
|
239 |
+
|
240 |
+
|
241 |
+
def dict_to_nonedict(opt):
|
242 |
+
if isinstance(opt, dict):
|
243 |
+
new_opt = dict()
|
244 |
+
for key, sub_opt in opt.items():
|
245 |
+
new_opt[key] = dict_to_nonedict(sub_opt)
|
246 |
+
return NoneDict(**new_opt)
|
247 |
+
elif isinstance(opt, list):
|
248 |
+
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
249 |
+
else:
|
250 |
+
return opt
|
251 |
+
|
252 |
+
|
253 |
+
class NoneDict(dict):
|
254 |
+
def __missing__(self, key):
|
255 |
+
return None
|
core/data/deg_kair_utils/utils_params.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import torchvision
|
4 |
+
|
5 |
+
from models import basicblock as B
|
6 |
+
|
7 |
+
def show_kv(net):
|
8 |
+
for k, v in net.items():
|
9 |
+
print(k)
|
10 |
+
|
11 |
+
# should run train debug mode first to get an initial model
|
12 |
+
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
|
13 |
+
#
|
14 |
+
#for k, v in crt_net.items():
|
15 |
+
# print(k)
|
16 |
+
#for k, v in crt_net.items():
|
17 |
+
# if k in pretrained_net:
|
18 |
+
# crt_net[k] = pretrained_net[k]
|
19 |
+
# print('replace ... ', k)
|
20 |
+
|
21 |
+
# x2 -> x4
|
22 |
+
#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
|
23 |
+
#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
|
24 |
+
#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
|
25 |
+
#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
|
26 |
+
#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
|
27 |
+
#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
|
28 |
+
#torch.save(crt_net, '../pretrained_tmp.pth')
|
29 |
+
|
30 |
+
# x2 -> x3
|
31 |
+
'''
|
32 |
+
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
|
33 |
+
new_filter = torch.Tensor(576, 64, 3, 3)
|
34 |
+
new_filter[0:256, :, :, :] = in_filter
|
35 |
+
new_filter[256:512, :, :, :] = in_filter
|
36 |
+
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
|
37 |
+
crt_net['model.2.weight'] = new_filter
|
38 |
+
|
39 |
+
in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
|
40 |
+
new_bias = torch.Tensor(576)
|
41 |
+
new_bias[0:256] = in_bias
|
42 |
+
new_bias[256:512] = in_bias
|
43 |
+
new_bias[512:] = in_bias[0:576 - 512]
|
44 |
+
crt_net['model.2.bias'] = new_bias
|
45 |
+
|
46 |
+
torch.save(crt_net, '../pretrained_tmp.pth')
|
47 |
+
'''
|
48 |
+
|
49 |
+
# x2 -> x8
|
50 |
+
'''
|
51 |
+
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
|
52 |
+
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
|
53 |
+
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
|
54 |
+
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
|
55 |
+
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
|
56 |
+
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
|
57 |
+
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
|
58 |
+
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
|
59 |
+
torch.save(crt_net, '../pretrained_tmp.pth')
|
60 |
+
'''
|
61 |
+
|
62 |
+
# x3/4/8 RGB -> Y
|
63 |
+
|
64 |
+
def rgb2gray_net(net, only_input=True):
|
65 |
+
|
66 |
+
if only_input:
|
67 |
+
in_filter = net['0.weight']
|
68 |
+
in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
|
69 |
+
in_new_filter.unsqueeze_(1)
|
70 |
+
net['0.weight'] = in_new_filter
|
71 |
+
|
72 |
+
# out_filter = pretrained_net['model.13.weight']
|
73 |
+
# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
|
74 |
+
# out_filter[2, :, :, :] * 0.114
|
75 |
+
# out_new_filter.unsqueeze_(0)
|
76 |
+
# crt_net['model.13.weight'] = out_new_filter
|
77 |
+
# out_bias = pretrained_net['model.13.bias']
|
78 |
+
# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
|
79 |
+
# out_new_bias = torch.Tensor(1).fill_(out_new_bias)
|
80 |
+
# crt_net['model.13.bias'] = out_new_bias
|
81 |
+
|
82 |
+
# torch.save(crt_net, '../pretrained_tmp.pth')
|
83 |
+
|
84 |
+
return net
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
|
90 |
+
net = torchvision.models.vgg19(pretrained=True)
|
91 |
+
for k,v in net.features.named_parameters():
|
92 |
+
if k=='0.weight':
|
93 |
+
in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
|
94 |
+
in_new_filter.unsqueeze_(1)
|
95 |
+
v = in_new_filter
|
96 |
+
print(v.shape)
|
97 |
+
print(v[0,0,0,0])
|
98 |
+
if k=='0.bias':
|
99 |
+
in_new_bias = v
|
100 |
+
print(v[0])
|
101 |
+
|
102 |
+
print(net.features[0])
|
103 |
+
|
104 |
+
net.features[0] = B.conv(1, 64, mode='C')
|
105 |
+
|
106 |
+
print(net.features[0])
|
107 |
+
net.features[0].weight.data=in_new_filter
|
108 |
+
net.features[0].bias.data=in_new_bias
|
109 |
+
|
110 |
+
for k,v in net.features.named_parameters():
|
111 |
+
if k=='0.weight':
|
112 |
+
print(v[0,0,0,0])
|
113 |
+
if k=='0.bias':
|
114 |
+
print(v[0])
|
115 |
+
|
116 |
+
# transfer parameters of old model to new one
|
117 |
+
model_old = torch.load(model_path)
|
118 |
+
state_dict = model.state_dict()
|
119 |
+
for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
|
120 |
+
state_dict[key2] = param
|
121 |
+
print([key, key2])
|
122 |
+
# print([param.size(), param2.size()])
|
123 |
+
torch.save(state_dict, 'model_new.pth')
|
124 |
+
|
125 |
+
|
126 |
+
# rgb2gray_net(net)
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
core/data/deg_kair_utils/utils_receptivefield.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# online calculation: https://fomoro.com/research/article/receptive-field-calculator#
|
4 |
+
|
5 |
+
# [filter size, stride, padding]
|
6 |
+
#Assume the two dimensions are the same
|
7 |
+
#Each kernel requires the following parameters:
|
8 |
+
# - k_i: kernel size
|
9 |
+
# - s_i: stride
|
10 |
+
# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow)
|
11 |
+
#
|
12 |
+
#Each layer i requires the following parameters to be fully represented:
|
13 |
+
# - n_i: number of feature (data layer has n_1 = imagesize )
|
14 |
+
# - j_i: distance (projected to image pixel distance) between center of two adjacent features
|
15 |
+
# - r_i: receptive field of a feature in layer i
|
16 |
+
# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
|
17 |
+
|
18 |
+
import math
|
19 |
+
|
20 |
+
def outFromIn(conv, layerIn):
|
21 |
+
n_in = layerIn[0]
|
22 |
+
j_in = layerIn[1]
|
23 |
+
r_in = layerIn[2]
|
24 |
+
start_in = layerIn[3]
|
25 |
+
k = conv[0]
|
26 |
+
s = conv[1]
|
27 |
+
p = conv[2]
|
28 |
+
|
29 |
+
n_out = math.floor((n_in - k + 2*p)/s) + 1
|
30 |
+
actualP = (n_out-1)*s - n_in + k
|
31 |
+
pR = math.ceil(actualP/2)
|
32 |
+
pL = math.floor(actualP/2)
|
33 |
+
|
34 |
+
j_out = j_in * s
|
35 |
+
r_out = r_in + (k - 1)*j_in
|
36 |
+
start_out = start_in + ((k-1)/2 - pL)*j_in
|
37 |
+
return n_out, j_out, r_out, start_out
|
38 |
+
|
39 |
+
def printLayer(layer, layer_name):
|
40 |
+
print(layer_name + ":")
|
41 |
+
print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3]))
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
layerInfos = []
|
46 |
+
if __name__ == '__main__':
|
47 |
+
|
48 |
+
convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
|
49 |
+
layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
|
50 |
+
imsize = 128
|
51 |
+
|
52 |
+
print ("-------Net summary------")
|
53 |
+
currentLayer = [imsize, 1, 1, 0.5]
|
54 |
+
printLayer(currentLayer, "input image")
|
55 |
+
for i in range(len(convnet)):
|
56 |
+
currentLayer = outFromIn(convnet[i], currentLayer)
|
57 |
+
layerInfos.append(currentLayer)
|
58 |
+
printLayer(currentLayer, layer_names[i])
|
59 |
+
|
60 |
+
|
61 |
+
# run utils/utils_receptivefield.py
|
62 |
+
|
core/data/deg_kair_utils/utils_regularizers.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
'''
|
6 |
+
# --------------------------------------------
|
7 |
+
# Kai Zhang (github: https://github.com/cszn)
|
8 |
+
# 03/Mar/2019
|
9 |
+
# --------------------------------------------
|
10 |
+
'''
|
11 |
+
|
12 |
+
|
13 |
+
# --------------------------------------------
|
14 |
+
# SVD Orthogonal Regularization
|
15 |
+
# --------------------------------------------
|
16 |
+
def regularizer_orth(m):
|
17 |
+
"""
|
18 |
+
# ----------------------------------------
|
19 |
+
# SVD Orthogonal Regularization
|
20 |
+
# ----------------------------------------
|
21 |
+
# Applies regularization to the training by performing the
|
22 |
+
# orthogonalization technique described in the paper
|
23 |
+
# This function is to be called by the torch.nn.Module.apply() method,
|
24 |
+
# which applies svd_orthogonalization() to every layer of the model.
|
25 |
+
# usage: net.apply(regularizer_orth)
|
26 |
+
# ----------------------------------------
|
27 |
+
"""
|
28 |
+
classname = m.__class__.__name__
|
29 |
+
if classname.find('Conv') != -1:
|
30 |
+
w = m.weight.data.clone()
|
31 |
+
c_out, c_in, f1, f2 = w.size()
|
32 |
+
# dtype = m.weight.data.type()
|
33 |
+
w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
|
34 |
+
# self.netG.apply(svd_orthogonalization)
|
35 |
+
u, s, v = torch.svd(w)
|
36 |
+
s[s > 1.5] = s[s > 1.5] - 1e-4
|
37 |
+
s[s < 0.5] = s[s < 0.5] + 1e-4
|
38 |
+
w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
|
39 |
+
m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
|
40 |
+
else:
|
41 |
+
pass
|
42 |
+
|
43 |
+
|
44 |
+
# --------------------------------------------
|
45 |
+
# SVD Orthogonal Regularization
|
46 |
+
# --------------------------------------------
|
47 |
+
def regularizer_orth2(m):
|
48 |
+
"""
|
49 |
+
# ----------------------------------------
|
50 |
+
# Applies regularization to the training by performing the
|
51 |
+
# orthogonalization technique described in the paper
|
52 |
+
# This function is to be called by the torch.nn.Module.apply() method,
|
53 |
+
# which applies svd_orthogonalization() to every layer of the model.
|
54 |
+
# usage: net.apply(regularizer_orth2)
|
55 |
+
# ----------------------------------------
|
56 |
+
"""
|
57 |
+
classname = m.__class__.__name__
|
58 |
+
if classname.find('Conv') != -1:
|
59 |
+
w = m.weight.data.clone()
|
60 |
+
c_out, c_in, f1, f2 = w.size()
|
61 |
+
# dtype = m.weight.data.type()
|
62 |
+
w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
|
63 |
+
u, s, v = torch.svd(w)
|
64 |
+
s_mean = s.mean()
|
65 |
+
s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
|
66 |
+
s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
|
67 |
+
w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
|
68 |
+
m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
|
69 |
+
else:
|
70 |
+
pass
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
def regularizer_clip(m):
|
75 |
+
"""
|
76 |
+
# ----------------------------------------
|
77 |
+
# usage: net.apply(regularizer_clip)
|
78 |
+
# ----------------------------------------
|
79 |
+
"""
|
80 |
+
eps = 1e-4
|
81 |
+
c_min = -1.5
|
82 |
+
c_max = 1.5
|
83 |
+
|
84 |
+
classname = m.__class__.__name__
|
85 |
+
if classname.find('Conv') != -1 or classname.find('Linear') != -1:
|
86 |
+
w = m.weight.data.clone()
|
87 |
+
w[w > c_max] -= eps
|
88 |
+
w[w < c_min] += eps
|
89 |
+
m.weight.data = w
|
90 |
+
|
91 |
+
if m.bias is not None:
|
92 |
+
b = m.bias.data.clone()
|
93 |
+
b[b > c_max] -= eps
|
94 |
+
b[b < c_min] += eps
|
95 |
+
m.bias.data = b
|
96 |
+
|
97 |
+
# elif classname.find('BatchNorm2d') != -1:
|
98 |
+
#
|
99 |
+
# rv = m.running_var.data.clone()
|
100 |
+
# rm = m.running_mean.data.clone()
|
101 |
+
#
|
102 |
+
# if m.affine:
|
103 |
+
# m.weight.data
|
104 |
+
# m.bias.data
|
core/data/deg_kair_utils/utils_sisr.py
ADDED
@@ -0,0 +1,848 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from utils import utils_image as util
|
3 |
+
import random
|
4 |
+
|
5 |
+
import scipy
|
6 |
+
import scipy.stats as ss
|
7 |
+
import scipy.io as io
|
8 |
+
from scipy import ndimage
|
9 |
+
from scipy.interpolate import interp2d
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
"""
|
16 |
+
# --------------------------------------------
|
17 |
+
# Super-Resolution
|
18 |
+
# --------------------------------------------
|
19 |
+
#
|
20 |
+
# Kai Zhang ([email protected])
|
21 |
+
# https://github.com/cszn
|
22 |
+
# modified by Kai Zhang (github: https://github.com/cszn)
|
23 |
+
# 03/03/2020
|
24 |
+
# --------------------------------------------
|
25 |
+
"""
|
26 |
+
|
27 |
+
|
28 |
+
"""
|
29 |
+
# --------------------------------------------
|
30 |
+
# anisotropic Gaussian kernels
|
31 |
+
# --------------------------------------------
|
32 |
+
"""
|
33 |
+
|
34 |
+
|
35 |
+
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
36 |
+
""" generate an anisotropic Gaussian kernel
|
37 |
+
Args:
|
38 |
+
ksize : e.g., 15, kernel size
|
39 |
+
theta : [0, pi], rotation angle range
|
40 |
+
l1 : [0.1,50], scaling of eigenvalues
|
41 |
+
l2 : [0.1,l1], scaling of eigenvalues
|
42 |
+
If l1 = l2, will get an isotropic Gaussian kernel.
|
43 |
+
Returns:
|
44 |
+
k : kernel
|
45 |
+
"""
|
46 |
+
|
47 |
+
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
48 |
+
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
49 |
+
D = np.array([[l1, 0], [0, l2]])
|
50 |
+
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
51 |
+
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
52 |
+
|
53 |
+
return k
|
54 |
+
|
55 |
+
|
56 |
+
def gm_blur_kernel(mean, cov, size=15):
|
57 |
+
center = size / 2.0 + 0.5
|
58 |
+
k = np.zeros([size, size])
|
59 |
+
for y in range(size):
|
60 |
+
for x in range(size):
|
61 |
+
cy = y - center + 1
|
62 |
+
cx = x - center + 1
|
63 |
+
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
64 |
+
|
65 |
+
k = k / np.sum(k)
|
66 |
+
return k
|
67 |
+
|
68 |
+
|
69 |
+
"""
|
70 |
+
# --------------------------------------------
|
71 |
+
# calculate PCA projection matrix
|
72 |
+
# --------------------------------------------
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
def get_pca_matrix(x, dim_pca=15):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
x: 225x10000 matrix
|
80 |
+
dim_pca: 15
|
81 |
+
Returns:
|
82 |
+
pca_matrix: 15x225
|
83 |
+
"""
|
84 |
+
C = np.dot(x, x.T)
|
85 |
+
w, v = scipy.linalg.eigh(C)
|
86 |
+
pca_matrix = v[:, -dim_pca:].T
|
87 |
+
|
88 |
+
return pca_matrix
|
89 |
+
|
90 |
+
|
91 |
+
def show_pca(x):
|
92 |
+
"""
|
93 |
+
x: PCA projection matrix, e.g., 15x225
|
94 |
+
"""
|
95 |
+
for i in range(x.shape[0]):
|
96 |
+
xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
|
97 |
+
util.surf(xc)
|
98 |
+
|
99 |
+
|
100 |
+
def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
|
101 |
+
kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
|
102 |
+
for i in range(num_samples):
|
103 |
+
|
104 |
+
theta = np.pi*np.random.rand(1)
|
105 |
+
l1 = 0.1+l_max*np.random.rand(1)
|
106 |
+
l2 = 0.1+(l1-0.1)*np.random.rand(1)
|
107 |
+
|
108 |
+
k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
|
109 |
+
|
110 |
+
# util.imshow(k)
|
111 |
+
|
112 |
+
kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
|
113 |
+
|
114 |
+
# io.savemat('k.mat', {'k': kernels})
|
115 |
+
|
116 |
+
pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
|
117 |
+
|
118 |
+
io.savemat(path, {'p': pca_matrix})
|
119 |
+
|
120 |
+
return pca_matrix
|
121 |
+
|
122 |
+
|
123 |
+
"""
|
124 |
+
# --------------------------------------------
|
125 |
+
# shifted anisotropic Gaussian kernels
|
126 |
+
# --------------------------------------------
|
127 |
+
"""
|
128 |
+
|
129 |
+
|
130 |
+
def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
131 |
+
""""
|
132 |
+
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
133 |
+
# Kai Zhang
|
134 |
+
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
135 |
+
# max_var = 2.5 * sf
|
136 |
+
"""
|
137 |
+
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
138 |
+
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
139 |
+
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
140 |
+
theta = np.random.rand() * np.pi # random theta
|
141 |
+
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
142 |
+
|
143 |
+
# Set COV matrix using Lambdas and Theta
|
144 |
+
LAMBDA = np.diag([lambda_1, lambda_2])
|
145 |
+
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
146 |
+
[np.sin(theta), np.cos(theta)]])
|
147 |
+
SIGMA = Q @ LAMBDA @ Q.T
|
148 |
+
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
149 |
+
|
150 |
+
# Set expectation position (shifting kernel for aligned image)
|
151 |
+
MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
152 |
+
MU = MU[None, None, :, None]
|
153 |
+
|
154 |
+
# Create meshgrid for Gaussian
|
155 |
+
[X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
156 |
+
Z = np.stack([X, Y], 2)[:, :, :, None]
|
157 |
+
|
158 |
+
# Calcualte Gaussian for every pixel of the kernel
|
159 |
+
ZZ = Z-MU
|
160 |
+
ZZ_t = ZZ.transpose(0,1,3,2)
|
161 |
+
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
162 |
+
|
163 |
+
# shift the kernel so it will be centered
|
164 |
+
#raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
165 |
+
|
166 |
+
# Normalize the kernel and return
|
167 |
+
#kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
168 |
+
kernel = raw_kernel / np.sum(raw_kernel)
|
169 |
+
return kernel
|
170 |
+
|
171 |
+
|
172 |
+
def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
|
173 |
+
""""
|
174 |
+
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
175 |
+
# Kai Zhang
|
176 |
+
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
177 |
+
# max_var = 2.5 * sf
|
178 |
+
"""
|
179 |
+
sf = random.choice([1, 2, 3, 4])
|
180 |
+
scale_factor = np.array([sf, sf])
|
181 |
+
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
182 |
+
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
183 |
+
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
184 |
+
theta = np.random.rand() * np.pi # random theta
|
185 |
+
noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
|
186 |
+
|
187 |
+
# Set COV matrix using Lambdas and Theta
|
188 |
+
LAMBDA = np.diag([lambda_1, lambda_2])
|
189 |
+
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
190 |
+
[np.sin(theta), np.cos(theta)]])
|
191 |
+
SIGMA = Q @ LAMBDA @ Q.T
|
192 |
+
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
193 |
+
|
194 |
+
# Set expectation position (shifting kernel for aligned image)
|
195 |
+
MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
196 |
+
MU = MU[None, None, :, None]
|
197 |
+
|
198 |
+
# Create meshgrid for Gaussian
|
199 |
+
[X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
200 |
+
Z = np.stack([X, Y], 2)[:, :, :, None]
|
201 |
+
|
202 |
+
# Calcualte Gaussian for every pixel of the kernel
|
203 |
+
ZZ = Z-MU
|
204 |
+
ZZ_t = ZZ.transpose(0,1,3,2)
|
205 |
+
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
206 |
+
|
207 |
+
# shift the kernel so it will be centered
|
208 |
+
#raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
209 |
+
|
210 |
+
# Normalize the kernel and return
|
211 |
+
#kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
212 |
+
kernel = raw_kernel / np.sum(raw_kernel)
|
213 |
+
return kernel
|
214 |
+
|
215 |
+
|
216 |
+
"""
|
217 |
+
# --------------------------------------------
|
218 |
+
# degradation models
|
219 |
+
# --------------------------------------------
|
220 |
+
"""
|
221 |
+
|
222 |
+
|
223 |
+
def bicubic_degradation(x, sf=3):
|
224 |
+
'''
|
225 |
+
Args:
|
226 |
+
x: HxWxC image, [0, 1]
|
227 |
+
sf: down-scale factor
|
228 |
+
Return:
|
229 |
+
bicubicly downsampled LR image
|
230 |
+
'''
|
231 |
+
x = util.imresize_np(x, scale=1/sf)
|
232 |
+
return x
|
233 |
+
|
234 |
+
|
235 |
+
def srmd_degradation(x, k, sf=3):
|
236 |
+
''' blur + bicubic downsampling
|
237 |
+
Args:
|
238 |
+
x: HxWxC image, [0, 1]
|
239 |
+
k: hxw, double
|
240 |
+
sf: down-scale factor
|
241 |
+
Return:
|
242 |
+
downsampled LR image
|
243 |
+
Reference:
|
244 |
+
@inproceedings{zhang2018learning,
|
245 |
+
title={Learning a single convolutional super-resolution network for multiple degradations},
|
246 |
+
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
247 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
248 |
+
pages={3262--3271},
|
249 |
+
year={2018}
|
250 |
+
}
|
251 |
+
'''
|
252 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
253 |
+
x = bicubic_degradation(x, sf=sf)
|
254 |
+
return x
|
255 |
+
|
256 |
+
|
257 |
+
def dpsr_degradation(x, k, sf=3):
|
258 |
+
|
259 |
+
''' bicubic downsampling + blur
|
260 |
+
Args:
|
261 |
+
x: HxWxC image, [0, 1]
|
262 |
+
k: hxw, double
|
263 |
+
sf: down-scale factor
|
264 |
+
Return:
|
265 |
+
downsampled LR image
|
266 |
+
Reference:
|
267 |
+
@inproceedings{zhang2019deep,
|
268 |
+
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
269 |
+
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
270 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
271 |
+
pages={1671--1681},
|
272 |
+
year={2019}
|
273 |
+
}
|
274 |
+
'''
|
275 |
+
x = bicubic_degradation(x, sf=sf)
|
276 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
def classical_degradation(x, k, sf=3):
|
281 |
+
''' blur + downsampling
|
282 |
+
|
283 |
+
Args:
|
284 |
+
x: HxWxC image, [0, 1]/[0, 255]
|
285 |
+
k: hxw, double
|
286 |
+
sf: down-scale factor
|
287 |
+
|
288 |
+
Return:
|
289 |
+
downsampled LR image
|
290 |
+
'''
|
291 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
292 |
+
#x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
293 |
+
st = 0
|
294 |
+
return x[st::sf, st::sf, ...]
|
295 |
+
|
296 |
+
|
297 |
+
def modcrop_np(img, sf):
|
298 |
+
'''
|
299 |
+
Args:
|
300 |
+
img: numpy image, WxH or WxHxC
|
301 |
+
sf: scale factor
|
302 |
+
Return:
|
303 |
+
cropped image
|
304 |
+
'''
|
305 |
+
w, h = img.shape[:2]
|
306 |
+
im = np.copy(img)
|
307 |
+
return im[:w - w % sf, :h - h % sf, ...]
|
308 |
+
|
309 |
+
|
310 |
+
'''
|
311 |
+
# =================
|
312 |
+
# Numpy
|
313 |
+
# =================
|
314 |
+
'''
|
315 |
+
|
316 |
+
|
317 |
+
def shift_pixel(x, sf, upper_left=True):
|
318 |
+
"""shift pixel for super-resolution with different scale factors
|
319 |
+
Args:
|
320 |
+
x: WxHxC or WxH, image or kernel
|
321 |
+
sf: scale factor
|
322 |
+
upper_left: shift direction
|
323 |
+
"""
|
324 |
+
h, w = x.shape[:2]
|
325 |
+
shift = (sf-1)*0.5
|
326 |
+
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
327 |
+
if upper_left:
|
328 |
+
x1 = xv + shift
|
329 |
+
y1 = yv + shift
|
330 |
+
else:
|
331 |
+
x1 = xv - shift
|
332 |
+
y1 = yv - shift
|
333 |
+
|
334 |
+
x1 = np.clip(x1, 0, w-1)
|
335 |
+
y1 = np.clip(y1, 0, h-1)
|
336 |
+
|
337 |
+
if x.ndim == 2:
|
338 |
+
x = interp2d(xv, yv, x)(x1, y1)
|
339 |
+
if x.ndim == 3:
|
340 |
+
for i in range(x.shape[-1]):
|
341 |
+
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
342 |
+
|
343 |
+
return x
|
344 |
+
|
345 |
+
|
346 |
+
'''
|
347 |
+
# =================
|
348 |
+
# pytorch
|
349 |
+
# =================
|
350 |
+
'''
|
351 |
+
|
352 |
+
|
353 |
+
def splits(a, sf):
|
354 |
+
'''
|
355 |
+
a: tensor NxCxWxHx2
|
356 |
+
sf: scale factor
|
357 |
+
out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
|
358 |
+
'''
|
359 |
+
b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
|
360 |
+
b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
|
361 |
+
return b
|
362 |
+
|
363 |
+
|
364 |
+
def c2c(x):
|
365 |
+
return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
|
366 |
+
|
367 |
+
|
368 |
+
def r2c(x):
|
369 |
+
return torch.stack([x, torch.zeros_like(x)], -1)
|
370 |
+
|
371 |
+
|
372 |
+
def cdiv(x, y):
|
373 |
+
a, b = x[..., 0], x[..., 1]
|
374 |
+
c, d = y[..., 0], y[..., 1]
|
375 |
+
cd2 = c**2 + d**2
|
376 |
+
return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
|
377 |
+
|
378 |
+
|
379 |
+
def csum(x, y):
|
380 |
+
return torch.stack([x[..., 0] + y, x[..., 1]], -1)
|
381 |
+
|
382 |
+
|
383 |
+
def cabs(x):
|
384 |
+
return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
|
385 |
+
|
386 |
+
|
387 |
+
def cmul(t1, t2):
|
388 |
+
'''
|
389 |
+
complex multiplication
|
390 |
+
t1: NxCxHxWx2
|
391 |
+
output: NxCxHxWx2
|
392 |
+
'''
|
393 |
+
real1, imag1 = t1[..., 0], t1[..., 1]
|
394 |
+
real2, imag2 = t2[..., 0], t2[..., 1]
|
395 |
+
return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
|
396 |
+
|
397 |
+
|
398 |
+
def cconj(t, inplace=False):
|
399 |
+
'''
|
400 |
+
# complex's conjugation
|
401 |
+
t: NxCxHxWx2
|
402 |
+
output: NxCxHxWx2
|
403 |
+
'''
|
404 |
+
c = t.clone() if not inplace else t
|
405 |
+
c[..., 1] *= -1
|
406 |
+
return c
|
407 |
+
|
408 |
+
|
409 |
+
def rfft(t):
|
410 |
+
return torch.rfft(t, 2, onesided=False)
|
411 |
+
|
412 |
+
|
413 |
+
def irfft(t):
|
414 |
+
return torch.irfft(t, 2, onesided=False)
|
415 |
+
|
416 |
+
|
417 |
+
def fft(t):
|
418 |
+
return torch.fft(t, 2)
|
419 |
+
|
420 |
+
|
421 |
+
def ifft(t):
|
422 |
+
return torch.ifft(t, 2)
|
423 |
+
|
424 |
+
|
425 |
+
def p2o(psf, shape):
|
426 |
+
'''
|
427 |
+
Args:
|
428 |
+
psf: NxCxhxw
|
429 |
+
shape: [H,W]
|
430 |
+
|
431 |
+
Returns:
|
432 |
+
otf: NxCxHxWx2
|
433 |
+
'''
|
434 |
+
otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
|
435 |
+
otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
|
436 |
+
for axis, axis_size in enumerate(psf.shape[2:]):
|
437 |
+
otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
|
438 |
+
otf = torch.rfft(otf, 2, onesided=False)
|
439 |
+
n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
|
440 |
+
otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
|
441 |
+
return otf
|
442 |
+
|
443 |
+
|
444 |
+
'''
|
445 |
+
# =================
|
446 |
+
PyTorch
|
447 |
+
# =================
|
448 |
+
'''
|
449 |
+
|
450 |
+
def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
|
451 |
+
'''
|
452 |
+
FB: NxCxWxHx2
|
453 |
+
F2B: NxCxWxHx2
|
454 |
+
|
455 |
+
x1 = FB.*FR;
|
456 |
+
FBR = BlockMM(nr,nc,Nb,m,x1);
|
457 |
+
invW = BlockMM(nr,nc,Nb,m,F2B);
|
458 |
+
invWBR = FBR./(invW + tau*Nb);
|
459 |
+
fun = @(block_struct) block_struct.data.*invWBR;
|
460 |
+
FCBinvWBR = blockproc(FBC,[nr,nc],fun);
|
461 |
+
FX = (FR-FCBinvWBR)/tau;
|
462 |
+
Xest = real(ifft2(FX));
|
463 |
+
'''
|
464 |
+
x1 = cmul(FB, FR)
|
465 |
+
FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
|
466 |
+
invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
|
467 |
+
invWBR = cdiv(FBR, csum(invW, tau))
|
468 |
+
FCBinvWBR = cmul(FBC, invWBR.repeat(1,1,sf,sf,1))
|
469 |
+
FX = (FR-FCBinvWBR)/tau
|
470 |
+
Xest = torch.irfft(FX, 2, onesided=False)
|
471 |
+
return Xest
|
472 |
+
|
473 |
+
|
474 |
+
def real2complex(x):
|
475 |
+
return torch.stack([x, torch.zeros_like(x)], -1)
|
476 |
+
|
477 |
+
|
478 |
+
def modcrop(img, sf):
|
479 |
+
'''
|
480 |
+
img: tensor image, NxCxWxH or CxWxH or WxH
|
481 |
+
sf: scale factor
|
482 |
+
'''
|
483 |
+
w, h = img.shape[-2:]
|
484 |
+
im = img.clone()
|
485 |
+
return im[..., :w - w % sf, :h - h % sf]
|
486 |
+
|
487 |
+
|
488 |
+
def upsample(x, sf=3, center=False):
|
489 |
+
'''
|
490 |
+
x: tensor image, NxCxWxH
|
491 |
+
'''
|
492 |
+
st = (sf-1)//2 if center else 0
|
493 |
+
z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
|
494 |
+
z[..., st::sf, st::sf].copy_(x)
|
495 |
+
return z
|
496 |
+
|
497 |
+
|
498 |
+
def downsample(x, sf=3, center=False):
|
499 |
+
st = (sf-1)//2 if center else 0
|
500 |
+
return x[..., st::sf, st::sf]
|
501 |
+
|
502 |
+
|
503 |
+
def circular_pad(x, pad):
|
504 |
+
'''
|
505 |
+
# x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding)
|
506 |
+
'''
|
507 |
+
x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
|
508 |
+
x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
|
509 |
+
x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
|
510 |
+
x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
|
511 |
+
return x
|
512 |
+
|
513 |
+
|
514 |
+
def pad_circular(input, padding):
|
515 |
+
# type: (Tensor, List[int]) -> Tensor
|
516 |
+
"""
|
517 |
+
Arguments
|
518 |
+
:param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
|
519 |
+
:param padding: (tuple): m-elem tuple where m is the degree of convolution
|
520 |
+
Returns
|
521 |
+
:return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
|
522 |
+
H + 2 * padding[1]], W + 2 * padding[2]))`
|
523 |
+
"""
|
524 |
+
offset = 3
|
525 |
+
for dimension in range(input.dim() - offset + 1):
|
526 |
+
input = dim_pad_circular(input, padding[dimension], dimension + offset)
|
527 |
+
return input
|
528 |
+
|
529 |
+
|
530 |
+
def dim_pad_circular(input, padding, dimension):
|
531 |
+
# type: (Tensor, int, int) -> Tensor
|
532 |
+
input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
|
533 |
+
[slice(0, padding)]]], dim=dimension - 1)
|
534 |
+
input = torch.cat([input[[slice(None)] * (dimension - 1) +
|
535 |
+
[slice(-2 * padding, -padding)]], input], dim=dimension - 1)
|
536 |
+
return input
|
537 |
+
|
538 |
+
|
539 |
+
def imfilter(x, k):
|
540 |
+
'''
|
541 |
+
x: image, NxcxHxW
|
542 |
+
k: kernel, cx1xhxw
|
543 |
+
'''
|
544 |
+
x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
|
545 |
+
x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
|
546 |
+
return x
|
547 |
+
|
548 |
+
|
549 |
+
def G(x, k, sf=3, center=False):
|
550 |
+
'''
|
551 |
+
x: image, NxcxHxW
|
552 |
+
k: kernel, cx1xhxw
|
553 |
+
sf: scale factor
|
554 |
+
center: the first one or the moddle one
|
555 |
+
|
556 |
+
Matlab function:
|
557 |
+
tmp = imfilter(x,h,'circular');
|
558 |
+
y = downsample2(tmp,K);
|
559 |
+
'''
|
560 |
+
x = downsample(imfilter(x, k), sf=sf, center=center)
|
561 |
+
return x
|
562 |
+
|
563 |
+
|
564 |
+
def Gt(x, k, sf=3, center=False):
|
565 |
+
'''
|
566 |
+
x: image, NxcxHxW
|
567 |
+
k: kernel, cx1xhxw
|
568 |
+
sf: scale factor
|
569 |
+
center: the first one or the moddle one
|
570 |
+
|
571 |
+
Matlab function:
|
572 |
+
tmp = upsample2(x,K);
|
573 |
+
y = imfilter(tmp,h,'circular');
|
574 |
+
'''
|
575 |
+
x = imfilter(upsample(x, sf=sf, center=center), k)
|
576 |
+
return x
|
577 |
+
|
578 |
+
|
579 |
+
def interpolation_down(x, sf, center=False):
|
580 |
+
mask = torch.zeros_like(x)
|
581 |
+
if center:
|
582 |
+
start = torch.tensor((sf-1)//2)
|
583 |
+
mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
|
584 |
+
LR = x[..., start::sf, start::sf]
|
585 |
+
else:
|
586 |
+
mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
|
587 |
+
LR = x[..., ::sf, ::sf]
|
588 |
+
y = x.mul(mask)
|
589 |
+
|
590 |
+
return LR, y, mask
|
591 |
+
|
592 |
+
|
593 |
+
'''
|
594 |
+
# =================
|
595 |
+
Numpy
|
596 |
+
# =================
|
597 |
+
'''
|
598 |
+
|
599 |
+
|
600 |
+
def blockproc(im, blocksize, fun):
|
601 |
+
xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
|
602 |
+
xblocks_proc = []
|
603 |
+
for xb in xblocks:
|
604 |
+
yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
|
605 |
+
yblocks_proc = []
|
606 |
+
for yb in yblocks:
|
607 |
+
yb_proc = fun(yb)
|
608 |
+
yblocks_proc.append(yb_proc)
|
609 |
+
xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))
|
610 |
+
|
611 |
+
proc = np.concatenate(xblocks_proc, axis=0)
|
612 |
+
|
613 |
+
return proc
|
614 |
+
|
615 |
+
|
616 |
+
def fun_reshape(a):
|
617 |
+
return np.reshape(a, (-1,1,a.shape[-1]), order='F')
|
618 |
+
|
619 |
+
|
620 |
+
def fun_mul(a, b):
|
621 |
+
return a*b
|
622 |
+
|
623 |
+
|
624 |
+
def BlockMM(nr, nc, Nb, m, x1):
|
625 |
+
'''
|
626 |
+
myfun = @(block_struct) reshape(block_struct.data,m,1);
|
627 |
+
x1 = blockproc(x1,[nr nc],myfun);
|
628 |
+
x1 = reshape(x1,m,Nb);
|
629 |
+
x1 = sum(x1,2);
|
630 |
+
x = reshape(x1,nr,nc);
|
631 |
+
'''
|
632 |
+
fun = fun_reshape
|
633 |
+
x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
|
634 |
+
x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
|
635 |
+
x1 = np.sum(x1, 1)
|
636 |
+
x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
|
637 |
+
return x
|
638 |
+
|
639 |
+
|
640 |
+
def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
|
641 |
+
'''
|
642 |
+
x1 = FB.*FR;
|
643 |
+
FBR = BlockMM(nr,nc,Nb,m,x1);
|
644 |
+
invW = BlockMM(nr,nc,Nb,m,F2B);
|
645 |
+
invWBR = FBR./(invW + tau*Nb);
|
646 |
+
fun = @(block_struct) block_struct.data.*invWBR;
|
647 |
+
FCBinvWBR = blockproc(FBC,[nr,nc],fun);
|
648 |
+
FX = (FR-FCBinvWBR)/tau;
|
649 |
+
Xest = real(ifft2(FX));
|
650 |
+
'''
|
651 |
+
x1 = FB*FR
|
652 |
+
FBR = BlockMM(nr, nc, Nb, m, x1)
|
653 |
+
invW = BlockMM(nr, nc, Nb, m, F2B)
|
654 |
+
invWBR = FBR/(invW + tau*Nb)
|
655 |
+
FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
|
656 |
+
FX = (FR-FCBinvWBR)/tau
|
657 |
+
Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
|
658 |
+
return Xest
|
659 |
+
|
660 |
+
|
661 |
+
def psf2otf(psf, shape=None):
|
662 |
+
"""
|
663 |
+
Convert point-spread function to optical transfer function.
|
664 |
+
Compute the Fast Fourier Transform (FFT) of the point-spread
|
665 |
+
function (PSF) array and creates the optical transfer function (OTF)
|
666 |
+
array that is not influenced by the PSF off-centering.
|
667 |
+
By default, the OTF array is the same size as the PSF array.
|
668 |
+
To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
|
669 |
+
post-pads the PSF array (down or to the right) with zeros to match
|
670 |
+
dimensions specified in OUTSIZE, then circularly shifts the values of
|
671 |
+
the PSF array up (or to the left) until the central pixel reaches (1,1)
|
672 |
+
position.
|
673 |
+
Parameters
|
674 |
+
----------
|
675 |
+
psf : `numpy.ndarray`
|
676 |
+
PSF array
|
677 |
+
shape : int
|
678 |
+
Output shape of the OTF array
|
679 |
+
Returns
|
680 |
+
-------
|
681 |
+
otf : `numpy.ndarray`
|
682 |
+
OTF array
|
683 |
+
Notes
|
684 |
+
-----
|
685 |
+
Adapted from MATLAB psf2otf function
|
686 |
+
"""
|
687 |
+
if type(shape) == type(None):
|
688 |
+
shape = psf.shape
|
689 |
+
shape = np.array(shape)
|
690 |
+
if np.all(psf == 0):
|
691 |
+
# return np.zeros_like(psf)
|
692 |
+
return np.zeros(shape)
|
693 |
+
if len(psf.shape) == 1:
|
694 |
+
psf = psf.reshape((1, psf.shape[0]))
|
695 |
+
inshape = psf.shape
|
696 |
+
psf = zero_pad(psf, shape, position='corner')
|
697 |
+
for axis, axis_size in enumerate(inshape):
|
698 |
+
psf = np.roll(psf, -int(axis_size / 2), axis=axis)
|
699 |
+
# Compute the OTF
|
700 |
+
otf = np.fft.fft2(psf, axes=(0, 1))
|
701 |
+
# Estimate the rough number of operations involved in the FFT
|
702 |
+
# and discard the PSF imaginary part if within roundoff error
|
703 |
+
# roundoff error = machine epsilon = sys.float_info.epsilon
|
704 |
+
# or np.finfo().eps
|
705 |
+
n_ops = np.sum(psf.size * np.log2(psf.shape))
|
706 |
+
otf = np.real_if_close(otf, tol=n_ops)
|
707 |
+
return otf
|
708 |
+
|
709 |
+
|
710 |
+
def zero_pad(image, shape, position='corner'):
|
711 |
+
"""
|
712 |
+
Extends image to a certain size with zeros
|
713 |
+
Parameters
|
714 |
+
----------
|
715 |
+
image: real 2d `numpy.ndarray`
|
716 |
+
Input image
|
717 |
+
shape: tuple of int
|
718 |
+
Desired output shape of the image
|
719 |
+
position : str, optional
|
720 |
+
The position of the input image in the output one:
|
721 |
+
* 'corner'
|
722 |
+
top-left corner (default)
|
723 |
+
* 'center'
|
724 |
+
centered
|
725 |
+
Returns
|
726 |
+
-------
|
727 |
+
padded_img: real `numpy.ndarray`
|
728 |
+
The zero-padded image
|
729 |
+
"""
|
730 |
+
shape = np.asarray(shape, dtype=int)
|
731 |
+
imshape = np.asarray(image.shape, dtype=int)
|
732 |
+
if np.alltrue(imshape == shape):
|
733 |
+
return image
|
734 |
+
if np.any(shape <= 0):
|
735 |
+
raise ValueError("ZERO_PAD: null or negative shape given")
|
736 |
+
dshape = shape - imshape
|
737 |
+
if np.any(dshape < 0):
|
738 |
+
raise ValueError("ZERO_PAD: target size smaller than source one")
|
739 |
+
pad_img = np.zeros(shape, dtype=image.dtype)
|
740 |
+
idx, idy = np.indices(imshape)
|
741 |
+
if position == 'center':
|
742 |
+
if np.any(dshape % 2 != 0):
|
743 |
+
raise ValueError("ZERO_PAD: source and target shapes "
|
744 |
+
"have different parity.")
|
745 |
+
offx, offy = dshape // 2
|
746 |
+
else:
|
747 |
+
offx, offy = (0, 0)
|
748 |
+
pad_img[idx + offx, idy + offy] = image
|
749 |
+
return pad_img
|
750 |
+
|
751 |
+
|
752 |
+
def upsample_np(x, sf=3, center=False):
|
753 |
+
st = (sf-1)//2 if center else 0
|
754 |
+
z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
|
755 |
+
z[st::sf, st::sf, ...] = x
|
756 |
+
return z
|
757 |
+
|
758 |
+
|
759 |
+
def downsample_np(x, sf=3, center=False):
|
760 |
+
st = (sf-1)//2 if center else 0
|
761 |
+
return x[st::sf, st::sf, ...]
|
762 |
+
|
763 |
+
|
764 |
+
def imfilter_np(x, k):
|
765 |
+
'''
|
766 |
+
x: image, NxcxHxW
|
767 |
+
k: kernel, cx1xhxw
|
768 |
+
'''
|
769 |
+
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
770 |
+
return x
|
771 |
+
|
772 |
+
|
773 |
+
def G_np(x, k, sf=3, center=False):
|
774 |
+
'''
|
775 |
+
x: image, NxcxHxW
|
776 |
+
k: kernel, cx1xhxw
|
777 |
+
|
778 |
+
Matlab function:
|
779 |
+
tmp = imfilter(x,h,'circular');
|
780 |
+
y = downsample2(tmp,K);
|
781 |
+
'''
|
782 |
+
x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
|
783 |
+
return x
|
784 |
+
|
785 |
+
|
786 |
+
def Gt_np(x, k, sf=3, center=False):
|
787 |
+
'''
|
788 |
+
x: image, NxcxHxW
|
789 |
+
k: kernel, cx1xhxw
|
790 |
+
|
791 |
+
Matlab function:
|
792 |
+
tmp = upsample2(x,K);
|
793 |
+
y = imfilter(tmp,h,'circular');
|
794 |
+
'''
|
795 |
+
x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
|
796 |
+
return x
|
797 |
+
|
798 |
+
|
799 |
+
if __name__ == '__main__':
|
800 |
+
img = util.imread_uint('test.bmp', 3)
|
801 |
+
|
802 |
+
img = util.uint2single(img)
|
803 |
+
k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
|
804 |
+
util.imshow(k*10)
|
805 |
+
|
806 |
+
|
807 |
+
for sf in [2, 3, 4]:
|
808 |
+
|
809 |
+
# modcrop
|
810 |
+
img = modcrop_np(img, sf=sf)
|
811 |
+
|
812 |
+
# 1) bicubic degradation
|
813 |
+
img_b = bicubic_degradation(img, sf=sf)
|
814 |
+
print(img_b.shape)
|
815 |
+
|
816 |
+
# 2) srmd degradation
|
817 |
+
img_s = srmd_degradation(img, k, sf=sf)
|
818 |
+
print(img_s.shape)
|
819 |
+
|
820 |
+
# 3) dpsr degradation
|
821 |
+
img_d = dpsr_degradation(img, k, sf=sf)
|
822 |
+
print(img_d.shape)
|
823 |
+
|
824 |
+
# 4) classical degradation
|
825 |
+
img_d = classical_degradation(img, k, sf=sf)
|
826 |
+
print(img_d.shape)
|
827 |
+
|
828 |
+
k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
|
829 |
+
#print(k)
|
830 |
+
# util.imshow(k*10)
|
831 |
+
|
832 |
+
k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
|
833 |
+
# util.imshow(k*10)
|
834 |
+
|
835 |
+
|
836 |
+
# PCA
|
837 |
+
# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
|
838 |
+
# print(pca_matrix.shape)
|
839 |
+
# show_pca(pca_matrix)
|
840 |
+
# run utils/utils_sisr.py
|
841 |
+
# run utils_sisr.py
|
842 |
+
|
843 |
+
|
844 |
+
|
845 |
+
|
846 |
+
|
847 |
+
|
848 |
+
|
core/data/deg_kair_utils/utils_video.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from os import path as osp
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from abc import ABCMeta, abstractmethod
|
9 |
+
|
10 |
+
|
11 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
12 |
+
"""Scan a directory to find the interested files.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
dir_path (str): Path of the directory.
|
16 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
17 |
+
interested in. Default: None.
|
18 |
+
recursive (bool, optional): If set to True, recursively scan the
|
19 |
+
directory. Default: False.
|
20 |
+
full_path (bool, optional): If set to True, include the dir_path.
|
21 |
+
Default: False.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
A generator for all the interested files with relative paths.
|
25 |
+
"""
|
26 |
+
|
27 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
28 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
29 |
+
|
30 |
+
root = dir_path
|
31 |
+
|
32 |
+
def _scandir(dir_path, suffix, recursive):
|
33 |
+
for entry in os.scandir(dir_path):
|
34 |
+
if not entry.name.startswith('.') and entry.is_file():
|
35 |
+
if full_path:
|
36 |
+
return_path = entry.path
|
37 |
+
else:
|
38 |
+
return_path = osp.relpath(entry.path, root)
|
39 |
+
|
40 |
+
if suffix is None:
|
41 |
+
yield return_path
|
42 |
+
elif return_path.endswith(suffix):
|
43 |
+
yield return_path
|
44 |
+
else:
|
45 |
+
if recursive:
|
46 |
+
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
47 |
+
else:
|
48 |
+
continue
|
49 |
+
|
50 |
+
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
51 |
+
|
52 |
+
|
53 |
+
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
|
54 |
+
"""Read a sequence of images from a given folder path.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
path (list[str] | str): List of image paths or image folder path.
|
58 |
+
require_mod_crop (bool): Require mod crop for each image.
|
59 |
+
Default: False.
|
60 |
+
scale (int): Scale factor for mod_crop. Default: 1.
|
61 |
+
return_imgname(bool): Whether return image names. Default False.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Tensor: size (t, c, h, w), RGB, [0, 1].
|
65 |
+
list[str]: Returned image name list.
|
66 |
+
"""
|
67 |
+
if isinstance(path, list):
|
68 |
+
img_paths = path
|
69 |
+
else:
|
70 |
+
img_paths = sorted(list(scandir(path, full_path=True)))
|
71 |
+
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
|
72 |
+
|
73 |
+
if require_mod_crop:
|
74 |
+
imgs = [mod_crop(img, scale) for img in imgs]
|
75 |
+
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
|
76 |
+
imgs = torch.stack(imgs, dim=0)
|
77 |
+
|
78 |
+
if return_imgname:
|
79 |
+
imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
|
80 |
+
return imgs, imgnames
|
81 |
+
else:
|
82 |
+
return imgs
|
83 |
+
|
84 |
+
|
85 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
86 |
+
"""Numpy array to tensor.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
imgs (list[ndarray] | ndarray): Input images.
|
90 |
+
bgr2rgb (bool): Whether to change bgr to rgb.
|
91 |
+
float32 (bool): Whether to change to float32.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
list[tensor] | tensor: Tensor images. If returned results only have
|
95 |
+
one element, just return tensor.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def _totensor(img, bgr2rgb, float32):
|
99 |
+
if img.shape[2] == 3 and bgr2rgb:
|
100 |
+
if img.dtype == 'float64':
|
101 |
+
img = img.astype('float32')
|
102 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
103 |
+
img = torch.from_numpy(img.transpose(2, 0, 1))
|
104 |
+
if float32:
|
105 |
+
img = img.float()
|
106 |
+
return img
|
107 |
+
|
108 |
+
if isinstance(imgs, list):
|
109 |
+
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
110 |
+
else:
|
111 |
+
return _totensor(imgs, bgr2rgb, float32)
|
112 |
+
|
113 |
+
|
114 |
+
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
|
115 |
+
"""Convert torch Tensors into image numpy arrays.
|
116 |
+
|
117 |
+
After clamping to [min, max], values will be normalized to [0, 1].
|
118 |
+
|
119 |
+
Args:
|
120 |
+
tensor (Tensor or list[Tensor]): Accept shapes:
|
121 |
+
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
|
122 |
+
2) 3D Tensor of shape (3/1 x H x W);
|
123 |
+
3) 2D Tensor of shape (H x W).
|
124 |
+
Tensor channel should be in RGB order.
|
125 |
+
rgb2bgr (bool): Whether to change rgb to bgr.
|
126 |
+
out_type (numpy type): output types. If ``np.uint8``, transform outputs
|
127 |
+
to uint8 type with range [0, 255]; otherwise, float type with
|
128 |
+
range [0, 1]. Default: ``np.uint8``.
|
129 |
+
min_max (tuple[int]): min and max values for clamp.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
|
133 |
+
shape (H x W). The channel order is BGR.
|
134 |
+
"""
|
135 |
+
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
|
136 |
+
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
|
137 |
+
|
138 |
+
if torch.is_tensor(tensor):
|
139 |
+
tensor = [tensor]
|
140 |
+
result = []
|
141 |
+
for _tensor in tensor:
|
142 |
+
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
143 |
+
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
|
144 |
+
|
145 |
+
n_dim = _tensor.dim()
|
146 |
+
if n_dim == 4:
|
147 |
+
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
|
148 |
+
img_np = img_np.transpose(1, 2, 0)
|
149 |
+
if rgb2bgr:
|
150 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
151 |
+
elif n_dim == 3:
|
152 |
+
img_np = _tensor.numpy()
|
153 |
+
img_np = img_np.transpose(1, 2, 0)
|
154 |
+
if img_np.shape[2] == 1: # gray image
|
155 |
+
img_np = np.squeeze(img_np, axis=2)
|
156 |
+
else:
|
157 |
+
if rgb2bgr:
|
158 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
159 |
+
elif n_dim == 2:
|
160 |
+
img_np = _tensor.numpy()
|
161 |
+
else:
|
162 |
+
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
|
163 |
+
if out_type == np.uint8:
|
164 |
+
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
|
165 |
+
img_np = (img_np * 255.0).round()
|
166 |
+
img_np = img_np.astype(out_type)
|
167 |
+
result.append(img_np)
|
168 |
+
if len(result) == 1:
|
169 |
+
result = result[0]
|
170 |
+
return result
|
171 |
+
|
172 |
+
|
173 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
|
174 |
+
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
|
175 |
+
|
176 |
+
We use vertical flip and transpose for rotation implementation.
|
177 |
+
All the images in the list use the same augmentation.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
|
181 |
+
is an ndarray, it will be transformed to a list.
|
182 |
+
hflip (bool): Horizontal flip. Default: True.
|
183 |
+
rotation (bool): Ratotation. Default: True.
|
184 |
+
flows (list[ndarray]: Flows to be augmented. If the input is an
|
185 |
+
ndarray, it will be transformed to a list.
|
186 |
+
Dimension is (h, w, 2). Default: None.
|
187 |
+
return_status (bool): Return the status of flip and rotation.
|
188 |
+
Default: False.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
list[ndarray] | ndarray: Augmented images and flows. If returned
|
192 |
+
results only have one element, just return ndarray.
|
193 |
+
|
194 |
+
"""
|
195 |
+
hflip = hflip and random.random() < 0.5
|
196 |
+
vflip = rotation and random.random() < 0.5
|
197 |
+
rot90 = rotation and random.random() < 0.5
|
198 |
+
|
199 |
+
def _augment(img):
|
200 |
+
if hflip: # horizontal
|
201 |
+
cv2.flip(img, 1, img)
|
202 |
+
if vflip: # vertical
|
203 |
+
cv2.flip(img, 0, img)
|
204 |
+
if rot90:
|
205 |
+
img = img.transpose(1, 0, 2)
|
206 |
+
return img
|
207 |
+
|
208 |
+
def _augment_flow(flow):
|
209 |
+
if hflip: # horizontal
|
210 |
+
cv2.flip(flow, 1, flow)
|
211 |
+
flow[:, :, 0] *= -1
|
212 |
+
if vflip: # vertical
|
213 |
+
cv2.flip(flow, 0, flow)
|
214 |
+
flow[:, :, 1] *= -1
|
215 |
+
if rot90:
|
216 |
+
flow = flow.transpose(1, 0, 2)
|
217 |
+
flow = flow[:, :, [1, 0]]
|
218 |
+
return flow
|
219 |
+
|
220 |
+
if not isinstance(imgs, list):
|
221 |
+
imgs = [imgs]
|
222 |
+
imgs = [_augment(img) for img in imgs]
|
223 |
+
if len(imgs) == 1:
|
224 |
+
imgs = imgs[0]
|
225 |
+
|
226 |
+
if flows is not None:
|
227 |
+
if not isinstance(flows, list):
|
228 |
+
flows = [flows]
|
229 |
+
flows = [_augment_flow(flow) for flow in flows]
|
230 |
+
if len(flows) == 1:
|
231 |
+
flows = flows[0]
|
232 |
+
return imgs, flows
|
233 |
+
else:
|
234 |
+
if return_status:
|
235 |
+
return imgs, (hflip, vflip, rot90)
|
236 |
+
else:
|
237 |
+
return imgs
|
238 |
+
|
239 |
+
|
240 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
|
241 |
+
"""Paired random crop. Support Numpy array and Tensor inputs.
|
242 |
+
|
243 |
+
It crops lists of lq and gt images with corresponding locations.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
|
247 |
+
should have the same shape. If the input is an ndarray, it will
|
248 |
+
be transformed to a list containing itself.
|
249 |
+
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
|
250 |
+
should have the same shape. If the input is an ndarray, it will
|
251 |
+
be transformed to a list containing itself.
|
252 |
+
gt_patch_size (int): GT patch size.
|
253 |
+
scale (int): Scale factor.
|
254 |
+
gt_path (str): Path to ground-truth. Default: None.
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
list[ndarray] | ndarray: GT images and LQ images. If returned results
|
258 |
+
only have one element, just return ndarray.
|
259 |
+
"""
|
260 |
+
|
261 |
+
if not isinstance(img_gts, list):
|
262 |
+
img_gts = [img_gts]
|
263 |
+
if not isinstance(img_lqs, list):
|
264 |
+
img_lqs = [img_lqs]
|
265 |
+
|
266 |
+
# determine input type: Numpy array or Tensor
|
267 |
+
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
|
268 |
+
|
269 |
+
if input_type == 'Tensor':
|
270 |
+
h_lq, w_lq = img_lqs[0].size()[-2:]
|
271 |
+
h_gt, w_gt = img_gts[0].size()[-2:]
|
272 |
+
else:
|
273 |
+
h_lq, w_lq = img_lqs[0].shape[0:2]
|
274 |
+
h_gt, w_gt = img_gts[0].shape[0:2]
|
275 |
+
lq_patch_size = gt_patch_size // scale
|
276 |
+
|
277 |
+
if h_gt != h_lq * scale or w_gt != w_lq * scale:
|
278 |
+
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
|
279 |
+
f'multiplication of LQ ({h_lq}, {w_lq}).')
|
280 |
+
if h_lq < lq_patch_size or w_lq < lq_patch_size:
|
281 |
+
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
|
282 |
+
f'({lq_patch_size}, {lq_patch_size}). '
|
283 |
+
f'Please remove {gt_path}.')
|
284 |
+
|
285 |
+
# randomly choose top and left coordinates for lq patch
|
286 |
+
top = random.randint(0, h_lq - lq_patch_size)
|
287 |
+
left = random.randint(0, w_lq - lq_patch_size)
|
288 |
+
|
289 |
+
# crop lq patch
|
290 |
+
if input_type == 'Tensor':
|
291 |
+
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
|
292 |
+
else:
|
293 |
+
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
|
294 |
+
|
295 |
+
# crop corresponding gt patch
|
296 |
+
top_gt, left_gt = int(top * scale), int(left * scale)
|
297 |
+
if input_type == 'Tensor':
|
298 |
+
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
|
299 |
+
else:
|
300 |
+
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
|
301 |
+
if len(img_gts) == 1:
|
302 |
+
img_gts = img_gts[0]
|
303 |
+
if len(img_lqs) == 1:
|
304 |
+
img_lqs = img_lqs[0]
|
305 |
+
return img_gts, img_lqs
|
306 |
+
|
307 |
+
|
308 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
309 |
+
class BaseStorageBackend(metaclass=ABCMeta):
|
310 |
+
"""Abstract class of storage backends.
|
311 |
+
|
312 |
+
All backends need to implement two apis: ``get()`` and ``get_text()``.
|
313 |
+
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
|
314 |
+
as texts.
|
315 |
+
"""
|
316 |
+
|
317 |
+
@abstractmethod
|
318 |
+
def get(self, filepath):
|
319 |
+
pass
|
320 |
+
|
321 |
+
@abstractmethod
|
322 |
+
def get_text(self, filepath):
|
323 |
+
pass
|
324 |
+
|
325 |
+
|
326 |
+
class MemcachedBackend(BaseStorageBackend):
|
327 |
+
"""Memcached storage backend.
|
328 |
+
|
329 |
+
Attributes:
|
330 |
+
server_list_cfg (str): Config file for memcached server list.
|
331 |
+
client_cfg (str): Config file for memcached client.
|
332 |
+
sys_path (str | None): Additional path to be appended to `sys.path`.
|
333 |
+
Default: None.
|
334 |
+
"""
|
335 |
+
|
336 |
+
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
|
337 |
+
if sys_path is not None:
|
338 |
+
import sys
|
339 |
+
sys.path.append(sys_path)
|
340 |
+
try:
|
341 |
+
import mc
|
342 |
+
except ImportError:
|
343 |
+
raise ImportError('Please install memcached to enable MemcachedBackend.')
|
344 |
+
|
345 |
+
self.server_list_cfg = server_list_cfg
|
346 |
+
self.client_cfg = client_cfg
|
347 |
+
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
|
348 |
+
# mc.pyvector servers as a point which points to a memory cache
|
349 |
+
self._mc_buffer = mc.pyvector()
|
350 |
+
|
351 |
+
def get(self, filepath):
|
352 |
+
filepath = str(filepath)
|
353 |
+
import mc
|
354 |
+
self._client.Get(filepath, self._mc_buffer)
|
355 |
+
value_buf = mc.ConvertBuffer(self._mc_buffer)
|
356 |
+
return value_buf
|
357 |
+
|
358 |
+
def get_text(self, filepath):
|
359 |
+
raise NotImplementedError
|
360 |
+
|
361 |
+
|
362 |
+
class HardDiskBackend(BaseStorageBackend):
|
363 |
+
"""Raw hard disks storage backend."""
|
364 |
+
|
365 |
+
def get(self, filepath):
|
366 |
+
filepath = str(filepath)
|
367 |
+
with open(filepath, 'rb') as f:
|
368 |
+
value_buf = f.read()
|
369 |
+
return value_buf
|
370 |
+
|
371 |
+
def get_text(self, filepath):
|
372 |
+
filepath = str(filepath)
|
373 |
+
with open(filepath, 'r') as f:
|
374 |
+
value_buf = f.read()
|
375 |
+
return value_buf
|
376 |
+
|
377 |
+
|
378 |
+
class LmdbBackend(BaseStorageBackend):
|
379 |
+
"""Lmdb storage backend.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
db_paths (str | list[str]): Lmdb database paths.
|
383 |
+
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
|
384 |
+
readonly (bool, optional): Lmdb environment parameter. If True,
|
385 |
+
disallow any write operations. Default: True.
|
386 |
+
lock (bool, optional): Lmdb environment parameter. If False, when
|
387 |
+
concurrent access occurs, do not lock the database. Default: False.
|
388 |
+
readahead (bool, optional): Lmdb environment parameter. If False,
|
389 |
+
disable the OS filesystem readahead mechanism, which may improve
|
390 |
+
random read performance when a database is larger than RAM.
|
391 |
+
Default: False.
|
392 |
+
|
393 |
+
Attributes:
|
394 |
+
db_paths (list): Lmdb database path.
|
395 |
+
_client (list): A list of several lmdb envs.
|
396 |
+
"""
|
397 |
+
|
398 |
+
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
|
399 |
+
try:
|
400 |
+
import lmdb
|
401 |
+
except ImportError:
|
402 |
+
raise ImportError('Please install lmdb to enable LmdbBackend.')
|
403 |
+
|
404 |
+
if isinstance(client_keys, str):
|
405 |
+
client_keys = [client_keys]
|
406 |
+
|
407 |
+
if isinstance(db_paths, list):
|
408 |
+
self.db_paths = [str(v) for v in db_paths]
|
409 |
+
elif isinstance(db_paths, str):
|
410 |
+
self.db_paths = [str(db_paths)]
|
411 |
+
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
|
412 |
+
f'but received {len(client_keys)} and {len(self.db_paths)}.')
|
413 |
+
|
414 |
+
self._client = {}
|
415 |
+
for client, path in zip(client_keys, self.db_paths):
|
416 |
+
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
|
417 |
+
|
418 |
+
def get(self, filepath, client_key):
|
419 |
+
"""Get values according to the filepath from one lmdb named client_key.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
|
423 |
+
client_key (str): Used for distinguishing different lmdb envs.
|
424 |
+
"""
|
425 |
+
filepath = str(filepath)
|
426 |
+
assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
|
427 |
+
client = self._client[client_key]
|
428 |
+
with client.begin(write=False) as txn:
|
429 |
+
value_buf = txn.get(filepath.encode('ascii'))
|
430 |
+
return value_buf
|
431 |
+
|
432 |
+
def get_text(self, filepath):
|
433 |
+
raise NotImplementedError
|
434 |
+
|
435 |
+
|
436 |
+
class FileClient(object):
|
437 |
+
"""A general file client to access files in different backend.
|
438 |
+
|
439 |
+
The client loads a file or text in a specified backend from its path
|
440 |
+
and return it as a binary file. it can also register other backend
|
441 |
+
accessor with a given name and backend class.
|
442 |
+
|
443 |
+
Attributes:
|
444 |
+
backend (str): The storage backend type. Options are "disk",
|
445 |
+
"memcached" and "lmdb".
|
446 |
+
client (:obj:`BaseStorageBackend`): The backend object.
|
447 |
+
"""
|
448 |
+
|
449 |
+
_backends = {
|
450 |
+
'disk': HardDiskBackend,
|
451 |
+
'memcached': MemcachedBackend,
|
452 |
+
'lmdb': LmdbBackend,
|
453 |
+
}
|
454 |
+
|
455 |
+
def __init__(self, backend='disk', **kwargs):
|
456 |
+
if backend not in self._backends:
|
457 |
+
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
|
458 |
+
f' are {list(self._backends.keys())}')
|
459 |
+
self.backend = backend
|
460 |
+
self.client = self._backends[backend](**kwargs)
|
461 |
+
|
462 |
+
def get(self, filepath, client_key='default'):
|
463 |
+
# client_key is used only for lmdb, where different fileclients have
|
464 |
+
# different lmdb environments.
|
465 |
+
if self.backend == 'lmdb':
|
466 |
+
return self.client.get(filepath, client_key)
|
467 |
+
else:
|
468 |
+
return self.client.get(filepath)
|
469 |
+
|
470 |
+
def get_text(self, filepath):
|
471 |
+
return self.client.get_text(filepath)
|
472 |
+
|
473 |
+
|
474 |
+
def imfrombytes(content, flag='color', float32=False):
|
475 |
+
"""Read an image from bytes.
|
476 |
+
|
477 |
+
Args:
|
478 |
+
content (bytes): Image bytes got from files or other streams.
|
479 |
+
flag (str): Flags specifying the color type of a loaded image,
|
480 |
+
candidates are `color`, `grayscale` and `unchanged`.
|
481 |
+
float32 (bool): Whether to change to float32., If True, will also norm
|
482 |
+
to [0, 1]. Default: False.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
ndarray: Loaded image array.
|
486 |
+
"""
|
487 |
+
img_np = np.frombuffer(content, np.uint8)
|
488 |
+
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
|
489 |
+
img = cv2.imdecode(img_np, imread_flags[flag])
|
490 |
+
if float32:
|
491 |
+
img = img.astype(np.float32) / 255.
|
492 |
+
return img
|
493 |
+
|
core/data/deg_kair_utils/utils_videoio.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from os import path as osp
|
7 |
+
from torchvision.utils import make_grid
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
import six
|
11 |
+
from collections import OrderedDict
|
12 |
+
import math
|
13 |
+
import glob
|
14 |
+
import av
|
15 |
+
import io
|
16 |
+
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
|
17 |
+
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
|
18 |
+
CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
|
19 |
+
|
20 |
+
if sys.version_info <= (3, 3):
|
21 |
+
FileNotFoundError = IOError
|
22 |
+
else:
|
23 |
+
FileNotFoundError = FileNotFoundError
|
24 |
+
|
25 |
+
|
26 |
+
def is_str(x):
|
27 |
+
"""Whether the input is an string instance."""
|
28 |
+
return isinstance(x, six.string_types)
|
29 |
+
|
30 |
+
|
31 |
+
def is_filepath(x):
|
32 |
+
return is_str(x) or isinstance(x, Path)
|
33 |
+
|
34 |
+
|
35 |
+
def fopen(filepath, *args, **kwargs):
|
36 |
+
if is_str(filepath):
|
37 |
+
return open(filepath, *args, **kwargs)
|
38 |
+
elif isinstance(filepath, Path):
|
39 |
+
return filepath.open(*args, **kwargs)
|
40 |
+
raise ValueError('`filepath` should be a string or a Path')
|
41 |
+
|
42 |
+
|
43 |
+
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
|
44 |
+
if not osp.isfile(filename):
|
45 |
+
raise FileNotFoundError(msg_tmpl.format(filename))
|
46 |
+
|
47 |
+
|
48 |
+
def mkdir_or_exist(dir_name, mode=0o777):
|
49 |
+
if dir_name == '':
|
50 |
+
return
|
51 |
+
dir_name = osp.expanduser(dir_name)
|
52 |
+
os.makedirs(dir_name, mode=mode, exist_ok=True)
|
53 |
+
|
54 |
+
|
55 |
+
def symlink(src, dst, overwrite=True, **kwargs):
|
56 |
+
if os.path.lexists(dst) and overwrite:
|
57 |
+
os.remove(dst)
|
58 |
+
os.symlink(src, dst, **kwargs)
|
59 |
+
|
60 |
+
|
61 |
+
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
|
62 |
+
"""Scan a directory to find the interested files.
|
63 |
+
Args:
|
64 |
+
dir_path (str | :obj:`Path`): Path of the directory.
|
65 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
66 |
+
interested in. Default: None.
|
67 |
+
recursive (bool, optional): If set to True, recursively scan the
|
68 |
+
directory. Default: False.
|
69 |
+
case_sensitive (bool, optional) : If set to False, ignore the case of
|
70 |
+
suffix. Default: True.
|
71 |
+
Returns:
|
72 |
+
A generator for all the interested files with relative paths.
|
73 |
+
"""
|
74 |
+
if isinstance(dir_path, (str, Path)):
|
75 |
+
dir_path = str(dir_path)
|
76 |
+
else:
|
77 |
+
raise TypeError('"dir_path" must be a string or Path object')
|
78 |
+
|
79 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
80 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
81 |
+
|
82 |
+
if suffix is not None and not case_sensitive:
|
83 |
+
suffix = suffix.lower() if isinstance(suffix, str) else tuple(
|
84 |
+
item.lower() for item in suffix)
|
85 |
+
|
86 |
+
root = dir_path
|
87 |
+
|
88 |
+
def _scandir(dir_path, suffix, recursive, case_sensitive):
|
89 |
+
for entry in os.scandir(dir_path):
|
90 |
+
if not entry.name.startswith('.') and entry.is_file():
|
91 |
+
rel_path = osp.relpath(entry.path, root)
|
92 |
+
_rel_path = rel_path if case_sensitive else rel_path.lower()
|
93 |
+
if suffix is None or _rel_path.endswith(suffix):
|
94 |
+
yield rel_path
|
95 |
+
elif recursive and os.path.isdir(entry.path):
|
96 |
+
# scan recursively if entry.path is a directory
|
97 |
+
yield from _scandir(entry.path, suffix, recursive,
|
98 |
+
case_sensitive)
|
99 |
+
|
100 |
+
return _scandir(dir_path, suffix, recursive, case_sensitive)
|
101 |
+
|
102 |
+
|
103 |
+
class Cache:
|
104 |
+
|
105 |
+
def __init__(self, capacity):
|
106 |
+
self._cache = OrderedDict()
|
107 |
+
self._capacity = int(capacity)
|
108 |
+
if capacity <= 0:
|
109 |
+
raise ValueError('capacity must be a positive integer')
|
110 |
+
|
111 |
+
@property
|
112 |
+
def capacity(self):
|
113 |
+
return self._capacity
|
114 |
+
|
115 |
+
@property
|
116 |
+
def size(self):
|
117 |
+
return len(self._cache)
|
118 |
+
|
119 |
+
def put(self, key, val):
|
120 |
+
if key in self._cache:
|
121 |
+
return
|
122 |
+
if len(self._cache) >= self.capacity:
|
123 |
+
self._cache.popitem(last=False)
|
124 |
+
self._cache[key] = val
|
125 |
+
|
126 |
+
def get(self, key, default=None):
|
127 |
+
val = self._cache[key] if key in self._cache else default
|
128 |
+
return val
|
129 |
+
|
130 |
+
|
131 |
+
class VideoReader:
|
132 |
+
"""Video class with similar usage to a list object.
|
133 |
+
|
134 |
+
This video warpper class provides convenient apis to access frames.
|
135 |
+
There exists an issue of OpenCV's VideoCapture class that jumping to a
|
136 |
+
certain frame may be inaccurate. It is fixed in this class by checking
|
137 |
+
the position after jumping each time.
|
138 |
+
Cache is used when decoding videos. So if the same frame is visited for
|
139 |
+
the second time, there is no need to decode again if it is stored in the
|
140 |
+
cache.
|
141 |
+
|
142 |
+
"""
|
143 |
+
|
144 |
+
def __init__(self, filename, cache_capacity=10):
|
145 |
+
# Check whether the video path is a url
|
146 |
+
if not filename.startswith(('https://', 'http://')):
|
147 |
+
check_file_exist(filename, 'Video file not found: ' + filename)
|
148 |
+
self._vcap = cv2.VideoCapture(filename)
|
149 |
+
assert cache_capacity > 0
|
150 |
+
self._cache = Cache(cache_capacity)
|
151 |
+
self._position = 0
|
152 |
+
# get basic info
|
153 |
+
self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
|
154 |
+
self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
|
155 |
+
self._fps = self._vcap.get(CAP_PROP_FPS)
|
156 |
+
self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
|
157 |
+
self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
|
158 |
+
|
159 |
+
@property
|
160 |
+
def vcap(self):
|
161 |
+
""":obj:`cv2.VideoCapture`: The raw VideoCapture object."""
|
162 |
+
return self._vcap
|
163 |
+
|
164 |
+
@property
|
165 |
+
def opened(self):
|
166 |
+
"""bool: Indicate whether the video is opened."""
|
167 |
+
return self._vcap.isOpened()
|
168 |
+
|
169 |
+
@property
|
170 |
+
def width(self):
|
171 |
+
"""int: Width of video frames."""
|
172 |
+
return self._width
|
173 |
+
|
174 |
+
@property
|
175 |
+
def height(self):
|
176 |
+
"""int: Height of video frames."""
|
177 |
+
return self._height
|
178 |
+
|
179 |
+
@property
|
180 |
+
def resolution(self):
|
181 |
+
"""tuple: Video resolution (width, height)."""
|
182 |
+
return (self._width, self._height)
|
183 |
+
|
184 |
+
@property
|
185 |
+
def fps(self):
|
186 |
+
"""float: FPS of the video."""
|
187 |
+
return self._fps
|
188 |
+
|
189 |
+
@property
|
190 |
+
def frame_cnt(self):
|
191 |
+
"""int: Total frames of the video."""
|
192 |
+
return self._frame_cnt
|
193 |
+
|
194 |
+
@property
|
195 |
+
def fourcc(self):
|
196 |
+
"""str: "Four character code" of the video."""
|
197 |
+
return self._fourcc
|
198 |
+
|
199 |
+
@property
|
200 |
+
def position(self):
|
201 |
+
"""int: Current cursor position, indicating frame decoded."""
|
202 |
+
return self._position
|
203 |
+
|
204 |
+
def _get_real_position(self):
|
205 |
+
return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
|
206 |
+
|
207 |
+
def _set_real_position(self, frame_id):
|
208 |
+
self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
|
209 |
+
pos = self._get_real_position()
|
210 |
+
for _ in range(frame_id - pos):
|
211 |
+
self._vcap.read()
|
212 |
+
self._position = frame_id
|
213 |
+
|
214 |
+
def read(self):
|
215 |
+
"""Read the next frame.
|
216 |
+
|
217 |
+
If the next frame have been decoded before and in the cache, then
|
218 |
+
return it directly, otherwise decode, cache and return it.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
ndarray or None: Return the frame if successful, otherwise None.
|
222 |
+
"""
|
223 |
+
# pos = self._position
|
224 |
+
if self._cache:
|
225 |
+
img = self._cache.get(self._position)
|
226 |
+
if img is not None:
|
227 |
+
ret = True
|
228 |
+
else:
|
229 |
+
if self._position != self._get_real_position():
|
230 |
+
self._set_real_position(self._position)
|
231 |
+
ret, img = self._vcap.read()
|
232 |
+
if ret:
|
233 |
+
self._cache.put(self._position, img)
|
234 |
+
else:
|
235 |
+
ret, img = self._vcap.read()
|
236 |
+
if ret:
|
237 |
+
self._position += 1
|
238 |
+
return img
|
239 |
+
|
240 |
+
def get_frame(self, frame_id):
|
241 |
+
"""Get frame by index.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
frame_id (int): Index of the expected frame, 0-based.
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
ndarray or None: Return the frame if successful, otherwise None.
|
248 |
+
"""
|
249 |
+
if frame_id < 0 or frame_id >= self._frame_cnt:
|
250 |
+
raise IndexError(
|
251 |
+
f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
|
252 |
+
if frame_id == self._position:
|
253 |
+
return self.read()
|
254 |
+
if self._cache:
|
255 |
+
img = self._cache.get(frame_id)
|
256 |
+
if img is not None:
|
257 |
+
self._position = frame_id + 1
|
258 |
+
return img
|
259 |
+
self._set_real_position(frame_id)
|
260 |
+
ret, img = self._vcap.read()
|
261 |
+
if ret:
|
262 |
+
if self._cache:
|
263 |
+
self._cache.put(self._position, img)
|
264 |
+
self._position += 1
|
265 |
+
return img
|
266 |
+
|
267 |
+
def current_frame(self):
|
268 |
+
"""Get the current frame (frame that is just visited).
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
ndarray or None: If the video is fresh, return None, otherwise
|
272 |
+
return the frame.
|
273 |
+
"""
|
274 |
+
if self._position == 0:
|
275 |
+
return None
|
276 |
+
return self._cache.get(self._position - 1)
|
277 |
+
|
278 |
+
def cvt2frames(self,
|
279 |
+
frame_dir,
|
280 |
+
file_start=0,
|
281 |
+
filename_tmpl='{:06d}.jpg',
|
282 |
+
start=0,
|
283 |
+
max_num=0,
|
284 |
+
show_progress=False):
|
285 |
+
"""Convert a video to frame images.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
frame_dir (str): Output directory to store all the frame images.
|
289 |
+
file_start (int): Filenames will start from the specified number.
|
290 |
+
filename_tmpl (str): Filename template with the index as the
|
291 |
+
placeholder.
|
292 |
+
start (int): The starting frame index.
|
293 |
+
max_num (int): Maximum number of frames to be written.
|
294 |
+
show_progress (bool): Whether to show a progress bar.
|
295 |
+
"""
|
296 |
+
mkdir_or_exist(frame_dir)
|
297 |
+
if max_num == 0:
|
298 |
+
task_num = self.frame_cnt - start
|
299 |
+
else:
|
300 |
+
task_num = min(self.frame_cnt - start, max_num)
|
301 |
+
if task_num <= 0:
|
302 |
+
raise ValueError('start must be less than total frame number')
|
303 |
+
if start > 0:
|
304 |
+
self._set_real_position(start)
|
305 |
+
|
306 |
+
def write_frame(file_idx):
|
307 |
+
img = self.read()
|
308 |
+
if img is None:
|
309 |
+
return
|
310 |
+
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
|
311 |
+
cv2.imwrite(filename, img)
|
312 |
+
|
313 |
+
if show_progress:
|
314 |
+
pass
|
315 |
+
#track_progress(write_frame, range(file_start,file_start + task_num))
|
316 |
+
else:
|
317 |
+
for i in range(task_num):
|
318 |
+
write_frame(file_start + i)
|
319 |
+
|
320 |
+
def __len__(self):
|
321 |
+
return self.frame_cnt
|
322 |
+
|
323 |
+
def __getitem__(self, index):
|
324 |
+
if isinstance(index, slice):
|
325 |
+
return [
|
326 |
+
self.get_frame(i)
|
327 |
+
for i in range(*index.indices(self.frame_cnt))
|
328 |
+
]
|
329 |
+
# support negative indexing
|
330 |
+
if index < 0:
|
331 |
+
index += self.frame_cnt
|
332 |
+
if index < 0:
|
333 |
+
raise IndexError('index out of range')
|
334 |
+
return self.get_frame(index)
|
335 |
+
|
336 |
+
def __iter__(self):
|
337 |
+
self._set_real_position(0)
|
338 |
+
return self
|
339 |
+
|
340 |
+
def __next__(self):
|
341 |
+
img = self.read()
|
342 |
+
if img is not None:
|
343 |
+
return img
|
344 |
+
else:
|
345 |
+
raise StopIteration
|
346 |
+
|
347 |
+
next = __next__
|
348 |
+
|
349 |
+
def __enter__(self):
|
350 |
+
return self
|
351 |
+
|
352 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
353 |
+
self._vcap.release()
|
354 |
+
|
355 |
+
|
356 |
+
def frames2video(frame_dir,
|
357 |
+
video_file,
|
358 |
+
fps=30,
|
359 |
+
fourcc='XVID',
|
360 |
+
filename_tmpl='{:06d}.jpg',
|
361 |
+
start=0,
|
362 |
+
end=0,
|
363 |
+
show_progress=False):
|
364 |
+
"""Read the frame images from a directory and join them as a video.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
frame_dir (str): The directory containing video frames.
|
368 |
+
video_file (str): Output filename.
|
369 |
+
fps (float): FPS of the output video.
|
370 |
+
fourcc (str): Fourcc of the output video, this should be compatible
|
371 |
+
with the output file type.
|
372 |
+
filename_tmpl (str): Filename template with the index as the variable.
|
373 |
+
start (int): Starting frame index.
|
374 |
+
end (int): Ending frame index.
|
375 |
+
show_progress (bool): Whether to show a progress bar.
|
376 |
+
"""
|
377 |
+
if end == 0:
|
378 |
+
ext = filename_tmpl.split('.')[-1]
|
379 |
+
end = len([name for name in scandir(frame_dir, ext)])
|
380 |
+
first_file = osp.join(frame_dir, filename_tmpl.format(start))
|
381 |
+
check_file_exist(first_file, 'The start frame not found: ' + first_file)
|
382 |
+
img = cv2.imread(first_file)
|
383 |
+
height, width = img.shape[:2]
|
384 |
+
resolution = (width, height)
|
385 |
+
vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
|
386 |
+
resolution)
|
387 |
+
|
388 |
+
def write_frame(file_idx):
|
389 |
+
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
|
390 |
+
img = cv2.imread(filename)
|
391 |
+
vwriter.write(img)
|
392 |
+
|
393 |
+
if show_progress:
|
394 |
+
pass
|
395 |
+
# track_progress(write_frame, range(start, end))
|
396 |
+
else:
|
397 |
+
for i in range(start, end):
|
398 |
+
write_frame(i)
|
399 |
+
vwriter.release()
|
400 |
+
|
401 |
+
|
402 |
+
def video2images(video_path, output_dir):
|
403 |
+
vidcap = cv2.VideoCapture(video_path)
|
404 |
+
in_fps = vidcap.get(cv2.CAP_PROP_FPS)
|
405 |
+
print('video fps:', in_fps)
|
406 |
+
if not os.path.isdir(output_dir):
|
407 |
+
os.makedirs(output_dir)
|
408 |
+
loaded, frame = vidcap.read()
|
409 |
+
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
410 |
+
print(f'number of total frames is: {total_frames:06}')
|
411 |
+
for i_frame in range(total_frames):
|
412 |
+
if i_frame % 100 == 0:
|
413 |
+
print(f'{i_frame:06} / {total_frames:06}')
|
414 |
+
frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
|
415 |
+
cv2.imwrite(frame_name, frame)
|
416 |
+
loaded, frame = vidcap.read()
|
417 |
+
|
418 |
+
|
419 |
+
def images2video(image_dir, video_path, fps=24, image_ext='png'):
|
420 |
+
'''
|
421 |
+
#codec = cv2.VideoWriter_fourcc(*'XVID')
|
422 |
+
#codec = cv2.VideoWriter_fourcc('A','V','C','1')
|
423 |
+
#codec = cv2.VideoWriter_fourcc('Y','U','V','1')
|
424 |
+
#codec = cv2.VideoWriter_fourcc('P','I','M','1')
|
425 |
+
#codec = cv2.VideoWriter_fourcc('M','J','P','G')
|
426 |
+
codec = cv2.VideoWriter_fourcc('M','P','4','2')
|
427 |
+
#codec = cv2.VideoWriter_fourcc('D','I','V','3')
|
428 |
+
#codec = cv2.VideoWriter_fourcc('D','I','V','X')
|
429 |
+
#codec = cv2.VideoWriter_fourcc('U','2','6','3')
|
430 |
+
#codec = cv2.VideoWriter_fourcc('I','2','6','3')
|
431 |
+
#codec = cv2.VideoWriter_fourcc('F','L','V','1')
|
432 |
+
#codec = cv2.VideoWriter_fourcc('H','2','6','4')
|
433 |
+
#codec = cv2.VideoWriter_fourcc('A','Y','U','V')
|
434 |
+
#codec = cv2.VideoWriter_fourcc('I','U','Y','V')
|
435 |
+
编码器常用的几种:
|
436 |
+
cv2.VideoWriter_fourcc("I", "4", "2", "0")
|
437 |
+
压缩的yuv颜色编码器,4:2:0色彩度子采样 兼容性好,产生很大的视频 avi
|
438 |
+
cv2.VideoWriter_fourcc("P", I", "M", "1")
|
439 |
+
采用mpeg-1编码,文件为avi
|
440 |
+
cv2.VideoWriter_fourcc("X", "V", "T", "D")
|
441 |
+
采用mpeg-4编码,得到视频大小平均 拓展名avi
|
442 |
+
cv2.VideoWriter_fourcc("T", "H", "E", "O")
|
443 |
+
Ogg Vorbis, 拓展名为ogv
|
444 |
+
cv2.VideoWriter_fourcc("F", "L", "V", "1")
|
445 |
+
FLASH视频,拓展名为.flv
|
446 |
+
'''
|
447 |
+
image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
|
448 |
+
print(len(image_files))
|
449 |
+
height, width, _ = cv2.imread(image_files[0]).shape
|
450 |
+
out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V')
|
451 |
+
out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
|
452 |
+
|
453 |
+
for image_file in image_files:
|
454 |
+
img = cv2.imread(image_file)
|
455 |
+
img = cv2.resize(img, (width, height), interpolation=3)
|
456 |
+
out_video.write(img)
|
457 |
+
out_video.release()
|
458 |
+
|
459 |
+
|
460 |
+
def add_video_compression(imgs):
|
461 |
+
codec_type = ['libx264', 'h264', 'mpeg4']
|
462 |
+
codec_prob = [1 / 3., 1 / 3., 1 / 3.]
|
463 |
+
codec = random.choices(codec_type, codec_prob)[0]
|
464 |
+
# codec = 'mpeg4'
|
465 |
+
bitrate = [1e4, 1e5]
|
466 |
+
bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
|
467 |
+
|
468 |
+
buf = io.BytesIO()
|
469 |
+
with av.open(buf, 'w', 'mp4') as container:
|
470 |
+
stream = container.add_stream(codec, rate=1)
|
471 |
+
stream.height = imgs[0].shape[0]
|
472 |
+
stream.width = imgs[0].shape[1]
|
473 |
+
stream.pix_fmt = 'yuv420p'
|
474 |
+
stream.bit_rate = bitrate
|
475 |
+
|
476 |
+
for img in imgs:
|
477 |
+
img = np.uint8((img.clip(0, 1)*255.).round())
|
478 |
+
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
479 |
+
frame.pict_type = 'NONE'
|
480 |
+
# pdb.set_trace()
|
481 |
+
for packet in stream.encode(frame):
|
482 |
+
container.mux(packet)
|
483 |
+
|
484 |
+
# Flush stream
|
485 |
+
for packet in stream.encode():
|
486 |
+
container.mux(packet)
|
487 |
+
|
488 |
+
outputs = []
|
489 |
+
with av.open(buf, 'r', 'mp4') as container:
|
490 |
+
if container.streams.video:
|
491 |
+
for frame in container.decode(**{'video': 0}):
|
492 |
+
outputs.append(
|
493 |
+
frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
|
494 |
+
|
495 |
+
#outputs = np.stack(outputs, axis=0)
|
496 |
+
return outputs
|
497 |
+
|
498 |
+
|
499 |
+
if __name__ == '__main__':
|
500 |
+
|
501 |
+
# -----------------------------------
|
502 |
+
# test VideoReader(filename, cache_capacity=10)
|
503 |
+
# -----------------------------------
|
504 |
+
# video_reader = VideoReader('utils/test.mp4')
|
505 |
+
# from utils import utils_image as util
|
506 |
+
# inputs = []
|
507 |
+
# for frame in video_reader:
|
508 |
+
# print(frame.dtype)
|
509 |
+
# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
510 |
+
# #util.imshow(np.flip(frame, axis=2))
|
511 |
+
|
512 |
+
# -----------------------------------
|
513 |
+
# test video2images(video_path, output_dir)
|
514 |
+
# -----------------------------------
|
515 |
+
# video2images('utils/test.mp4', 'frames')
|
516 |
+
|
517 |
+
# -----------------------------------
|
518 |
+
# test images2video(image_dir, video_path, fps=24, image_ext='png')
|
519 |
+
# -----------------------------------
|
520 |
+
# images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
|
521 |
+
|
522 |
+
|
523 |
+
# -----------------------------------
|
524 |
+
# test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
|
525 |
+
# -----------------------------------
|
526 |
+
# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
|
527 |
+
|
528 |
+
|
529 |
+
# -----------------------------------
|
530 |
+
# test add_video_compression(imgs)
|
531 |
+
# -----------------------------------
|
532 |
+
# imgs = []
|
533 |
+
# image_ext = 'png'
|
534 |
+
# frames = 'frames'
|
535 |
+
# from utils import utils_image as util
|
536 |
+
# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
|
537 |
+
# for i, image_file in enumerate(image_files):
|
538 |
+
# if i < 7:
|
539 |
+
# img = util.imread_uint(image_file, 3)
|
540 |
+
# img = util.uint2single(img)
|
541 |
+
# imgs.append(img)
|
542 |
+
#
|
543 |
+
# results = add_video_compression(imgs)
|
544 |
+
# for i, img in enumerate(results):
|
545 |
+
# util.imshow(util.single2uint(img))
|
546 |
+
# util.imsave(util.single2uint(img),f'{i:05}.png')
|
547 |
+
|
548 |
+
# run utils/utils_video.py
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
|
555 |
+
|
core/scripts/__init__.py
ADDED
File without changes
|
core/scripts/cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
from .. import WarpCore
|
4 |
+
from .. import templates
|
5 |
+
|
6 |
+
|
7 |
+
def template_init(args):
|
8 |
+
return ''''
|
9 |
+
|
10 |
+
|
11 |
+
'''.strip()
|
12 |
+
|
13 |
+
|
14 |
+
def init_template(args):
|
15 |
+
parser = argparse.ArgumentParser(description='WarpCore template init tool')
|
16 |
+
parser.add_argument('-t', '--template', type=str, default='WarpCore')
|
17 |
+
args = parser.parse_args(args)
|
18 |
+
|
19 |
+
if args.template == 'WarpCore':
|
20 |
+
template_cls = WarpCore
|
21 |
+
else:
|
22 |
+
try:
|
23 |
+
template_cls = __import__(args.template)
|
24 |
+
except ModuleNotFoundError:
|
25 |
+
template_cls = getattr(templates, args.template)
|
26 |
+
print(template_cls)
|
27 |
+
|
28 |
+
|
29 |
+
def main():
|
30 |
+
if len(sys.argv) < 2:
|
31 |
+
print('Usage: core <command>')
|
32 |
+
sys.exit(1)
|
33 |
+
if sys.argv[1] == 'init':
|
34 |
+
init_template(sys.argv[2:])
|
35 |
+
else:
|
36 |
+
print('Unknown command')
|
37 |
+
sys.exit(1)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
main()
|
core/templates/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion import DiffusionCore
|
core/templates/diffusion.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .. import WarpCore
|
2 |
+
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
|
3 |
+
from abc import abstractmethod
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from gdf import GDF
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
import wandb
|
12 |
+
|
13 |
+
import webdataset as wds
|
14 |
+
from webdataset.handlers import warn_and_continue
|
15 |
+
from torch.distributed import barrier
|
16 |
+
from enum import Enum
|
17 |
+
|
18 |
+
class TargetReparametrization(Enum):
|
19 |
+
EPSILON = 'epsilon'
|
20 |
+
X0 = 'x0'
|
21 |
+
|
22 |
+
class DiffusionCore(WarpCore):
|
23 |
+
@dataclass(frozen=True)
|
24 |
+
class Config(WarpCore.Config):
|
25 |
+
# TRAINING PARAMS
|
26 |
+
lr: float = EXPECTED_TRAIN
|
27 |
+
grad_accum_steps: int = EXPECTED_TRAIN
|
28 |
+
batch_size: int = EXPECTED_TRAIN
|
29 |
+
updates: int = EXPECTED_TRAIN
|
30 |
+
warmup_updates: int = EXPECTED_TRAIN
|
31 |
+
save_every: int = 500
|
32 |
+
backup_every: int = 20000
|
33 |
+
use_fsdp: bool = True
|
34 |
+
|
35 |
+
# EMA UPDATE
|
36 |
+
ema_start_iters: int = None
|
37 |
+
ema_iters: int = None
|
38 |
+
ema_beta: float = None
|
39 |
+
|
40 |
+
# GDF setting
|
41 |
+
gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
|
42 |
+
|
43 |
+
@dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
|
44 |
+
class Info(WarpCore.Info):
|
45 |
+
ema_loss: float = None
|
46 |
+
|
47 |
+
@dataclass(frozen=True)
|
48 |
+
class Models(WarpCore.Models):
|
49 |
+
generator : nn.Module = EXPECTED
|
50 |
+
generator_ema : nn.Module = None # optional
|
51 |
+
|
52 |
+
@dataclass(frozen=True)
|
53 |
+
class Optimizers(WarpCore.Optimizers):
|
54 |
+
generator : any = EXPECTED
|
55 |
+
|
56 |
+
@dataclass(frozen=True)
|
57 |
+
class Schedulers(WarpCore.Schedulers):
|
58 |
+
generator: any = None
|
59 |
+
|
60 |
+
@dataclass(frozen=True)
|
61 |
+
class Extras(WarpCore.Extras):
|
62 |
+
gdf: GDF = EXPECTED
|
63 |
+
sampling_configs: dict = EXPECTED
|
64 |
+
|
65 |
+
# --------------------------------------------
|
66 |
+
info: Info
|
67 |
+
config: Config
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
71 |
+
raise NotImplementedError("This method needs to be overriden")
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
75 |
+
raise NotImplementedError("This method needs to be overriden")
|
76 |
+
|
77 |
+
@abstractmethod
|
78 |
+
def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
|
79 |
+
raise NotImplementedError("This method needs to be overriden")
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def webdataset_path(self, extras: Extras):
|
83 |
+
raise NotImplementedError("This method needs to be overriden")
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def webdataset_filters(self, extras: Extras):
|
87 |
+
raise NotImplementedError("This method needs to be overriden")
|
88 |
+
|
89 |
+
@abstractmethod
|
90 |
+
def webdataset_preprocessors(self, extras: Extras):
|
91 |
+
raise NotImplementedError("This method needs to be overriden")
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
|
95 |
+
raise NotImplementedError("This method needs to be overriden")
|
96 |
+
# -------------
|
97 |
+
|
98 |
+
def setup_data(self, extras: Extras) -> WarpCore.Data:
|
99 |
+
# SETUP DATASET
|
100 |
+
dataset_path = self.webdataset_path(extras)
|
101 |
+
preprocessors = self.webdataset_preprocessors(extras)
|
102 |
+
filters = self.webdataset_filters(extras)
|
103 |
+
|
104 |
+
handler = warn_and_continue # None
|
105 |
+
# handler = None
|
106 |
+
dataset = wds.WebDataset(
|
107 |
+
dataset_path, resampled=True, handler=handler
|
108 |
+
).select(filters).shuffle(690, handler=handler).decode(
|
109 |
+
"pilrgb", handler=handler
|
110 |
+
).to_tuple(
|
111 |
+
*[p[0] for p in preprocessors], handler=handler
|
112 |
+
).map_tuple(
|
113 |
+
*[p[1] for p in preprocessors], handler=handler
|
114 |
+
).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
|
115 |
+
|
116 |
+
# SETUP DATALOADER
|
117 |
+
real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
|
118 |
+
dataloader = DataLoader(
|
119 |
+
dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
|
120 |
+
)
|
121 |
+
|
122 |
+
return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
|
123 |
+
|
124 |
+
def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
|
125 |
+
batch = next(data.iterator)
|
126 |
+
|
127 |
+
with torch.no_grad():
|
128 |
+
conditions = self.get_conditions(batch, models, extras)
|
129 |
+
latents = self.encode_latents(batch, models, extras)
|
130 |
+
noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
|
131 |
+
|
132 |
+
# FORWARD PASS
|
133 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
134 |
+
pred = models.generator(noised, noise_cond, **conditions)
|
135 |
+
if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
|
136 |
+
pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
|
137 |
+
target = noise
|
138 |
+
elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
|
139 |
+
pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
|
140 |
+
target = latents
|
141 |
+
loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
|
142 |
+
loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
|
143 |
+
|
144 |
+
return loss, loss_adjusted
|
145 |
+
|
146 |
+
def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
|
147 |
+
start_iter = self.info.iter+1
|
148 |
+
max_iters = self.config.updates * self.config.grad_accum_steps
|
149 |
+
if self.is_main_node:
|
150 |
+
print(f"STARTING AT STEP: {start_iter}/{max_iters}")
|
151 |
+
|
152 |
+
pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
|
153 |
+
models.generator.train()
|
154 |
+
for i in pbar:
|
155 |
+
# FORWARD PASS
|
156 |
+
loss, loss_adjusted = self.forward_pass(data, extras, models)
|
157 |
+
|
158 |
+
# BACKWARD PASS
|
159 |
+
if i % self.config.grad_accum_steps == 0 or i == max_iters:
|
160 |
+
loss_adjusted.backward()
|
161 |
+
grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
|
162 |
+
optimizers_dict = optimizers.to_dict()
|
163 |
+
for k in optimizers_dict:
|
164 |
+
optimizers_dict[k].step()
|
165 |
+
schedulers_dict = schedulers.to_dict()
|
166 |
+
for k in schedulers_dict:
|
167 |
+
schedulers_dict[k].step()
|
168 |
+
models.generator.zero_grad(set_to_none=True)
|
169 |
+
self.info.total_steps += 1
|
170 |
+
else:
|
171 |
+
with models.generator.no_sync():
|
172 |
+
loss_adjusted.backward()
|
173 |
+
self.info.iter = i
|
174 |
+
|
175 |
+
# UPDATE EMA
|
176 |
+
if models.generator_ema is not None and i % self.config.ema_iters == 0:
|
177 |
+
update_weights_ema(
|
178 |
+
models.generator_ema, models.generator,
|
179 |
+
beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
|
180 |
+
)
|
181 |
+
|
182 |
+
# UPDATE LOSS METRICS
|
183 |
+
self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
|
184 |
+
|
185 |
+
if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
|
186 |
+
wandb.alert(
|
187 |
+
title=f"NaN value encountered in training run {self.info.wandb_run_id}",
|
188 |
+
text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
|
189 |
+
wait_duration=60*30
|
190 |
+
)
|
191 |
+
|
192 |
+
if self.is_main_node:
|
193 |
+
logs = {
|
194 |
+
'loss': self.info.ema_loss,
|
195 |
+
'raw_loss': loss.mean().item(),
|
196 |
+
'grad_norm': grad_norm.item(),
|
197 |
+
'lr': optimizers.generator.param_groups[0]['lr'],
|
198 |
+
'total_steps': self.info.total_steps,
|
199 |
+
}
|
200 |
+
|
201 |
+
pbar.set_postfix(logs)
|
202 |
+
if self.config.wandb_project is not None:
|
203 |
+
wandb.log(logs)
|
204 |
+
|
205 |
+
if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
|
206 |
+
# SAVE AND CHECKPOINT STUFF
|
207 |
+
if np.isnan(loss.mean().item()):
|
208 |
+
if self.is_main_node and self.config.wandb_project is not None:
|
209 |
+
tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
|
210 |
+
wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
|
211 |
+
else:
|
212 |
+
self.save_checkpoints(models, optimizers)
|
213 |
+
if self.is_main_node:
|
214 |
+
create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
|
215 |
+
self.sample(models, data, extras)
|
216 |
+
|
217 |
+
def models_to_save(self):
|
218 |
+
return ['generator', 'generator_ema']
|
219 |
+
|
220 |
+
def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
|
221 |
+
barrier()
|
222 |
+
suffix = '' if suffix is None else suffix
|
223 |
+
self.save_info(self.info, suffix=suffix)
|
224 |
+
models_dict = models.to_dict()
|
225 |
+
optimizers_dict = optimizers.to_dict()
|
226 |
+
for key in self.models_to_save():
|
227 |
+
model = models_dict[key]
|
228 |
+
if model is not None:
|
229 |
+
self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
|
230 |
+
for key in optimizers_dict:
|
231 |
+
optimizer = optimizers_dict[key]
|
232 |
+
if optimizer is not None:
|
233 |
+
self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
|
234 |
+
if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
|
235 |
+
self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
|
236 |
+
torch.cuda.empty_cache()
|
core/utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
|
2 |
+
from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
|
3 |
+
|
4 |
+
# MOVE IT SOMERWHERE ELSE
|
5 |
+
def update_weights_ema(tgt_model, src_model, beta=0.999):
|
6 |
+
for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
|
7 |
+
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
|
8 |
+
for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
|
9 |
+
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
|
core/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (763 Bytes). View file
|
|
core/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (804 Bytes). View file
|
|
core/utils/__pycache__/base_dto.cpython-310.pyc
ADDED
Binary file (3.09 kB). View file
|
|
core/utils/__pycache__/base_dto.cpython-39.pyc
ADDED
Binary file (3.11 kB). View file
|
|
core/utils/__pycache__/save_and_load.cpython-310.pyc
ADDED
Binary file (2.19 kB). View file
|
|
core/utils/__pycache__/save_and_load.cpython-39.pyc
ADDED
Binary file (2.2 kB). View file
|
|
core/utils/base_dto.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from dataclasses import dataclass, _MISSING_TYPE
|
3 |
+
from munch import Munch
|
4 |
+
|
5 |
+
EXPECTED = "___REQUIRED___"
|
6 |
+
EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
|
7 |
+
|
8 |
+
# pylint: disable=invalid-field-call
|
9 |
+
def nested_dto(x, raw=False):
|
10 |
+
return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
|
11 |
+
|
12 |
+
@dataclass(frozen=True)
|
13 |
+
class Base:
|
14 |
+
training: bool = None
|
15 |
+
def __new__(cls, **kwargs):
|
16 |
+
training = kwargs.get('training', True)
|
17 |
+
setteable_fields = cls.setteable_fields(**kwargs)
|
18 |
+
mandatory_fields = cls.mandatory_fields(**kwargs)
|
19 |
+
invalid_kwargs = [
|
20 |
+
{k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
|
21 |
+
]
|
22 |
+
print(mandatory_fields)
|
23 |
+
assert (
|
24 |
+
len(invalid_kwargs) == 0
|
25 |
+
), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
|
26 |
+
missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
|
27 |
+
assert (
|
28 |
+
len(missing_kwargs) == 0
|
29 |
+
), f"Required fields missing initializing this DTO: {missing_kwargs}."
|
30 |
+
return object.__new__(cls)
|
31 |
+
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def setteable_fields(cls, **kwargs):
|
35 |
+
return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def mandatory_fields(cls, **kwargs):
|
39 |
+
training = kwargs.get('training', True)
|
40 |
+
return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def from_dict(cls, kwargs):
|
44 |
+
for k in kwargs:
|
45 |
+
if isinstance(kwargs[k], (dict, list, tuple)):
|
46 |
+
kwargs[k] = Munch.fromDict(kwargs[k])
|
47 |
+
return cls(**kwargs)
|
48 |
+
|
49 |
+
def to_dict(self):
|
50 |
+
# selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
|
51 |
+
selfdict = {}
|
52 |
+
for k in dataclasses.fields(self):
|
53 |
+
selfdict[k.name] = getattr(self, k.name)
|
54 |
+
if isinstance(selfdict[k.name], Munch):
|
55 |
+
selfdict[k.name] = selfdict[k.name].toDict()
|
56 |
+
return selfdict
|
core/utils/save_and_load.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from pathlib import Path
|
5 |
+
import safetensors
|
6 |
+
import wandb
|
7 |
+
|
8 |
+
|
9 |
+
def create_folder_if_necessary(path):
|
10 |
+
path = "/".join(path.split("/")[:-1])
|
11 |
+
Path(path).mkdir(parents=True, exist_ok=True)
|
12 |
+
|
13 |
+
|
14 |
+
def safe_save(ckpt, path):
|
15 |
+
try:
|
16 |
+
os.remove(f"{path}.bak")
|
17 |
+
except OSError:
|
18 |
+
pass
|
19 |
+
try:
|
20 |
+
os.rename(path, f"{path}.bak")
|
21 |
+
except OSError:
|
22 |
+
pass
|
23 |
+
if path.endswith(".pt") or path.endswith(".ckpt"):
|
24 |
+
torch.save(ckpt, path)
|
25 |
+
elif path.endswith(".json"):
|
26 |
+
with open(path, "w", encoding="utf-8") as f:
|
27 |
+
json.dump(ckpt, f, indent=4)
|
28 |
+
elif path.endswith(".safetensors"):
|
29 |
+
safetensors.torch.save_file(ckpt, path)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"File extension not supported: {path}")
|
32 |
+
|
33 |
+
|
34 |
+
def load_or_fail(path, wandb_run_id=None):
|
35 |
+
accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
|
36 |
+
try:
|
37 |
+
assert any(
|
38 |
+
[path.endswith(ext) for ext in accepted_extensions]
|
39 |
+
), f"Automatic loading not supported for this extension: {path}"
|
40 |
+
if not os.path.exists(path):
|
41 |
+
checkpoint = None
|
42 |
+
elif path.endswith(".pt") or path.endswith(".ckpt"):
|
43 |
+
checkpoint = torch.load(path, map_location="cpu")
|
44 |
+
elif path.endswith(".json"):
|
45 |
+
with open(path, "r", encoding="utf-8") as f:
|
46 |
+
checkpoint = json.load(f)
|
47 |
+
elif path.endswith(".safetensors"):
|
48 |
+
checkpoint = {}
|
49 |
+
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
|
50 |
+
for key in f.keys():
|
51 |
+
checkpoint[key] = f.get_tensor(key)
|
52 |
+
return checkpoint
|
53 |
+
except Exception as e:
|
54 |
+
if wandb_run_id is not None:
|
55 |
+
wandb.alert(
|
56 |
+
title=f"Corrupt checkpoint for run {wandb_run_id}",
|
57 |
+
text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
|
58 |
+
)
|
59 |
+
raise e
|