Daniel Verdu committed on
Commit
9e08039
·
1 Parent(s): 30cd7b1

first commit in hf_spaces

Browse files
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #importing the libraries
2
+ import os, sys, re
3
+ import streamlit as st
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+ import uuid
8
+
9
+ # Import torch libraries
10
+ import fastai
11
+ import torch
12
+
13
+ # Import util functions from app_utils
14
+ from app_utils import download
15
+ from app_utils import generate_random_filename
16
+ from app_utils import clean_me
17
+ from app_utils import clean_all
18
+ from app_utils import create_directory
19
+ from app_utils import get_model_bin
20
+ from app_utils import convertToJPG
21
+
22
+ # Import util functions from deoldify
23
+ # NOTE: This must be the first call in order to work properly!
24
+ from deoldify import device
25
+ from deoldify.device_id import DeviceId
26
+ #choices: CPU, GPU0...GPU7
27
+ device.set(device=DeviceId.CPU)
28
+ from deoldify.visualize import *
29
+
30
+
31
+ ####### INPUT PARAMS ###########
32
+ model_folder = 'models/'
33
+ max_img_size = 800
34
+ ################################
35
+
36
@st.cache(allow_output_mutation=True)
def load_model(model_dir, option):
    """Fetch the generator weights for ``option`` and build the colorizer.

    Args:
        model_dir: Directory in which the ``.pth`` weight file is stored.
        option: Colorizer mode, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the chosen mode.

    Raises:
        ValueError: If ``option`` is not one of the supported modes
            (previously this surfaced as an ``UnboundLocalError``).
    """
    mode = option.lower()

    if mode == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        return get_image_colorizer(artistic=True)

    if mode == 'stable':
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        return get_image_colorizer(artistic=False)

    raise ValueError(
        f"Unknown colorizer mode {option!r}; expected 'Artistic' or 'Stable'."
    )
48
+
49
def resize_img(input_img, max_size):
    """Return a copy of ``input_img`` whose longest side is at most ``max_size``.

    The aspect ratio is preserved. Images that already fit are returned as an
    unmodified copy.

    Args:
        input_img: Array-like image exposing ``.copy()`` and ``.shape`` with
            height first and width second.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A new image array; the input is never modified.
    """
    img = input_img.copy()
    img_height, img_width = img.shape[0], img.shape[1]

    if max(img_height, img_width) <= max_size:
        return img

    if img_height > img_width:
        # Portrait: pin the height, scale the width by the same factor.
        new_height = int(max_size)
        new_width = int(img_width * (max_size / img_height))
    else:
        # Landscape or square: pin the width, scale the height.
        # (The previous code assigned these two values the wrong way round,
        # which stretched landscape images.)
        new_width = int(max_size)
        new_height = int(img_height * (max_size / img_width))

    # Extreme aspect ratios could truncate the short side to 0 pixels.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (new_width, new_height))
