tennant commited on
Commit
cdc4110
·
verified ·
1 Parent(s): 931a18b

Update gen_demo.py

Browse files
Files changed (1) hide show
  1. gen_demo.py +21 -8
gen_demo.py CHANGED
@@ -103,7 +103,7 @@ gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(de
103
  load_checkpoint(ckpt_path, gpt_model)
104
  gpt_model.eval();
105
 
106
- @spaces.GPU(duration=120)
107
  def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
108
  n_slots_inf = []
109
  for num_slots_to_inference in nums:
@@ -120,7 +120,7 @@ num_slots_to_gen = 32
120
 
121
  # Function to generate image from class
122
  @spaces.GPU(duration=120)
123
- def generate_from_class(class_id, cfg_scale):
124
  with torch.no_grad():
125
  dtype = torch.float
126
  num_slots_to_gen = 32
@@ -136,7 +136,18 @@ def generate_from_class(class_id, cfg_scale):
136
  null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
137
  null_slots = null_slots[:, num_slots_to_gen:, :]
138
  slots_gen = torch.cat([slots_gen, null_slots], dim=1)
139
- return slots_gen
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  with gr.Blocks() as demo:
142
  with gr.Row():
@@ -213,13 +224,15 @@ with gr.Blocks() as demo:
213
  while len(labels) < 3:
214
  labels.append(256)
215
 
216
- # Generate the image based on the selected class
217
- slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
218
 
219
- recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
220
 
221
- # Always generate the model decomposition for potential gallery display
222
- model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
 
 
223
 
224
  if not show_gallery_value:
225
  # If only the image should be shown, return just the processed image
 
103
  load_checkpoint(ckpt_path, gpt_model)
104
  gpt_model.eval();
105
 
106
+ # @spaces.GPU(duration=120)
107
  def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
108
  n_slots_inf = []
109
  for num_slots_to_inference in nums:
 
120
 
121
  # Function to generate image from class
122
  @spaces.GPU(duration=120)
123
+ def generate_from_class_and_viz_diff_slots(class_id, cfg_scale, nums, cfg=1.0, return_figs=False):
124
  with torch.no_grad():
125
  dtype = torch.float
126
  num_slots_to_gen = 32
 
136
  null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
137
  null_slots = null_slots[:, num_slots_to_gen:, :]
138
  slots_gen = torch.cat([slots_gen, null_slots], dim=1)
139
+
140
+ slots = slots_gen
141
+ n_slots_inf = []
142
+ for num_slots_to_inference in nums:
143
+ drop_mask = ae_model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
144
+ recon_n = ae_model.sample(slots, drop_mask=drop_mask, cfg=cfg)
145
+ n_slots_inf.append(recon_n)
146
+ drop_mask = ae_model.nested_sampler(slots.shape[0], device, 32)
147
+ recon_n = ae_model.sample(slots, drop_mask=drop_mask, cfg=cfg)
148
+ recon = recon_n
149
+
150
+ return slots_gen, [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))], convert_np(recon[0])
151
 
152
  with gr.Blocks() as demo:
153
  with gr.Row():
 
224
  while len(labels) < 3:
225
  labels.append(256)
226
 
227
+ # # Generate the image based on the selected class
228
+ # slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
229
 
230
+ # recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
231
 
232
+ # # Always generate the model decomposition for potential gallery display
233
+ # model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
234
+ slots_gen, model_decompose, recon = \
235
+ generate_from_class_and_viz_diff_slots(selected_class_id, cfg_scale=slider_value, nums=labels, cfg=slider_value)
236
 
237
  if not show_gallery_value:
238
  # If only the image should be shown, return just the processed image