Haiyu Wu commited on
Commit
3192730
·
1 Parent(s): 5ba379a
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -25,11 +25,11 @@ def clear_generation_time():
25
 
26
 
27
  def generating():
28
- return "**Generating images...**"
29
 
30
 
31
  def done():
32
- return "**Done!**"
33
 
34
 
35
  def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.4, 0.4, 0.2]):
@@ -102,14 +102,14 @@ def image_generation(input_image, quality, random_perturbation, sigma, dimension
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
- if quality > 22:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
  features = sample_nearby_vectors(features, [sigma], [1]).float().to(device)
112
- if quality > 22:
113
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality, class_rep=features)
114
  else:
115
  _, _, images, *_ = generator(features)
@@ -121,7 +121,7 @@ def image_generation(input_image, quality, random_perturbation, sigma, dimension
121
  return generated_images
122
 
123
  @spaces.GPU
124
- def process_input(image_input, num1, num2, num3, num4, num5, num6, num7, num8, random_seed, target_quality, random_perturbation, sigma, progress=gr.Progress()):
125
  # Ensure all dimension numbers are within [0, 512)
126
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
127
 
@@ -135,7 +135,7 @@ def process_input(image_input, num1, num2, num3, num4, num5, num6, num7, num8, r
135
  input_data = Image.open(image_input)
136
  input_data = np.array(input_data.resize((112, 112)))
137
 
138
- generated_images = image_generation(input_data, target_quality, random_perturbation, sigma, [num1, num2, num3, num4, num5, num6, num7, num8], progress)
139
 
140
  return generated_images
141
 
@@ -152,10 +152,6 @@ def toggle_inputs(random_perturbation):
152
  gr.update(interactive=not random_perturbation), # num2
153
  gr.update(interactive=not random_perturbation), # num3
154
  gr.update(interactive=not random_perturbation), # num4
155
- gr.update(interactive=not random_perturbation), # num5
156
- gr.update(interactive=not random_perturbation), # num6
157
- gr.update(interactive=not random_perturbation), # num7
158
- gr.update(interactive=not random_perturbation), # num8
159
  ]
160
 
161
 
@@ -189,13 +185,13 @@ def main():
189
 
190
  with gr.Row():
191
  num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
192
- num2 = gr.Number(label="Dimension 2", value=25, minimum=0, maximum=511, step=1)
193
- num3 = gr.Number(label="Dimension 3", value=50, minimum=0, maximum=511, step=1)
194
- num4 = gr.Number(label="Dimension 4", value=75, minimum=0, maximum=511, step=1)
195
- num5 = gr.Number(label="Dimension 5", value=100, minimum=0, maximum=511, step=1)
196
- num6 = gr.Number(label="Dimension 6", value=125, minimum=0, maximum=511, step=1)
197
- num7 = gr.Number(label="Dimension 7", value=150, minimum=0, maximum=511, step=1)
198
- num8 = gr.Number(label="Dimension 8", value=200, minimum=0, maximum=511, step=1)
199
 
200
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
201
  target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)
@@ -227,7 +223,7 @@ def main():
227
  random_perturbation.change(
228
  fn=toggle_inputs,
229
  inputs=[random_perturbation],
230
- outputs=[sigma, num1, num2, num3, num4, num5, num6, num7, num8]
231
  )
232
 
233
  generated_images = gr.State([])
@@ -242,7 +238,7 @@ def main():
242
  outputs=[generation_time]
243
  ).then(
244
  fn=process_input,
245
- inputs=[image_file, num1, num2, num3, num4, num5, num6, num7, num8, random_seed, target_quality, random_perturbation, sigma],
246
  outputs=[generated_images]
247
  ).then(
248
  fn=done,
 
25
 
26
 
27
  def generating():
28
+ return "Generating images..."
29
 
30
 
31
  def done():
32
+ return "Done!"
33
 
34
 
35
  def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.4, 0.4, 0.2]):
 
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
+ if quality > 25:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
  features = sample_nearby_vectors(features, [sigma], [1]).float().to(device)
112
+ if quality > 25:
113
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality, class_rep=features)
114
  else:
115
  _, _, images, *_ = generator(features)
 
121
  return generated_images
122
 
123
  @spaces.GPU
124
+ def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, random_perturbation, sigma, progress=gr.Progress()):
125
  # Ensure all dimension numbers are within [0, 512)
126
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
127
 
 
135
  input_data = Image.open(image_input)
136
  input_data = np.array(input_data.resize((112, 112)))
137
 
138
+ generated_images = image_generation(input_data, target_quality, random_perturbation, sigma, [num1, num2, num3, num4], progress)
139
 
140
  return generated_images
141
 
 
152
  gr.update(interactive=not random_perturbation), # num2
153
  gr.update(interactive=not random_perturbation), # num3
154
  gr.update(interactive=not random_perturbation), # num4
 
 
 
 
155
  ]
156
 
157
 
 
185
 
186
  with gr.Row():
187
  num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
188
+ num2 = gr.Number(label="Dimension 2", value=0, minimum=0, maximum=511, step=1)
189
+ num3 = gr.Number(label="Dimension 3", value=0, minimum=0, maximum=511, step=1)
190
+ num4 = gr.Number(label="Dimension 4", value=0, minimum=0, maximum=511, step=1)
191
+ # num5 = gr.Number(label="Dimension 5", value=0, minimum=0, maximum=511, step=1)
192
+ # num6 = gr.Number(label="Dimension 6", value=0, minimum=0, maximum=511, step=1)
193
+ # num7 = gr.Number(label="Dimension 7", value=0, minimum=0, maximum=511, step=1)
194
+ # num8 = gr.Number(label="Dimension 8", value=0, minimum=0, maximum=511, step=1)
195
 
196
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
197
  target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)
 
223
  random_perturbation.change(
224
  fn=toggle_inputs,
225
  inputs=[random_perturbation],
226
+ outputs=[sigma, num1, num2, num3, num4]
227
  )
228
 
229
  generated_images = gr.State([])
 
238
  outputs=[generation_time]
239
  ).then(
240
  fn=process_input,
241
+ inputs=[image_file, num1, num2, num3, num4, random_seed, target_quality, random_perturbation, sigma],
242
  outputs=[generated_images]
243
  ).then(
244
  fn=done,