67
+
68
def get_image_download_link(img, filename, text):
    """Build an HTML snippet with a styled link that downloads ``img`` as JPEG.

    The image is JPEG-encoded in memory and embedded in the link as a base64
    ``data:`` URI, so no file is written to disk.

    Args:
        img: PIL-style image exposing ``mode``, ``convert`` and ``save``.
        filename: File name suggested to the browser (HTML-escaped here).
        text: Visible link text (HTML-escaped here).

    Returns:
        A string holding a ``<style>`` block followed by the ``<a>`` element.
    """
    # Imported locally so the function does not depend on names leaking in
    # through ``from deoldify.visualize import *``.
    import base64
    import html
    from io import BytesIO

    # CSS ids must not start with a digit, hence the alphabetic prefix.
    button_id = 'btn' + uuid.uuid4().hex

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: 0.25em 0.38em;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;

            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
                }}
        </style> """

    # JPEG cannot store alpha or palette data; flatten such images first.
    if img.mode not in ('RGB', 'L', 'CMYK'):
        img = img.convert('RGB')

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # ``filename`` comes from the uploaded file, so escape it before putting
    # it into attribute and element context.
    safe_name = html.escape(str(filename), quote=True)
    safe_text = html.escape(str(text), quote=True)

    href = custom_css + (
        f'<a href="data:image/jpeg;base64,{img_str}" id="{button_id}" '
        f'download="{safe_name}">{safe_text}</a>'
    )
    return href
103
+
104
+
105
# General configuration
st.set_page_config(layout="centered")
st.set_option('deprecation.showfileUploaderEncoding', False)
# Hide the uploaded-file row of the uploader widget. The block must be closed
# with </style>; it was previously "closed" with a second opening <style> tag.
st.markdown('''
<style>
    .uploadedFile {display: none}
</style>''',
unsafe_allow_html=True)

# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
title_message = st.empty()

title_message.markdown("**Model loading, please wait** ⌛")

# # Sidebar
color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
# det_nms_thres = st.sidebar.slider("Non-maximum suppression IoU", 0.1, 0.9, value=0.4, step=0.1)

# Load models
colorizer = load_model(model_folder, color_option)

title_message.markdown("**To begin, please upload an image** 👇")

# Choose your own image
uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])

# show = st.image(use_column_width='auto')
input_img_pos = st.empty()
output_img_pos = st.empty()

if uploaded_file is not None:
    img_name = uploaded_file.name

    # Normalise to 3-channel RGB so grayscale ('L'), palette and RGBA uploads
    # all produce an (H, W, 3) array.
    # NOTE(review): the colorizer's post-processing uses a 3-channel cv2 colour
    # conversion -- confirm plot_transformed_pil_image expects RGB input.
    pil_img = Image.open(uploaded_file).convert('RGB')
    img_rgb = np.array(pil_img)

    resized_img_rgb = resize_img(img_rgb, max_img_size)
    resized_pil_img = Image.fromarray(resized_img_rgb)

    title_message.markdown("**Processing your image, please wait** ⌛")

    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

    title_message.markdown("**To begin, please upload an image** 👇")

    # Plot images
    input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
    output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)

    st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)
app_utils.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import random
4
+ import _thread as thread
5
+ from uuid import uuid4
6
+ import urllib
7
+
8
+ import numpy as np
9
+ import skimage
10
+ from skimage.filters import gaussian
11
+ from PIL import Image
12
+
13
def compress_image(image, path_original):
    """Save ``image`` as a JPEG (quality 85) no larger than 1920x1080.

    The output is written next to ``path_original``, named after the part of
    its base name before the first dot plus ``.jpg``.

    Args:
        image: PIL image to shrink (if needed) and save.
        path_original: Path of the source file; only used to derive the
            output path.
    """
    width = 1920
    height = 1080

    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    resample = Image.LANCZOS

    # NOTE(review): splitting on the first dot maps "a.b.png" to "a.jpg";
    # kept as-is because callers may rely on this naming.
    name = os.path.basename(path_original).split('.')
    first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')

    if image.size[0] > width and image.size[1] > height:
        # Both sides too large: shrink in place to fit the bounding box.
        image.thumbnail((width, height), resample)
    elif image.size[0] > width:
        wpercent = width / float(image.size[0])
        height = int(float(image.size[1]) * float(wpercent))
        image = image.resize((width, height), resample)
    elif image.size[1] > height:
        wpercent = height / float(image.size[1])
        width = int(float(image.size[0]) * float(wpercent))
        image = image.resize((width, height), resample)

    # JPEG cannot store alpha or palette data (e.g. RGBA from a GIF/PNG).
    if image.mode not in ('RGB', 'L', 'CMYK'):
        image = image.convert('RGB')

    image.save(first_name, quality=85)
36
+
37
+
38
def convertToJPG(path_original):
    """Convert a JPEG, GIF, PNG or BMP file to a size-limited JPEG.

    The result is written by ``compress_image``. Files in any other format
    are left untouched and nothing is written.

    Args:
        path_original: Path of the image file to convert.
    """
    # The context manager closes the file handle even if conversion fails.
    with Image.open(path_original) as img:
        fmt = img.format

        if fmt in ("JPEG", "BMP"):
            image = img.convert('RGB')
        elif fmt in ("GIF", "PNG"):
            # Flatten any transparency onto a white background so the result
            # is plain RGB, which JPEG can store.
            rgba = img.convert("RGBA")
            image = Image.new("RGB", rgba.size, (255, 255, 255))
            image.paste(rgba, mask=rgba.split()[3])
        else:
            return

        compress_image(image, path_original)
70
+
71
+
72
+
73
def blur(image, x0, x1, y0, y1, sigma=1, multichannel=True):
    """Return a copy of ``image`` with one rectangle Gaussian-blurred.

    Args:
        image: Array image indexed as ``[row, column]``.
        x0, x1: Column bounds of the rectangle, in either order.
        y0, y1: Row bounds of the rectangle, in either order.
        sigma: Forwarded to ``skimage.filters.gaussian``.
        multichannel: Forwarded to ``skimage.filters.gaussian``.

    Returns:
        A new array; ``image`` itself is not modified.
    """
    top, bottom = min(y0, y1), max(y0, y1)
    left, right = min(x0, x1), max(x0, x1)

    result = image.copy()
    region = result[top:bottom, left:right].copy()
    blurred = gaussian(region, sigma=sigma, multichannel=multichannel)
    # The filter output is scaled by 255 before being written back --
    # presumably gaussian() returns floats in [0, 1]; TODO confirm.
    result[top:bottom, left:right] = np.round(255 * blurred)
    return result
82
+
83
+
84
+
85
def download(url, filename, timeout=60):
    """Download ``url`` to ``filename`` and return ``filename``.

    Args:
        url: Address to fetch with an HTTP GET.
        filename: Destination path, opened in binary write mode.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``filename``, unchanged.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            Previously the error page was silently written to ``filename``.
    """
    # Stream the body so large files are not held in memory all at once.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(filename, 'wb') as handler:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                handler.write(chunk)

    return filename
91
+
92
+
93
def generate_random_filename(upload_directory, extension):
    """Return ``<upload_directory>/<random uuid4>.<extension>``.

    Args:
        upload_directory: Directory part of the generated path.
        extension: File extension without the leading dot.
    """
    basename = "".join((str(uuid4()), ".", extension))
    return os.path.join(upload_directory, basename)
97
+
98
+
99
def clean_me(filename):
    """Delete ``filename`` if it exists; do nothing when it is already gone.

    Args:
        filename: Path of the file to remove.
    """
    # Attempt the removal directly rather than checking first: the file can
    # disappear between an existence check and the delete.
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass
102
+
103
+
104
def clean_all(files):
    """Run ``clean_me`` on every path in ``files``.

    Args:
        files: Iterable of file paths.
    """
    for path in files:
        clean_me(path)
107
+
108
+
109
def create_directory(path):
    """Create the parent directory of ``path`` (and any missing ancestors).

    Args:
        path: A file path; only its directory component is created.
    """
    parent = os.path.dirname(path)
    # A bare file name has no directory component, and os.makedirs('') raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
111
+
112
+
113
def get_model_bin(url, output_path):
    """Download ``url`` to ``output_path`` unless that file already exists.

    Args:
        url: Location of the model weights.
        output_path: Destination file; its parent directory is created.

    Returns:
        ``output_path``.
    """
    # ``import urllib`` alone does not guarantee the ``request`` submodule
    # is loaded, so import it explicitly.
    import urllib.request

    if not os.path.exists(output_path):
        create_directory(output_path)

        # Download to a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated file that a later
        # call would treat as a complete model.
        partial_path = output_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, output_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    return output_path
125
+
126
+
127
+ #model_list = [(url, output_path), (url, output_path)]
128
def get_multi_model_bin(model_list):
    """Start one background ``get_model_bin`` thread per entry.

    Returns as soon as the threads are started; it does not wait for the
    downloads to finish.

    Args:
        model_list: Iterable of ``(url, output_path)`` tuples.
    """
    for model_args in model_list:
        thread.start_new_thread(get_model_bin, model_args)
131
+
deoldify/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from deoldify._device import _Device
2
+
3
+ device = _Device()
deoldify/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (272 Bytes). View file
 
deoldify/__pycache__/_device.cpython-38.pyc ADDED
Binary file (1.4 kB). View file
 
deoldify/__pycache__/augs.cpython-38.pyc ADDED
Binary file (937 Bytes). View file
 
deoldify/__pycache__/critics.cpython-38.pyc ADDED
Binary file (1.61 kB). View file
 
deoldify/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (1.65 kB). View file
 
deoldify/__pycache__/device_id.cpython-38.pyc ADDED
Binary file (568 Bytes). View file
 
deoldify/__pycache__/filters.cpython-38.pyc ADDED
Binary file (4.99 kB). View file
 
deoldify/__pycache__/generators.cpython-38.pyc ADDED
Binary file (3.12 kB). View file
 
deoldify/__pycache__/layers.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
deoldify/__pycache__/loss.cpython-38.pyc ADDED
Binary file (6.52 kB). View file
 
deoldify/__pycache__/unet.cpython-38.pyc ADDED
Binary file (8.14 kB). View file
 
deoldify/__pycache__/visualize.cpython-38.pyc ADDED
Binary file (6.77 kB). View file
 
deoldify/_device.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum
3
+ from .device_id import DeviceId
4
+
5
+ #NOTE: This must be called first before any torch imports in order to work properly!
6
+
7
class DeviceException(Exception):
    """Exception type declared for device-selection errors."""
9
+
10
class _Device:
    """Holds the selected compute device and publishes it to CUDA.

    The choice is exported through the ``CUDA_VISIBLE_DEVICES`` environment
    variable. A new instance starts on the CPU.
    """

    def __init__(self):
        self.set(DeviceId.CPU)

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU, `False` otherwise. '''
        on_cpu = self.current() is DeviceId.CPU
        return not on_cpu

    def current(self):
        """Return the ``DeviceId`` most recently passed to ``set``."""
        return self._current_device

    def set(self, device:DeviceId):
        """Select ``device``, update ``CUDA_VISIBLE_DEVICES`` and return it."""
        if device != DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
            # torch is imported only here, after the variable has been set.
            import torch
            torch.backends.cudnn.benchmark = False
        else:
            # An empty value hides every CUDA device.
            os.environ['CUDA_VISIBLE_DEVICES'] = ''

        self._current_device = device
        return device
deoldify/augs.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from fastai.vision.image import TfmPixel
4
+
5
+ # Contributed by Rani Horev. Thank you!
6
+ def _noisify(
7
+ x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
8
+ ):
9
+ if noise_range > 255 or noise_range < 0:
10
+ raise Exception("noise_range must be between 0 and 255, inclusively.")
11
+
12
+ h, w = x.shape[1:]
13
+ img_size = h * w
14
+ mult = 10000.0
15
+ pct_pixels = (
16
+ random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
17
+ )
18
+ noise_count = int(img_size * pct_pixels)
19
+
20
+ for ii in range(noise_count):
21
+ yy = random.randrange(h)
22
+ xx = random.randrange(w)
23
+ noise = random.randrange(-noise_range, noise_range) / 255.0
24
+ x[:, yy, xx].add_(noise)
25
+
26
+ return x
27
+
28
+
29
+ noisify = TfmPixel(_noisify)
deoldify/critics.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.core import *
2
+ from fastai.torch_core import *
3
+ from fastai.vision import *
4
+ from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
5
+
6
+ _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
7
+
8
+
9
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Build a ``conv_layer`` using the critic-wide ``_conv_args`` defaults.

    Extra keyword arguments are forwarded to ``conv_layer`` unchanged.
    """
    layer = conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
    return layer
