Haiyu Wu commited on
Commit
94d3e3e
·
1 Parent(s): 2223153
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -81,7 +81,7 @@ def initialize_models():
81
  return generator, id_model, pose_model, quality_model
82
 
83
  @spaces.GPU
84
- def image_generation(input_image, quality, use_target_pose, pose, dimension, progress=gr.Progress()):
85
  generator, id_model, pose_model, quality_model = initialize_models()
86
 
87
  generated_images = []
@@ -93,7 +93,7 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
93
  input_image.div_(255).sub_(0.5).div_(0.5)
94
  feature = id_model(input_image).clone().detach().cpu().numpy()
95
 
96
- if not use_target_pose:
97
  features = []
98
  norm = np.linalg.norm(feature, 2, 1, True)
99
  for i in progress.tqdm(np.arange(0, 4.8, 2), desc="Generating images"):
@@ -102,16 +102,15 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
- if quality > 25:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
- features = sample_nearby_vectors(features, [0.7], [1]).float().to(device)
112
- if quality > 25 or pose > 20:
113
- images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
114
- q_target=quality, pose=pose, class_rep=features)
115
  else:
116
  _, _, images, *_ = generator(features)
117
 
@@ -122,7 +121,7 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
122
  return generated_images
123
 
124
  @spaces.GPU
125
- def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose, progress=gr.Progress()):
126
  # Ensure all dimension numbers are within [0, 512)
127
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
128
 
@@ -136,7 +135,7 @@ def process_input(image_input, num1, num2, num3, num4, random_seed, target_quali
136
  input_data = Image.open(image_input)
137
  input_data = np.array(input_data.resize((112, 112)))
138
 
139
- generated_images = image_generation(input_data, target_quality, use_target_pose, target_pose, [num1, num2, num3, num4], progress)
140
 
141
  return generated_images
142
 
@@ -148,11 +147,15 @@ def select_image(value, images):
148
 
149
  def toggle_inputs(use_pose):
150
  return [
151
- gr.update(visible=use_pose, interactive=use_pose), # target_pose
152
  gr.update(interactive=not use_pose), # num1
153
  gr.update(interactive=not use_pose), # num2
154
  gr.update(interactive=not use_pose), # num3
155
  gr.update(interactive=not use_pose), # num4
 
 
 
 
156
  ]
157
 
158
 
@@ -186,16 +189,20 @@ def main():
186
 
187
  with gr.Row():
188
  num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
189
- num2 = gr.Number(label="Dimension 2", value=50, minimum=0, maximum=511, step=1)
190
- num3 = gr.Number(label="Dimension 3", value=100, minimum=0, maximum=511, step=1)
191
- num4 = gr.Number(label="Dimension 4", value=200, minimum=0, maximum=511, step=1)
 
 
 
 
192
 
193
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
194
  target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)
195
 
196
  with gr.Row():
197
- use_target_pose = gr.Checkbox(label="Use Target Pose")
198
- target_pose = gr.Slider(label="Target Pose", value=0, minimum=0, maximum=90, step=1, visible=False)
199
 
200
  submit = gr.Button("Submit", variant="primary")
201
 
@@ -219,10 +226,10 @@ def main():
219
  - These values are added to the dimensions (before normalization), **please ignore it if pose editing is on**.
