zwh7 hylee commited on
Commit
aac5eb3
0 Parent(s):

Duplicate from hylee/White-box-Cartoonization

Browse files

Co-authored-by: hylee <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ .data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
29
+ *.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ python_version: 3.7
3
+ title: White Box Cartoonization
4
+ emoji: 📚
5
+ colorFrom: purple
6
+ colorTo: green
7
+ sdk: gradio
8
+ sdk_version: 2.9.4
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ duplicated_from: hylee/White-box-Cartoonization
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+ import argparse
5
+ import functools
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ from typing import Callable
10
+ import uuid
11
+
12
+ import gradio as gr
13
+ import huggingface_hub
14
+ import numpy as np
15
+ import PIL.Image
16
+
17
+ from io import BytesIO
18
+ from wbc.cartoonize import Cartoonize
19
+
20
+ ORIGINAL_REPO_URL = 'https://github.com/SystemErrorWang/White-box-Cartoonization'
21
+ TITLE = 'SystemErrorWang/White-box-Cartoonization'
22
+ DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
23
+
24
+ """
25
+ ARTICLE = """
26
+
27
+ """
28
+
29
+ SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
30
+ def compress_UUID():
31
+ '''
32
+ 根据http://www.ietf.org/rfc/rfc1738.txt,由uuid编码扩bai大字符域生成du串
33
+ 包括:[0-9a-zA-Z\-_]共64个
34
+ 长度:(32-2)/3*2=20
35
+ 备注:可在地球上人zhi人都用,使用100年不重复(2^120)
36
+ :return:String
37
+ '''
38
+ row = str(uuid.uuid4()).replace('-', '')
39
+ safe_code = ''
40
+ for i in range(10):
41
+ enbin = "%012d" % int(bin(int(row[i * 3] + row[i * 3 + 1] + row[i * 3 + 2], 16))[2:], 10)
42
+ safe_code += (SAFEHASH[int(enbin[0:6], 2)] + SAFEHASH[int(enbin[6:12], 2)])
43
+ safe_code = safe_code.replace('-', '')
44
+ return safe_code
45
+
46
+
47
+ def parse_args() -> argparse.Namespace:
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument('--device', type=str, default='cpu')
50
+ parser.add_argument('--theme', type=str)
51
+ parser.add_argument('--live', action='store_true')
52
+ parser.add_argument('--share', action='store_true')
53
+ parser.add_argument('--port', type=int)
54
+ parser.add_argument('--disable-queue',
55
+ dest='enable_queue',
56
+ action='store_false')
57
+ parser.add_argument('--allow-flagging', type=str, default='never')
58
+ parser.add_argument('--allow-screenshot', action='store_true')
59
+ return parser.parse_args()
60
+
61
+ def run(
62
+ image,
63
+ cartoonize : Cartoonize
64
+ ) -> tuple[PIL.Image.Image]:
65
+
66
+ out_path = compress_UUID()+'.png'
67
+ cartoonize.run_sigle(image.name, out_path)
68
+
69
+ return PIL.Image.open(out_path)
70
+
71
+
72
+ def main():
73
+ gr.close_all()
74
+
75
+ args = parse_args()
76
+
77
+ cartoonize = Cartoonize(os.path.join(os.path.dirname(os.path.abspath(__file__)),'wbc/saved_models/'))
78
+
79
+ func = functools.partial(run, cartoonize=cartoonize)
80
+ func = functools.update_wrapper(func, run)
81
+
82
+ gr.Interface(
83
+ func,
84
+ [
85
+ gr.inputs.Image(type='file', label='Input Image'),
86
+ ],
87
+ [
88
+ gr.outputs.Image(
89
+ type='pil',
90
+ label='Result'),
91
+ ],
92
+ # examples=examples,
93
+ theme=args.theme,
94
+ title=TITLE,
95
+ description=DESCRIPTION,
96
+ article=ARTICLE,
97
+ allow_screenshot=args.allow_screenshot,
98
+ allow_flagging=args.allow_flagging,
99
+ live=args.live,
100
+ ).launch(
101
+ enable_queue=args.enable_queue,
102
+ server_port=args.port,
103
+ share=args.share,
104
+ )
105
+
106
+
107
+ if __name__ == '__main__':
108
+ main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ opencv-python-headless==4.5.5.62
2
+ Pillow==9.0.1
3
+ scipy==1.7.3
4
+ tensorflow-gpu==1.14.0
5
+ scikit-image==0.14.5
wbc/cartoonize.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import wbc.network as network
6
+ import wbc.guided_filter as guided_filter
7
+ from tqdm import tqdm
8
+
9
+
10
+ def resize_crop(image):
11
+ h, w, c = np.shape(image)
12
+ if min(h, w) > 720:
13
+ if h > w:
14
+ h, w = int(720 * h / w), 720
15
+ else:
16
+ h, w = 720, int(720 * w / h)
17
+ image = cv2.resize(image, (w, h),
18
+ interpolation=cv2.INTER_AREA)
19
+ h, w = (h // 8) * 8, (w // 8) * 8
20
+ image = image[:h, :w, :]
21
+ return image
22
+
23
+
24
+ def cartoonize(load_folder, save_folder, model_path):
25
+ print(model_path)
26
+ input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
27
+ network_out = network.unet_generator(input_photo)
28
+ final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
29
+
30
+ all_vars = tf.trainable_variables()
31
+ gene_vars = [var for var in all_vars if 'generator' in var.name]
32
+ saver = tf.train.Saver(var_list=gene_vars)
33
+
34
+ config = tf.ConfigProto()
35
+ config.gpu_options.allow_growth = True
36
+ sess = tf.Session(config=config)
37
+
38
+ sess.run(tf.global_variables_initializer())
39
+ saver.restore(sess, tf.train.latest_checkpoint(model_path))
40
+ name_list = os.listdir(load_folder)
41
+ for name in tqdm(name_list):
42
+ try:
43
+ load_path = os.path.join(load_folder, name)
44
+ save_path = os.path.join(save_folder, name)
45
+ image = cv2.imread(load_path)
46
+ image = resize_crop(image)
47
+ batch_image = image.astype(np.float32) / 127.5 - 1
48
+ batch_image = np.expand_dims(batch_image, axis=0)
49
+ output = sess.run(final_out, feed_dict={input_photo: batch_image})
50
+ output = (np.squeeze(output) + 1) * 127.5
51
+ output = np.clip(output, 0, 255).astype(np.uint8)
52
+ cv2.imwrite(save_path, output)
53
+ except:
54
+ print('cartoonize {} failed'.format(load_path))
55
+
56
+
57
+ class Cartoonize:
58
+ def __init__(self, model_path):
59
+ print(model_path)
60
+ self.input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
61
+ network_out = network.unet_generator(self.input_photo)
62
+ self.final_out = guided_filter.guided_filter(self.input_photo, network_out, r=1, eps=5e-3)
63
+
64
+ all_vars = tf.trainable_variables()
65
+ gene_vars = [var for var in all_vars if 'generator' in var.name]
66
+ saver = tf.train.Saver(var_list=gene_vars)
67
+
68
+ config = tf.ConfigProto()
69
+ config.gpu_options.allow_growth = True
70
+ self.sess = tf.Session(config=config)
71
+
72
+ self.sess.run(tf.global_variables_initializer())
73
+ saver.restore(self.sess, tf.train.latest_checkpoint(model_path))
74
+
75
+ def run(self, load_folder, save_folder):
76
+ name_list = os.listdir(load_folder)
77
+ for name in tqdm(name_list):
78
+ try:
79
+ load_path = os.path.join(load_folder, name)
80
+ save_path = os.path.join(save_folder, name)
81
+ image = cv2.imread(load_path)
82
+ image = resize_crop(image)
83
+ batch_image = image.astype(np.float32) / 127.5 - 1
84
+ batch_image = np.expand_dims(batch_image, axis=0)
85
+ output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})
86
+ output = (np.squeeze(output) + 1) * 127.5
87
+ output = np.clip(output, 0, 255).astype(np.uint8)
88
+ cv2.imwrite(save_path, output)
89
+ except:
90
+ print('cartoonize {} failed'.format(load_path))
91
+
92
+ def run_sigle(self, load_path, save_path):
93
+ try:
94
+ image = cv2.imread(load_path)
95
+ image = resize_crop(image)
96
+ batch_image = image.astype(np.float32) / 127.5 - 1
97
+ batch_image = np.expand_dims(batch_image, axis=0)
98
+ output = self.sess.run(self.final_out, feed_dict={self.input_photo: batch_image})
99
+ output = (np.squeeze(output) + 1) * 127.5
100
+ output = np.clip(output, 0, 255).astype(np.uint8)
101
+ cv2.imwrite(save_path, output)
102
+ except:
103
+ print('cartoonize {} failed'.format(load_path))
104
+
105
+
106
+ if __name__ == '__main__':
107
+ model_path = 'saved_models'
108
+ load_folder = 'test_images'
109
+ save_folder = 'cartoonized_images'
110
+ if not os.path.exists(save_folder):
111
+ os.mkdir(save_folder)
112
+ cartoonize(load_folder, save_folder, model_path)
wbc/guided_filter.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+
4
+
5
+
6
+
7
+ def tf_box_filter(x, r):
8
+ k_size = int(2*r+1)
9
+ ch = x.get_shape().as_list()[-1]
10
+ weight = 1/(k_size**2)
11
+ box_kernel = weight*np.ones((k_size, k_size, ch, 1))
12
+ box_kernel = np.array(box_kernel).astype(np.float32)
13
+ output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
14
+ return output
15
+
16
+
17
+
18
+ def guided_filter(x, y, r, eps=1e-2):
19
+
20
+ x_shape = tf.shape(x)
21
+ #y_shape = tf.shape(y)
22
+
23
+ N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
24
+
25
+ mean_x = tf_box_filter(x, r) / N
26
+ mean_y = tf_box_filter(y, r) / N
27
+ cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
28
+ var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
29
+
30
+ A = cov_xy / (var_x + eps)
31
+ b = mean_y - A * mean_x
32
+
33
+ mean_A = tf_box_filter(A, r) / N
34
+ mean_b = tf_box_filter(b, r) / N
35
+
36
+ output = mean_A * x + mean_b
37
+
38
+ return output
39
+
40
+
41
+
42
+ def fast_guided_filter(lr_x, lr_y, hr_x, r=1, eps=1e-8):
43
+
44
+ #assert lr_x.shape.ndims == 4 and lr_y.shape.ndims == 4 and hr_x.shape.ndims == 4
45
+
46
+ lr_x_shape = tf.shape(lr_x)
47
+ #lr_y_shape = tf.shape(lr_y)
48
+ hr_x_shape = tf.shape(hr_x)
49
+
50
+ N = tf_box_filter(tf.ones((1, lr_x_shape[1], lr_x_shape[2], 1), dtype=lr_x.dtype), r)
51
+
52
+ mean_x = tf_box_filter(lr_x, r) / N
53
+ mean_y = tf_box_filter(lr_y, r) / N
54
+ cov_xy = tf_box_filter(lr_x * lr_y, r) / N - mean_x * mean_y
55
+ var_x = tf_box_filter(lr_x * lr_x, r) / N - mean_x * mean_x
56
+
57
+ A = cov_xy / (var_x + eps)
58
+ b = mean_y - A * mean_x
59
+
60
+ mean_A = tf.image.resize_images(A, hr_x_shape[1: 3])
61
+ mean_b = tf.image.resize_images(b, hr_x_shape[1: 3])
62
+
63
+ output = mean_A * hr_x + mean_b
64
+
65
+ return output
66
+
67
+
68
+ if __name__ == '__main__':
69
+ import cv2
70
+ from tqdm import tqdm
71
+
72
+ input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
73
+ #input_superpixel = tf.placeholder(tf.float32, [16, 256, 256, 3])
74
+ output = guided_filter(input_photo, input_photo, 5, eps=1)
75
+ image = cv2.imread('output_figure1/cartoon2.jpg')
76
+ image = image/127.5 - 1
77
+ image = np.expand_dims(image, axis=0)
78
+
79
+ config = tf.ConfigProto()
80
+ config.gpu_options.allow_growth = True
81
+ sess = tf.Session(config=config)
82
+ sess.run(tf.global_variables_initializer())
83
+
84
+ out = sess.run(output, feed_dict={input_photo: image})
85
+ out = (np.squeeze(out)+1)*127.5
86
+ out = np.clip(out, 0, 255).astype(np.uint8)
87
+ cv2.imwrite('output_figure1/cartoon2_filter.jpg', out)
wbc/network.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import tensorflow.contrib.slim as slim
4
+
5
+
6
+
7
+ def resblock(inputs, out_channel=32, name='resblock'):
8
+
9
+ with tf.variable_scope(name):
10
+
11
+ x = slim.convolution2d(inputs, out_channel, [3, 3],
12
+ activation_fn=None, scope='conv1')
13
+ x = tf.nn.leaky_relu(x)
14
+ x = slim.convolution2d(x, out_channel, [3, 3],
15
+ activation_fn=None, scope='conv2')
16
+
17
+ return x + inputs
18
+
19
+
20
+
21
+
22
+ def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False):
23
+ with tf.variable_scope(name, reuse=reuse):
24
+
25
+ x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
26
+ x0 = tf.nn.leaky_relu(x0)
27
+
28
+ x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None)
29
+ x1 = tf.nn.leaky_relu(x1)
30
+ x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None)
31
+ x1 = tf.nn.leaky_relu(x1)
32
+
33
+ x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None)
34
+ x2 = tf.nn.leaky_relu(x2)
35
+ x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None)
36
+ x2 = tf.nn.leaky_relu(x2)
37
+
38
+ for idx in range(num_blocks):
39
+ x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx))
40
+
41
+ x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None)
42
+ x2 = tf.nn.leaky_relu(x2)
43
+
44
+ h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
45
+ x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2))
46
+ x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None)
47
+ x3 = tf.nn.leaky_relu(x3)
48
+ x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
49
+ x3 = tf.nn.leaky_relu(x3)
50
+
51
+ h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
52
+ x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2))
53
+ x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None)
54
+ x4 = tf.nn.leaky_relu(x4)
55
+ x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
56
+
57
+ return x4
58
+
59
+ if __name__ == '__main__':
60
+
61
+
62
+ pass
wbc/saved_models/checkpoint ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ model_checkpoint_path: "model-33999"
2
+ all_model_checkpoint_paths: "model-33999"
3
+ all_model_checkpoint_paths: "model-37499"
wbc/saved_models/model-33999.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e2df1a5aa86faa4f979720bfc2436f79333a480876f8d6790b7671cf50fe75b
3
+ size 5868300
wbc/saved_models/model-33999.index ADDED
Binary file (1.56 kB). View file