11
+
12
+
13
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
):
    """Critic to train a `GAN`.

    Args:
        n_channels: Number of input channels.
        nf: Number of filters of the first convolution; doubled after each
            block.
        n_blocks: Number of conv / dropout / strided-conv blocks.
        p: Dropout probability (a float; half of it is used after the first
            convolution).

    Returns:
        An ``nn.Sequential`` ending in a ``Flatten`` layer.
    """
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Self-attention is enabled on the first block only.
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)
31
+
32
+
33
def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    """Wrap a ``custom_gan_critic`` model in a ``Learner``.

    Args:
        data: Data bunch handed to the learner.
        loss_critic: Loss function of the critic.
        nf: Base filter count forwarded to ``custom_gan_critic``.
    """
    critic = custom_gan_critic(nf=nf)
    learner = Learner(
        data,
        critic,
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
    return learner
deoldify/dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastai
2
+ from fastai import *
3
+ from fastai.core import *
4
+ from fastai.vision.transform import get_transforms
5
+ from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
6
+ from .augs import noisify
7
+
8
+
9
def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=None,
) -> ImageDataBunch:
    """Build the image-to-image data bunch used for colorization.

    Inputs are read from ``crappy_path``; each target is the file at the same
    relative location under ``good_path``.

    Args:
        sz: Image size passed to the transform step.
        bs: Batch size.
        crappy_path: Root folder of the input images.
        good_path: Root folder of the target images.
        random_seed: Seed for sub-sampling and the train/validation split.
        keep_pct: Fraction of the data to keep.
        num_workers: Number of data-loader workers.
        stats: Normalization statistics, applied to inputs and targets.
        xtra_tfms: Extra transforms; ``None`` means none. (A shared mutable
            ``[]`` default was used before.)

    Returns:
        The data bunch, with ``c`` set to 3.
    """
    if xtra_tfms is None:
        xtra_tfms = []

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(0.1, seed=random_seed)
    )

    data = (
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            tfm_y=True,
        )
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )

    data.c = 3
    return data
42
+
43
+
44
def get_dummy_databunch() -> ImageDataBunch:
    """Return a minimal data bunch built from the ``./dummy/`` folder.

    The same folder is used for inputs and targets, with size 1, batch size 1
    and ``keep_pct=0.001``.
    """
    dummy_dir = Path('./dummy/')
    return get_colorize_data(
        sz=1, bs=1, crappy_path=dummy_dir, good_path=dummy_dir, keep_pct=0.001
    )
deoldify/device_id.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+
3
class DeviceId(IntEnum):
    """Identifiers of the selectable compute devices.

    ``GPU0`` to ``GPU7`` carry the values 0-7; ``CPU`` is 99.
    """

    # The stray trailing commas that turned each value into a one-element
    # tuple were removed; the resulting integer values are unchanged.
    GPU0 = 0
    GPU1 = 1
    GPU2 = 2
    GPU3 = 3
    GPU4 = 4
    GPU5 = 5
    GPU6 = 6
    GPU7 = 7
    CPU = 99
deoldify/filters.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import ndarray
2
+ from abc import ABC, abstractmethod
3
+ from .critics import colorize_crit_learner
4
+ from fastai.core import *
5
+ from fastai.vision import *
6
+ from fastai.vision.image import *
7
+ from fastai.vision.data import *
8
+ from fastai import *
9
+ import math
10
+ from scipy import misc
11
+ import cv2
12
+ from PIL import Image as PilImage
13
+
14
+
15
class IFilter(ABC):
    """Abstract interface of an image filter."""

    @abstractmethod
    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
    ) -> PilImage:
        """Return the filtered image; concrete filters must implement this."""
