Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
import tensorflow as tf
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import logging
|
7 |
+
import time
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
# Initialize logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger("style_transfer_app")
|
13 |
+
|
14 |
+
# Set TensorFlow threading options
|
15 |
+
tf.config.threading.set_inter_op_parallelism_threads(8)
|
16 |
+
tf.config.threading.set_intra_op_parallelism_threads(8)
|
17 |
+
|
18 |
+
def load_img(image):
|
19 |
+
"""Load and preprocess image for style transfer"""
|
20 |
+
max_dim = 256
|
21 |
+
# Convert PIL Image to tensor
|
22 |
+
img = tf.convert_to_tensor(np.array(image))
|
23 |
+
img = tf.image.convert_image_dtype(img, tf.float32)
|
24 |
+
shape = tf.cast(tf.shape(img)[:-1], tf.float32)
|
25 |
+
long_dim = max(shape)
|
26 |
+
scale = max_dim / long_dim
|
27 |
+
new_shape = tf.cast(shape * scale, tf.int32)
|
28 |
+
img = tf.image.resize(img, new_shape)
|
29 |
+
img = img[tf.newaxis, :]
|
30 |
+
return img
|
31 |
+
|
32 |
+
def vgg_layers(layer_names):
|
33 |
+
"""Create VGG model with specified layers"""
|
34 |
+
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
|
35 |
+
vgg.trainable = False
|
36 |
+
outputs = [vgg.get_layer(name).output for name in layer_names]
|
37 |
+
model = tf.keras.Model([vgg.input], outputs)
|
38 |
+
return model
|
39 |
+
|
40 |
+
def gram_matrix(input_tensor):
|
41 |
+
"""Calculate Gram matrix"""
|
42 |
+
result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
|
43 |
+
input_shape = tf.shape(input_tensor)
|
44 |
+
num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)
|
45 |
+
return result / num_locations
|
46 |
+
|
47 |
+
class StyleContentModel(tf.keras.models.Model):
|
48 |
+
def __init__(self, style_layers, content_layers):
|
49 |
+
super(StyleContentModel, self).__init__()
|
50 |
+
self.vgg = vgg_layers(style_layers + content_layers)
|
51 |
+
self.style_layers = style_layers
|
52 |
+
self.content_layers = content_layers
|
53 |
+
self.num_style_layers = len(style_layers)
|
54 |
+
self.vgg.trainable = False
|
55 |
+
|
56 |
+
def call(self, inputs):
|
57 |
+
inputs = inputs * 255.0
|
58 |
+
preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
|
59 |
+
outputs = self.vgg(preprocessed_input)
|
60 |
+
style_outputs, content_outputs = (outputs[:self.num_style_layers],
|
61 |
+
outputs[self.num_style_layers:])
|
62 |
+
style_outputs = [gram_matrix(style_output)
|
63 |
+
for style_output in style_outputs]
|
64 |
+
content_dict = {content_name: value
|
65 |
+
for content_name, value
|
66 |
+
in zip(self.content_layers, content_outputs)}
|
67 |
+
style_dict = {style_name: value
|
68 |
+
for style_name, value
|
69 |
+
in zip(self.style_layers, style_outputs)}
|
70 |
+
return {'content': content_dict, 'style': style_dict}
|
71 |
+
|
72 |
+
def clip_0_1(image):
|
73 |
+
"""Clip tensor values between 0 and 1"""
|
74 |
+
return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
|
75 |
+
|
76 |
+
def style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight):
|
77 |
+
"""Calculate style and content loss"""
|
78 |
+
style_outputs = outputs['style']
|
79 |
+
content_outputs = outputs['content']
|
80 |
+
style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2)
|
81 |
+
for name in style_outputs.keys()])
|
82 |
+
style_loss *= style_weight / len(style_outputs)
|
83 |
+
content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2)
|
84 |
+
for name in content_outputs.keys()])
|
85 |
+
content_loss *= content_weight / len(content_outputs)
|
86 |
+
loss = style_loss + content_loss
|
87 |
+
return loss
|
88 |
+
|
89 |
+
@tf.function
|
90 |
+
def train_step(image, extractor, style_targets, content_targets, opt, style_weight, content_weight, total_variation_weight):
|
91 |
+
"""Perform one training step"""
|
92 |
+
with tf.GradientTape() as tape:
|
93 |
+
outputs = extractor(image)
|
94 |
+
loss = style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight)
|
95 |
+
loss += total_variation_weight * tf.image.total_variation(image)
|
96 |
+
grad = tape.gradient(loss, image)
|
97 |
+
opt.apply_gradients([(grad, image)])
|
98 |
+
image.assign(clip_0_1(image))
|
99 |
+
return loss
|
100 |
+
|
101 |
+
def tensor_to_image(tensor):
|
102 |
+
"""Convert tensor to PIL Image"""
|
103 |
+
tensor = tensor * 255
|
104 |
+
tensor = np.array(tensor, dtype=np.uint8)
|
105 |
+
if np.ndim(tensor) > 3:
|
106 |
+
tensor = tensor[0]
|
107 |
+
return Image.fromarray(tensor)
|
108 |
+
|
109 |
+
# Initialize the style-content model
|
110 |
+
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
|
111 |
+
content_layers = ['block5_conv2']
|
112 |
+
extractor = StyleContentModel(style_layers, content_layers)
|
113 |
+
|
114 |
+
@spaces.GPU(duration=120) # Style transfer typically needs more than 60s
|
115 |
+
def style_transfer_fn(content_image, style_image, progress=gr.Progress(track_tqdm=True)):
|
116 |
+
"""Main style transfer function for Gradio interface"""
|
117 |
+
try:
|
118 |
+
# Preprocess images
|
119 |
+
content_img = load_img(content_image)
|
120 |
+
style_img = load_img(style_image)
|
121 |
+
|
122 |
+
# Extract style and content features
|
123 |
+
style_targets = extractor(style_img)['style']
|
124 |
+
content_targets = extractor(content_img)['content']
|
125 |
+
image = tf.Variable(content_img)
|
126 |
+
|
127 |
+
# Set optimization parameters
|
128 |
+
opt = tf.keras.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
|
129 |
+
style_weight = 1e-2
|
130 |
+
content_weight = 1e4
|
131 |
+
total_variation_weight = 30
|
132 |
+
|
133 |
+
epochs = 10
|
134 |
+
steps_per_epoch = 100
|
135 |
+
|
136 |
+
start_time = time.time()
|
137 |
+
|
138 |
+
# Training loop
|
139 |
+
for n in tqdm(range(epochs), desc="Epochs"):
|
140 |
+
for m in tqdm(range(steps_per_epoch), desc="Steps", leave=False):
|
141 |
+
loss = train_step(image, extractor, style_targets, content_targets,
|
142 |
+
opt, style_weight, content_weight, total_variation_weight)
|
143 |
+
|
144 |
+
# Convert result to image
|
145 |
+
result_image = tensor_to_image(image)
|
146 |
+
|
147 |
+
return result_image
|
148 |
+
|
149 |
+
except Exception as e:
|
150 |
+
logger.error(f"Error during style transfer: {e}")
|
151 |
+
raise gr.Error("An error occurred during style transfer.")
|
152 |
+
|
153 |
+
# Create Gradio interface
|
154 |
+
iface = gr.Interface(
|
155 |
+
fn=style_transfer_fn,
|
156 |
+
inputs=[
|
157 |
+
gr.Image(label="Content Image", type="pil"),
|
158 |
+
gr.Image(label="Style Image", type="pil")
|
159 |
+
],
|
160 |
+
outputs=gr.Image(label="Stylized Image"),
|
161 |
+
title="Neural Style Transfer",
|
162 |
+
description="Upload a content image and a style image to create a stylized combination.",
|
163 |
+
examples=[
|
164 |
+
["examples/content.jpg", "examples/style.jpg"]
|
165 |
+
],
|
166 |
+
cache_examples=True
|
167 |
+
)
|
168 |
+
|
169 |
+
# Launch the interface
|
170 |
+
if __name__ == "__main__":
|
171 |
+
iface.launch()
|