Daniel Verdu committed on
Commit
0cb9530
0 Parent(s):

first commit2

Browse files
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Colorizing_images
3
+ emoji: 📽
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
__pycache__/app_utils.cpython-38.pyc ADDED
Binary file (3.46 kB). View file
 
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #importing the libraries
2
+ import os, sys, re
3
+ import streamlit as st
4
+ import PIL
5
+ from PIL import Image
6
+ import cv2
7
+ import numpy as np
8
+ import uuid
9
+
10
+ import ssl
11
+ ssl._create_default_https_context = ssl._create_unverified_context
12
+
13
+ # Import torch libraries
14
+ import fastai
15
+ import torch
16
+
17
+ # Import util functions from app_utils
18
+ from app_utils import download
19
+ from app_utils import generate_random_filename
20
+ from app_utils import clean_me
21
+ from app_utils import clean_all
22
+ from app_utils import get_model_bin
23
+ from app_utils import convertToJPG
24
+
25
+ # Import util functions from deoldify
26
+ # NOTE: This must be the first call in order to work properly!
27
+ from deoldify import device
28
+ from deoldify.device_id import DeviceId
29
+ #choices: CPU, GPU0...GPU7
30
+ device.set(device=DeviceId.CPU)
31
+ from deoldify.visualize import *
32
+
33
+
34
+ ####### INPUT PARAMS ###########
35
+ model_folder = 'models/'
36
+ max_img_size = 800
37
+ ################################
38
+
39
@st.cache(allow_output_mutation=True)
def load_model(model_dir, option):
    """Fetch the generator weights for ``option`` and build the colorizer.

    Args:
        model_dir: Folder in which the ``*.pth`` weight file is stored.
        option: Colorizer mode, ``'artistic'`` or ``'stable'``
            (compared case-insensitively).

    Returns:
        Whatever ``get_image_colorizer`` returns for the selected mode.

    Raises:
        ValueError: If ``option`` is not one of the supported modes.
    """
    mode = option.lower()

    if mode == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        return get_image_colorizer(artistic=True)

    if mode == 'stable':
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        return get_image_colorizer(artistic=False)

    # Previously an unknown mode fell through to `return colorizer` and
    # failed with an UnboundLocalError that did not name the bad value.
    raise ValueError(f"Unknown colorizer mode: {option!r}")
51
+
52
def resize_img(input_img, max_size):
    """Return a copy of ``input_img`` whose longest side is at most ``max_size``.

    The aspect ratio is preserved. Images that already fit are returned as
    an unmodified copy.

    Args:
        input_img: Image array; ``shape[0]`` is read as the height and
            ``shape[1]`` as the width.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A new array; the input is never modified.
    """
    img = input_img.copy()
    img_height, img_width = img.shape[0], img.shape[1]

    if max(img_height, img_width) <= max_size:
        return img

    limit = int(max_size)
    if img_height > img_width:
        # Portrait: height is the limiting side.
        new_height = limit
        new_width = img_width * limit // img_height
    else:
        # Landscape or square: width is the limiting side. The previous code
        # assigned the scaled *height* to the width and ``max_size`` to the
        # height here, which distorted every landscape image.
        new_width = limit
        new_height = img_height * limit // img_width

    # cv2.resize expects (width, height) and rejects zero-sized targets.
    return cv2.resize(img, (max(1, new_width), max(1, new_height)))