21
+
22
+
23
class BaseFilter(IFilter):
    """Shared helpers for filters that run an image through a ``Learner``."""

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__()
        self.learn = learn
        # Run on whatever device the model's parameters live on.
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        """Hook for subclasses; the base implementation returns ``image``."""
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        """Stretch ``orig`` to a ``targ`` x ``targ`` square."""
        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symmetric, constant, etc). Not as good!
        return orig.resize((targ, targ), resample=PIL.Image.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        """Square ``orig`` to ``sz`` and apply ``_transform``."""
        squared = self._scale_to_square(orig, sz)
        return self._transform(squared)

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
        """Run the model on ``orig`` at size ``sz`` and return its output.

        If prediction raises a ``RuntimeError`` whose message contains
        ``'memory'``, a warning is printed and the un-processed model input
        image is returned instead.
        """
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)
        x, y = self.norm((x, x), do_x=True)

        try:
            result = self.learn.pred_batch(
                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
            )
        except RuntimeError as rerr:
            if 'memory' in str(rerr):
                print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
                return model_image
            raise rerr

        out = self.denorm(result[0].px, do_x=False)
        out = image2np(out * 255).astype(np.uint8)
        return PilImage.fromarray(out)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        """Resize ``image`` back to the dimensions of ``orig``."""
        return image.resize(orig.size, resample=PIL.Image.BILINEAR)
70
+
71
+
72
class ColorizerFilter(BaseFilter):
    """Filter that colorizes an image with the wrapped learner."""

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__(learn=learn, stats=stats)
        # The model input side is render_factor * render_base pixels.
        self.render_base = 16

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
        """Colorize ``filtered_image`` and size the result like ``orig_image``.

        With ``post_process`` the model's chrominance is merged with the
        original image's luminance; otherwise the raw model output is returned.
        """
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if not post_process:
            return raw_color
        return self._post_process(raw_color, orig_image)

    def _transform(self, image: PilImage) -> PilImage:
        """Drop colour information (via mode 'LA'), then return to 'RGB'."""
        return image.convert('LA').convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance. This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end. This is primarily intended just for
    # inference
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        """Combine channel 0 of ``orig`` with channels 1-2 of ``raw_color`` in YUV."""
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
        # do a black and white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
        merged = np.copy(orig_yuv)
        merged[:, :, 1:3] = color_yuv[:, :, 1:3]
        return PilImage.fromarray(cv2.cvtColor(merged, cv2.COLOR_YUV2BGR))
107
+
108
+
109
class MasterFilter(BaseFilter):
    """Applies a sequence of filters one after another.

    NOTE(review): ``BaseFilter.__init__`` is not called, so ``learn``,
    ``device``, ``norm`` and ``denorm`` are not set on this object --
    confirm that is intended.
    """

    def __init__(self, filters: [IFilter], render_factor: int):
        self.filters = filters
        self.render_factor = render_factor

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
        """Pass ``filtered_image`` through every filter in order.

        Args:
            orig_image: Original image, given unchanged to every filter.
            filtered_image: Image fed to the first filter.
            render_factor: Overrides the instance default when not ``None``.
            post_process: Forwarded to every filter.

        Returns:
            The output of the last filter (``filtered_image`` if there are
            no filters).
        """
        render_factor = self.render_factor if render_factor is None else render_factor
        # Named ``stage`` so the loop variable does not shadow builtin ``filter``.
        for stage in self.filters:
            filtered_image = stage.filter(orig_image, filtered_image, render_factor, post_process)

        return filtered_image
