Haiyu Wu commited on
Commit
89e5981
·
1 Parent(s): d48b521
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -102,14 +102,14 @@ def image_generation(input_image, quality, random_perturbation, sigma, dimension
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
- if quality > 25:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
  features = sample_nearby_vectors(features, [sigma], [1]).float().to(device)
112
- if quality > 25:
113
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality, class_rep=features)
114
  else:
115
  _, _, images, *_ = generator(features)
@@ -194,7 +194,7 @@ def main():
194
  # num8 = gr.Number(label="Dimension 8", value=0, minimum=0, maximum=511, step=1)
195
 
196
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
197
- target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=30, step=1, value=24)
198
 
199
  with gr.Row():
200
  random_perturbation = gr.Checkbox(label="Random Perturbation")
 
102
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
103
  features.append(updated_feature)
104
  features = torch.tensor(np.vstack(features)).float().to(device)
105
+ if quality > 22:
106
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
107
  else:
108
  _, _, images, *_ = generator(features)
109
  else:
110
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
111
  features = sample_nearby_vectors(features, [sigma], [1]).float().to(device)
112
+ if quality > 22:
113
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality, class_rep=features)
114
  else:
115
  _, _, images, *_ = generator(features)
 
194
  # num8 = gr.Number(label="Dimension 8", value=0, minimum=0, maximum=511, step=1)
195
 
196
  random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
197
+ target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=30, step=1, value=22)
198
 
199
  with gr.Row():
200
  random_perturbation = gr.Checkbox(label="Random Perturbation")