ariG23498 HF staff commited on
Commit
a7f8f41
·
1 Parent(s): 6b85f2c
Files changed (2) hide show
  1. app.py +49 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import wget
4
+
5
+ enc_url = 'https://huggingface.co/ariG23498/nst/blob/main/nst-encoder.h5'
6
+ enc_filename = wget.download(enc_url)
7
+
8
+ dec_url = 'https://huggingface.co/ariG23498/nst/blob/main/nst-decoder.h5'
9
+ dec_filename = wget.download(dec_url)
10
+
11
+ encoder = tf.keras.models.load_model(enc_filename, compile=False)
12
+ decoder = tf.keras.models.load_model(dec_filename, compile=False)
13
+
14
+ def get_mean_std(tensor, epsilon=1e-5):
15
+ axes = [1, 2]
16
+ tensor_mean, tensor_var = tf.nn.moments(tensor, axes=axes, keepdims=True)
17
+ tensor_std = tf.sqrt(tensor_var + epsilon)
18
+ return tensor_mean, tensor_std
19
+
20
+ def ada_in(style, content, epsilon=1e-5):
21
+ c_mean, c_std = get_mean_std(content)
22
+ s_mean, s_std = get_mean_std(style)
23
+ t = s_std * (content - c_mean) / c_std + s_mean
24
+ return t
25
+
26
+ def load_resize(image):
27
+ image = tf.image.convert_image_dtype(image, dtype="float32")
28
+ image = tf.image.resize(image, (224, 224))
29
+ return image
30
+
31
+ def infer(style, content):
32
+ style = load_resize(style)
33
+ style = style[tf.newaxis, ...]
34
+ content = load_resize(content)
35
+ content = content[tf.newaxis, ...]
36
+
37
+ style_enc = encoder(style)
38
+ content_enc = encoder(content)
39
+
40
+ t = ada_in(style=style_enc, content=content_enc)
41
+
42
+ recons_image = decoder(t)
43
+ return recons_image[0].numpy()
44
+
45
+ iface = gr.Interface(
46
+ fn=infer,
47
+ inputs=[gr.inputs.Image(label="style"),
48
+ gr.inputs.Image(label="content")],
49
+ outputs="image").launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tensorflow>2.4
2
+ wget