Spaces:
Runtime error
Runtime error
Daniel Verdu
commited on
Commit
•
0cb9530
0
Parent(s):
first commit2
Browse files- .gitattributes +27 -0
- README.md +37 -0
- __pycache__/app_utils.cpython-38.pyc +0 -0
- app.py +171 -0
- app_utils.py +134 -0
- deoldify/__init__.py +3 -0
- deoldify/__pycache__/__init__.cpython-38.pyc +0 -0
- deoldify/__pycache__/_device.cpython-38.pyc +0 -0
- deoldify/__pycache__/augs.cpython-38.pyc +0 -0
- deoldify/__pycache__/critics.cpython-38.pyc +0 -0
- deoldify/__pycache__/dataset.cpython-38.pyc +0 -0
- deoldify/__pycache__/device_id.cpython-38.pyc +0 -0
- deoldify/__pycache__/filters.cpython-38.pyc +0 -0
- deoldify/__pycache__/generators.cpython-38.pyc +0 -0
- deoldify/__pycache__/layers.cpython-38.pyc +0 -0
- deoldify/__pycache__/loss.cpython-38.pyc +0 -0
- deoldify/__pycache__/unet.cpython-38.pyc +0 -0
- deoldify/__pycache__/visualize.cpython-38.pyc +0 -0
- deoldify/_device.py +30 -0
- deoldify/augs.py +29 -0
- deoldify/critics.py +44 -0
- deoldify/dataset.py +48 -0
- deoldify/device_id.py +12 -0
- deoldify/filters.py +120 -0
- deoldify/generators.py +167 -0
- deoldify/layers.py +48 -0
- deoldify/loss.py +136 -0
- deoldify/save.py +29 -0
- deoldify/unet.py +285 -0
- deoldify/visualize.py +246 -0
- models/ColorizeArtistic_gen.pth +3 -0
- packages.txt +0 -0
- requirements.txt +12 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Colorizing_images
|
3 |
+
emoji: 📽
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: streamlit
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
|
13 |
+
`title`: _string_
|
14 |
+
Display title for the Space
|
15 |
+
|
16 |
+
`emoji`: _string_
|
17 |
+
Space emoji (emoji-only character allowed)
|
18 |
+
|
19 |
+
`colorFrom`: _string_
|
20 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
21 |
+
|
22 |
+
`colorTo`: _string_
|
23 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
24 |
+
|
25 |
+
`sdk`: _string_
|
26 |
+
Can be either `gradio` or `streamlit`
|
27 |
+
|
28 |
+
`sdk_version` : _string_
|
29 |
+
Only applicable for `streamlit` SDK.
|
30 |
+
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
31 |
+
|
32 |
+
`app_file`: _string_
|
33 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
34 |
+
Path is relative to the root of the repository.
|
35 |
+
|
36 |
+
`pinned`: _boolean_
|
37 |
+
Whether the Space stays on top of your list.
|
__pycache__/app_utils.cpython-38.pyc
ADDED
Binary file (3.46 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#importing the libraries
|
2 |
+
import os, sys, re
|
3 |
+
import streamlit as st
|
4 |
+
import PIL
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import uuid
|
9 |
+
|
10 |
+
import ssl
|
11 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
12 |
+
|
13 |
+
# Import torch libraries
|
14 |
+
import fastai
|
15 |
+
import torch
|
16 |
+
|
17 |
+
# Import util functions from app_utils
|
18 |
+
from app_utils import download
|
19 |
+
from app_utils import generate_random_filename
|
20 |
+
from app_utils import clean_me
|
21 |
+
from app_utils import clean_all
|
22 |
+
from app_utils import get_model_bin
|
23 |
+
from app_utils import convertToJPG
|
24 |
+
|
25 |
+
# Import util functions from deoldify
|
26 |
+
# NOTE: This must be the first call in order to work properly!
|
27 |
+
from deoldify import device
|
28 |
+
from deoldify.device_id import DeviceId
|
29 |
+
#choices: CPU, GPU0...GPU7
|
30 |
+
device.set(device=DeviceId.CPU)
|
31 |
+
from deoldify.visualize import *
|
32 |
+
|
33 |
+
|
34 |
+
####### INPUT PARAMS ###########
|
35 |
+
model_folder = 'models/'
|
36 |
+
max_img_size = 800
|
37 |
+
################################
|
38 |
+
|
39 |
+
@st.cache(allow_output_mutation=True)
|
40 |
+
def load_model(model_dir, option):
|
41 |
+
if option.lower() == 'artistic':
|
42 |
+
model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
|
43 |
+
get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
|
44 |
+
colorizer = get_image_colorizer(artistic=True)
|
45 |
+
elif option.lower() == 'stable':
|
46 |
+
model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
|
47 |
+
get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
|
48 |
+
colorizer = get_image_colorizer(artistic=False)
|
49 |
+
|
50 |
+
return colorizer
|
51 |
+
|
52 |
+
def resize_img(input_img, max_size):
|
53 |
+
img = input_img.copy()
|
54 |
+
img_height, img_width = img.shape[0],img.shape[1]
|
55 |
+
|
56 |
+
if max(img_height, img_width) > max_size:
|
57 |
+
if img_height > img_width:
|
58 |
+
new_width = img_width*(max_size/img_height)
|
59 |
+
new_height = max_size
|
60 |
+
resized_img = cv2.resize(img,(int(new_width), int(new_height)))
|
61 |
+
return resized_img
|
62 |
+
|
63 |
+
elif img_height <= img_width:
|
64 |
+
new_width = img_height*(max_size/img_width)
|
65 |
+
new_height = max_size
|
66 |
+
resized_img = cv2.resize(img,(int(new_width), int(new_height)))
|
67 |
+
return resized_img
|
68 |
+
|
69 |
+
return img
|
70 |
+
|
71 |
+
def get_image_download_link(img,filename,text):
|
72 |
+
button_uuid = str(uuid.uuid4()).replace('-', '')
|
73 |
+
button_id = re.sub('\d+', '', button_uuid)
|
74 |
+
|
75 |
+
custom_css = f"""
|
76 |
+
<style>
|
77 |
+
#{button_id} {{
|
78 |
+
background-color: rgb(255, 255, 255);
|
79 |
+
color: rgb(38, 39, 48);
|
80 |
+
padding: 0.25em 0.38em;
|
81 |
+
position: relative;
|
82 |
+
text-decoration: none;
|
83 |
+
border-radius: 4px;
|
84 |
+
border-width: 1px;
|
85 |
+
border-style: solid;
|
86 |
+
border-color: rgb(230, 234, 241);
|
87 |
+
border-image: initial;
|
88 |
+
|
89 |
+
}}
|
90 |
+
#{button_id}:hover {{
|
91 |
+
border-color: rgb(246, 51, 102);
|
92 |
+
color: rgb(246, 51, 102);
|
93 |
+
}}
|
94 |
+
#{button_id}:active {{
|
95 |
+
box-shadow: none;
|
96 |
+
background-color: rgb(246, 51, 102);
|
97 |
+
color: white;
|
98 |
+
}}
|
99 |
+
</style> """
|
100 |
+
|
101 |
+
buffered = BytesIO()
|
102 |
+
img.save(buffered, format="JPEG")
|
103 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
104 |
+
href = custom_css + f'<a href="data:file/txt;base64,{img_str}" id="{button_id}" download="{filename}">{text}</a>'
|
105 |
+
return href
|
106 |
+
|
107 |
+
|
108 |
+
# General configuration
|
109 |
+
# st.set_page_config(layout="centered")
|
110 |
+
st.set_page_config(layout="wide")
|
111 |
+
st.set_option('deprecation.showfileUploaderEncoding', False)
|
112 |
+
st.markdown('''
|
113 |
+
<style>
|
114 |
+
.uploadedFile {display: none}
|
115 |
+
<style>''',
|
116 |
+
unsafe_allow_html=True)
|
117 |
+
|
118 |
+
# Main window configuration
|
119 |
+
st.title("Black and white colorizer")
|
120 |
+
st.markdown("This app puts color into your black and white pictures")
|
121 |
+
title_message = st.empty()
|
122 |
+
|
123 |
+
title_message.markdown("**Model loading, please wait** ⌛")
|
124 |
+
|
125 |
+
# # Sidebar
|
126 |
+
color_option = st.sidebar.selectbox('Select colorizer mode',
|
127 |
+
('Artistic', 'Stable'))
|
128 |
+
|
129 |
+
# st.sidebar.title('Model parameters')
|
130 |
+
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
|
131 |
+
# det_nms_thres = st.sidebar.slider("Non-maximum supression IoU", 0.1, 0.9, value=0.4, step=0.1)
|
132 |
+
|
133 |
+
# Load models
|
134 |
+
try:
|
135 |
+
colorizer = load_model(model_folder, color_option)
|
136 |
+
except Exception as e:
|
137 |
+
print(e)
|
138 |
+
colorizer = None
|
139 |
+
print('Error while loading the model. Please refresh the page')
|
140 |
+
|
141 |
+
if colorizer is not None:
|
142 |
+
print('Running colorizer')
|
143 |
+
title_message.markdown("**To begin, please upload an image** 👇")
|
144 |
+
|
145 |
+
#Choose your own image
|
146 |
+
uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])
|
147 |
+
|
148 |
+
# show = st.image(use_column_width='auto')
|
149 |
+
input_img_pos = st.empty()
|
150 |
+
output_img_pos = st.empty()
|
151 |
+
|
152 |
+
if uploaded_file is not None:
|
153 |
+
img_name = uploaded_file.name
|
154 |
+
|
155 |
+
pil_img = PIL.Image.open(uploaded_file)
|
156 |
+
img_rgb = np.array(pil_img)
|
157 |
+
|
158 |
+
resized_img_rgb = resize_img(img_rgb, max_img_size)
|
159 |
+
resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
|
160 |
+
|
161 |
+
title_message.markdown("**Processing your image, please wait** ⌛")
|
162 |
+
|
163 |
+
output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
|
164 |
+
|
165 |
+
title_message.markdown("**To begin, please upload an image** 👇")
|
166 |
+
|
167 |
+
# Plot images
|
168 |
+
input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
|
169 |
+
output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)
|
170 |
+
|
171 |
+
st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)
|
app_utils.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
import random
|
4 |
+
import _thread as thread
|
5 |
+
from uuid import uuid4
|
6 |
+
import urllib
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import skimage
|
10 |
+
from skimage.filters import gaussian
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
def compress_image(image, path_original):
|
14 |
+
size = 1920, 1080
|
15 |
+
width = 1920
|
16 |
+
height = 1080
|
17 |
+
|
18 |
+
name = os.path.basename(path_original).split('.')
|
19 |
+
first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')
|
20 |
+
|
21 |
+
if image.size[0] > width and image.size[1] > height:
|
22 |
+
image.thumbnail(size, Image.ANTIALIAS)
|
23 |
+
image.save(first_name, quality=85)
|
24 |
+
elif image.size[0] > width:
|
25 |
+
wpercent = (width/float(image.size[0]))
|
26 |
+
height = int((float(image.size[1])*float(wpercent)))
|
27 |
+
image = image.resize((width,height), Image.ANTIALIAS)
|
28 |
+
image.save(first_name,quality=85)
|
29 |
+
elif image.size[1] > height:
|
30 |
+
wpercent = (height/float(image.size[1]))
|
31 |
+
width = int((float(image.size[0])*float(wpercent)))
|
32 |
+
image = image.resize((width,height), Image.ANTIALIAS)
|
33 |
+
image.save(first_name, quality=85)
|
34 |
+
else:
|
35 |
+
image.save(first_name, quality=85)
|
36 |
+
|
37 |
+
|
38 |
+
def convertToJPG(path_original):
|
39 |
+
img = Image.open(path_original)
|
40 |
+
name = os.path.basename(path_original).split('.')
|
41 |
+
first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')
|
42 |
+
|
43 |
+
if img.format == "JPEG":
|
44 |
+
image = img.convert('RGB')
|
45 |
+
compress_image(image, path_original)
|
46 |
+
img.close()
|
47 |
+
|
48 |
+
elif img.format == "GIF":
|
49 |
+
i = img.convert("RGBA")
|
50 |
+
bg = Image.new("RGBA", i.size)
|
51 |
+
image = Image.composite(i, bg, i)
|
52 |
+
compress_image(image, path_original)
|
53 |
+
img.close()
|
54 |
+
|
55 |
+
elif img.format == "PNG":
|
56 |
+
try:
|
57 |
+
image = Image.new("RGB", img.size, (255,255,255))
|
58 |
+
image.paste(img,img)
|
59 |
+
compress_image(image, path_original)
|
60 |
+
except ValueError:
|
61 |
+
image = img.convert('RGB')
|
62 |
+
compress_image(image, path_original)
|
63 |
+
|
64 |
+
img.close()
|
65 |
+
|
66 |
+
elif img.format == "BMP":
|
67 |
+
image = img.convert('RGB')
|
68 |
+
compress_image(image, path_original)
|
69 |
+
img.close()
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def blur(image, x0, x1, y0, y1, sigma=1, multichannel=True):
|
74 |
+
y0, y1 = min(y0, y1), max(y0, y1)
|
75 |
+
x0, x1 = min(x0, x1), max(x0, x1)
|
76 |
+
im = image.copy()
|
77 |
+
sub_im = im[y0:y1,x0:x1].copy()
|
78 |
+
blur_sub_im = gaussian(sub_im, sigma=sigma, multichannel=multichannel)
|
79 |
+
blur_sub_im = np.round(255 * blur_sub_im)
|
80 |
+
im[y0:y1,x0:x1] = blur_sub_im
|
81 |
+
return im
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
def download(url, filename):
|
86 |
+
data = requests.get(url).content
|
87 |
+
with open(filename, 'wb') as handler:
|
88 |
+
handler.write(data)
|
89 |
+
|
90 |
+
return filename
|
91 |
+
|
92 |
+
|
93 |
+
def generate_random_filename(upload_directory, extension):
|
94 |
+
filename = str(uuid4())
|
95 |
+
filename = os.path.join(upload_directory, filename + "." + extension)
|
96 |
+
return filename
|
97 |
+
|
98 |
+
|
99 |
+
def clean_me(filename):
|
100 |
+
if os.path.exists(filename):
|
101 |
+
os.remove(filename)
|
102 |
+
|
103 |
+
|
104 |
+
def clean_all(files):
|
105 |
+
for me in files:
|
106 |
+
clean_me(me)
|
107 |
+
|
108 |
+
|
109 |
+
def get_model_bin(url, output_path):
|
110 |
+
# print('Getting model dir: ', output_path)
|
111 |
+
if not os.path.exists(output_path):
|
112 |
+
print('Downloading model')
|
113 |
+
|
114 |
+
# Check if folder exists
|
115 |
+
output_folder = output_path.replace('\\','/').split('/')[0]
|
116 |
+
if not os.path.exists(output_folder):
|
117 |
+
os.makedirs(output_folder, exist_ok=True)
|
118 |
+
|
119 |
+
urllib.request.urlretrieve(url, output_path)
|
120 |
+
|
121 |
+
# cmd = "wget -O %s %s" % (output_path, url)
|
122 |
+
# print(cmd)
|
123 |
+
# os.system(cmd)
|
124 |
+
else:
|
125 |
+
print('Model exists')
|
126 |
+
|
127 |
+
return output_path
|
128 |
+
|
129 |
+
|
130 |
+
#model_list = [(url, output_path), (url, output_path)]
|
131 |
+
def get_multi_model_bin(model_list):
|
132 |
+
for m in model_list:
|
133 |
+
thread.start_new_thread(get_model_bin, m)
|
134 |
+
|
deoldify/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from deoldify._device import _Device
|
2 |
+
|
3 |
+
device = _Device()
|
deoldify/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (272 Bytes). View file
|
|
deoldify/__pycache__/_device.cpython-38.pyc
ADDED
Binary file (1.4 kB). View file
|
|
deoldify/__pycache__/augs.cpython-38.pyc
ADDED
Binary file (937 Bytes). View file
|
|
deoldify/__pycache__/critics.cpython-38.pyc
ADDED
Binary file (1.61 kB). View file
|
|
deoldify/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (1.65 kB). View file
|
|
deoldify/__pycache__/device_id.cpython-38.pyc
ADDED
Binary file (568 Bytes). View file
|
|
deoldify/__pycache__/filters.cpython-38.pyc
ADDED
Binary file (4.99 kB). View file
|
|
deoldify/__pycache__/generators.cpython-38.pyc
ADDED
Binary file (3.64 kB). View file
|
|
deoldify/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (1.53 kB). View file
|
|
deoldify/__pycache__/loss.cpython-38.pyc
ADDED
Binary file (6.52 kB). View file
|
|
deoldify/__pycache__/unet.cpython-38.pyc
ADDED
Binary file (8.14 kB). View file
|
|
deoldify/__pycache__/visualize.cpython-38.pyc
ADDED
Binary file (6.83 kB). View file
|
|
deoldify/_device.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from enum import Enum
|
3 |
+
from .device_id import DeviceId
|
4 |
+
|
5 |
+
#NOTE: This must be called first before any torch imports in order to work properly!
|
6 |
+
|
7 |
+
class DeviceException(Exception):
|
8 |
+
pass
|
9 |
+
|
10 |
+
class _Device:
|
11 |
+
def __init__(self):
|
12 |
+
self.set(DeviceId.CPU)
|
13 |
+
|
14 |
+
def is_gpu(self):
|
15 |
+
''' Returns `True` if the current device is GPU, `False` otherwise. '''
|
16 |
+
return self.current() is not DeviceId.CPU
|
17 |
+
|
18 |
+
def current(self):
|
19 |
+
return self._current_device
|
20 |
+
|
21 |
+
def set(self, device:DeviceId):
|
22 |
+
if device == DeviceId.CPU:
|
23 |
+
os.environ['CUDA_VISIBLE_DEVICES']=''
|
24 |
+
else:
|
25 |
+
os.environ['CUDA_VISIBLE_DEVICES']=str(device.value)
|
26 |
+
import torch
|
27 |
+
torch.backends.cudnn.benchmark=False
|
28 |
+
|
29 |
+
self._current_device = device
|
30 |
+
return device
|
deoldify/augs.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from fastai.vision.image import TfmPixel
|
4 |
+
|
5 |
+
# Contributed by Rani Horev. Thank you!
|
6 |
+
def _noisify(
|
7 |
+
x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
|
8 |
+
):
|
9 |
+
if noise_range > 255 or noise_range < 0:
|
10 |
+
raise Exception("noise_range must be between 0 and 255, inclusively.")
|
11 |
+
|
12 |
+
h, w = x.shape[1:]
|
13 |
+
img_size = h * w
|
14 |
+
mult = 10000.0
|
15 |
+
pct_pixels = (
|
16 |
+
random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
|
17 |
+
)
|
18 |
+
noise_count = int(img_size * pct_pixels)
|
19 |
+
|
20 |
+
for ii in range(noise_count):
|
21 |
+
yy = random.randrange(h)
|
22 |
+
xx = random.randrange(w)
|
23 |
+
noise = random.randrange(-noise_range, noise_range) / 255.0
|
24 |
+
x[:, yy, xx].add_(noise)
|
25 |
+
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
noisify = TfmPixel(_noisify)
|
deoldify/critics.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.core import *
|
2 |
+
from fastai.torch_core import *
|
3 |
+
from fastai.vision import *
|
4 |
+
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
|
5 |
+
|
6 |
+
_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
|
7 |
+
|
8 |
+
|
9 |
+
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
|
10 |
+
return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
|
11 |
+
|
12 |
+
|
13 |
+
def custom_gan_critic(
|
14 |
+
n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
|
15 |
+
):
|
16 |
+
"Critic to train a `GAN`."
|
17 |
+
layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
|
18 |
+
for i in range(n_blocks):
|
19 |
+
layers += [
|
20 |
+
_conv(nf, nf, ks=3, stride=1),
|
21 |
+
nn.Dropout2d(p),
|
22 |
+
_conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
|
23 |
+
]
|
24 |
+
nf *= 2
|
25 |
+
layers += [
|
26 |
+
_conv(nf, nf, ks=3, stride=1),
|
27 |
+
_conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
|
28 |
+
Flatten(),
|
29 |
+
]
|
30 |
+
return nn.Sequential(*layers)
|
31 |
+
|
32 |
+
|
33 |
+
def colorize_crit_learner(
|
34 |
+
data: ImageDataBunch,
|
35 |
+
loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
|
36 |
+
nf: int = 256,
|
37 |
+
) -> Learner:
|
38 |
+
return Learner(
|
39 |
+
data,
|
40 |
+
custom_gan_critic(nf=nf),
|
41 |
+
metrics=accuracy_thresh_expand,
|
42 |
+
loss_func=loss_critic,
|
43 |
+
wd=1e-3,
|
44 |
+
)
|
deoldify/dataset.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fastai
|
2 |
+
from fastai import *
|
3 |
+
from fastai.core import *
|
4 |
+
from fastai.vision.transform import get_transforms
|
5 |
+
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
|
6 |
+
from .augs import noisify
|
7 |
+
|
8 |
+
|
9 |
+
def get_colorize_data(
|
10 |
+
sz: int,
|
11 |
+
bs: int,
|
12 |
+
crappy_path: Path,
|
13 |
+
good_path: Path,
|
14 |
+
random_seed: int = None,
|
15 |
+
keep_pct: float = 1.0,
|
16 |
+
num_workers: int = 8,
|
17 |
+
stats: tuple = imagenet_stats,
|
18 |
+
xtra_tfms=[],
|
19 |
+
) -> ImageDataBunch:
|
20 |
+
|
21 |
+
src = (
|
22 |
+
ImageImageList.from_folder(crappy_path, convert_mode='RGB')
|
23 |
+
.use_partial_data(sample_pct=keep_pct, seed=random_seed)
|
24 |
+
.split_by_rand_pct(0.1, seed=random_seed)
|
25 |
+
)
|
26 |
+
|
27 |
+
data = (
|
28 |
+
src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
|
29 |
+
.transform(
|
30 |
+
get_transforms(
|
31 |
+
max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
|
32 |
+
),
|
33 |
+
size=sz,
|
34 |
+
tfm_y=True,
|
35 |
+
)
|
36 |
+
.databunch(bs=bs, num_workers=num_workers, no_check=True)
|
37 |
+
.normalize(stats, do_y=True)
|
38 |
+
)
|
39 |
+
|
40 |
+
data.c = 3
|
41 |
+
return data
|
42 |
+
|
43 |
+
|
44 |
+
def get_dummy_databunch() -> ImageDataBunch:
|
45 |
+
path = Path('./dummy/')
|
46 |
+
return get_colorize_data(
|
47 |
+
sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
|
48 |
+
)
|
deoldify/device_id.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import IntEnum
|
2 |
+
|
3 |
+
class DeviceId(IntEnum):
|
4 |
+
GPU0 = 0,
|
5 |
+
GPU1 = 1,
|
6 |
+
GPU2 = 2,
|
7 |
+
GPU3 = 3,
|
8 |
+
GPU4 = 4,
|
9 |
+
GPU5 = 5,
|
10 |
+
GPU6 = 6,
|
11 |
+
GPU7 = 7,
|
12 |
+
CPU = 99
|
deoldify/filters.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy import ndarray
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from .critics import colorize_crit_learner
|
4 |
+
from fastai.core import *
|
5 |
+
from fastai.vision import *
|
6 |
+
from fastai.vision.image import *
|
7 |
+
from fastai.vision.data import *
|
8 |
+
from fastai import *
|
9 |
+
import math
|
10 |
+
from scipy import misc
|
11 |
+
import cv2
|
12 |
+
from PIL import Image as PilImage
|
13 |
+
|
14 |
+
|
15 |
+
class IFilter(ABC):
|
16 |
+
@abstractmethod
|
17 |
+
def filter(
|
18 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
|
19 |
+
) -> PilImage:
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
class BaseFilter(IFilter):
|
24 |
+
def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
|
25 |
+
super().__init__()
|
26 |
+
self.learn = learn
|
27 |
+
self.device = next(self.learn.model.parameters()).device
|
28 |
+
self.norm, self.denorm = normalize_funcs(*stats)
|
29 |
+
|
30 |
+
def _transform(self, image: PilImage) -> PilImage:
|
31 |
+
return image
|
32 |
+
|
33 |
+
def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
|
34 |
+
# a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
|
35 |
+
# I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
|
36 |
+
targ_sz = (targ, targ)
|
37 |
+
return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
|
38 |
+
|
39 |
+
def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
|
40 |
+
result = self._scale_to_square(orig, sz)
|
41 |
+
result = self._transform(result)
|
42 |
+
return result
|
43 |
+
|
44 |
+
def _model_process(self, orig: PilImage, sz: int) -> PilImage:
|
45 |
+
model_image = self._get_model_ready_image(orig, sz)
|
46 |
+
x = pil2tensor(model_image, np.float32)
|
47 |
+
x = x.to(self.device)
|
48 |
+
x.div_(255)
|
49 |
+
x, y = self.norm((x, x), do_x=True)
|
50 |
+
|
51 |
+
try:
|
52 |
+
result = self.learn.pred_batch(
|
53 |
+
ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
|
54 |
+
)
|
55 |
+
except RuntimeError as rerr:
|
56 |
+
if 'memory' not in str(rerr):
|
57 |
+
raise rerr
|
58 |
+
print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
|
59 |
+
return model_image
|
60 |
+
|
61 |
+
out = result[0]
|
62 |
+
out = self.denorm(out.px, do_x=False)
|
63 |
+
out = image2np(out * 255).astype(np.uint8)
|
64 |
+
return PilImage.fromarray(out)
|
65 |
+
|
66 |
+
def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
|
67 |
+
targ_sz = orig.size
|
68 |
+
image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
|
69 |
+
return image
|
70 |
+
|
71 |
+
|
72 |
+
class ColorizerFilter(BaseFilter):
|
73 |
+
def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
|
74 |
+
super().__init__(learn=learn, stats=stats)
|
75 |
+
self.render_base = 16
|
76 |
+
|
77 |
+
def filter(
|
78 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
|
79 |
+
render_sz = render_factor * self.render_base
|
80 |
+
model_image = self._model_process(orig=filtered_image, sz=render_sz)
|
81 |
+
raw_color = self._unsquare(model_image, orig_image)
|
82 |
+
|
83 |
+
if post_process:
|
84 |
+
return self._post_process(raw_color, orig_image)
|
85 |
+
else:
|
86 |
+
return raw_color
|
87 |
+
|
88 |
+
def _transform(self, image: PilImage) -> PilImage:
|
89 |
+
return image.convert('LA').convert('RGB')
|
90 |
+
|
91 |
+
# This takes advantage of the fact that human eyes are much less sensitive to
|
92 |
+
# imperfections in chrominance compared to luminance. This means we can
|
93 |
+
# save a lot on memory and processing in the model, yet get a great high
|
94 |
+
# resolution result at the end. This is primarily intended just for
|
95 |
+
# inference
|
96 |
+
def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
|
97 |
+
color_np = np.asarray(raw_color)
|
98 |
+
orig_np = np.asarray(orig)
|
99 |
+
color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
|
100 |
+
# do a black and white transform first to get better luminance values
|
101 |
+
orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
|
102 |
+
hires = np.copy(orig_yuv)
|
103 |
+
hires[:, :, 1:3] = color_yuv[:, :, 1:3]
|
104 |
+
final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
|
105 |
+
final = PilImage.fromarray(final)
|
106 |
+
return final
|
107 |
+
|
108 |
+
|
109 |
+
class MasterFilter(BaseFilter):
|
110 |
+
def __init__(self, filters: [IFilter], render_factor: int):
|
111 |
+
self.filters = filters
|
112 |
+
self.render_factor = render_factor
|
113 |
+
|
114 |
+
def filter(
|
115 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
|
116 |
+
render_factor = self.render_factor if render_factor is None else render_factor
|
117 |
+
for filter in self.filters:
|
118 |
+
filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)
|
119 |
+
|
120 |
+
return filtered_image
|
deoldify/generators.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basics import F, nn
|
2 |
+
from fastai.basic_data import DataBunch
|
3 |
+
from fastai.basic_train import Learner
|
4 |
+
from fastai.layers import NormType
|
5 |
+
from fastai.torch_core import SplitFuncOrIdxList, to_device, apply_init
|
6 |
+
from fastai.vision import *
|
7 |
+
from fastai.vision.learner import cnn_config, create_body
|
8 |
+
from .unet import DynamicUnetWide, DynamicUnetDeep
|
9 |
+
from .loss import FeatureLoss
|
10 |
+
from .dataset import *
|
11 |
+
|
12 |
+
# Weights are implicitly read from ./models/ folder
|
13 |
+
def gen_inference_wide(
|
14 |
+
root_folder: Path, weights_name: str, nf_factor: int = 2,
|
15 |
+
arch=models.resnet101
|
16 |
+
) -> Learner:
|
17 |
+
|
18 |
+
data = get_dummy_databunch()
|
19 |
+
learn = gen_learner_wide(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch)
|
20 |
+
learn = get_inference(learn, root_folder, weights_name)
|
21 |
+
return learn
|
22 |
+
|
23 |
+
def gen_inference_deep(root_folder: Path, weights_name: str,
|
24 |
+
arch=models.resnet34, nf_factor: float = 1.5
|
25 |
+
) -> Learner:
|
26 |
+
|
27 |
+
data = get_dummy_databunch()
|
28 |
+
learn = gen_learner_deep(data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor)
|
29 |
+
learn = get_inference(learn, root_folder, weights_name)
|
30 |
+
return learn
|
31 |
+
|
32 |
+
# Weights are implicitly read from ./models/ folder
|
33 |
+
# Load loads weights from os.path.join(learner.path, learner.model_dir, weights_name)
|
34 |
+
def get_inference(learn, root_folder, weights_name) -> Learner:
|
35 |
+
learn.path = root_folder
|
36 |
+
try:
|
37 |
+
learn.load(weights_name)
|
38 |
+
print('Model loaded successfully')
|
39 |
+
except Exception as e:
|
40 |
+
print(e)
|
41 |
+
print('Error while reading the model')
|
42 |
+
learn.model.eval()
|
43 |
+
|
44 |
+
return learn
|
45 |
+
|
46 |
+
|
47 |
+
def gen_learner_wide(
|
48 |
+
data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
|
49 |
+
) -> Learner:
|
50 |
+
return unet_learner_wide(
|
51 |
+
data,
|
52 |
+
arch=arch,
|
53 |
+
wd=1e-3,
|
54 |
+
blur=True,
|
55 |
+
norm_type=NormType.Spectral,
|
56 |
+
self_attention=True,
|
57 |
+
y_range=(-3.0, 3.0),
|
58 |
+
loss_func=gen_loss,
|
59 |
+
nf_factor=nf_factor,
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
64 |
+
def unet_learner_wide(
|
65 |
+
data: DataBunch,
|
66 |
+
arch: Callable,
|
67 |
+
pretrained: bool = True,
|
68 |
+
blur_final: bool = True,
|
69 |
+
norm_type: Optional[NormType] = NormType,
|
70 |
+
split_on: Optional[SplitFuncOrIdxList] = None,
|
71 |
+
blur: bool = False,
|
72 |
+
self_attention: bool = False,
|
73 |
+
y_range: Optional[Tuple[float, float]] = None,
|
74 |
+
last_cross: bool = True,
|
75 |
+
bottle: bool = False,
|
76 |
+
nf_factor: int = 1,
|
77 |
+
**kwargs: Any
|
78 |
+
) -> Learner:
|
79 |
+
"Build Unet learner from `data` and `arch`."
|
80 |
+
meta = cnn_config(arch)
|
81 |
+
body = create_body(arch, pretrained)
|
82 |
+
model = to_device(
|
83 |
+
DynamicUnetWide(
|
84 |
+
body,
|
85 |
+
n_classes=data.c,
|
86 |
+
blur=blur,
|
87 |
+
blur_final=blur_final,
|
88 |
+
self_attention=self_attention,
|
89 |
+
y_range=y_range,
|
90 |
+
norm_type=norm_type,
|
91 |
+
last_cross=last_cross,
|
92 |
+
bottle=bottle,
|
93 |
+
nf_factor=nf_factor,
|
94 |
+
),
|
95 |
+
data.device,
|
96 |
+
)
|
97 |
+
learn = Learner(data, model, **kwargs)
|
98 |
+
learn.split(ifnone(split_on, meta['split']))
|
99 |
+
if pretrained:
|
100 |
+
learn.freeze()
|
101 |
+
apply_init(model[2], nn.init.kaiming_normal_)
|
102 |
+
return learn
|
103 |
+
|
104 |
+
|
105 |
+
# ----------------------------------------------------------------------
|
106 |
+
|
107 |
+
def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
|
108 |
+
nf_factor: float = 1.5
|
109 |
+
) -> Learner:
|
110 |
+
|
111 |
+
return unet_learner_deep(
|
112 |
+
data,
|
113 |
+
arch,
|
114 |
+
wd=1e-3,
|
115 |
+
blur=True,
|
116 |
+
norm_type=NormType.Spectral,
|
117 |
+
self_attention=True,
|
118 |
+
y_range=(-3.0, 3.0),
|
119 |
+
loss_func=gen_loss,
|
120 |
+
nf_factor=nf_factor,
|
121 |
+
)
|
122 |
+
|
123 |
+
|
124 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
125 |
+
def unet_learner_deep(
|
126 |
+
data: DataBunch,
|
127 |
+
arch: Callable,
|
128 |
+
pretrained: bool = True,
|
129 |
+
blur_final: bool = True,
|
130 |
+
norm_type: Optional[NormType] = NormType,
|
131 |
+
split_on: Optional[SplitFuncOrIdxList] = None,
|
132 |
+
blur: bool = False,
|
133 |
+
self_attention: bool = False,
|
134 |
+
y_range: Optional[Tuple[float, float]] = None,
|
135 |
+
last_cross: bool = True,
|
136 |
+
bottle: bool = False,
|
137 |
+
nf_factor: float = 1.5,
|
138 |
+
**kwargs: Any
|
139 |
+
) -> Learner:
|
140 |
+
|
141 |
+
"Build Unet learner from `data` and `arch`."
|
142 |
+
meta = cnn_config(arch)
|
143 |
+
body = create_body(arch, pretrained)
|
144 |
+
model = to_device(
|
145 |
+
DynamicUnetDeep(
|
146 |
+
body,
|
147 |
+
n_classes=data.c,
|
148 |
+
blur=blur,
|
149 |
+
blur_final=blur_final,
|
150 |
+
self_attention=self_attention,
|
151 |
+
y_range=y_range,
|
152 |
+
norm_type=norm_type,
|
153 |
+
last_cross=last_cross,
|
154 |
+
bottle=bottle,
|
155 |
+
nf_factor=nf_factor,
|
156 |
+
),
|
157 |
+
data.device,
|
158 |
+
)
|
159 |
+
learn = Learner(data, model, **kwargs)
|
160 |
+
learn.split(ifnone(split_on, meta['split']))
|
161 |
+
if pretrained:
|
162 |
+
learn.freeze()
|
163 |
+
apply_init(model[2], nn.init.kaiming_normal_)
|
164 |
+
return learn
|
165 |
+
|
166 |
+
|
167 |
+
# -----------------------------
|
deoldify/layers.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.layers import *
|
2 |
+
from fastai.torch_core import *
|
3 |
+
from torch.nn.parameter import Parameter
|
4 |
+
from torch.autograd import Variable
|
5 |
+
|
6 |
+
|
7 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
8 |
+
|
9 |
+
|
10 |
+
def custom_conv_layer(
|
11 |
+
ni: int,
|
12 |
+
nf: int,
|
13 |
+
ks: int = 3,
|
14 |
+
stride: int = 1,
|
15 |
+
padding: int = None,
|
16 |
+
bias: bool = None,
|
17 |
+
is_1d: bool = False,
|
18 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
19 |
+
use_activ: bool = True,
|
20 |
+
leaky: float = None,
|
21 |
+
transpose: bool = False,
|
22 |
+
init: Callable = nn.init.kaiming_normal_,
|
23 |
+
self_attention: bool = False,
|
24 |
+
extra_bn: bool = False,
|
25 |
+
):
|
26 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
27 |
+
if padding is None:
|
28 |
+
padding = (ks - 1) // 2 if not transpose else 0
|
29 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
|
30 |
+
if bias is None:
|
31 |
+
bias = not bn
|
32 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
33 |
+
conv = init_default(
|
34 |
+
conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
|
35 |
+
init,
|
36 |
+
)
|
37 |
+
if norm_type == NormType.Weight:
|
38 |
+
conv = weight_norm(conv)
|
39 |
+
elif norm_type == NormType.Spectral:
|
40 |
+
conv = spectral_norm(conv)
|
41 |
+
layers = [conv]
|
42 |
+
if use_activ:
|
43 |
+
layers.append(relu(True, leaky=leaky))
|
44 |
+
if bn:
|
45 |
+
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
46 |
+
if self_attention:
|
47 |
+
layers.append(SelfAttention(nf))
|
48 |
+
return nn.Sequential(*layers)
|
deoldify/loss.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai import *
|
2 |
+
from fastai.core import *
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callbacks import hook_outputs
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
|
8 |
+
class FeatureLoss(nn.Module):
|
9 |
+
def __init__(self, layer_wgts=[20, 70, 10]):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.m_feat = models.vgg16_bn(True).features.cuda().eval()
|
13 |
+
requires_grad(self.m_feat, False)
|
14 |
+
blocks = [
|
15 |
+
i - 1
|
16 |
+
for i, o in enumerate(children(self.m_feat))
|
17 |
+
if isinstance(o, nn.MaxPool2d)
|
18 |
+
]
|
19 |
+
layer_ids = blocks[2:5]
|
20 |
+
self.loss_features = [self.m_feat[i] for i in layer_ids]
|
21 |
+
self.hooks = hook_outputs(self.loss_features, detach=False)
|
22 |
+
self.wgts = layer_wgts
|
23 |
+
self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
|
24 |
+
self.base_loss = F.l1_loss
|
25 |
+
|
26 |
+
def _make_features(self, x, clone=False):
|
27 |
+
self.m_feat(x)
|
28 |
+
return [(o.clone() if clone else o) for o in self.hooks.stored]
|
29 |
+
|
30 |
+
def forward(self, input, target):
|
31 |
+
out_feat = self._make_features(target, clone=True)
|
32 |
+
in_feat = self._make_features(input)
|
33 |
+
self.feat_losses = [self.base_loss(input, target)]
|
34 |
+
self.feat_losses += [
|
35 |
+
self.base_loss(f_in, f_out) * w
|
36 |
+
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
|
37 |
+
]
|
38 |
+
|
39 |
+
self.metrics = dict(zip(self.metric_names, self.feat_losses))
|
40 |
+
return sum(self.feat_losses)
|
41 |
+
|
42 |
+
def __del__(self):
|
43 |
+
self.hooks.remove()
|
44 |
+
|
45 |
+
|
46 |
+
# Refactored code, originally from https://github.com/VinceMarron/style_transfer
|
47 |
+
class WassFeatureLoss(nn.Module):
|
48 |
+
def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
|
49 |
+
super().__init__()
|
50 |
+
self.m_feat = models.vgg16_bn(True).features.cuda().eval()
|
51 |
+
requires_grad(self.m_feat, False)
|
52 |
+
blocks = [
|
53 |
+
i - 1
|
54 |
+
for i, o in enumerate(children(self.m_feat))
|
55 |
+
if isinstance(o, nn.MaxPool2d)
|
56 |
+
]
|
57 |
+
layer_ids = blocks[2:5]
|
58 |
+
self.loss_features = [self.m_feat[i] for i in layer_ids]
|
59 |
+
self.hooks = hook_outputs(self.loss_features, detach=False)
|
60 |
+
self.wgts = layer_wgts
|
61 |
+
self.wass_wgts = wass_wgts
|
62 |
+
self.metric_names = (
|
63 |
+
['pixel']
|
64 |
+
+ [f'feat_{i}' for i in range(len(layer_ids))]
|
65 |
+
+ [f'wass_{i}' for i in range(len(layer_ids))]
|
66 |
+
)
|
67 |
+
self.base_loss = F.l1_loss
|
68 |
+
|
69 |
+
def _make_features(self, x, clone=False):
|
70 |
+
self.m_feat(x)
|
71 |
+
return [(o.clone() if clone else o) for o in self.hooks.stored]
|
72 |
+
|
73 |
+
def _calc_2_moments(self, tensor):
|
74 |
+
chans = tensor.shape[1]
|
75 |
+
tensor = tensor.view(1, chans, -1)
|
76 |
+
n = tensor.shape[2]
|
77 |
+
mu = tensor.mean(2)
|
78 |
+
tensor = (tensor - mu[:, :, None]).squeeze(0)
|
79 |
+
# Prevents nasty bug that happens very occassionally- divide by zero. Why such things happen?
|
80 |
+
if n == 0:
|
81 |
+
return None, None
|
82 |
+
cov = torch.mm(tensor, tensor.t()) / float(n)
|
83 |
+
return mu, cov
|
84 |
+
|
85 |
+
def _get_style_vals(self, tensor):
|
86 |
+
mean, cov = self._calc_2_moments(tensor)
|
87 |
+
if mean is None:
|
88 |
+
return None, None, None
|
89 |
+
eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
|
90 |
+
eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
|
91 |
+
root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
|
92 |
+
tr_cov = eigvals.clamp(min=0).sum()
|
93 |
+
return mean, tr_cov, root_cov
|
94 |
+
|
95 |
+
def _calc_l2wass_dist(
|
96 |
+
self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
|
97 |
+
):
|
98 |
+
tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
|
99 |
+
mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
|
100 |
+
cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
|
101 |
+
var_overlap = torch.sqrt(
|
102 |
+
torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
|
103 |
+
).sum()
|
104 |
+
dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
|
105 |
+
return dist
|
106 |
+
|
107 |
+
def _single_wass_loss(self, pred, targ):
|
108 |
+
mean_test, tr_cov_test, root_cov_test = targ
|
109 |
+
mean_synth, cov_synth = self._calc_2_moments(pred)
|
110 |
+
loss = self._calc_l2wass_dist(
|
111 |
+
mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
|
112 |
+
)
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def forward(self, input, target):
|
116 |
+
out_feat = self._make_features(target, clone=True)
|
117 |
+
in_feat = self._make_features(input)
|
118 |
+
self.feat_losses = [self.base_loss(input, target)]
|
119 |
+
self.feat_losses += [
|
120 |
+
self.base_loss(f_in, f_out) * w
|
121 |
+
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
|
122 |
+
]
|
123 |
+
|
124 |
+
styles = [self._get_style_vals(i) for i in out_feat]
|
125 |
+
|
126 |
+
if styles[0][0] is not None:
|
127 |
+
self.feat_losses += [
|
128 |
+
self._single_wass_loss(f_pred, f_targ) * w
|
129 |
+
for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
|
130 |
+
]
|
131 |
+
|
132 |
+
self.metrics = dict(zip(self.metric_names, self.feat_losses))
|
133 |
+
return sum(self.feat_losses)
|
134 |
+
|
135 |
+
def __del__(self):
|
136 |
+
self.hooks.remove()
|
deoldify/save.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basic_train import Learner, LearnerCallback
|
2 |
+
from fastai.vision.gan import GANLearner
|
3 |
+
|
4 |
+
|
5 |
+
class GANSaveCallback(LearnerCallback):
|
6 |
+
"""A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
learn: GANLearner,
|
11 |
+
learn_gen: Learner,
|
12 |
+
filename: str,
|
13 |
+
save_iters: int = 1000,
|
14 |
+
):
|
15 |
+
super().__init__(learn)
|
16 |
+
self.learn_gen = learn_gen
|
17 |
+
self.filename = filename
|
18 |
+
self.save_iters = save_iters
|
19 |
+
|
20 |
+
def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
|
21 |
+
if iteration == 0:
|
22 |
+
return
|
23 |
+
|
24 |
+
if iteration % self.save_iters == 0:
|
25 |
+
self._save_gen_learner(iteration=iteration, epoch=epoch)
|
26 |
+
|
27 |
+
def _save_gen_learner(self, iteration: int, epoch: int):
|
28 |
+
filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
|
29 |
+
self.learn_gen.save(filename)
|
deoldify/unet.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.layers import *
|
2 |
+
from .layers import *
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callbacks.hooks import *
|
5 |
+
from fastai.vision import *
|
6 |
+
|
7 |
+
|
8 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
9 |
+
|
10 |
+
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
|
11 |
+
|
12 |
+
|
13 |
+
def _get_sfs_idxs(sizes: Sizes) -> List[int]:
|
14 |
+
"Get the indexes of the layers where the size of the activation changes."
|
15 |
+
feature_szs = [size[-1] for size in sizes]
|
16 |
+
sfs_idxs = list(
|
17 |
+
np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
|
18 |
+
)
|
19 |
+
if feature_szs[0] != feature_szs[1]:
|
20 |
+
sfs_idxs = [0] + sfs_idxs
|
21 |
+
return sfs_idxs
|
22 |
+
|
23 |
+
|
24 |
+
class CustomPixelShuffle_ICNR(nn.Module):
|
25 |
+
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
ni: int,
|
30 |
+
nf: int = None,
|
31 |
+
scale: int = 2,
|
32 |
+
blur: bool = False,
|
33 |
+
leaky: float = None,
|
34 |
+
**kwargs
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
nf = ifnone(nf, ni)
|
38 |
+
self.conv = custom_conv_layer(
|
39 |
+
ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
|
40 |
+
)
|
41 |
+
icnr(self.conv[0].weight)
|
42 |
+
self.shuf = nn.PixelShuffle(scale)
|
43 |
+
# Blurring over (h*w) kernel
|
44 |
+
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
|
45 |
+
# - https://arxiv.org/abs/1806.02658
|
46 |
+
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
|
47 |
+
self.blur = nn.AvgPool2d(2, stride=1)
|
48 |
+
self.relu = relu(True, leaky=leaky)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.shuf(self.relu(self.conv(x)))
|
52 |
+
return self.blur(self.pad(x)) if self.blur else x
|
53 |
+
|
54 |
+
|
55 |
+
class UnetBlockDeep(nn.Module):
|
56 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
up_in_c: int,
|
61 |
+
x_in_c: int,
|
62 |
+
hook: Hook,
|
63 |
+
final_div: bool = True,
|
64 |
+
blur: bool = False,
|
65 |
+
leaky: float = None,
|
66 |
+
self_attention: bool = False,
|
67 |
+
nf_factor: float = 1.0,
|
68 |
+
**kwargs
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
self.hook = hook
|
72 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
73 |
+
up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
|
74 |
+
)
|
75 |
+
self.bn = batchnorm_2d(x_in_c)
|
76 |
+
ni = up_in_c // 2 + x_in_c
|
77 |
+
nf = int((ni if final_div else ni // 2) * nf_factor)
|
78 |
+
self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
|
79 |
+
self.conv2 = custom_conv_layer(
|
80 |
+
nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
|
81 |
+
)
|
82 |
+
self.relu = relu(leaky=leaky)
|
83 |
+
|
84 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
85 |
+
s = self.hook.stored
|
86 |
+
up_out = self.shuf(up_in)
|
87 |
+
ssh = s.shape[-2:]
|
88 |
+
if ssh != up_out.shape[-2:]:
|
89 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
90 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
91 |
+
return self.conv2(self.conv1(cat_x))
|
92 |
+
|
93 |
+
|
94 |
+
class DynamicUnetDeep(SequentialEx):
|
95 |
+
"Create a U-Net from a given architecture."
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
encoder: nn.Module,
|
100 |
+
n_classes: int,
|
101 |
+
blur: bool = False,
|
102 |
+
blur_final=True,
|
103 |
+
self_attention: bool = False,
|
104 |
+
y_range: Optional[Tuple[float, float]] = None,
|
105 |
+
last_cross: bool = True,
|
106 |
+
bottle: bool = False,
|
107 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
108 |
+
nf_factor: float = 1.0,
|
109 |
+
**kwargs
|
110 |
+
):
|
111 |
+
extra_bn = norm_type == NormType.Spectral
|
112 |
+
imsize = (256, 256)
|
113 |
+
sfs_szs = model_sizes(encoder, size=imsize)
|
114 |
+
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
115 |
+
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
|
116 |
+
x = dummy_eval(encoder, imsize).detach()
|
117 |
+
|
118 |
+
ni = sfs_szs[-1][1]
|
119 |
+
middle_conv = nn.Sequential(
|
120 |
+
custom_conv_layer(
|
121 |
+
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
122 |
+
),
|
123 |
+
custom_conv_layer(
|
124 |
+
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
125 |
+
),
|
126 |
+
).eval()
|
127 |
+
x = middle_conv(x)
|
128 |
+
layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
129 |
+
|
130 |
+
for i, idx in enumerate(sfs_idxs):
|
131 |
+
not_final = i != len(sfs_idxs) - 1
|
132 |
+
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
|
133 |
+
do_blur = blur and (not_final or blur_final)
|
134 |
+
sa = self_attention and (i == len(sfs_idxs) - 3)
|
135 |
+
unet_block = UnetBlockDeep(
|
136 |
+
up_in_c,
|
137 |
+
x_in_c,
|
138 |
+
self.sfs[i],
|
139 |
+
final_div=not_final,
|
140 |
+
blur=blur,
|
141 |
+
self_attention=sa,
|
142 |
+
norm_type=norm_type,
|
143 |
+
extra_bn=extra_bn,
|
144 |
+
nf_factor=nf_factor,
|
145 |
+
**kwargs
|
146 |
+
).eval()
|
147 |
+
layers.append(unet_block)
|
148 |
+
x = unet_block(x)
|
149 |
+
|
150 |
+
ni = x.shape[1]
|
151 |
+
if imsize != sfs_szs[0][-2:]:
|
152 |
+
layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
153 |
+
if last_cross:
|
154 |
+
layers.append(MergeLayer(dense=True))
|
155 |
+
ni += in_channels(encoder)
|
156 |
+
layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
157 |
+
layers += [
|
158 |
+
custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
|
159 |
+
]
|
160 |
+
if y_range is not None:
|
161 |
+
layers.append(SigmoidRange(*y_range))
|
162 |
+
super().__init__(*layers)
|
163 |
+
|
164 |
+
def __del__(self):
|
165 |
+
if hasattr(self, "sfs"):
|
166 |
+
self.sfs.remove()
|
167 |
+
|
168 |
+
|
169 |
+
# ------------------------------------------------------
|
170 |
+
class UnetBlockWide(nn.Module):
|
171 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
up_in_c: int,
|
176 |
+
x_in_c: int,
|
177 |
+
n_out: int,
|
178 |
+
hook: Hook,
|
179 |
+
final_div: bool = True,
|
180 |
+
blur: bool = False,
|
181 |
+
leaky: float = None,
|
182 |
+
self_attention: bool = False,
|
183 |
+
**kwargs
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
self.hook = hook
|
187 |
+
up_out = x_out = n_out // 2
|
188 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
189 |
+
up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
|
190 |
+
)
|
191 |
+
self.bn = batchnorm_2d(x_in_c)
|
192 |
+
ni = up_out + x_in_c
|
193 |
+
self.conv = custom_conv_layer(
|
194 |
+
ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
|
195 |
+
)
|
196 |
+
self.relu = relu(leaky=leaky)
|
197 |
+
|
198 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
199 |
+
s = self.hook.stored
|
200 |
+
up_out = self.shuf(up_in)
|
201 |
+
ssh = s.shape[-2:]
|
202 |
+
if ssh != up_out.shape[-2:]:
|
203 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
204 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
205 |
+
return self.conv(cat_x)
|
206 |
+
|
207 |
+
|
208 |
+
class DynamicUnetWide(SequentialEx):
|
209 |
+
"Create a U-Net from a given architecture."
|
210 |
+
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
encoder: nn.Module,
|
214 |
+
n_classes: int,
|
215 |
+
blur: bool = False,
|
216 |
+
blur_final=True,
|
217 |
+
self_attention: bool = False,
|
218 |
+
y_range: Optional[Tuple[float, float]] = None,
|
219 |
+
last_cross: bool = True,
|
220 |
+
bottle: bool = False,
|
221 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
222 |
+
nf_factor: int = 1,
|
223 |
+
**kwargs
|
224 |
+
):
|
225 |
+
|
226 |
+
nf = 512 * nf_factor
|
227 |
+
extra_bn = norm_type == NormType.Spectral
|
228 |
+
imsize = (256, 256)
|
229 |
+
sfs_szs = model_sizes(encoder, size=imsize)
|
230 |
+
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
231 |
+
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
|
232 |
+
x = dummy_eval(encoder, imsize).detach()
|
233 |
+
|
234 |
+
ni = sfs_szs[-1][1]
|
235 |
+
middle_conv = nn.Sequential(
|
236 |
+
custom_conv_layer(
|
237 |
+
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
238 |
+
),
|
239 |
+
custom_conv_layer(
|
240 |
+
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
241 |
+
),
|
242 |
+
).eval()
|
243 |
+
x = middle_conv(x)
|
244 |
+
layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
245 |
+
|
246 |
+
for i, idx in enumerate(sfs_idxs):
|
247 |
+
not_final = i != len(sfs_idxs) - 1
|
248 |
+
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
|
249 |
+
do_blur = blur and (not_final or blur_final)
|
250 |
+
sa = self_attention and (i == len(sfs_idxs) - 3)
|
251 |
+
|
252 |
+
n_out = nf if not_final else nf // 2
|
253 |
+
|
254 |
+
unet_block = UnetBlockWide(
|
255 |
+
up_in_c,
|
256 |
+
x_in_c,
|
257 |
+
n_out,
|
258 |
+
self.sfs[i],
|
259 |
+
final_div=not_final,
|
260 |
+
blur=blur,
|
261 |
+
self_attention=sa,
|
262 |
+
norm_type=norm_type,
|
263 |
+
extra_bn=extra_bn,
|
264 |
+
**kwargs
|
265 |
+
).eval()
|
266 |
+
layers.append(unet_block)
|
267 |
+
x = unet_block(x)
|
268 |
+
|
269 |
+
ni = x.shape[1]
|
270 |
+
if imsize != sfs_szs[0][-2:]:
|
271 |
+
layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
272 |
+
if last_cross:
|
273 |
+
layers.append(MergeLayer(dense=True))
|
274 |
+
ni += in_channels(encoder)
|
275 |
+
layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
276 |
+
layers += [
|
277 |
+
custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
|
278 |
+
]
|
279 |
+
if y_range is not None:
|
280 |
+
layers.append(SigmoidRange(*y_range))
|
281 |
+
super().__init__(*layers)
|
282 |
+
|
283 |
+
def __del__(self):
|
284 |
+
if hasattr(self, "sfs"):
|
285 |
+
self.sfs.remove()
|
deoldify/visualize.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import gc
|
3 |
+
import requests
|
4 |
+
from io import BytesIO
|
5 |
+
import base64
|
6 |
+
from scipy import misc
|
7 |
+
from PIL import Image
|
8 |
+
from matplotlib.axes import Axes
|
9 |
+
from matplotlib.figure import Figure
|
10 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
11 |
+
from typing import Tuple
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from fastai.core import *
|
15 |
+
from fastai.vision import *
|
16 |
+
|
17 |
+
from .filters import IFilter, MasterFilter, ColorizerFilter
|
18 |
+
from .generators import gen_inference_deep, gen_inference_wide
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class ModelImageVisualizer:
|
23 |
+
def __init__(self, filter: IFilter, results_dir: str = None):
|
24 |
+
self.filter = filter
|
25 |
+
self.results_dir = None if results_dir is None else Path(results_dir)
|
26 |
+
|
27 |
+
if self.results_dir is not None:
|
28 |
+
self.results_dir.mkdir(parents=True, exist_ok=True)
|
29 |
+
|
30 |
+
def _clean_mem(self):
|
31 |
+
torch.cuda.empty_cache()
|
32 |
+
# gc.collect()
|
33 |
+
|
34 |
+
def _open_pil_image(self, path: Path) -> Image:
|
35 |
+
return Image.open(path).convert('RGB')
|
36 |
+
|
37 |
+
def _get_image_from_url(self, url: str) -> Image:
|
38 |
+
response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
|
39 |
+
img = Image.open(BytesIO(response.content)).convert('RGB')
|
40 |
+
return img
|
41 |
+
|
42 |
+
def plot_transformed_image_from_url(
|
43 |
+
self,
|
44 |
+
url: str,
|
45 |
+
path: str = 'test_images/image.png',
|
46 |
+
results_dir:Path = None,
|
47 |
+
figsize: Tuple[int, int] = (20, 20),
|
48 |
+
render_factor: int = None,
|
49 |
+
|
50 |
+
display_render_factor: bool = False,
|
51 |
+
compare: bool = False,
|
52 |
+
post_process: bool = True,
|
53 |
+
watermarked: bool = True,
|
54 |
+
) -> Path:
|
55 |
+
img = self._get_image_from_url(url)
|
56 |
+
img.save(path)
|
57 |
+
return self.plot_transformed_image(
|
58 |
+
path=path,
|
59 |
+
results_dir=results_dir,
|
60 |
+
figsize=figsize,
|
61 |
+
render_factor=render_factor,
|
62 |
+
display_render_factor=display_render_factor,
|
63 |
+
compare=compare,
|
64 |
+
post_process = post_process,
|
65 |
+
watermarked=watermarked,
|
66 |
+
)
|
67 |
+
|
68 |
+
def plot_transformed_image(
|
69 |
+
self,
|
70 |
+
path: str,
|
71 |
+
results_dir:Path = None,
|
72 |
+
figsize: Tuple[int, int] = (20, 20),
|
73 |
+
render_factor: int = None,
|
74 |
+
display_render_factor: bool = False,
|
75 |
+
compare: bool = False,
|
76 |
+
post_process: bool = True,
|
77 |
+
watermarked: bool = True,
|
78 |
+
) -> Path:
|
79 |
+
path = Path(path)
|
80 |
+
if results_dir is None:
|
81 |
+
results_dir = Path(self.results_dir)
|
82 |
+
result = self.get_transformed_image(
|
83 |
+
path, render_factor, post_process=post_process,watermarked=watermarked
|
84 |
+
)
|
85 |
+
orig = self._open_pil_image(path)
|
86 |
+
if compare:
|
87 |
+
self._plot_comparison(
|
88 |
+
figsize, render_factor, display_render_factor, orig, result
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
self._plot_solo(figsize, render_factor, display_render_factor, result)
|
92 |
+
|
93 |
+
orig.close()
|
94 |
+
result_path = self._save_result_image(path, result, results_dir=results_dir)
|
95 |
+
result.close()
|
96 |
+
return result_path
|
97 |
+
|
98 |
+
def plot_transformed_pil_image(
|
99 |
+
self,
|
100 |
+
input_image: Image,
|
101 |
+
figsize: Tuple[int, int] = (20, 20),
|
102 |
+
render_factor: int = None,
|
103 |
+
display_render_factor: bool = False,
|
104 |
+
compare: bool = False,
|
105 |
+
post_process: bool = True,
|
106 |
+
) -> Image:
|
107 |
+
|
108 |
+
result = self.get_transformed_pil_image(
|
109 |
+
input_image, render_factor, post_process=post_process
|
110 |
+
)
|
111 |
+
|
112 |
+
if compare:
|
113 |
+
self._plot_comparison(
|
114 |
+
figsize, render_factor, display_render_factor, input_image, result
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
self._plot_solo(figsize, render_factor, display_render_factor, result)
|
118 |
+
|
119 |
+
return result
|
120 |
+
|
121 |
+
def _plot_comparison(
|
122 |
+
self,
|
123 |
+
figsize: Tuple[int, int],
|
124 |
+
render_factor: int,
|
125 |
+
display_render_factor: bool,
|
126 |
+
orig: Image,
|
127 |
+
result: Image,
|
128 |
+
):
|
129 |
+
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
130 |
+
self._plot_image(
|
131 |
+
orig,
|
132 |
+
axes=axes[0],
|
133 |
+
figsize=figsize,
|
134 |
+
render_factor=render_factor,
|
135 |
+
display_render_factor=False,
|
136 |
+
)
|
137 |
+
self._plot_image(
|
138 |
+
result,
|
139 |
+
axes=axes[1],
|
140 |
+
figsize=figsize,
|
141 |
+
render_factor=render_factor,
|
142 |
+
display_render_factor=display_render_factor,
|
143 |
+
)
|
144 |
+
|
145 |
+
def _plot_solo(
|
146 |
+
self,
|
147 |
+
figsize: Tuple[int, int],
|
148 |
+
render_factor: int,
|
149 |
+
display_render_factor: bool,
|
150 |
+
result: Image,
|
151 |
+
):
|
152 |
+
fig, axes = plt.subplots(1, 1, figsize=figsize)
|
153 |
+
self._plot_image(
|
154 |
+
result,
|
155 |
+
axes=axes,
|
156 |
+
figsize=figsize,
|
157 |
+
render_factor=render_factor,
|
158 |
+
display_render_factor=display_render_factor,
|
159 |
+
)
|
160 |
+
|
161 |
+
def _save_result_image(self, source_path: Path, image: Image, results_dir = None) -> Path:
|
162 |
+
if results_dir is None:
|
163 |
+
results_dir = Path(self.results_dir)
|
164 |
+
result_path = results_dir / source_path.name
|
165 |
+
image.save(result_path)
|
166 |
+
return result_path
|
167 |
+
|
168 |
+
def get_transformed_image(
|
169 |
+
self, path: Path, render_factor: int = None, post_process: bool = True,
|
170 |
+
watermarked: bool = True,
|
171 |
+
) -> Image:
|
172 |
+
self._clean_mem()
|
173 |
+
orig_image = self._open_pil_image(path)
|
174 |
+
filtered_image = self.filter.filter(
|
175 |
+
orig_image, orig_image, render_factor=render_factor,post_process=post_process
|
176 |
+
)
|
177 |
+
|
178 |
+
return filtered_image
|
179 |
+
|
180 |
+
def get_transformed_pil_image(
|
181 |
+
self, input_image: Image, render_factor: int = None, post_process: bool = True,
|
182 |
+
) -> Image:
|
183 |
+
self._clean_mem()
|
184 |
+
filtered_image = self.filter.filter(
|
185 |
+
input_image, input_image, render_factor=render_factor,post_process=post_process
|
186 |
+
)
|
187 |
+
|
188 |
+
return filtered_image
|
189 |
+
|
190 |
+
def _plot_image(
|
191 |
+
self,
|
192 |
+
image: Image,
|
193 |
+
render_factor: int,
|
194 |
+
axes: Axes = None,
|
195 |
+
figsize=(20, 20),
|
196 |
+
display_render_factor = False,
|
197 |
+
):
|
198 |
+
if axes is None:
|
199 |
+
_, axes = plt.subplots(figsize=figsize)
|
200 |
+
axes.imshow(np.asarray(image) / 255)
|
201 |
+
axes.axis('off')
|
202 |
+
if render_factor is not None and display_render_factor:
|
203 |
+
plt.text(
|
204 |
+
10,
|
205 |
+
10,
|
206 |
+
'render_factor: ' + str(render_factor),
|
207 |
+
color='white',
|
208 |
+
backgroundcolor='black',
|
209 |
+
)
|
210 |
+
|
211 |
+
def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
|
212 |
+
columns = min(num_images, max_columns)
|
213 |
+
rows = num_images // columns
|
214 |
+
rows = rows if rows * columns == num_images else rows + 1
|
215 |
+
return rows, columns
|
216 |
+
|
217 |
+
|
218 |
+
def get_image_colorizer(root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
|
219 |
+
) -> ModelImageVisualizer:
|
220 |
+
|
221 |
+
if artistic:
|
222 |
+
return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
223 |
+
else:
|
224 |
+
return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
225 |
+
|
226 |
+
|
227 |
+
def get_stable_image_colorizer(root_folder: Path = Path('./'), weights_name: str = 'ColorizeStable_gen',
|
228 |
+
results_dir='result_images', render_factor: int = 35
|
229 |
+
) -> ModelImageVisualizer:
|
230 |
+
|
231 |
+
learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
|
232 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
233 |
+
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
234 |
+
vis = ModelImageVisualizer(filtr)
|
235 |
+
return vis
|
236 |
+
|
237 |
+
|
238 |
+
def get_artistic_image_colorizer(root_folder: Path = Path('./'), weights_name: str = 'ColorizeArtistic_gen',
|
239 |
+
results_dir='result_images', render_factor: int = 35
|
240 |
+
) -> ModelImageVisualizer:
|
241 |
+
|
242 |
+
learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
|
243 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
244 |
+
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
245 |
+
vis = ModelImageVisualizer(filtr)
|
246 |
+
return vis
|
models/ColorizeArtistic_gen.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f750246fa220529323b85a8905f9b49c0e5d427099185334d048fb5b5e22477
|
3 |
+
size 255144681
|
packages.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scipy==1.7.1
|
2 |
+
scikit_image==0.18.3
|
3 |
+
streamlit==0.88.0
|
4 |
+
requests==2.26.0
|
5 |
+
torch==1.9.0
|
6 |
+
torchvision==0.10.0
|
7 |
+
matplotlib==3.4.3
|
8 |
+
numpy==1.21.2
|
9 |
+
fastai==1.0.51
|
10 |
+
Pillow==8.3.2
|
11 |
+
opencv-python-headless
|
12 |
+
|