70
+
71
def get_image_download_link(img, filename, text):
    """Build an HTML download link that embeds ``img`` as a base64 JPEG.

    Args:
        img: Image object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        filename: File name suggested to the browser for the download.
        text: Visible link label.

    Returns:
        A string holding a ``<style>`` block plus an ``<a>`` element. It is
        meant to be rendered with ``unsafe_allow_html=True``, so ``filename``
        and ``text`` are HTML-escaped here.
    """
    # Imported locally so the function does not depend on these names leaking
    # in through a wildcard import elsewhere in the module.
    import base64
    import html
    from io import BytesIO

    button_uuid = str(uuid.uuid4()).replace('-', '')
    # Digits are stripped so the id is a valid CSS selector.
    button_id = re.sub(r'\d+', '', button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: 0.25em 0.38em;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;

            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
                }}
        </style> """

    # JPEG cannot store alpha/palette data; normalise such images first.
    if getattr(img, 'mode', 'RGB') not in ('RGB', 'L'):
        img = img.convert('RGB')

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # The file name comes from the user's upload; escape it (and the label)
    # before interpolating into raw HTML.
    safe_name = html.escape(str(filename), quote=True)
    safe_text = html.escape(str(text), quote=True)

    href = custom_css + (
        f'<a href="data:image/jpeg;base64,{img_str}" id="{button_id}" '
        f'download="{safe_name}">{safe_text}</a>'
    )
    return href
106
+
107
+
108
# General configuration
# st.set_page_config(layout="centered")
st.set_page_config(layout="wide")
st.set_option('deprecation.showfileUploaderEncoding', False)
# The injected CSS block must be terminated with </style>; the original
# opened a second <style> tag instead of closing the first one.
st.markdown('''
<style>
    .uploadedFile {display: none}
</style>''',
unsafe_allow_html=True)

# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
title_message = st.empty()

title_message.markdown("**Model loading, please wait** ⌛")

# # Sidebar
color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
# det_nms_thres = st.sidebar.slider("Non-maximum supression IoU", 0.1, 0.9, value=0.4, step=0.1)

# Load models
try:
    colorizer = load_model(model_folder, color_option)
except Exception as e:
    print(e)
    colorizer = None
    print('Error while loading the model. Please refresh the page')
    # Also tell the user; otherwise the page stays on the "loading" message.
    title_message.markdown("**Error while loading the model. Please refresh the page**")

if colorizer is not None:
    print('Running colorizer')
    title_message.markdown("**To begin, please upload an image** 👇")

    # Choose your own image
    uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])

    # show = st.image(use_column_width='auto')
    input_img_pos = st.empty()
    output_img_pos = st.empty()

    if uploaded_file is not None:
        img_name = uploaded_file.name

        # Normalise to 3-channel RGB so grayscale, palette and RGBA uploads
        # all reach the resize / colorize steps in the same layout.
        pil_img = PIL.Image.open(uploaded_file).convert('RGB')
        img_rgb = np.array(pil_img)

        resized_img_rgb = resize_img(img_rgb, max_img_size)
        resized_pil_img = PIL.Image.fromarray(resized_img_rgb)

        title_message.markdown("**Processing your image, please wait** ⌛")

        output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

        title_message.markdown("**To begin, please upload an image** 👇")

        # Plot images
        input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
        output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)

        st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)
app_utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import random
4
+ import _thread as thread
5
+ from uuid import uuid4
6
+ import urllib
7
+
8
+ import numpy as np
9
+ import skimage
10
+ from skimage.filters import gaussian
11
+ from PIL import Image
12
+
13
def compress_image(image, path_original):
    """Save ``image`` as a JPEG beside ``path_original``, fitted to 1920x1080.

    The output path is the directory of ``path_original`` plus the part of
    its base name before the first ``.`` and a ``.jpg`` suffix. Images
    larger than 1920x1080 are shrunk with their aspect ratio preserved;
    smaller images are saved at their original size.

    Args:
        image: PIL image to write.
        path_original: Path of the source file; only used to derive the
            output path.
    """
    max_width, max_height = 1920, 1080

    stem = os.path.basename(path_original).split('.')[0]
    target_path = os.path.join(os.path.dirname(path_original), stem + '.jpg')

    # The JPEG writer rejects modes such as RGBA, LA and P, so normalise
    # those to RGB instead of failing inside save().
    if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')

    img_width, img_height = image.size[0], image.size[1]

    # Image.LANCZOS is the same filter as the ANTIALIAS alias, which was
    # removed in Pillow 10.
    if img_width > max_width and img_height > max_height:
        image.thumbnail((max_width, max_height), Image.LANCZOS)
    elif img_width > max_width:
        ratio = max_width / float(img_width)
        scaled_height = int(float(img_height) * float(ratio))
        image = image.resize((max_width, scaled_height), Image.LANCZOS)
    elif img_height > max_height:
        ratio = max_height / float(img_height)
        scaled_width = int(float(img_width) * float(ratio))
        image = image.resize((scaled_width, max_height), Image.LANCZOS)

    image.save(target_path, quality=85)
36
+
37
+
38
def convertToJPG(path_original):
    """Convert a JPEG, GIF, PNG or BMP file to a size-limited JPEG.

    The result is written by ``compress_image``. Files in any other format
    are left untouched, as before.

    Args:
        path_original: Path of the image file to convert.
    """
    # The context manager closes the file in every branch, including the
    # unsupported-format case that previously leaked the handle.
    with Image.open(path_original) as img:
        if img.format == "GIF":
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size)
            # Drop the alpha channel after compositing: an RGBA image cannot
            # be written as JPEG.
            image = Image.composite(rgba, background, rgba).convert('RGB')
        elif img.format == "PNG":
            try:
                # Flatten transparency onto a white background.
                image = Image.new("RGB", img.size, (255, 255, 255))
                image.paste(img, img)
            except ValueError:
                # Raised when the image cannot be used as its own paste mask.
                image = img.convert('RGB')
        elif img.format in ("JPEG", "BMP"):
            image = img.convert('RGB')
        else:
            return

        compress_image(image, path_original)
70
+
71
+
72
+
73
def blur(image, x0, x1, y0, y1, sigma=1, multichannel=True):
    """Return a copy of ``image`` with one rectangle Gaussian-blurred.

    Args:
        image: Array indexed as ``image[y, x]``; it is not modified.
        x0, x1: Horizontal bounds of the rectangle, in either order.
        y0, y1: Vertical bounds of the rectangle, in either order.
        sigma: Standard deviation passed to ``skimage.filters.gaussian``.
        multichannel: Whether the last axis holds colour channels.

    Returns:
        The blurred copy.
    """
    y0, y1 = min(y0, y1), max(y0, y1)
    x0, x1 = min(x0, x1), max(x0, x1)

    result = image.copy()
    region = result[y0:y1, x0:x1].copy()

    try:
        # Newer scikit-image replaced `multichannel` with `channel_axis`.
        blurred = gaussian(
            region, sigma=sigma, channel_axis=-1 if multichannel else None
        )
    except TypeError:
        # Older scikit-image only knows the `multichannel` keyword.
        blurred = gaussian(region, sigma=sigma, multichannel=multichannel)

    # NOTE(review): the 255 rescale assumes gaussian() returns floats in
    # [0, 1] (its behaviour for uint8 input) - confirm for other dtypes.
    result[y0:y1, x0:x1] = np.round(255 * blurred)
    return result
82
+
83
+
84
+
85
def download(url, filename):
    """Download ``url`` to ``filename`` and return ``filename``.

    Args:
        url: Address to fetch with an HTTP GET.
        filename: Destination path; an existing file is overwritten.

    Returns:
        ``filename``, unchanged.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            Previously the error page itself was written to ``filename``.
    """
    # Stream to disk in chunks so large files are not held in memory, and
    # use a timeout so a stalled server cannot hang the caller forever.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(filename, 'wb') as handler:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                handler.write(chunk)

    return filename
91
+
92
+
93
def generate_random_filename(upload_directory, extension):
    """Return ``<upload_directory>/<random uuid4>.<extension>``.

    Nothing is created on disk; only the path string is built.
    """
    random_name = f"{uuid4()}.{extension}"
    return os.path.join(upload_directory, random_name)
97
+
98
+
99
def clean_me(filename):
    """Delete ``filename`` if it exists; a missing file is not an error."""
    # Attempt the removal directly rather than checking first: the file may
    # disappear between an exists() check and the remove() call.
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass
102
+
103
+
104
def clean_all(files):
    """Delete every path in ``files`` by passing each one to ``clean_me``."""
    for path in files:
        clean_me(path)
107
+
108
+
109
def get_model_bin(url, output_path):
    """Download ``url`` to ``output_path`` unless that file already exists.

    Args:
        url: Location of the model file.
        output_path: Destination path. Missing parent folders are created.

    Returns:
        ``output_path``.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly.
    import urllib.request

    if os.path.exists(output_path):
        print('Model exists')
        return output_path

    print('Downloading model')

    # Create the full parent folder. The previous code used only the first
    # path component, which is empty for absolute paths and too shallow for
    # nested ones.
    output_folder = os.path.dirname(output_path)
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    # Download to a temporary name and rename on success, so an interrupted
    # download is never mistaken for a complete model on the next run.
    partial_path = os.fspath(output_path) + '.part'
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, output_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)

    return output_path