deoldify/generators.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.vision import *
2
+ from fastai.vision.learner import cnn_config
3
+ from .unet import DynamicUnetWide, DynamicUnetDeep
4
+ from .loss import FeatureLoss
5
+ from .dataset import *
6
+
7
+ # Weights are implicitly read from ./models/ folder
8
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
    """Build the wide generator, load ``weights_name`` and switch to eval mode.

    The learner's ``path`` is set to ``root_folder`` before loading.
    """
    learn = gen_learner_wide(
        data=get_dummy_databunch(), gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn
18
+
19
+
20
def gen_learner_wide(
    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
) -> Learner:
    """Create the wide U-Net generator learner with the project's fixed settings.

    Uses blur, spectral normalization, self-attention, weight decay 1e-3 and
    an output range of (-3, 3).
    """
    settings = dict(
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
    return unet_learner_wide(data, **settings)
34
+
35
+
36
+ # The code below is meant to be merged into fastaiv1 ideally
37
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    """Build a ``DynamicUnetWide`` learner from `data` and `arch`.

    Remaining keyword arguments are passed to ``Learner``. When
    ``pretrained`` is true the learner is frozen after splitting.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    unet = DynamicUnetWide(
        body,
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    model = to_device(unet, data.device)
    learn = Learner(data, model, **kwargs)
    # Fall back to the architecture's own split when none is given.
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
76
+
77
+
78
+ # ----------------------------------------------------------------------
79
+
80
+ # Weights are implicitly read from ./models/ folder
81
def gen_inference_deep(
    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
    """Build the deep generator, load ``weights_name`` and switch to eval mode.

    The learner's ``path`` is set to ``root_folder`` before loading.
    """
    learn = gen_learner_deep(
        data=get_dummy_databunch(), gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn
91
+
92
+
93
def gen_learner_deep(
    data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
) -> Learner:
    """Create the deep U-Net generator learner with the project's fixed settings.

    Uses blur, spectral normalization, self-attention, weight decay 1e-3 and
    an output range of (-3, 3).
    """
    settings = dict(
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
    return unet_learner_deep(data, arch, **settings)
107
+
108
+
109
+ # The code below is meant to be merged into fastaiv1 ideally
110
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    """Build a ``DynamicUnetDeep`` learner from `data` and `arch`.

    Remaining keyword arguments are passed to ``Learner``. When
    ``pretrained`` is true the learner is frozen after splitting.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    unet = DynamicUnetDeep(
        body,
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    model = to_device(unet, data.device)
    learn = Learner(data, model, **kwargs)
    # Fall back to the architecture's own split when none is given.
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
149
+
150
+
151
+ # -----------------------------
deoldify/layers.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from fastai.torch_core import *
3
+ from torch.nn.parameter import Parameter
4
+ from torch.autograd import Variable
5
+
6
+
7
+ # The code below is meant to be merged into fastaiv1 ideally
8
+
9
+
10
def custom_conv_layer(
    ni: int,
    nf: int,
    ks: int = 3,
    stride: int = 1,
    padding: int = None,
    bias: bool = None,
    is_1d: bool = False,
    norm_type: Optional[NormType] = NormType.Batch,
    use_activ: bool = True,
    leaky: float = None,
    transpose: bool = False,
    init: Callable = nn.init.kaiming_normal_,
    self_attention: bool = False,
    extra_bn: bool = False,
):
    """Create a convolution (`ni` to `nf`) followed by optional layers.

    The sequence is: convolution, ReLU (if `use_activ`), batchnorm (if
    `norm_type` is Batch/BatchZero or `extra_bn` is set), self-attention
    (if `self_attention`).

    Args:
        padding: Defaults to ``(ks - 1) // 2``, or 0 for transposed convs.
        bias: Defaults to ``True`` exactly when no batchnorm layer is added.
        norm_type: Weight/Spectral wrap the convolution itself; Batch and
            BatchZero add a batchnorm layer.
        extra_bn: Add a batchnorm layer regardless of ``norm_type``.

    Returns:
        An ``nn.Sequential`` of the layers described above.
    """
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    # Plain truthiness instead of the former ``extra_bn == True`` comparison.
    bn = norm_type in (NormType.Batch, NormType.BatchZero) or bool(extra_bn)
    if bias is None:
        bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(
        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
        init,
    )
    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
deoldify/loss.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai import *
2
+ from fastai.core import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks import hook_outputs
5
+ import torchvision.models as models
6
+
7
+
8
class FeatureLoss(nn.Module):
    """L1 pixel loss plus weighted L1 losses on VGG16-BN feature maps.

    Features are captured with forward hooks on three layers located just
    before max-pool layers of the VGG feature extractor.
    """

    def __init__(self, layer_wgts=None):
        """
        Args:
            layer_wgts: One weight per hooked feature layer. ``None`` selects
                ``[20, 70, 10]`` (a fresh list per instance, replacing the
                former shared mutable default).
        """
        super().__init__()
        if layer_wgts is None:
            layer_wgts = [20, 70, 10]

        # NOTE(review): the extractor is moved to CUDA unconditionally, so
        # constructing this loss presumably requires a GPU -- confirm.
        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        # Index of the layer immediately preceding each max-pool.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run ``x`` through the extractor and return the hooked activations."""
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; per-term values are kept in ``self.metrics``."""
        # Target features are cloned because the next forward pass overwrites
        # the values stored by the hooks.
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # ``hooks`` is missing if __init__ failed before registering them.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
44
+
45
+
46
+ # Refactored code, originally from https://github.com/VinceMarron/style_transfer
47
class WassFeatureLoss(nn.Module):
    """Pixel + feature L1 loss with an added Wasserstein-style term per layer.

    Refactored code, originally from https://github.com/VinceMarron/style_transfer
    """

    def __init__(self, layer_wgts=None, wass_wgts=None):
        """
        Args:
            layer_wgts: Weights of the feature L1 terms; ``None`` selects
                ``[5, 15, 2]``.
            wass_wgts: Weights of the Wasserstein terms; ``None`` selects
                ``[3.0, 0.7, 0.01]``.

        Fresh lists are created per instance, replacing the former shared
        mutable defaults.
        """
        super().__init__()
        if layer_wgts is None:
            layer_wgts = [5, 15, 2]
        if wass_wgts is None:
            wass_wgts = [3.0, 0.7, 0.01]

        # NOTE(review): the extractor is moved to CUDA unconditionally, so
        # constructing this loss presumably requires a GPU -- confirm.
        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        # Index of the layer immediately preceding each max-pool.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.wass_wgts = wass_wgts
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'wass_{i}' for i in range(len(layer_ids))]
        )
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run ``x`` through the extractor and return the hooked activations."""
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def _calc_2_moments(self, tensor):
        """Return ``(mean, covariance)`` over channels, or ``(None, None)`` if empty."""
        chans = tensor.shape[1]
        tensor = tensor.view(1, chans, -1)
        n = tensor.shape[2]
        mu = tensor.mean(2)
        tensor = (tensor - mu[:, :, None]).squeeze(0)
        # Prevents nasty bug that happens very occasionally- divide by zero. Why such things happen?
        if n == 0:
            return None, None
        cov = torch.mm(tensor, tensor.t()) / float(n)
        return mu, cov

    def _get_style_vals(self, tensor):
        """Return ``(mean, trace of cov, matrix square root of cov)`` for ``tensor``."""
        mean, cov = self._calc_2_moments(tensor)
        if mean is None:
            return None, None, None
        # NOTE(review): torch.symeig is deprecated in newer PyTorch in favour
        # of torch.linalg.eigh -- confirm against the pinned torch version.
        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
        tr_cov = eigvals.clamp(min=0).sum()
        return mean, tr_cov, root_cov

    def _calc_l2wass_dist(
        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
    ):
        """Combine mean difference, covariance traces and overlap into one distance."""
        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
        # The 1e-8 keeps the square root away from exactly zero.
        var_overlap = torch.sqrt(
            torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
        ).sum()
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
        return dist

    def _single_wass_loss(self, pred, targ):
        """Distance between the moments of ``pred`` and the precomputed ``targ`` triple."""
        mean_test, tr_cov_test, root_cov_test = targ
        mean_synth, cov_synth = self._calc_2_moments(pred)
        loss = self._calc_l2wass_dist(
            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
        )
        return loss

    def forward(self, input, target):
        """Return the summed loss; per-term values are kept in ``self.metrics``."""
        # Target features are cloned because the next forward pass overwrites
        # the values stored by the hooks.
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        styles = [self._get_style_vals(i) for i in out_feat]

        # Skip the Wasserstein terms when the target moments are unavailable.
        if styles[0][0] is not None:
            self.feat_losses += [
                self._single_wass_loss(f_pred, f_targ) * w
                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
            ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # ``hooks`` is missing if __init__ failed before registering them.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
deoldify/save.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_train import Learner, LearnerCallback
2
+ from fastai.vision.gan import GANLearner
3
+
4
+
5
class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that saves `learn_gen` every `save_iters` iterations while `learn` trains.

    Each checkpoint is written through ``learn_gen.save`` under the name
    ``'<filename>_<epoch>_<iteration>'``.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save the generator when `iteration` is a non-zero multiple of `save_iters`."""
        # Iteration 0 is skipped explicitly: 0 % n == 0 would otherwise
        # trigger a save before any training step has run.
        if iteration and iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        """Save `learn_gen` under a name tagged with the epoch and iteration."""
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)
deoldify/unet.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from .layers import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks.hooks import *
5
+ from fastai.vision import *
6
+
7
+
8
+ # The code below is meant to be merged into fastaiv1 ideally
9
+
10
+ __all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
11
+
12
+
13
def _get_sfs_idxs(sizes: Sizes) -> List[int]:
    "Get the indexes of the layers where the size of the activation changes."
    widths = [shape[-1] for shape in sizes]
    before, after = np.array(widths[:-1]), np.array(widths[1:])
    change_idxs = list(np.flatnonzero(before != after))
    if widths[0] != widths[1]:
        # NOTE(review): when this holds, index 0 is already in `change_idxs`,
        # so it ends up listed twice.  Kept as in the original because the
        # decoder layout built from this list depends on it - confirm intent.
        change_idxs.insert(0, 0)
    return change_idxs
22
+
23
+
24
class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        out_channels = ifnone(nf, ni)
        # A 1x1 conv expands channels by scale**2 so PixelShuffle can fold
        # them back into spatial resolution.
        expanded = out_channels * (scale ** 2)
        self.conv = custom_conv_layer(ni, expanded, ks=1, use_activ=False, **kwargs)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        # NOTE(review): the `blur` argument is never read; `self.blur` is
        # always this pooling module, so the check in forward() looks like it
        # is always true.  Left as-is because existing weights were produced
        # with this behaviour - confirm before changing.
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        upsampled = self.shuf(self.relu(self.conv(x)))
        if not self.blur:
            return upsampled
        return self.blur(self.pad(upsampled))
53
+
54
+
55
class UnetBlockDeep(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        half_up = up_in_c // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_up, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        # Channels after concatenating the upsampled path with the skip input.
        ni = half_up + x_in_c
        base_nf = ni if final_div else ni // 2
        nf = int(base_nf * nf_factor)
        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        skip_hw = skip.shape[-2:]
        # Resize the upsampled path when its spatial size differs from the
        # stored skip activation so the two can be concatenated.
        if skip_hw != upsampled.shape[-2:]:
            upsampled = F.interpolate(upsampled, skip.shape[-2:], mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv2(self.conv1(self.relu(merged)))
92
+
93
+
94
class DynamicUnetDeep(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        # The decoder is sized by pushing a dummy 256x256 input through the
        # encoder and reading the resulting activation shapes.
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Deepest size change first: the decoder consumes skips bottom-up.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            # Self-attention, when requested, goes on the third-from-last block.
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                # Pass the per-block flag: it was computed but never used, so
                # `blur_final` had no effect.
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Run the dummy activation through so the next block sees the
            # right channel count.
            x = unet_block(x)

        ni = x.shape[1]
        # One more upsample when the first hooked activation is smaller than
        # the input image.
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` is missing if __init__ failed before the hooks were created.
        if hasattr(self, "sfs"):
            self.sfs.remove()
167
+
168
+
169
+ # ------------------------------------------------------
170
class UnetBlockWide(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook
        # Both the upsampled path and the block output use half of `n_out`.
        half_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)
        # Channels after concatenating the upsampled path with the skip input.
        merged_c = half_out + x_in_c
        self.conv = custom_conv_layer(
            merged_c, half_out, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        skip_hw = skip.shape[-2:]
        # Resize the upsampled path when its spatial size differs from the
        # stored skip activation so the two can be concatenated.
        if skip_hw != upsampled.shape[-2:]:
            upsampled = F.interpolate(upsampled, skip.shape[-2:], mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv(self.relu(merged))
206
+
207
+
208
class DynamicUnetWide(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):

        # Decoder width: every non-final block is asked for `nf` channels.
        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        # The decoder is sized by pushing a dummy 256x256 input through the
        # encoder and reading the resulting activation shapes.
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Deepest size change first: the decoder consumes skips bottom-up.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            # Self-attention, when requested, goes on the third-from-last block.
            sa = self_attention and (i == len(sfs_idxs) - 3)

            # The final block is asked for half the width of the others.
            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                # Pass the per-block flag: it was computed but never used, so
                # `blur_final` had no effect.
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Run the dummy activation through so the next block sees the
            # right channel count.
            x = unet_block(x)

        ni = x.shape[1]
        # One more upsample when the first hooked activation is smaller than
        # the input image.
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` is missing if __init__ failed before the hooks were created.
        if hasattr(self, "sfs"):
            self.sfs.remove()
deoldify/visualize.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.core import *
2
+ from fastai.vision import *
3
+ from matplotlib.axes import Axes
4
+ from matplotlib.figure import Figure
5
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
6
+ from .filters import IFilter, MasterFilter, ColorizerFilter
7
+ from .generators import gen_inference_deep, gen_inference_wide
8
+ # from tensorboardX import SummaryWriter
9
+ from scipy import misc
10
+ from PIL import Image
11
+ # import ffmpeg
12
+ # import youtube_dl
13
+ import gc
14
+ import requests
15
+ from io import BytesIO
16
+ import base64
17
+ # from IPython import display as ipythondisplay
18
+ # from IPython.display import HTML
19
+ # from IPython.display import Image as ipythonimage
20
+ import cv2
21
+
22
+
23
+ # # adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/
24
+ # def get_watermarked(pil_image: Image) -> Image:
25
+ # try:
26
+ # image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
27
+ # (h, w) = image.shape[:2]
28
+ # image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255])
29
+ # pct = 0.05
30
+ # full_watermark = cv2.imread(
31
+ # './resource_images/watermark.png', cv2.IMREAD_UNCHANGED
32
+ # )
33
+ # (fwH, fwW) = full_watermark.shape[:2]
34
+ # wH = int(pct * h)
35
+ # wW = int((pct * h / fwH) * fwW)
36
+ # watermark = cv2.resize(full_watermark, (wH, wW), interpolation=cv2.INTER_AREA)
37
+ # overlay = np.zeros((h, w, 4), dtype="uint8")
38
+ # (wH, wW) = watermark.shape[:2]
39
+ # overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark
40
+ # # blend the two images together using transparent overlays
41
+ # output = image.copy()
42
+ # cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)
43
+ # rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
44
+ # final_image = Image.fromarray(rgb_image)
45
+ # return final_image
46
+ # except:
47
+ # # Don't want this to crash everything, so let's just not watermark the image for now.
48
+ # return pil_image
49
+
50
+
51
class ModelImageVisualizer:
    """Run an image filter over images and plot and/or save the results.

    `filter` must expose ``filter(orig, current, render_factor=, post_process=)``.
    `results_dir`, when given, is created on construction and used as the
    default destination for saved results.
    """

    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        # results_dir is optional (it defaults to None); only create the
        # directory when one was actually supplied.
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        """Release cached CUDA memory before running the filter."""
        torch.cuda.empty_cache()
        # gc.collect()

    def _open_pil_image(self, path: Path) -> Image:
        """Open the image at `path` converted to RGB."""
        return PIL.Image.open(path).convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        """Download `url` and return it as an RGB image.

        Raises `requests.HTTPError` on a 4xx/5xx response instead of trying
        to decode an error page as an image.
        """
        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        response.raise_for_status()
        return PIL.Image.open(BytesIO(response.content)).convert('RGB')

    def _resolve_results_dir(self, results_dir=None) -> Path:
        """Return the explicit `results_dir` or fall back to the instance one."""
        if results_dir is not None:
            return Path(results_dir)
        if self.results_dir is None:
            raise ValueError(
                'No results directory available: pass results_dir here or '
                'when constructing ModelImageVisualizer.'
            )
        return self.results_dir

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: (int, int) = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Download `url` to `path`, then transform, plot and save it.

        Returns the path of the saved result image.
        """
        img = self._get_image_from_url(url)
        # Nothing else creates the download directory, so make sure it exists.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        img.save(path)
        return self.plot_transformed_image(
            path=path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: (int, int) = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Transform the image at `path`, plot it and save it.

        With `compare` the original is plotted next to the result.  The
        result is saved into `results_dir` (or the instance default) under
        the source file name; that saved path is returned.
        """
        path = Path(path)
        # Resolve first so a missing destination fails before the filter runs.
        results_dir = self._resolve_results_dir(results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        try:
            orig = self._open_pil_image(path)
            try:
                if compare:
                    self._plot_comparison(
                        figsize, render_factor, display_render_factor, orig, result
                    )
                else:
                    self._plot_solo(
                        figsize, render_factor, display_render_factor, result
                    )
            finally:
                orig.close()
            return self._save_result_image(path, result, results_dir=results_dir)
        finally:
            # Close the images even when plotting or saving raises.
            result.close()

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: (int, int) = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:
        """Transform an in-memory image, plot it and return the result."""
        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )

        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        return result

    def _plot_comparison(
        self,
        figsize: (int, int),
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        """Plot `orig` and `result` side by side."""
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        # The render-factor label is only ever shown on the result panel.
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: (int, int),
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        """Plot `result` alone."""
        fig, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _save_result_image(
        self, source_path: Path, image: Image, results_dir=None
    ) -> Path:
        """Save `image` into the results directory under the source file name."""
        result_path = self._resolve_results_dir(results_dir) / source_path.name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self,
        path: Path,
        render_factor: int = None,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        """Open the image at `path` and return it run through the filter.

        `watermarked` is accepted for interface compatibility; watermarking
        is disabled in this file, so the flag has no effect.
        """
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        return self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )

    def get_transformed_pil_image(
        self,
        input_image: Image,
        render_factor: int = None,
        post_process: bool = True,
    ) -> Image:
        """Return `input_image` run through the filter."""
        self._clean_mem()
        return self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        """Draw `image` on `axes` (a new figure when None), optionally labelled."""
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            # Draw on the axes that hold this image rather than on whatever
            # pyplot currently considers the active axes.
            axes.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> (int, int):
        """Return the ``(rows, columns)`` grid that fits `num_images`."""
        if num_images <= 0:
            # Nothing to lay out; also avoids dividing by zero columns below.
            return 0, 0
        columns = min(num_images, max_columns)
        rows = num_images // columns
        # Add a row for the leftover images that do not fill a full row.
        rows = rows if rows * columns == num_images else rows + 1
        return rows, columns
246
+
247
+
248
+ # class VideoColorizer:
249
+ # def __init__(self, vis: ModelImageVisualizer):
250
+ # self.vis = vis
251
+ # workfolder = Path('./video')
252
+ # self.source_folder = workfolder / "source"
253
+ # self.bwframes_root = workfolder / "bwframes"
254
+ # self.audio_root = workfolder / "audio"
255
+ # self.colorframes_root = workfolder / "colorframes"
256
+ # self.result_folder = workfolder / "result"
257
+
258
+ # def _purge_images(self, dir):
259
+ # for f in os.listdir(dir):
260
+ # if re.search('.*?\.jpg', f):
261
+ # os.remove(os.path.join(dir, f))
262
+
263
+ # def _get_fps(self, source_path: Path) -> str:
264
+ # probe = ffmpeg.probe(str(source_path))
265
+ # stream_data = next(
266
+ # (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
267
+ # None,
268
+ # )
269
+ # return stream_data['avg_frame_rate']
270
+
271
+ # def _download_video_from_url(self, source_url, source_path: Path):
272
+ # if source_path.exists():
273
+ # source_path.unlink()
274
+
275
+ # ydl_opts = {
276
+ # 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
277
+ # 'outtmpl': str(source_path),
278
+ # 'retries': 30,
279
+ # 'fragment-retries': 30
280
+ # }
281
+ # with youtube_dl.YoutubeDL(ydl_opts) as ydl:
282
+ # ydl.download([source_url])
283
+
284
+ # def _extract_raw_frames(self, source_path: Path):
285
+ # bwframes_folder = self.bwframes_root / (source_path.stem)
286
+ # bwframe_path_template = str(bwframes_folder / '%5d.jpg')
287
+ # bwframes_folder.mkdir(parents=True, exist_ok=True)
288
+ # self._purge_images(bwframes_folder)
289
+ # ffmpeg.input(str(source_path)).output(
290
+ # str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0
291
+ # ).run(capture_stdout=True)
292
+
293
+ # def _colorize_raw_frames(
294
+ # self, source_path: Path, render_factor: int = None, post_process: bool = True,
295
+ # watermarked: bool = True,
296
+ # ):
297
+ # colorframes_folder = self.colorframes_root / (source_path.stem)
298
+ # colorframes_folder.mkdir(parents=True, exist_ok=True)
299
+ # self._purge_images(colorframes_folder)
300
+ # bwframes_folder = self.bwframes_root / (source_path.stem)
301
+
302
+ # for img in progress_bar(os.listdir(str(bwframes_folder))):
303
+ # img_path = bwframes_folder / img
304
+
305
+ # if os.path.isfile(str(img_path)):
306
+ # color_image = self.vis.get_transformed_image(
307
+ # str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked
308
+ # )
309
+ # color_image.save(str(colorframes_folder / img))
310
+
311
+ # def _build_video(self, source_path: Path) -> Path:
312
+ # colorized_path = self.result_folder / (
313
+ # source_path.name.replace('.mp4', '_no_audio.mp4')
314
+ # )
315
+ # colorframes_folder = self.colorframes_root / (source_path.stem)
316
+ # colorframes_path_template = str(colorframes_folder / '%5d.jpg')
317
+ # colorized_path.parent.mkdir(parents=True, exist_ok=True)
318
+ # if colorized_path.exists():
319
+ # colorized_path.unlink()
320
+ # fps = self._get_fps(source_path)
321
+
322
+ # ffmpeg.input(
323
+ # str(colorframes_path_template),
324
+ # format='image2',
325
+ # vcodec='mjpeg',
326
+ # framerate=fps,
327
+ # ).output(str(colorized_path), crf=17, vcodec='libx264').run(capture_stdout=True)
328
+
329
+ # result_path = self.result_folder / source_path.name
330
+ # if result_path.exists():
331
+ # result_path.unlink()
332
+ # # making copy of non-audio version in case adding back audio doesn't apply or fails.
333
+ # shutil.copyfile(str(colorized_path), str(result_path))
334
+
335
+ # # adding back sound here
336
+ # audio_file = Path(str(source_path).replace('.mp4', '.aac'))
337
+ # if audio_file.exists():
338
+ # audio_file.unlink()
339
+
340
+ # os.system(
341
+ # 'ffmpeg -y -i "'
342
+ # + str(source_path)
343
+ # + '" -vn -acodec copy "'
344
+ # + str(audio_file)
345
+ # + '"'
346
+ # )
347
+
348
+ # if audio_file.exists:
349
+ # os.system(
350
+ # 'ffmpeg -y -i "'
351
+ # + str(colorized_path)
352
+ # + '" -i "'
353
+ # + str(audio_file)
354
+ # + '" -shortest -c:v copy -c:a aac -b:a 256k "'
355
+ # + str(result_path)
356
+ # + '"'
357
+ # )
358
+ # print('Video created here: ' + str(result_path))
359
+ # return result_path
360
+
361
+ # def colorize_from_url(
362
+ # self,
363
+ # source_url,
364
+ # file_name: str,
365
+ # render_factor: int = None,
366
+ # post_process: bool = True,
367
+ # watermarked: bool = True,
368
+
369
+ # ) -> Path:
370
+ # source_path = self.source_folder / file_name
371
+ # self._download_video_from_url(source_url, source_path)
372
+ # return self._colorize_from_path(
373
+ # source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
374
+ # )
375
+
376
+ # def colorize_from_file_name(
377
+ # self, file_name: str, render_factor: int = None, watermarked: bool = True, post_process: bool = True,
378
+ # ) -> Path:
379
+ # source_path = self.source_folder / file_name
380
+ # return self._colorize_from_path(
381
+ # source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
382
+ # )
383
+
384
+ # def _colorize_from_path(
385
+ # self, source_path: Path, render_factor: int = None, watermarked: bool = True, post_process: bool = True
386
+ # ) -> Path:
387
+ # if not source_path.exists():
388
+ # raise Exception(
389
+ # 'Video at path specfied, ' + str(source_path) + ' could not be found.'
390
+ # )
391
+ # self._extract_raw_frames(source_path)
392
+ # self._colorize_raw_frames(
393
+ # source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked
394
+ # )
395
+ # return self._build_video(source_path)
396
+
397
+
398
+ # def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
399
+ # return get_stable_video_colorizer(render_factor=render_factor)
400
+
401
+
402
+ # def get_artistic_video_colorizer(
403
+ # root_folder: Path = Path('./'),
404
+ # weights_name: str = 'ColorizeArtistic_gen',
405
+ # results_dir='result_images',
406
+ # render_factor: int = 35
407
+ # ) -> VideoColorizer:
408
+ # learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
409
+ # filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
410
+ # vis = ModelImageVisualizer(filtr, results_dir=results_dir)
411
+ # return VideoColorizer(vis)
412
+
413
+
414
+ # def get_stable_video_colorizer(
415
+ # root_folder: Path = Path('./'),
416
+ # weights_name: str = 'ColorizeVideo_gen',
417
+ # results_dir='result_images',
418
+ # render_factor: int = 21
419
+ # ) -> VideoColorizer:
420
+ # learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
421
+ # filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
422
+ # vis = ModelImageVisualizer(filtr, results_dir=results_dir)
423
+ # return VideoColorizer(vis)
424
+
425
+
426
def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    """Return the artistic colorizer when `artistic` is truthy, else the stable one.

    `root_folder` and `render_factor` are forwarded to the chosen factory.
    """
    build = get_artistic_image_colorizer if artistic else get_stable_image_colorizer
    return build(root_folder=root_folder, render_factor=render_factor)
433
+
434
+
435
def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='result_images',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Build a visualizer around the learner returned by `gen_inference_wide`.

    The learner is wrapped in a `ColorizerFilter` inside a `MasterFilter`
    configured with `render_factor`; results go to `results_dir`.
    """
    learner = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    colorize = ColorizerFilter(learn=learner)
    pipeline = MasterFilter([colorize], render_factor=render_factor)
    return ModelImageVisualizer(pipeline, results_dir=results_dir)
445
+
446
+
447
def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='result_images',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Build a visualizer around the learner returned by `gen_inference_deep`.

    The learner is wrapped in a `ColorizerFilter` inside a `MasterFilter`
    configured with `render_factor`; results go to `results_dir`.
    """
    learner = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    colorize = ColorizerFilter(learn=learner)
    pipeline = MasterFilter([colorize], render_factor=render_factor)
    return ModelImageVisualizer(pipeline, results_dir=results_dir)
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy==1.7.1
2
+ scikit_image==0.18.3
3
+ streamlit==0.88.0
4
+ requests==2.26.0
5
+ torch==1.9.0
6
+ torchvision==0.10.0
7
+ matplotlib==3.4.3
8
+ numpy==1.21.2
9
+ opencv_python==4.5.3.56
10
+ fastai==1.0.51
11
+ Pillow==8.3.2