220
  """)
221
 
222
- use_target_pose.change(
223
  fn=toggle_inputs,
224
- inputs=[use_target_pose],
225
- outputs=[target_pose, num1, num2, num3, num4]
226
  )
227
 
228
  generated_images = gr.State([])
@@ -237,7 +244,7 @@ def main():
237
  outputs=[generation_time]
238
  ).then(
239
  fn=process_input,
240
- inputs=[image_file, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose],
241
  outputs=[generated_images]
242
  ).then(
243
  fn=done,
 
81
  return generator, id_model, pose_model, quality_model
82
 
83
  @spaces.GPU
84
+ def image_generation(input_image, quality, random_perturbation, sigma, dimension, progress=gr.Progress()):
85
  generator, id_model, pose_model, quality_model = initialize_models()
86
 
87
  generated_images = []
 
93
  input_image.div_(255).sub_(0.5).div_(0.5)
94
  feature = id_model(input_image).clone().detach().cpu().numpy()
95
 
96
+ if not random_perturbation:
97
  features = []
98
  norm = np.linalg.norm(feature, 2, 1, True)
99
  for i in progress.tqdm(np.arange(0, 4.8, 2), desc="Generating images"):
 
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
+ if quality > 22:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
+ features = sample_nearby_vectors(features, [sigma], [1]).float().to(device)
112
+ if quality > 22:
113
+ images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality, class_rep=features)
 
114
  else:
115
  _, _, images, *_ = generator(features)
116
 
 
121
  return generated_images
122
 
123
  @spaces.GPU
124
+ def process_input(image_input, num1, num2, num3, num4, num5, num6, num7, num8, random_seed, target_quality, random_perturbation, sigma, progress=gr.Progress()):
125
  # Ensure all dimension numbers are within [0, 512)
126
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
127
 
 
135
  input_data = Image.open(image_input)
136
  input_data = np.array(input_data.resize((112, 112)))
137
 
138
+ generated_images = image_generation(input_data, target_quality, random_perturbation, sigma, [num1, num2, num3, num4, num5, num6, num7, num8], progress)
139
 
140
  return generated_images
141
 
 
147
 
148
  def toggle_inputs(use_pose):
149
  return [
150
+ gr.update(visible=use_pose, interactive=use_pose), # sigma
151
  gr.update(interactive=not use_pose), # num1
152
  gr.update(interactive=not use_pose), # num2
153
  gr.update(interactive=not use_pose), # num3
154
  gr.update(interactive=not use_pose), # num4
155
+ gr.update(interactive=not use_pose), # num5
156
+ gr.update(interactive=not use_pose), # num6
157
+ gr.update(interactive=not use_pose), # num7
158
+ gr.update(interactive=not use_pose), # num8
159
  ]
160
 
161
 
 
189
 
190
  with gr.Row():
191
  num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
192
+ num2 = gr.Number(label="Dimension 2", value=25, minimum=0, maximum=511, step=1)
193
+ num3 = gr.Number(label="Dimension 3", value=50, minimum=0, maximum=511, step=1)
194
+ num4 = gr.Number(label="Dimension 4", value=75, minimum=0, maximum=511, step=1)
195
+ num5 = gr.Number(label="Dimension 5", value=100, minimum=0, maximum=511, step=1)
196
+ num6 = gr.Number(label="Dimension 6", value=125, minimum=0, maximum=511, step=1)
197
+ num7 = gr.Number(label="Dimension 7", value=150, minimum=0, maximum=511, step=1)
198
+ num8 = gr.Number(label="Dimension 8", value=200, minimum=0, maximum=511, step=1)
199
 
200
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
201
  target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)
202
 
203
  with gr.Row():
204
+ random_perturbation = gr.Checkbox(label="Random Perturbation")
205
+ sigma = gr.Slider(label="Sigma value", value=0, minimum=0, maximum=1, step=0.1, visible=False)
206
 
207
  submit = gr.Button("Submit", variant="primary")
208
 
 
226
  - These values are added to the dimensions (before normalization), **please ignore it if pose editing is on**.
227
  """)
228
 
229
+ random_perturbation.change(
230
  fn=toggle_inputs,
231
+ inputs=[random_perturbation],
232
+ outputs=[sigma, num1, num2, num3, num4, num5, num6, num7, num8]
233
  )
234
 
235
  generated_images = gr.State([])
 
244
  outputs=[generation_time]
245
  ).then(
246
  fn=process_input,
247
+ inputs=[image_file, num1, num2, num3, num4, num5, num6, num7, num8, random_seed, target_quality, random_perturbation, sigma],
248
  outputs=[generated_images]
249
  ).then(
250
  fn=done,