128
+
129
+
130
# model_list = [(url, output_path), (url, output_path)]
def get_multi_model_bin(model_list):
    """Start one background ``get_model_bin`` thread per entry of ``model_list``.

    Each entry is passed to the thread as its positional-argument tuple. The
    threads are not joined, so this returns before the downloads finish.
    """
    for model_args in model_list:
        thread.start_new_thread(get_model_bin, model_args)
134
+
deoldify/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from deoldify._device import _Device
2
+
3
# Package-wide device selector; `_Device()` selects the CPU on construction.
device = _Device()
deoldify/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (272 Bytes). View file
 
deoldify/__pycache__/_device.cpython-38.pyc ADDED
Binary file (1.4 kB). View file
 
deoldify/__pycache__/augs.cpython-38.pyc ADDED
Binary file (937 Bytes). View file
 
deoldify/__pycache__/critics.cpython-38.pyc ADDED
Binary file (1.61 kB). View file
 
deoldify/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (1.65 kB). View file
 
deoldify/__pycache__/device_id.cpython-38.pyc ADDED
Binary file (568 Bytes). View file
 
deoldify/__pycache__/filters.cpython-38.pyc ADDED
Binary file (4.99 kB). View file
 
deoldify/__pycache__/generators.cpython-38.pyc ADDED
Binary file (3.64 kB). View file
 
deoldify/__pycache__/layers.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
deoldify/__pycache__/loss.cpython-38.pyc ADDED
Binary file (6.52 kB). View file
 
deoldify/__pycache__/unet.cpython-38.pyc ADDED
Binary file (8.14 kB). View file
 
deoldify/__pycache__/visualize.cpython-38.pyc ADDED
Binary file (6.83 kB). View file
 
deoldify/_device.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum
3
+ from .device_id import DeviceId
4
+
5
+ #NOTE: This must be called first before any torch imports in order to work properly!
6
+
7
class DeviceException(Exception):
    """Error type reserved for device-selection problems."""
