TuringsSolutions committed
Commit 61497c3 · verified · 1 Parent(s): 98e577b

Create app.py

Files changed (1)
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
+import gradio as gr
+import numpy as np
+import tensorflow as tf
+from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
+from keras.models import Model
+import matplotlib.pyplot as plt
+import logging
+from skimage.transform import resize
+from PIL import Image, ImageEnhance
+from tqdm import tqdm
+
+class SwarmAgent:
+    def __init__(self, position, velocity):
+        self.position = position
+        self.velocity = velocity
+        self.m = np.zeros_like(position)
+        self.v = np.zeros_like(position)
+
+class SwarmNeuralNetwork:
+    def __init__(self, num_agents, image_shape, target_image):
+        self.image_shape = image_shape
+        self.resized_shape = (64, 64, 3)
+        self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
+        # Allow target_image=None so a saved model can be loaded later via load_model()
+        self.target_image = self.load_target_image(target_image) if target_image is not None else None
+        self.generated_image = np.random.randn(*image_shape)  # Start with noise
+        self.mobilenet = self.load_mobilenet_model()
+        self.current_epoch = 0
+        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # Noise schedule
+
+    def random_position(self):
+        return np.random.randn(*self.image_shape)  # Use Gaussian noise
+
+    def random_velocity(self):
+        return np.random.randn(*self.image_shape) * 0.01
+
+    def load_target_image(self, img):
+        # Force 3-channel RGB (drops any alpha channel) before resizing
+        img = img.convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
+        img_array = np.array(img) / 127.5 - 1  # Normalize to [-1, 1]
+        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
+        plt.title('Target Image')
+        plt.show()
+        return img_array
+
+    def resize_image(self, image):
+        return resize(image, self.resized_shape, anti_aliasing=True)
+
+    def load_mobilenet_model(self):
+        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
+        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)
+
+    def add_positional_encoding(self, image):
+        h, w, c = image.shape
+        pos_enc = np.zeros_like(image)
+        for i in range(h):
+            for j in range(w):
+                pos_enc[i, j, :] = [i / h, j / w, 0]
+        return image + pos_enc
+
+    def multi_head_attention(self, agent, num_heads=4):
+        attention_scores = []
+        for _ in range(num_heads):
+            similarity = np.exp(-np.sum((agent.position - self.target_image)**2, axis=-1))
+            attention_score = similarity / np.sum(similarity)
+            attention_scores.append(attention_score)
+        attention = np.mean(attention_scores, axis=0)
+        return np.expand_dims(attention, axis=-1)
+
+    def multi_scale_perceptual_loss(self, agent_positions):
+        target_image_resized = self.resize_image((self.target_image + 1) / 2)  # Convert to [0, 1] for MobileNet
+        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
+        target_features = self.mobilenet.predict(target_image_preprocessed)
+
+        losses = []
+        for agent_position in agent_positions:
+            agent_image_resized = self.resize_image((agent_position + 1) / 2)
+            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
+            agent_features = self.mobilenet.predict(agent_image_preprocessed)
+
+            loss = np.mean((target_features - agent_features)**2)
+            losses.append(1 / (1 + loss))
+
+        return np.array(losses)
+
+    def update_agents(self, timestep):
+        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
+
+        for agent in self.agents:
+            # Predict noise
+            predicted_noise = agent.position - self.target_image
+
+            # Denoise
+            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)
+
+            # Add scaled noise for next step
+            agent.position = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)
+
+            # Clip values
+            agent.position = np.clip(agent.position, -1, 1)
+
+    def generate_image(self):
+        self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
+        # Normalize to [0, 1] range for display
+        self.generated_image = (self.generated_image + 1) / 2
+        self.generated_image = np.clip(self.generated_image, 0, 1)
+
+        # Apply sharpening filter
+        image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
+        enhancer = ImageEnhance.Sharpness(image_pil)
+        self.generated_image = np.array(enhancer.enhance(2.0)) / 255.0
+
+    def train(self, epochs):
+        logging.basicConfig(filename='training.log', level=logging.INFO)
+
+        for epoch in tqdm(range(epochs), desc="Training Epochs"):
+            self.update_agents(epoch)
+            self.generate_image()
+
+            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
+            logging.info(f"Epoch {epoch}, MSE: {mse}")
+
+            if epoch % 10 == 0:
+                print(f"Epoch {epoch}, MSE: {mse}")
+                self.display_image(self.generated_image, title=f'Epoch {epoch}')
+            self.current_epoch += 1
+
+    def display_image(self, image, title=''):
+        plt.imshow(image)
+        plt.title(title)
+        plt.axis('off')
+        plt.show()
+
+    def display_agent_positions(self, epoch):
+        fig, ax = plt.subplots()
+        positions = np.array([agent.position for agent in self.agents])
+        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
+        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
+        plt.title(f'Agent Positions at Epoch {epoch}')
+        plt.show()
+
+    def save_model(self, filename):
+        model_state = {
+            'agents': self.agents,
+            'target_image': self.target_image,  # Persisted so generate_new_image() works after load_model()
+            'generated_image': self.generated_image,
+            'current_epoch': self.current_epoch
+        }
+        np.save(filename, model_state)
+
+    def load_model(self, filename):
+        model_state = np.load(filename, allow_pickle=True).item()
+        self.agents = model_state['agents']
+        self.target_image = model_state.get('target_image', self.target_image)
+        self.generated_image = model_state['generated_image']
+        self.current_epoch = model_state['current_epoch']
+
+    def generate_new_image(self, num_steps=1000):
+        for agent in self.agents:
+            agent.position = np.random.randn(*self.image_shape)
+
+        for step in tqdm(range(num_steps), desc="Generating Image"):
+            self.update_agents(num_steps - step - 1)  # Reverse order
+
+        self.generate_image()
+        return self.generated_image
+
+# Gradio Interface
+def train_snn(image, num_agents, epochs, brightness, contrast, color):
+    # Sliders can deliver floats; the loops below expect integers
+    num_agents = int(num_agents)
+    epochs = int(epochs)
+
+    # Apply user-specified adjustments to the target image before it is loaded
+    image = ImageEnhance.Brightness(image).enhance(brightness)
+    image = ImageEnhance.Contrast(image).enhance(contrast)
+    image = ImageEnhance.Color(image).enhance(color)
+
+    snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(64, 64, 3), target_image=image)
+    snn.train(epochs=epochs)
+    snn.save_model('snn_model.npy')
+    return (snn.generated_image * 255).astype(np.uint8)  # uint8 array for gr.Image(type="numpy")
+
+def generate_new_image():
+    snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(64, 64, 3), target_image=None)
+    snn.load_model('snn_model.npy')
+    new_image = snn.generate_new_image()
+    return new_image
+
+interface = gr.Interface(
+    fn=train_snn,
+    inputs=[
+        gr.Image(type="pil", label="Upload Target Image"),
+        gr.Slider(minimum=500, maximum=3000, value=2000, label="Number of Agents"),
+        gr.Slider(minimum=10, maximum=200, value=100, label="Number of Epochs"),
+        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
+        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
+        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance")
+    ],
+    outputs=gr.Image(type="numpy", label="Generated Image"),
+    title="Swarm Neural Network Image Generation",
+    description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust brightness, contrast, and color balance for personalization."
+)
+
+interface.launch()
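
Note: generate_new_image() is defined in app.py but never wired into the UI; the Gradio interface only exposes train_snn. A minimal sketch of how both entry points could be served from one app, assuming the rest of app.py above is unchanged and the final interface.launch() line is replaced by the lines below (the second interface, its labels, and the tab names are illustrative and not part of this commit):

generate_interface = gr.Interface(
    fn=generate_new_image,
    inputs=[],
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="Generate From Saved Swarm",
    description="Reload snn_model.npy and run the reverse noise schedule from fresh noise."
)

# gr.TabbedInterface groups several Interfaces under tabs in a single app.
app = gr.TabbedInterface([interface, generate_interface], tab_names=["Train", "Generate"])
app.launch()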