Spaces:
Runtime error
Runtime error
merged changes
Browse files- .gitignore +1 -0
- __pycache__/app_utils.cpython-37.pyc +0 -0
- app.py +147 -75
- app_utils.py +14 -0
- deoldify/__pycache__/__init__.cpython-37.pyc +0 -0
- deoldify/__pycache__/_device.cpython-37.pyc +0 -0
- deoldify/__pycache__/augs.cpython-37.pyc +0 -0
- deoldify/__pycache__/critics.cpython-37.pyc +0 -0
- deoldify/__pycache__/dataset.cpython-37.pyc +0 -0
- deoldify/__pycache__/device_id.cpython-37.pyc +0 -0
- deoldify/__pycache__/filters.cpython-37.pyc +0 -0
- deoldify/__pycache__/generators.cpython-37.pyc +0 -0
- deoldify/__pycache__/layers.cpython-37.pyc +0 -0
- deoldify/__pycache__/loss.cpython-37.pyc +0 -0
- deoldify/__pycache__/unet.cpython-37.pyc +0 -0
- deoldify/__pycache__/visualize.cpython-37.pyc +0 -0
- deoldify/generators.py +52 -0
- deoldify/visualize.py +38 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
__pycache__/app_utils.cpython-37.pyc
ADDED
Binary file (3.35 kB). View file
|
|
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
#
|
2 |
import os, sys, re
|
3 |
import streamlit as st
|
4 |
import PIL
|
@@ -6,21 +6,8 @@ from PIL import Image
|
|
6 |
import cv2
|
7 |
import numpy as np
|
8 |
import uuid
|
9 |
-
|
10 |
-
import
|
11 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
12 |
-
|
13 |
-
# Import torch libraries
|
14 |
-
import fastai
|
15 |
-
import torch
|
16 |
-
|
17 |
-
# Import util functions from app_utils
|
18 |
-
from app_utils import download
|
19 |
-
from app_utils import generate_random_filename
|
20 |
-
from app_utils import clean_me
|
21 |
-
from app_utils import clean_all
|
22 |
-
from app_utils import get_model_bin
|
23 |
-
from app_utils import convertToJPG
|
24 |
|
25 |
# Import util functions from deoldify
|
26 |
# NOTE: This must be the first call in order to work properly!
|
@@ -30,13 +17,17 @@ from deoldify.device_id import DeviceId
|
|
30 |
device.set(device=DeviceId.CPU)
|
31 |
from deoldify.visualize import *
|
32 |
|
|
|
|
|
|
|
|
|
33 |
|
34 |
####### INPUT PARAMS ###########
|
35 |
model_folder = 'models/'
|
36 |
max_img_size = 800
|
37 |
################################
|
38 |
|
39 |
-
@st.cache(allow_output_mutation=True)
|
40 |
def load_model(model_dir, option):
|
41 |
if option.lower() == 'artistic':
|
42 |
model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
|
@@ -68,42 +59,132 @@ def resize_img(input_img, max_size):
|
|
68 |
|
69 |
return img
|
70 |
|
71 |
-
def get_image_download_link(img,filename,
|
72 |
button_uuid = str(uuid.uuid4()).replace('-', '')
|
73 |
button_id = re.sub('\d+', '', button_uuid)
|
74 |
-
|
75 |
-
custom_css = f"""
|
76 |
-
<style>
|
77 |
-
#{button_id} {{
|
78 |
-
background-color: rgb(255, 255, 255);
|
79 |
-
color: rgb(38, 39, 48);
|
80 |
-
padding: 0.25em 0.38em;
|
81 |
-
position: relative;
|
82 |
-
text-decoration: none;
|
83 |
-
border-radius: 4px;
|
84 |
-
border-width: 1px;
|
85 |
-
border-style: solid;
|
86 |
-
border-color: rgb(230, 234, 241);
|
87 |
-
border-image: initial;
|
88 |
-
|
89 |
-
}}
|
90 |
-
#{button_id}:hover {{
|
91 |
-
border-color: rgb(246, 51, 102);
|
92 |
-
color: rgb(246, 51, 102);
|
93 |
-
}}
|
94 |
-
#{button_id}:active {{
|
95 |
-
box-shadow: none;
|
96 |
-
background-color: rgb(246, 51, 102);
|
97 |
-
color: white;
|
98 |
-
}}
|
99 |
-
</style> """
|
100 |
|
101 |
buffered = BytesIO()
|
102 |
img.save(buffered, format="JPEG")
|
103 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
return href
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
# General configuration
|
109 |
# st.set_page_config(layout="centered")
|
@@ -118,12 +199,16 @@ unsafe_allow_html=True)
|
|
118 |
# Main window configuration
|
119 |
st.title("Black and white colorizer")
|
120 |
st.markdown("This app puts color into your black and white pictures")
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
|
125 |
# # Sidebar
|
126 |
-
|
127 |
('Artistic', 'Stable'))
|
128 |
|
129 |
# st.sidebar.title('Model parameters')
|
@@ -132,40 +217,27 @@ color_option = st.sidebar.selectbox('Select colorizer mode',
|
|
132 |
|
133 |
# Load models
|
134 |
try:
|
135 |
-
|
|
|
|
|
|
|
136 |
except Exception as e:
|
137 |
-
print(e)
|
138 |
colorizer = None
|
139 |
print('Error while loading the model. Please refresh the page')
|
|
|
|
|
140 |
|
141 |
if colorizer is not None:
|
142 |
-
|
143 |
-
title_message.markdown("**To begin, please upload an image** 👇")
|
144 |
|
145 |
#Choose your own image
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
input_img_pos = st.empty()
|
150 |
-
output_img_pos = st.empty()
|
151 |
|
152 |
-
if
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
img_rgb = np.array(pil_img)
|
157 |
-
|
158 |
-
resized_img_rgb = resize_img(img_rgb, max_img_size)
|
159 |
-
resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
|
160 |
-
|
161 |
-
title_message.markdown("**Processing your image, please wait** ⌛")
|
162 |
-
|
163 |
-
output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
|
164 |
-
|
165 |
-
title_message.markdown("**To begin, please upload an image** 👇")
|
166 |
|
167 |
-
# Plot images
|
168 |
-
input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
|
169 |
-
output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)
|
170 |
|
171 |
-
st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)
|
|
|
1 |
+
# Import general purpose libraries
|
2 |
import os, sys, re
|
3 |
import streamlit as st
|
4 |
import PIL
|
|
|
6 |
import cv2
|
7 |
import numpy as np
|
8 |
import uuid
|
9 |
+
from zipfile import ZipFile, ZIP_DEFLATED
|
10 |
+
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Import util functions from deoldify
|
13 |
# NOTE: This must be the first call in order to work properly!
|
|
|
17 |
device.set(device=DeviceId.CPU)
|
18 |
from deoldify.visualize import *
|
19 |
|
20 |
+
# Import util functions from app_utils
|
21 |
+
from app_utils import get_model_bin
|
22 |
+
|
23 |
+
|
24 |
|
25 |
####### INPUT PARAMS ###########
|
26 |
model_folder = 'models/'
|
27 |
max_img_size = 800
|
28 |
################################
|
29 |
|
30 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
31 |
def load_model(model_dir, option):
|
32 |
if option.lower() == 'artistic':
|
33 |
model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
|
|
|
59 |
|
60 |
return img
|
61 |
|
62 |
+
def get_image_download_link(img, filename, button_text):
|
63 |
button_uuid = str(uuid.uuid4()).replace('-', '')
|
64 |
button_id = re.sub('\d+', '', button_uuid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
buffered = BytesIO()
|
67 |
img.save(buffered, format="JPEG")
|
68 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
69 |
+
|
70 |
+
return get_button_html_code(img_str, filename, 'txt', button_id, button_text)
|
71 |
+
|
72 |
+
def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'):
|
73 |
+
custom_css = f"""
|
74 |
+
<style>
|
75 |
+
#{button_id} {{
|
76 |
+
background-color: rgb(255, 255, 255);
|
77 |
+
color: rgb(38, 39, 48);
|
78 |
+
padding: 0.25em 0.38em;
|
79 |
+
position: relative;
|
80 |
+
text-decoration: none;
|
81 |
+
border-radius: 4px;
|
82 |
+
border-width: 1px;
|
83 |
+
border-style: solid;
|
84 |
+
border-color: rgb(230, 234, 241);
|
85 |
+
border-image: initial;
|
86 |
+
|
87 |
+
}}
|
88 |
+
#{button_id}:hover {{
|
89 |
+
border-color: rgb(246, 51, 102);
|
90 |
+
color: rgb(246, 51, 102);
|
91 |
+
}}
|
92 |
+
#{button_id}:active {{
|
93 |
+
box-shadow: none;
|
94 |
+
background-color: rgb(246, 51, 102);
|
95 |
+
color: white;
|
96 |
+
}}
|
97 |
+
</style> """
|
98 |
+
|
99 |
+
href = custom_css + f'<a href="data:file/{filetype};base64,{data_str}" id="{button_id}" download="{filename}">{button_txt}</a>'
|
100 |
return href
|
101 |
|
102 |
+
def display_single_image(uploaded_file, img_size=800):
|
103 |
+
print('Type: ', type(uploaded_file))
|
104 |
+
st_title_message.markdown("**Processing your image, please wait** ⌛")
|
105 |
+
img_name = uploaded_file.name
|
106 |
+
|
107 |
+
# Open the image
|
108 |
+
pil_img = PIL.Image.open(uploaded_file)
|
109 |
+
img_rgb = np.array(pil_img)
|
110 |
+
resized_img_rgb = resize_img(img_rgb, img_size)
|
111 |
+
resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
|
112 |
+
|
113 |
+
# Send the image to the model
|
114 |
+
output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
|
115 |
+
|
116 |
+
# Plot images
|
117 |
+
st_input_img.image(resized_pil_img, 'Input image', use_column_width=True)
|
118 |
+
st_output_img.image(output_pil_img, 'Output image', use_column_width=True)
|
119 |
+
|
120 |
+
# Show download button
|
121 |
+
st_download_button.markdown(get_image_download_link(output_pil_img, img_name, 'Download Image'), unsafe_allow_html=True)
|
122 |
+
|
123 |
+
# Reset the message
|
124 |
+
st_title_message.markdown("**To begin, please upload an image** 👇")
|
125 |
+
|
126 |
+
def process_multiple_images(uploaded_files, img_size=800):
|
127 |
+
num_imgs = len(uploaded_files)
|
128 |
+
|
129 |
+
output_images_list = []
|
130 |
+
img_names_list = []
|
131 |
+
idx = 1
|
132 |
+
for idx, uploaded_file in enumerate(uploaded_files, start=1):
|
133 |
+
st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx,
|
134 |
+
num_imgs))
|
135 |
+
|
136 |
+
img_name = uploaded_file.name
|
137 |
+
img_type = uploaded_file.type
|
138 |
+
|
139 |
+
# Open the image
|
140 |
+
pil_img = PIL.Image.open(uploaded_file)
|
141 |
+
img_rgb = np.array(pil_img)
|
142 |
+
resized_img_rgb = resize_img(img_rgb, img_size)
|
143 |
+
resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
|
144 |
+
|
145 |
+
# Send the image to the model
|
146 |
+
output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
|
147 |
+
|
148 |
+
output_images_list.append(output_pil_img)
|
149 |
+
img_names_list.append(img_name.split('.')[0])
|
150 |
+
|
151 |
+
# Zip output files
|
152 |
+
zip_path = 'processed_images.zip'
|
153 |
+
zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_path)
|
154 |
+
|
155 |
+
st_download_button.download_button(
|
156 |
+
label='Download ZIP file',
|
157 |
+
data=zip_buf.read(),
|
158 |
+
file_name=zip_path,
|
159 |
+
mime="application/zip"
|
160 |
+
)
|
161 |
+
|
162 |
+
# Show message
|
163 |
+
st_title_message.markdown("**Images are ready for download** 💾")
|
164 |
+
|
165 |
+
def zip_multiple_images(pil_images_list, img_names_list, dest_path):
|
166 |
+
# Create zip file on memory
|
167 |
+
zip_buf = BytesIO()
|
168 |
+
|
169 |
+
with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zipObj:
|
170 |
+
for pil_img, img_name in zip(pil_images_list, img_names_list):
|
171 |
+
with BytesIO() as output:
|
172 |
+
# Save image in memory
|
173 |
+
pil_img.save(output, format="PNG")
|
174 |
+
|
175 |
+
# Read data
|
176 |
+
contents = output.getvalue()
|
177 |
+
|
178 |
+
# Write it to zip file
|
179 |
+
zipObj.writestr(img_name+".png", contents)
|
180 |
+
zip_buf.seek(0)
|
181 |
+
return zip_buf
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
###########################
|
186 |
+
###### STREAMLIT CODE #####
|
187 |
+
###########################
|
188 |
|
189 |
# General configuration
|
190 |
# st.set_page_config(layout="centered")
|
|
|
199 |
# Main window configuration
|
200 |
st.title("Black and white colorizer")
|
201 |
st.markdown("This app puts color into your black and white pictures")
|
202 |
+
st_title_message = st.empty()
|
203 |
+
st_file_uploader = st.empty()
|
204 |
+
st_input_img = st.empty()
|
205 |
+
st_output_img = st.empty()
|
206 |
+
st_download_button = st.empty()
|
207 |
|
208 |
+
st_title_message.markdown("**Model loading, please wait** ⌛")
|
209 |
|
210 |
# # Sidebar
|
211 |
+
st_color_option = st.sidebar.selectbox('Select colorizer mode',
|
212 |
('Artistic', 'Stable'))
|
213 |
|
214 |
# st.sidebar.title('Model parameters')
|
|
|
217 |
|
218 |
# Load models
|
219 |
try:
|
220 |
+
print('before loading the model')
|
221 |
+
colorizer = load_model(model_folder, st_color_option)
|
222 |
+
print('after loading the model')
|
223 |
+
|
224 |
except Exception as e:
|
|
|
225 |
colorizer = None
|
226 |
print('Error while loading the model. Please refresh the page')
|
227 |
+
print(e)
|
228 |
+
st_title_message.markdown("**Error while loading the model. Please refresh the page**")
|
229 |
|
230 |
if colorizer is not None:
|
231 |
+
st_title_message.markdown("**To begin, please upload an image** 👇")
|
|
|
232 |
|
233 |
#Choose your own image
|
234 |
+
uploaded_files = st_file_uploader.file_uploader("Upload a black and white photo",
|
235 |
+
type=['png', 'jpg', 'jpeg'],
|
236 |
+
accept_multiple_files=True)
|
|
|
|
|
237 |
|
238 |
+
if len(uploaded_files) == 1:
|
239 |
+
display_single_image(uploaded_files[0], max_img_size)
|
240 |
+
elif len(uploaded_files) > 1:
|
241 |
+
process_multiple_images(uploaded_files, max_img_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
|
|
|
|
|
|
|
243 |
|
|
app_utils.py
CHANGED
@@ -106,6 +106,7 @@ def clean_all(files):
|
|
106 |
clean_me(me)
|
107 |
|
108 |
|
|
|
109 |
def get_model_bin(url, output_path):
|
110 |
# print('Getting model dir: ', output_path)
|
111 |
if not os.path.exists(output_path):
|
@@ -115,14 +116,27 @@ def get_model_bin(url, output_path):
|
|
115 |
output_folder = output_path.replace('\\','/').split('/')[0]
|
116 |
if not os.path.exists(output_folder):
|
117 |
os.makedirs(output_folder, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
urllib.request.urlretrieve(url, output_path)
|
120 |
|
121 |
# cmd = "wget -O %s %s" % (output_path, url)
|
122 |
# print(cmd)
|
123 |
# os.system(cmd)
|
|
|
124 |
else:
|
125 |
print('Model exists')
|
|
|
|
|
126 |
|
127 |
return output_path
|
128 |
|
|
|
106 |
clean_me(me)
|
107 |
|
108 |
|
109 |
+
<<<<<<< HEAD
|
110 |
def get_model_bin(url, output_path):
|
111 |
# print('Getting model dir: ', output_path)
|
112 |
if not os.path.exists(output_path):
|
|
|
116 |
output_folder = output_path.replace('\\','/').split('/')[0]
|
117 |
if not os.path.exists(output_folder):
|
118 |
os.makedirs(output_folder, exist_ok=True)
|
119 |
+
=======
|
120 |
+
def create_directory(path):
|
121 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
122 |
+
|
123 |
+
|
124 |
+
def get_model_bin(url, output_path):
|
125 |
+
# print('Getting model dir: ', output_path)
|
126 |
+
if not os.path.exists(output_path):
|
127 |
+
create_directory(output_path)
|
128 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
129 |
|
130 |
urllib.request.urlretrieve(url, output_path)
|
131 |
|
132 |
# cmd = "wget -O %s %s" % (output_path, url)
|
133 |
# print(cmd)
|
134 |
# os.system(cmd)
|
135 |
+
<<<<<<< HEAD
|
136 |
else:
|
137 |
print('Model exists')
|
138 |
+
=======
|
139 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
140 |
|
141 |
return output_path
|
142 |
|
deoldify/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (218 Bytes). View file
|
|
deoldify/__pycache__/_device.cpython-37.pyc
ADDED
Binary file (1.32 kB). View file
|
|
deoldify/__pycache__/augs.cpython-37.pyc
ADDED
Binary file (867 Bytes). View file
|
|
deoldify/__pycache__/critics.cpython-37.pyc
ADDED
Binary file (1.52 kB). View file
|
|
deoldify/__pycache__/dataset.cpython-37.pyc
ADDED
Binary file (1.53 kB). View file
|
|
deoldify/__pycache__/device_id.cpython-37.pyc
ADDED
Binary file (510 Bytes). View file
|
|
deoldify/__pycache__/filters.cpython-37.pyc
ADDED
Binary file (4.9 kB). View file
|
|
deoldify/__pycache__/generators.cpython-37.pyc
ADDED
Binary file (3.15 kB). View file
|
|
deoldify/__pycache__/layers.cpython-37.pyc
ADDED
Binary file (1.43 kB). View file
|
|
deoldify/__pycache__/loss.cpython-37.pyc
ADDED
Binary file (6.47 kB). View file
|
|
deoldify/__pycache__/unet.cpython-37.pyc
ADDED
Binary file (8.21 kB). View file
|
|
deoldify/__pycache__/visualize.cpython-37.pyc
ADDED
Binary file (6.62 kB). View file
|
|
deoldify/generators.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from fastai.basics import F, nn
|
2 |
from fastai.basic_data import DataBunch
|
3 |
from fastai.basic_train import Learner
|
@@ -5,12 +6,17 @@ from fastai.layers import NormType
|
|
5 |
from fastai.torch_core import SplitFuncOrIdxList, to_device, apply_init
|
6 |
from fastai.vision import *
|
7 |
from fastai.vision.learner import cnn_config, create_body
|
|
|
|
|
|
|
|
|
8 |
from .unet import DynamicUnetWide, DynamicUnetDeep
|
9 |
from .loss import FeatureLoss
|
10 |
from .dataset import *
|
11 |
|
12 |
# Weights are implicitly read from ./models/ folder
|
13 |
def gen_inference_wide(
|
|
|
14 |
root_folder: Path, weights_name: str, nf_factor: int = 2,
|
15 |
arch=models.resnet101
|
16 |
) -> Learner:
|
@@ -41,6 +47,16 @@ def get_inference(learn, root_folder, weights_name) -> Learner:
|
|
41 |
print('Error while reading the model')
|
42 |
learn.model.eval()
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
return learn
|
45 |
|
46 |
|
@@ -104,10 +120,29 @@ def unet_learner_wide(
|
|
104 |
|
105 |
# ----------------------------------------------------------------------
|
106 |
|
|
|
107 |
def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
|
108 |
nf_factor: float = 1.5
|
109 |
) -> Learner:
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
return unet_learner_deep(
|
112 |
data,
|
113 |
arch,
|
@@ -123,6 +158,7 @@ def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
|
|
123 |
|
124 |
# The code below is meant to be merged into fastaiv1 ideally
|
125 |
def unet_learner_deep(
|
|
|
126 |
data: DataBunch,
|
127 |
arch: Callable,
|
128 |
pretrained: bool = True,
|
@@ -138,6 +174,22 @@ def unet_learner_deep(
|
|
138 |
**kwargs: Any
|
139 |
) -> Learner:
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
"Build Unet learner from `data` and `arch`."
|
142 |
meta = cnn_config(arch)
|
143 |
body = create_body(arch, pretrained)
|
|
|
1 |
+
<<<<<<< HEAD
|
2 |
from fastai.basics import F, nn
|
3 |
from fastai.basic_data import DataBunch
|
4 |
from fastai.basic_train import Learner
|
|
|
6 |
from fastai.torch_core import SplitFuncOrIdxList, to_device, apply_init
|
7 |
from fastai.vision import *
|
8 |
from fastai.vision.learner import cnn_config, create_body
|
9 |
+
=======
|
10 |
+
from fastai.vision import *
|
11 |
+
from fastai.vision.learner import cnn_config
|
12 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
13 |
from .unet import DynamicUnetWide, DynamicUnetDeep
|
14 |
from .loss import FeatureLoss
|
15 |
from .dataset import *
|
16 |
|
17 |
# Weights are implicitly read from ./models/ folder
|
18 |
def gen_inference_wide(
|
19 |
+
<<<<<<< HEAD
|
20 |
root_folder: Path, weights_name: str, nf_factor: int = 2,
|
21 |
arch=models.resnet101
|
22 |
) -> Learner:
|
|
|
47 |
print('Error while reading the model')
|
48 |
learn.model.eval()
|
49 |
|
50 |
+
=======
|
51 |
+
root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
|
52 |
+
data = get_dummy_databunch()
|
53 |
+
learn = gen_learner_wide(
|
54 |
+
data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
|
55 |
+
)
|
56 |
+
learn.path = root_folder
|
57 |
+
learn.load(weights_name)
|
58 |
+
learn.model.eval()
|
59 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
60 |
return learn
|
61 |
|
62 |
|
|
|
120 |
|
121 |
# ----------------------------------------------------------------------
|
122 |
|
123 |
+
<<<<<<< HEAD
|
124 |
def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
|
125 |
nf_factor: float = 1.5
|
126 |
) -> Learner:
|
127 |
|
128 |
+
=======
|
129 |
+
# Weights are implicitly read from ./models/ folder
|
130 |
+
def gen_inference_deep(
|
131 |
+
root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
|
132 |
+
data = get_dummy_databunch()
|
133 |
+
learn = gen_learner_deep(
|
134 |
+
data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
|
135 |
+
)
|
136 |
+
learn.path = root_folder
|
137 |
+
learn.load(weights_name)
|
138 |
+
learn.model.eval()
|
139 |
+
return learn
|
140 |
+
|
141 |
+
|
142 |
+
def gen_learner_deep(
|
143 |
+
data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
|
144 |
+
) -> Learner:
|
145 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
146 |
return unet_learner_deep(
|
147 |
data,
|
148 |
arch,
|
|
|
158 |
|
159 |
# The code below is meant to be merged into fastaiv1 ideally
|
160 |
def unet_learner_deep(
|
161 |
+
<<<<<<< HEAD
|
162 |
data: DataBunch,
|
163 |
arch: Callable,
|
164 |
pretrained: bool = True,
|
|
|
174 |
**kwargs: Any
|
175 |
) -> Learner:
|
176 |
|
177 |
+
=======
|
178 |
+
data: DataBunch,
|
179 |
+
arch: Callable,
|
180 |
+
pretrained: bool = True,
|
181 |
+
blur_final: bool = True,
|
182 |
+
norm_type: Optional[NormType] = NormType,
|
183 |
+
split_on: Optional[SplitFuncOrIdxList] = None,
|
184 |
+
blur: bool = False,
|
185 |
+
self_attention: bool = False,
|
186 |
+
y_range: Optional[Tuple[float, float]] = None,
|
187 |
+
last_cross: bool = True,
|
188 |
+
bottle: bool = False,
|
189 |
+
nf_factor: float = 1.5,
|
190 |
+
**kwargs: Any
|
191 |
+
) -> Learner:
|
192 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
193 |
"Build Unet learner from `data` and `arch`."
|
194 |
meta = cnn_config(arch)
|
195 |
body = create_body(arch, pretrained)
|
deoldify/visualize.py
CHANGED
@@ -19,13 +19,21 @@ from .generators import gen_inference_deep, gen_inference_wide
|
|
19 |
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
22 |
class ModelImageVisualizer:
|
23 |
def __init__(self, filter: IFilter, results_dir: str = None):
|
24 |
self.filter = filter
|
25 |
self.results_dir = None if results_dir is None else Path(results_dir)
|
|
|
26 |
|
27 |
if self.results_dir is not None:
|
28 |
self.results_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
29 |
|
30 |
def _clean_mem(self):
|
31 |
torch.cuda.empty_cache()
|
@@ -215,15 +223,22 @@ class ModelImageVisualizer:
|
|
215 |
return rows, columns
|
216 |
|
217 |
|
|
|
218 |
def get_image_colorizer(root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
|
219 |
) -> ModelImageVisualizer:
|
220 |
|
|
|
|
|
|
|
|
|
|
|
221 |
if artistic:
|
222 |
return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
223 |
else:
|
224 |
return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
225 |
|
226 |
|
|
|
227 |
def get_stable_image_colorizer(root_folder: Path = Path('./'), weights_name: str = 'ColorizeStable_gen',
|
228 |
results_dir='result_images', render_factor: int = 35
|
229 |
) -> ModelImageVisualizer:
|
@@ -243,4 +258,27 @@ def get_artistic_image_colorizer(root_folder: Path = Path('./'), weights_name: s
|
|
243 |
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
244 |
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
245 |
vis = ModelImageVisualizer(filtr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
return vis
|
|
|
19 |
|
20 |
|
21 |
|
22 |
+
<<<<<<< HEAD
|
23 |
+
=======
|
24 |
+
# class LoadedModel
|
25 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
26 |
class ModelImageVisualizer:
|
27 |
def __init__(self, filter: IFilter, results_dir: str = None):
|
28 |
self.filter = filter
|
29 |
self.results_dir = None if results_dir is None else Path(results_dir)
|
30 |
+
<<<<<<< HEAD
|
31 |
|
32 |
if self.results_dir is not None:
|
33 |
self.results_dir.mkdir(parents=True, exist_ok=True)
|
34 |
+
=======
|
35 |
+
self.results_dir.mkdir(parents=True, exist_ok=True)
|
36 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
37 |
|
38 |
def _clean_mem(self):
|
39 |
torch.cuda.empty_cache()
|
|
|
223 |
return rows, columns
|
224 |
|
225 |
|
226 |
+
<<<<<<< HEAD
|
227 |
def get_image_colorizer(root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
|
228 |
) -> ModelImageVisualizer:
|
229 |
|
230 |
+
=======
|
231 |
+
def get_image_colorizer(
|
232 |
+
root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
|
233 |
+
) -> ModelImageVisualizer:
|
234 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
235 |
if artistic:
|
236 |
return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
237 |
else:
|
238 |
return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
239 |
|
240 |
|
241 |
+
<<<<<<< HEAD
|
242 |
def get_stable_image_colorizer(root_folder: Path = Path('./'), weights_name: str = 'ColorizeStable_gen',
|
243 |
results_dir='result_images', render_factor: int = 35
|
244 |
) -> ModelImageVisualizer:
|
|
|
258 |
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
259 |
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
260 |
vis = ModelImageVisualizer(filtr)
|
261 |
+
=======
|
262 |
+
def get_stable_image_colorizer(
|
263 |
+
root_folder: Path = Path('./'),
|
264 |
+
weights_name: str = 'ColorizeStable_gen',
|
265 |
+
results_dir='result_images',
|
266 |
+
render_factor: int = 35
|
267 |
+
) -> ModelImageVisualizer:
|
268 |
+
learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
|
269 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
270 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
271 |
+
return vis
|
272 |
+
|
273 |
+
|
274 |
+
def get_artistic_image_colorizer(
|
275 |
+
root_folder: Path = Path('./'),
|
276 |
+
weights_name: str = 'ColorizeArtistic_gen',
|
277 |
+
results_dir='result_images',
|
278 |
+
render_factor: int = 35
|
279 |
+
) -> ModelImageVisualizer:
|
280 |
+
learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
|
281 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
282 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
283 |
+
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
|
284 |
return vis
|