9
+
10
class _Device:
    """Remembers the selected device and mirrors it into ``CUDA_VISIBLE_DEVICES``."""

    def __init__(self):
        # Start on the CPU until a caller selects something else.
        self.set(DeviceId.CPU)

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU, `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        """Return the ``DeviceId`` most recently passed to ``set``."""
        return self._current_device

    def set(self, device: DeviceId):
        """Select ``device`` and return it.

        For the CPU the CUDA device list is emptied. For a GPU the list is
        restricted to that id and cuDNN benchmarking is switched off.
        """
        use_cpu = device == DeviceId.CPU
        os.environ['CUDA_VISIBLE_DEVICES'] = '' if use_cpu else str(device.value)

        if not use_cpu:
            # torch is imported only after the environment variable is set.
            import torch
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device
deoldify/augs.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from fastai.vision.image import TfmPixel
4
+
5
+ # Contributed by Rani Horev. Thank you!
6
+ def _noisify(
7
+ x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
8
+ ):
9
+ if noise_range > 255 or noise_range < 0:
10
+ raise Exception("noise_range must be between 0 and 255, inclusively.")
11
+
12
+ h, w = x.shape[1:]
13
+ img_size = h * w
14
+ mult = 10000.0
15
+ pct_pixels = (
16
+ random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
17
+ )
18
+ noise_count = int(img_size * pct_pixels)
19
+
20
+ for ii in range(noise_count):
21
+ yy = random.randrange(h)
22
+ xx = random.randrange(w)
23
+ noise = random.randrange(-noise_range, noise_range) / 255.0
24
+ x[:, yy, xx].add_(noise)
25
+
26
+ return x
27
+
28
+
29
# Expose `_noisify` as a fastai `TfmPixel` transform.
noisify = TfmPixel(_noisify)
deoldify/critics.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.core import *
2
+ from fastai.torch_core import *
3
+ from fastai.vision import *
4
+ from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
5
+
6
# Keyword arguments applied to every critic convolution built by `_conv`.
_conv_args = {'leaky': 0.2, 'norm_type': NormType.Spectral}


def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Build a ``conv_layer`` from ``ni`` to ``nf`` channels with the critic defaults.

    Extra keyword arguments are forwarded to ``conv_layer`` alongside
    ``_conv_args``.
    """
    return conv_layer(
        ni,
        nf,
        ks=ks,
        stride=stride,
        **_conv_args,
        **kwargs,
    )
11
+
12
+
13
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
):
    """Critic to train a `GAN`.

    Args:
        n_channels: Number of channels of the input images.
        nf: Number of filters of the first convolution; doubled after each
            of the ``n_blocks`` blocks.
        n_blocks: Number of conv / dropout / strided-conv blocks.
        p: Dropout probability (a float; half of it is used after the
            first convolution).

    Returns:
        An ``nn.Sequential`` ending in ``Flatten()``.
    """
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]

    for block_idx in range(n_blocks):
        layers.extend(
            [
                _conv(nf, nf, ks=3, stride=1),
                nn.Dropout2d(p),
                # Self-attention is enabled on the first block only.
                _conv(nf, nf * 2, ks=4, stride=2, self_attention=(block_idx == 0)),
            ]
        )
        nf *= 2

    layers.extend(
        [
            _conv(nf, nf, ks=3, stride=1),
            _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
            Flatten(),
        ]
    )
    return nn.Sequential(*layers)
31
+
32
+
33
def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    """Create a ``Learner`` around ``custom_gan_critic(nf=nf)`` for ``data``.

    The learner uses ``loss_critic`` as its loss, ``accuracy_thresh_expand``
    as its metric and a weight decay of ``1e-3``.
    """
    critic = custom_gan_critic(nf=nf)
    return Learner(
        data,
        critic,
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
deoldify/dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastai
2
+ from fastai import *
3
+ from fastai.core import *
4
+ from fastai.vision.transform import get_transforms
5
+ from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
6
+ from .augs import noisify
7
+
8
+
9
def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=None,
) -> ImageDataBunch:
    """Build the image-to-image data bunch used for colorization.

    Inputs are read from ``crappy_path``. The target of each input is the
    file with the same relative path under ``good_path``.

    Args:
        sz: Size passed to the transform step.
        bs: Batch size.
        crappy_path: Root folder of the input images.
        good_path: Root folder of the target images.
        random_seed: Seed for the partial-data sampling and the 10%
            validation split.
        keep_pct: Fraction of the data to keep.
        num_workers: Number of data-loader workers.
        stats: Normalization statistics, applied to inputs and targets.
        xtra_tfms: Extra transforms for ``get_transforms``; ``None`` (the
            default) means no extra transforms.

    Returns:
        The normalized ``ImageDataBunch`` with ``c`` set to 3.
    """
    # Use None rather than a shared mutable list as the default.
    if xtra_tfms is None:
        xtra_tfms = []

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(0.1, seed=random_seed)
    )

    data = (
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            tfm_y=True,
        )
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )

    data.c = 3
    return data
42
+
43
+
44
def get_dummy_databunch() -> ImageDataBunch:
    """Return a minimal data bunch built from the ``./dummy/`` folder.

    The same folder serves as both input and target root, with size 1,
    batch size 1 and ``keep_pct=0.001``.
    """
    dummy_root = Path('./dummy/')
    return get_colorize_data(
        sz=1,
        bs=1,
        crappy_path=dummy_root,
        good_path=dummy_root,
        keep_pct=0.001,
    )
deoldify/device_id.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+
3
class DeviceId(IntEnum):
    """Identifiers of the selectable devices: eight GPUs and the CPU."""

    GPU0 = 0
    GPU1 = 1
    GPU2 = 2
    GPU3 = 3
    GPU4 = 4
    GPU5 = 5
    GPU6 = 6
    GPU7 = 7
    CPU = 99
deoldify/filters.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import ndarray
2
+ from abc import ABC, abstractmethod
3
+ from .critics import colorize_crit_learner
4
+ from fastai.core import *
5
+ from fastai.vision import *
6
+ from fastai.vision.image import *
7
+ from fastai.vision.data import *
8
+ from fastai import *
9
+ import math
10
+ from scipy import misc
11
+ import cv2
12
+ from PIL import Image as PilImage
13
+
14
+
15
class IFilter(ABC):
    """Abstract interface implemented by every image filter."""

    @abstractmethod
    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
    ) -> PilImage:
        """Produce a filtered image from ``orig_image`` and ``filtered_image``."""
21
+
22
+
23
class BaseFilter(IFilter):
    """Shared plumbing for filters that run an image through a fastai learner."""

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__()
        self.learn = learn
        # Run on whichever device holds the model's parameters.
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        """Hook for subclasses; the base implementation returns ``image`` as is."""
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symmetric, constant, etc). Not as good!
        return orig.resize((targ, targ), resample=PIL.Image.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        """Stretch ``orig`` to ``sz`` x ``sz`` and apply ``_transform``."""
        squared = self._scale_to_square(orig, sz)
        return self._transform(squared)

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
        """Run ``orig`` through the learner at ``sz`` x ``sz`` and return the result.

        If the prediction raises a ``RuntimeError`` whose message mentions
        'memory', a warning is printed and the model-ready input image is
        returned instead. Any other ``RuntimeError`` propagates.
        """
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)
        x, y = self.norm((x, x), do_x=True)

        try:
            result = self.learn.pred_batch(
                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
            )
        except RuntimeError as rerr:
            if 'memory' in str(rerr):
                print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
                return model_image
            raise

        prediction = result[0]
        denormed = self.denorm(prediction.px, do_x=False)
        pixels = image2np(denormed * 255).astype(np.uint8)
        return PilImage.fromarray(pixels)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        """Resize ``image`` back to the size of ``orig``."""
        return image.resize(orig.size, resample=PIL.Image.BILINEAR)
70
+
71
+
72
class ColorizerFilter(BaseFilter):
    """Filter that colorizes an image with the wrapped learner."""

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__(learn=learn, stats=stats)
        # The model input side length is render_factor * render_base.
        self.render_base = 16

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
        """Colorize ``filtered_image`` and return it at the size of ``orig_image``.

        With ``post_process`` set, the colors are merged with the luminance
        of ``orig_image``; otherwise the resized model output is returned.
        """
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if not post_process:
            return raw_color
        return self._post_process(raw_color, orig_image)

    def _transform(self, image: PilImage) -> PilImage:
        """Strip colour by converting to 'LA', then return an 'RGB' image."""
        grayscale = image.convert('LA')
        return grayscale.convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance.  This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end.  This is primarily intended just for
    # inference
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        color_yuv = cv2.cvtColor(np.asarray(raw_color), cv2.COLOR_BGR2YUV)
        # do a black and white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(np.asarray(orig), cv2.COLOR_BGR2YUV)
        # Keep channel 0 of the original and take channels 1-2 from the
        # model output.
        merged = np.copy(orig_yuv)
        merged[:, :, 1:3] = color_yuv[:, :, 1:3]
        merged_image = cv2.cvtColor(merged, cv2.COLOR_YUV2BGR)
        return PilImage.fromarray(merged_image)
107
+
108
+
109
class MasterFilter(BaseFilter):
    """Applies a list of filters in order, feeding each one the previous output."""

    def __init__(self, filters: [IFilter], render_factor: int):
        # BaseFilter.__init__ is not called here: this class holds no learner.
        self.filters = filters
        self.render_factor = render_factor

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
        """Run every filter in turn and return the final image.

        ``render_factor`` defaults to the value given at construction time.
        """
        if render_factor is None:
            render_factor = self.render_factor

        for stage in self.filters:
            filtered_image = stage.filter(orig_image, filtered_image, render_factor, post_process)

        return filtered_image
deoldify/generators.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basics import F, nn
2
+ from fastai.basic_data import DataBunch
3
+ from fastai.basic_train import Learner
4
+ from fastai.layers import NormType
5
+ from fastai.torch_core import SplitFuncOrIdxList, to_device, apply_init
6
+ from fastai.vision import *
7
+ from fastai.vision.learner import cnn_config, create_body
8
+ from .unet import DynamicUnetWide, DynamicUnetDeep
9
+ from .loss import FeatureLoss
10
+ from .dataset import *
11
+
12
+ # Weights are implicitly read from ./models/ folder
13
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2,
    arch=models.resnet101
) -> Learner:
    """Build the wide generator learner and prepare it for inference.

    The learner is created on the dummy data bunch with an L1 loss and then
    handed to ``get_inference`` together with ``root_folder`` and
    ``weights_name``.
    """
    dummy_data = get_dummy_databunch()
    learner = gen_learner_wide(
        data=dummy_data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    return get_inference(learner, root_folder, weights_name)
22
+
23
def gen_inference_deep(root_folder: Path, weights_name: str,
                       arch=models.resnet34, nf_factor: float = 1.5
                       ) -> Learner:
    """Build the deep generator learner and prepare it for inference.

    The learner is created on the dummy data bunch with an L1 loss and then
    handed to ``get_inference`` together with ``root_folder`` and
    ``weights_name``.
    """
    dummy_data = get_dummy_databunch()
    learner = gen_learner_deep(
        data=dummy_data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    return get_inference(learner, root_folder, weights_name)
31
+
32
+ # Weights are implicitly read from ./models/ folder
33
+ # Load loads weights from os.path.join(learner.path, learner.model_dir, weights_name)
34
def get_inference(learn, root_folder, weights_name) -> Learner:
    """Point ``learn`` at ``root_folder``, try to load ``weights_name``, set eval mode.

    A failure while loading is printed and otherwise ignored, so the
    returned learner may not contain the requested weights.
    """
    learn.path = root_folder

    # NOTE(review): a load error is only printed here, so callers cannot
    # detect it - confirm this best-effort behaviour is intended.
    try:
        learn.load(weights_name)
    except Exception as e:
        print(e)
        print('Error while reading the model')
    else:
        print('Model loaded successfully')

    learn.model.eval()
    return learn
45
+
46
+
47
def gen_learner_wide(
    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
) -> Learner:
    """Create the wide U-Net generator learner with this project's fixed settings.

    The settings are blur, spectral normalization, self-attention, output
    range (-3, 3) and weight decay 1e-3.
    """
    generator_settings = dict(
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
    return unet_learner_wide(data, **generator_settings)
61
+
62
+
63
+ # The code below is meant to be merged into fastaiv1 ideally
64
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    """Build a ``DynamicUnetWide`` learner from `data` and `arch`.

    The body created from ``arch`` is wrapped in a ``DynamicUnetWide`` with
    ``data.c`` output classes and moved to ``data.device``. With
    ``pretrained`` set, the learner is frozen after splitting. Remaining
    keyword arguments go to ``Learner``.

    NOTE(review): the default for ``norm_type`` is the ``NormType`` class
    itself rather than one of its members - confirm this is intended.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    unet = DynamicUnetWide(
        body,
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    model = to_device(unet, data.device)

    learn = Learner(data, model, **kwargs)
    split_spec = ifnone(split_on, meta['split'])
    learn.split(split_spec)
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
103
+
104
+
105
+ # ----------------------------------------------------------------------
106
+
107
def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
                     nf_factor: float = 1.5
                     ) -> Learner:
    """Create the deep U-Net generator learner with this project's fixed settings.

    The settings are blur, spectral normalization, self-attention, output
    range (-3, 3) and weight decay 1e-3.
    """
    generator_settings = dict(
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
    return unet_learner_deep(data, arch, **generator_settings)
122
+
123
+
124
+ # The code below is meant to be merged into fastaiv1 ideally
125
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    """Build a ``DynamicUnetDeep`` learner from `data` and `arch`.

    The body created from ``arch`` is wrapped in a ``DynamicUnetDeep`` with
    ``data.c`` output classes and moved to ``data.device``. With
    ``pretrained`` set, the learner is frozen after splitting. Remaining
    keyword arguments go to ``Learner``.

    NOTE(review): the default for ``norm_type`` is the ``NormType`` class
    itself rather than one of its members - confirm this is intended.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    unet = DynamicUnetDeep(
        body,
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    model = to_device(unet, data.device)

    learn = Learner(data, model, **kwargs)
    split_spec = ifnone(split_on, meta['split'])
    learn.split(split_spec)
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
165
+
166
+
167
+ # -----------------------------
deoldify/layers.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from fastai.torch_core import *
3
+ from torch.nn.parameter import Parameter
4
+ from torch.autograd import Variable
5
+
6
+
7
+ # The code below is meant to be merged into fastaiv1 ideally
8
+
9
+
10
def custom_conv_layer(
    ni: int,
    nf: int,
    ks: int = 3,
    stride: int = 1,
    padding: int = None,
    bias: bool = None,
    is_1d: bool = False,
    norm_type: Optional[NormType] = NormType.Batch,
    use_activ: bool = True,
    leaky: float = None,
    transpose: bool = False,
    init: Callable = nn.init.kaiming_normal_,
    self_attention: bool = False,
    extra_bn: bool = False,
):
    """Build a conv block from `ni` to `nf` channels as an ``nn.Sequential``.

    The block holds the convolution (weight- or spectral-normalized when
    ``norm_type`` asks for it), then optionally an activation, a batch-norm
    layer and a ``SelfAttention`` layer. Batch norm is added for
    ``NormType.Batch`` / ``NormType.BatchZero`` or when ``extra_bn`` is set.
    ``padding`` defaults to ``(ks - 1) // 2`` (0 for transposed convs) and
    ``bias`` defaults to "no batch norm".
    """
    if padding is None:
        padding = 0 if transpose else (ks - 1) // 2

    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
    if bias is None:
        bias = not bn

    if transpose:
        conv_func = nn.ConvTranspose2d
    elif is_1d:
        conv_func = nn.Conv1d
    else:
        conv_func = nn.Conv2d

    raw_conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
    conv = init_default(raw_conv, init)

    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)

    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        bn_cls = nn.BatchNorm1d if is_1d else nn.BatchNorm2d
        layers.append(bn_cls(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
deoldify/loss.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai import *
2
+ from fastai.core import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks import hook_outputs
5
+ import torchvision.models as models
6
+
7
+
8
class FeatureLoss(nn.Module):
    """Pixel L1 loss plus weighted L1 losses between VGG16-BN feature maps.

    A pretrained `vgg16_bn` feature extractor is moved to CUDA, put in eval
    mode and frozen. Hooks are registered (through fastai's `hook_outputs`)
    on the layers located just before the 3rd, 4th and 5th `MaxPool2d`.

    Args:
        layer_wgts: one weight per hooked layer, applied to that layer's
            feature loss. Copied on construction so the caller's sequence
            is never shared with the instance.

    After each `forward`, `self.metrics` maps `metric_names` to the individual
    loss terms.
    """

    def __init__(self, layer_wgts=(20, 70, 10)):
        super().__init__()

        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        # Index of the layer immediately preceding every max-pool.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        # Defensive copy: never alias the (possibly shared) argument.
        self.wgts = list(layer_wgts)
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run `x` through the feature net and return what the hooks stored."""
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the sum of the pixel loss and the weighted feature losses."""
        # Target features are cloned — presumably because the next pass
        # through `m_feat` replaces what the hooks hold.
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is missing when __init__ failed early (e.g. no CUDA device);
        # a finalizer must not raise in that case.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
44
+
45
+
46
+ # Refactored code, originally from https://github.com/VinceMarron/style_transfer
47
class WassFeatureLoss(nn.Module):
    """Pixel + VGG feature loss with an added L2-Wasserstein style term.

    On top of the terms computed by the plain feature loss (pixel L1 and
    weighted L1 between hooked VGG16-BN activations), each hooked layer
    contributes a weighted Wasserstein-style distance between the first two
    moments (channel mean and channel covariance) of the predicted and the
    target activations.

    Args:
        layer_wgts: per-layer weights of the L1 feature losses.
        wass_wgts: per-layer weights of the Wasserstein terms.

    Both sequences are copied on construction.
    """

    def __init__(self, layer_wgts=(5, 15, 2), wass_wgts=(3.0, 0.7, 0.01)):
        super().__init__()
        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        # Index of the layer immediately preceding every max-pool.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        # Defensive copies: never alias the (possibly shared) arguments.
        self.wgts = list(layer_wgts)
        self.wass_wgts = list(wass_wgts)
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'wass_{i}' for i in range(len(layer_ids))]
        )
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run `x` through the feature net and return what the hooks stored."""
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def _calc_2_moments(self, tensor):
        """Return `(mean, covariance)` over channels, or `(None, None)` if empty.

        The input is flattened to `(1, channels, -1)`.
        NOTE(review): this reshape looks like it assumes a batch of one —
        confirm before using larger batches.
        """
        chans = tensor.shape[1]
        tensor = tensor.view(1, chans, -1)
        n = tensor.shape[2]
        # Guard against an occasional zero-sized activation (would divide by zero).
        if n == 0:
            return None, None
        mu = tensor.mean(2)
        tensor = (tensor - mu[:, :, None]).squeeze(0)
        cov = torch.mm(tensor, tensor.t()) / float(n)
        return mu, cov

    def _get_style_vals(self, tensor):
        """Return `(mean, trace(cov), sqrt(cov))`, or three `None`s if empty."""
        mean, cov = self._calc_2_moments(tensor)
        if mean is None:
            return None, None, None
        # torch.symeig is deprecated/removed; eigh with UPLO='U' reads the
        # same (upper) triangle symeig used by default.
        eigvals, eigvects = torch.linalg.eigh(cov, UPLO='U')
        # Tiny negative eigenvalues are numerical noise: clamp before sqrt.
        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
        tr_cov = eigvals.clamp(min=0).sum()
        return mean, tr_cov, root_cov

    def _calc_l2wass_dist(
        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
    ):
        """Distance between the target ('stl') and predicted ('synth') moments."""
        tr_cov_synth = torch.linalg.eigvalsh(cov_synth, UPLO='U').clamp(min=0).sum()
        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
        # The 1e-8 keeps sqrt differentiable at zero eigenvalues.
        var_overlap = torch.sqrt(
            torch.linalg.eigvalsh(cov_prod, UPLO='U').clamp(min=0) + 1e-8
        ).sum()
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
        return dist

    def _single_wass_loss(self, pred, targ):
        """Wasserstein term for one layer; `targ` is a `_get_style_vals` tuple."""
        mean_test, tr_cov_test, root_cov_test = targ
        mean_synth, cov_synth = self._calc_2_moments(pred)
        loss = self._calc_l2wass_dist(
            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
        )
        return loss

    def forward(self, input, target):
        """Return the sum of pixel, feature and (when computable) Wasserstein terms."""
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        styles = [self._get_style_vals(i) for i in out_feat]

        # Skip the Wasserstein terms when the moments could not be computed.
        if styles[0][0] is not None:
            self.feat_losses += [
                self._single_wass_loss(f_pred, f_targ) * w
                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
            ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is missing when __init__ failed early; never raise here.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
deoldify/save.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_train import Learner, LearnerCallback
2
+ from fastai.vision.gan import GANLearner
3
+
4
+
5
class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that periodically saves the generator learner.

    Every `save_iters` iterations of the GAN learner `learn`, `learn_gen.save`
    is called with the name `<filename>_<epoch>_<iteration>`.

    Args:
        learn: the GAN learner this callback is attached to.
        learn_gen: the generator learner that gets saved.
        filename: prefix of the checkpoint names.
        save_iters: number of iterations between two checkpoints.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save a checkpoint on every non-zero multiple of `save_iters`."""
        # Iteration 0 is skipped: nothing has been trained yet.
        if iteration != 0 and iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        """Save the generator under `<filename>_<epoch>_<iteration>`."""
        filename = f'{self.filename}_{epoch}_{iteration}'
        self.learn_gen.save(filename)
deoldify/unet.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from .layers import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks.hooks import *
5
+ from fastai.vision import *
6
+
7
+
8
+ # The code below is meant to be merged into fastaiv1 ideally
9
+
10
+ __all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
11
+
12
+
13
+ def _get_sfs_idxs(sizes: Sizes) -> List[int]:
14
+ "Get the indexes of the layers where the size of the activation changes."
15
+ feature_szs = [size[-1] for size in sizes]
16
+ sfs_idxs = list(
17
+ np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
18
+ )
19
+ if feature_szs[0] != feature_szs[1]:
20
+ sfs_idxs = [0] + sfs_idxs
21
+ return sfs_idxs
22
+
23
+
24
class CustomPixelShuffle_ICNR(nn.Module):
    """Upsample by `scale` from `ni` filters to `nf` (default `ni`) with `nn.PixelShuffle`.

    A 1x1 `custom_conv_layer` (extra `**kwargs` are forwarded to it) expands
    the channels to `nf * scale**2`, its weights get the `icnr` init, and the
    activated result is pixel-shuffled and blurred.
    """

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        out_channels = ifnone(nf, ni)
        expanded = out_channels * (scale ** 2)
        self.conv = custom_conv_layer(ni, expanded, ks=1, use_activ=False, **kwargs)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel, see
        # "Super-Resolution using Convolutional Neural Networks without Any
        # Checkerboard Artifacts" - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        # NOTE(review): the `blur` argument is never stored; `self.blur` is
        # this pooling module, so the truthiness test in `forward` always
        # passes and the blur is always applied. Left untouched because
        # changing it would alter the outputs of existing models.
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        shuffled = self.shuf(self.relu(self.conv(x)))
        if self.blur:
            return self.blur(self.pad(shuffled))
        return shuffled
53
+
54
+
55
class UnetBlockDeep(nn.Module):
    """Decoder block: pixel-shuffle upsampling, skip concatenation, two convs.

    The upsampled input (`up_in_c // 2` channels) is concatenated with the
    batch-normalised activation stored by `hook` (`x_in_c` channels). The
    result goes through two `custom_conv_layer`s whose width is the
    concatenated channel count (halved when `final_div` is false), scaled by
    `nf_factor`.
    """

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        half_up = up_in_c // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_up, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        cat_channels = half_up + x_in_c
        base_width = cat_channels if final_div else cat_channels // 2
        out_channels = int(base_width * nf_factor)
        self.conv1 = custom_conv_layer(cat_channels, out_channels, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            out_channels,
            out_channels,
            leaky=leaky,
            self_attention=self_attention,
            **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        skip_hw = skip.shape[-2:]
        # Resize when the upsampled map does not match the skip connection.
        if skip_hw != upsampled.shape[-2:]:
            upsampled = F.interpolate(upsampled, skip_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        activated = self.relu(merged)
        return self.conv2(self.conv1(activated))
92
+
93
+
94
class DynamicUnetDeep(SequentialEx):
    """Create a U-Net from a given architecture, using `UnetBlockDeep` decoder blocks.

    The encoder is probed with a 256x256 dummy input to find the layers where
    the activation size changes; those layers are hooked and each one feeds a
    decoder block.

    Args:
        encoder: backbone, indexable by layer.
        n_classes: number of output channels.
        blur: enable blur in the upsampling blocks.
        blur_final: when false, the last decoder block is built with
            `blur=False` even if `blur` is set.
        self_attention: add self-attention to the third block from the end.
        y_range: optional `(low, high)` passed to `SigmoidRange`.
        last_cross: concatenate the network input before the final res block.
        bottle: forwarded to `res_block`.
        norm_type: normalisation used by the conv layers; `Spectral` also
            turns on the extra batchnorm.
        nf_factor: width multiplier of the decoder blocks.
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        last_block = len(sfs_idxs) - 1
        for i, idx in enumerate(sfs_idxs):
            not_final = i != last_block
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Previously computed but never used, so `blur_final` had no effect.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Feed the dummy activation forward to learn the next channel count.
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` is absent when construction failed before the hooks were set.
        if hasattr(self, "sfs"):
            self.sfs.remove()
167
+
168
+
169
+ # ------------------------------------------------------
170
class UnetBlockWide(nn.Module):
    """Decoder block: pixel-shuffle upsampling, skip concatenation, one conv.

    The input is upsampled to `n_out // 2` channels, concatenated with the
    batch-normalised activation stored by `hook` (`x_in_c` channels) and
    reduced to `n_out // 2` channels by a single `custom_conv_layer`.
    `final_div` is accepted but not used by this block.
    """

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        half_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        cat_channels = half_out + x_in_c
        self.conv = custom_conv_layer(
            cat_channels,
            half_out,
            leaky=leaky,
            self_attention=self_attention,
            **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        skip_hw = skip.shape[-2:]
        # Resize when the upsampled map does not match the skip connection.
        if skip_hw != upsampled.shape[-2:]:
            upsampled = F.interpolate(upsampled, skip_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv(self.relu(merged))
206
+
207
+
208
class DynamicUnetWide(SequentialEx):
    """Create a U-Net from a given architecture, using `UnetBlockWide` decoder blocks.

    The encoder is probed with a 256x256 dummy input to find the layers where
    the activation size changes; those layers are hooked and each one feeds a
    decoder block. Decoder blocks are sized from `512 * nf_factor` (half of
    that for the last block).

    Args:
        encoder: backbone, indexable by layer.
        n_classes: number of output channels.
        blur: enable blur in the upsampling blocks.
        blur_final: when false, the last decoder block is built with
            `blur=False` even if `blur` is set.
        self_attention: add self-attention to the third block from the end.
        y_range: optional `(low, high)` passed to `SigmoidRange`.
        last_cross: concatenate the network input before the final res block.
        bottle: forwarded to `res_block`.
        norm_type: normalisation used by the conv layers; `Spectral` also
            turns on the extra batchnorm.
        nf_factor: integer width multiplier of the decoder.
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):

        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        last_block = len(sfs_idxs) - 1
        for i, idx in enumerate(sfs_idxs):
            not_final = i != last_block
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Previously computed but never used, so `blur_final` had no effect.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)

            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Feed the dummy activation forward to learn the next channel count.
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` is absent when construction failed before the hooks were set.
        if hasattr(self, "sfs"):
            self.sfs.remove()
deoldify/visualize.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gc
3
+ import requests
4
+ from io import BytesIO
5
+ import base64
6
+ from scipy import misc
7
+ from PIL import Image
8
+ from matplotlib.axes import Axes
9
+ from matplotlib.figure import Figure
10
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ from fastai.core import *
15
+ from fastai.vision import *
16
+
17
+ from .filters import IFilter, MasterFilter, ColorizerFilter
18
+ from .generators import gen_inference_deep, gen_inference_wide
19
+
20
+
21
+
22
class ModelImageVisualizer:
    """Apply an image filter to pictures and plot and/or save the result.

    Args:
        filter: object whose `filter(orig, image, render_factor=...,
            post_process=...)` method produces the transformed image.
        results_dir: default directory for saved results; created on
            construction when given.

    NOTE(review): `from fastai.vision import *` runs after
    `from PIL import Image` at the top of this file and may rebind the name
    `Image`; the methods that call `.open` therefore import Pillow locally.
    """

    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)

        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        """Release cached CUDA memory before running the filter."""
        torch.cuda.empty_cache()
        # gc.collect()

    def _open_pil_image(self, path: Path) -> Image:
        """Open `path` with Pillow and convert it to RGB."""
        from PIL import Image as PILImage

        return PILImage.open(path).convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        """Download `url` and return it as an RGB Pillow image.

        Raises:
            requests.HTTPError: when the server answers with an error status.
        """
        from PIL import Image as PILImage

        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        # Fail on 4xx/5xx instead of trying to decode an error page as an image.
        response.raise_for_status()
        return PILImage.open(BytesIO(response.content)).convert('RGB')

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Download `url` to `path`, then behave like `plot_transformed_image`."""
        img = self._get_image_from_url(url)
        # The default download folder may not exist yet.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        img.save(path)
        return self.plot_transformed_image(
            path=path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Transform the image at `path`, plot it, save it and return the saved path.

        With `compare` the original is plotted next to the result. The result
        is written to `results_dir` (default: the directory given to the
        constructor) under the source file name.
        """
        path = Path(path)
        if results_dir is None:
            # Fails fast (TypeError) when no directory was configured at all.
            results_dir = Path(self.results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        try:
            if compare:
                orig = self._open_pil_image(path)
                try:
                    self._plot_comparison(
                        figsize, render_factor, display_render_factor, orig, result
                    )
                finally:
                    orig.close()
            else:
                self._plot_solo(figsize, render_factor, display_render_factor, result)
            return self._save_result_image(path, result, results_dir=results_dir)
        finally:
            result.close()

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:
        """Transform an in-memory image, plot it and return the result."""
        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )

        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        return result

    def _plot_comparison(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        """Plot `orig` (left) and `result` (right); only the result is labelled."""
        _, axes = plt.subplots(1, 2, figsize=figsize)
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        """Plot `result` alone in a new figure."""
        _, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _save_result_image(self, source_path: Path, image: Image, results_dir=None) -> Path:
        """Save `image` as `results_dir / source_path.name` and return that path.

        `results_dir` defaults to the directory given to the constructor and
        is created when missing.
        """
        target_dir = Path(self.results_dir if results_dir is None else results_dir)
        # An explicitly passed directory is not created by __init__.
        target_dir.mkdir(parents=True, exist_ok=True)
        result_path = target_dir / source_path.name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self, path: Path, render_factor: int = None, post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        """Open the image at `path` and return the filtered version.

        `watermarked` is accepted for call compatibility but is not used here.
        """
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        filtered_image = self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def get_transformed_pil_image(
        self, input_image: Image, render_factor: int = None, post_process: bool = True,
    ) -> Image:
        """Return the filtered version of an in-memory image."""
        self._clean_mem()
        filtered_image = self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        """Draw `image` on `axes` (a new figure when `axes` is None)."""
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            # Write on the axes we were given, not on pyplot's "current" axes.
            axes.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
        """Return `(rows, columns)` of the smallest grid holding `num_images`.

        At most `max_columns` columns are used. `(0, 0)` is returned when
        there is nothing to lay out.
        """
        columns = min(num_images, max_columns)
        if columns <= 0:
            return 0, 0
        rows = -(-num_images // columns)  # ceiling division
        return rows, columns
216
+
217
+
218
def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    """Return the artistic colorizer when `artistic` is true, else the stable one."""
    builder = get_artistic_image_colorizer if artistic else get_stable_image_colorizer
    return builder(root_folder=root_folder, render_factor=render_factor)
225
+
226
+
227
def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='result_images',
    render_factor: int = 35,
) -> ModelImageVisualizer:
    """Build a visualizer around the "wide" generator loaded from `weights_name`.

    NOTE: `results_dir` is accepted but not forwarded; the visualizer is
    created without a results directory.
    """
    learner = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    colorizer = ColorizerFilter(learn=learner)
    master_filter = MasterFilter([colorizer], render_factor=render_factor)
    return ModelImageVisualizer(master_filter)
236
+
237
+
238
def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='result_images',
    render_factor: int = 35,
) -> ModelImageVisualizer:
    """Build a visualizer around the "deep" generator loaded from `weights_name`.

    NOTE: `results_dir` is accepted but not forwarded; the visualizer is
    created without a results directory.
    """
    learner = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    colorizer = ColorizerFilter(learn=learner)
    master_filter = MasterFilter([colorizer], render_factor=render_factor)
    return ModelImageVisualizer(master_filter)
models/ColorizeArtistic_gen.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f750246fa220529323b85a8905f9b49c0e5d427099185334d048fb5b5e22477
3
+ size 255144681
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy==1.7.1
2
+ scikit_image==0.18.3
3
+ streamlit==0.88.0
4
+ requests==2.26.0
5
+ torch==1.9.0
6
+ torchvision==0.10.0
7
+ matplotlib==3.4.3
8
+ numpy==1.21.2
9
+ fastai==1.0.51
10
+ Pillow==8.3.2
11
+ opencv-python-headless
12
+