WarpWingHF committed
Commit 9c31576 · verified · 1 Parent(s): 325cbea

Create app.py

Files changed (1): app.py (+171, -0)
app.py ADDED
@@ -0,0 +1,171 @@
import gradio as gr
import spaces
import tensorflow as tf
import numpy as np
from PIL import Image
import logging
import time
from tqdm import tqdm

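# Note: `spaces` is the Hugging Face Spaces helper package; it provides the
# @spaces.GPU decorator used below to request ZeroGPU hardware per call.
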
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("style_transfer_app")

# Set TensorFlow threading options
tf.config.threading.set_inter_op_parallelism_threads(8)
tf.config.threading.set_intra_op_parallelism_threads(8)

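# A caveat worth noting: these thread-pool settings only affect CPU-side
# TensorFlow ops; the optimization loop itself runs on the GPU allocated by
# the @spaces.GPU decorator further down.
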
def load_img(image):
    """Load and preprocess an image for style transfer."""
    max_dim = 256
    # Convert PIL Image to a float tensor in [0, 1]
    img = tf.convert_to_tensor(np.array(image))
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Scale so the longest side is max_dim, preserving aspect ratio
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    long_dim = max(shape)
    scale = max_dim / long_dim
    new_shape = tf.cast(shape * scale, tf.int32)
    img = tf.image.resize(img, new_shape)
    # Add a batch dimension
    img = img[tf.newaxis, :]
    return img

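# Worked example of the resize above: a 768x1024 (HxW) input gives
# scale = 256 / 1024 = 0.25, so the image becomes 192x256 and is returned
# with shape (1, 192, 256, 3), float32, values in [0, 1].
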
def vgg_layers(layer_names):
    """Create a VGG model that returns the specified intermediate layers."""
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False
    outputs = [vgg.get_layer(name).output for name in layer_names]
    model = tf.keras.Model([vgg.input], outputs)
    return model

def gram_matrix(input_tensor):
    """Calculate the Gram matrix, normalized by the number of spatial locations."""
    result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    input_shape = tf.shape(input_tensor)
    num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
    return result / num_locations

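# The Gram matrix captures which feature channels co-activate across all
# spatial positions: G[c, d] = sum_ij F[i, j, c] * F[i, j, d] / (I * J).
# E.g. a (1, 64, 64, 128) feature map yields a (1, 128, 128) result; spatial
# layout is discarded, leaving only texture statistics.
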
class StyleContentModel(tf.keras.models.Model):
    """Extract style (Gram matrix) and content features from an image."""

    def __init__(self, style_layers, content_layers):
        super(StyleContentModel, self).__init__()
        self.vgg = vgg_layers(style_layers + content_layers)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        # Expects float input in [0, 1]; VGG19 preprocessing wants [0, 255]
        inputs = inputs * 255.0
        preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
        outputs = self.vgg(preprocessed_input)
        style_outputs, content_outputs = (outputs[:self.num_style_layers],
                                          outputs[self.num_style_layers:])
        style_outputs = [gram_matrix(style_output)
                         for style_output in style_outputs]
        content_dict = {content_name: value
                        for content_name, value
                        in zip(self.content_layers, content_outputs)}
        style_dict = {style_name: value
                      for style_name, value
                      in zip(self.style_layers, style_outputs)}
        return {'content': content_dict, 'style': style_dict}

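# Illustrative return structure (shapes assume VGG19 and a 256px input):
#   extractor(img) -> {
#       'content': {'block5_conv2': <1 x H x W x 512 feature map>},
#       'style':   {'block1_conv1': <1 x 64 x 64 Gram matrix>, ...},
#   }
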
def clip_0_1(image):
    """Clip tensor values to the [0, 1] range."""
    return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)

def style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight):
    """Calculate the weighted sum of style and content losses."""
    style_outputs = outputs['style']
    content_outputs = outputs['content']
    style_loss = tf.add_n([tf.reduce_mean((style_outputs[name] - style_targets[name])**2)
                           for name in style_outputs.keys()])
    style_loss *= style_weight / len(style_outputs)
    content_loss = tf.add_n([tf.reduce_mean((content_outputs[name] - content_targets[name])**2)
                             for name in content_outputs.keys()])
    content_loss *= content_weight / len(content_outputs)
    loss = style_loss + content_loss
    return loss

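# The combined objective, following Gatys et al.:
#   loss = (style_weight / num_style_layers) * sum_l mean((G_l - G*_l)^2)
#        + (content_weight / num_content_layers) * mean((F - F*)^2)
# where G are Gram matrices of style-layer activations and F are raw
# content-layer activations.
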
@tf.function
def train_step(image, extractor, style_targets, content_targets, opt, style_weight, content_weight, total_variation_weight):
    """Perform one optimization step on the generated image."""
    with tf.GradientTape() as tape:
        outputs = extractor(image)
        loss = style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight)
        # Total variation loss penalizes high-frequency artifacts
        loss += total_variation_weight * tf.image.total_variation(image)
    grad = tape.gradient(loss, image)
    opt.apply_gradients([(grad, image)])
    image.assign(clip_0_1(image))
    return loss

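# Because train_step is wrapped in tf.function, the first call traces and
# compiles a graph; reusing the same extractor/opt objects on every call
# avoids retracing on later steps.
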
def tensor_to_image(tensor):
    """Convert a tensor to a PIL Image."""
    tensor = tensor * 255
    tensor = np.array(tensor, dtype=np.uint8)
    if np.ndim(tensor) > 3:
        tensor = tensor[0]
    return Image.fromarray(tensor)

# Initialize the style-content model
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
content_layers = ['block5_conv2']
extractor = StyleContentModel(style_layers, content_layers)

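# Layer choice follows the classic Gatys et al. recipe: the first conv layer
# of each VGG block captures texture statistics at multiple scales (style),
# while a deep layer (block5_conv2) encodes overall layout (content).
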
@spaces.GPU(duration=120)  # Style transfer typically needs more than 60s
def style_transfer_fn(content_image, style_image, progress=gr.Progress(track_tqdm=True)):
    """Main style transfer function for the Gradio interface."""
    try:
        # Preprocess images
        content_img = load_img(content_image)
        style_img = load_img(style_image)

        # Extract style and content features
        style_targets = extractor(style_img)['style']
        content_targets = extractor(content_img)['content']
        image = tf.Variable(content_img)

        # Set optimization parameters
        opt = tf.keras.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
        style_weight = 1e-2
        content_weight = 1e4
        total_variation_weight = 30

        epochs = 10
        steps_per_epoch = 100

        start_time = time.time()

        # Training loop
        for n in tqdm(range(epochs), desc="Epochs"):
            for m in tqdm(range(steps_per_epoch), desc="Steps", leave=False):
                loss = train_step(image, extractor, style_targets, content_targets,
                                  opt, style_weight, content_weight, total_variation_weight)

        logger.info("Style transfer finished in %.1fs", time.time() - start_time)

        # Convert result to image
        result_image = tensor_to_image(image)

        return result_image

    except Exception as e:
        logger.error(f"Error during style transfer: {e}")
        raise gr.Error("An error occurred during style transfer.")

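# Rough budget check: 10 epochs x 100 steps = 1000 optimization steps on a
# 256px image, which must finish inside the 120s @spaces.GPU window; if calls
# time out, epochs or steps_per_epoch is the first knob to lower.
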
# Create Gradio interface
iface = gr.Interface(
    fn=style_transfer_fn,
    inputs=[
        gr.Image(label="Content Image", type="pil"),
        gr.Image(label="Style Image", type="pil")
    ],
    outputs=gr.Image(label="Stylized Image"),
    title="Neural Style Transfer",
    description="Upload a content image and a style image to create a stylized combination.",
    examples=[
        ["examples/content.jpg", "examples/style.jpg"]
    ],
    cache_examples=True
)

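# Note: cache_examples=True runs style_transfer_fn on the example pair at
# startup, so examples/content.jpg and examples/style.jpg must exist in the
# repo for the example cache to build.
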
# Launch the interface
if __name__ == "__main__":
    iface.launch()