Spaces:
Runtime error
Runtime error
init
Browse files
app.py
CHANGED
@@ -54,9 +54,9 @@ def run(
|
|
54 |
paprika: ImportGraph,
|
55 |
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
|
56 |
|
57 |
-
im1 = shinkai.test(image.name, True)
|
58 |
-
im2 = hayao.test(image.name, True)
|
59 |
-
im3 = paprika.test(image.name, True)
|
60 |
|
61 |
return PIL.Image.open(im1),PIL.Image.open(im2),PIL.Image.open(im3)
|
62 |
|
|
|
54 |
paprika: ImportGraph,
|
55 |
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
|
56 |
|
57 |
+
im1 = shinkai.test('shinkai', image.name, True)
|
58 |
+
im2 = hayao.test('hayao', image.name, True)
|
59 |
+
im3 = paprika.test('paprika', image.name, True)
|
60 |
|
61 |
return PIL.Image.open(im1),PIL.Image.open(im2),PIL.Image.open(im3)
|
62 |
|
test1.py
CHANGED
@@ -15,7 +15,7 @@ class ImportGraph:
|
|
15 |
with self.graph.as_default():
|
16 |
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
|
17 |
with tf.variable_scope("generator", reuse=False):
|
18 |
-
test_generated = generator.G_net(test_real).fake
|
19 |
saver = tf.train.Saver()
|
20 |
|
21 |
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
|
@@ -26,10 +26,13 @@ class ImportGraph:
|
|
26 |
else:
|
27 |
print(" [*] Failed to find a checkpoint")
|
28 |
|
29 |
-
def test(self, sample_file, if_adjust_brightness, img_size=[256,256]):
|
|
|
|
|
|
|
30 |
sample_image = np.asarray(load_test_data(sample_file, img_size))
|
31 |
image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))
|
32 |
-
fake_img = sess.run(test_generated, feed_dict={test_real: sample_image})
|
33 |
if if_adjust_brightness:
|
34 |
save_images(fake_img, image_path, sample_file)
|
35 |
else:
|
|
|
15 |
with self.graph.as_default():
|
16 |
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
|
17 |
with tf.variable_scope("generator", reuse=False):
|
18 |
+
self.test_generated = generator.G_net(test_real).fake
|
19 |
saver = tf.train.Saver()
|
20 |
|
21 |
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
|
|
|
26 |
else:
|
27 |
print(" [*] Failed to find a checkpoint")
|
28 |
|
29 |
+
def test(self, style_name, sample_file, if_adjust_brightness, img_size=[256,256]):
|
30 |
+
result_dir = 'results/' + style_name
|
31 |
+
check_folder(result_dir)
|
32 |
+
|
33 |
sample_image = np.asarray(load_test_data(sample_file, img_size))
|
34 |
image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))
|
35 |
+
fake_img = self.sess.run(self.test_generated, feed_dict={test_real: sample_image})
|
36 |
if if_adjust_brightness:
|
37 |
save_images(fake_img, image_path, sample_file)
|
38 |
else:
|