Spaces:
Runtime error
Runtime error
Update gen_demo.py
Browse files- gen_demo.py +21 -8
gen_demo.py
CHANGED
@@ -103,7 +103,7 @@ gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(de
|
|
103 |
load_checkpoint(ckpt_path, gpt_model)
|
104 |
gpt_model.eval();
|
105 |
|
106 |
-
@spaces.GPU(duration=120)
|
107 |
def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
|
108 |
n_slots_inf = []
|
109 |
for num_slots_to_inference in nums:
|
@@ -120,7 +120,7 @@ num_slots_to_gen = 32
|
|
120 |
|
121 |
# Function to generate image from class
|
122 |
@spaces.GPU(duration=120)
|
123 |
-
def
|
124 |
with torch.no_grad():
|
125 |
dtype = torch.float
|
126 |
num_slots_to_gen = 32
|
@@ -136,7 +136,18 @@ def generate_from_class(class_id, cfg_scale):
|
|
136 |
null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
|
137 |
null_slots = null_slots[:, num_slots_to_gen:, :]
|
138 |
slots_gen = torch.cat([slots_gen, null_slots], dim=1)
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
with gr.Blocks() as demo:
|
142 |
with gr.Row():
|
@@ -213,13 +224,15 @@ with gr.Blocks() as demo:
|
|
213 |
while len(labels) < 3:
|
214 |
labels.append(256)
|
215 |
|
216 |
-
# Generate the image based on the selected class
|
217 |
-
slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
|
218 |
|
219 |
-
recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
|
220 |
|
221 |
-
# Always generate the model decomposition for potential gallery display
|
222 |
-
model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
|
|
|
|
|
223 |
|
224 |
if not show_gallery_value:
|
225 |
# If only the image should be shown, return just the processed image
|
|
|
103 |
load_checkpoint(ckpt_path, gpt_model)
|
104 |
gpt_model.eval();
|
105 |
|
106 |
+
# @spaces.GPU(duration=120)
|
107 |
def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
|
108 |
n_slots_inf = []
|
109 |
for num_slots_to_inference in nums:
|
|
|
120 |
|
121 |
# Function to generate image from class
|
122 |
@spaces.GPU(duration=120)
|
123 |
+
def generate_from_class_and_viz_diff_slots(class_id, cfg_scale, nums, cfg=1.0, return_figs=False):
|
124 |
with torch.no_grad():
|
125 |
dtype = torch.float
|
126 |
num_slots_to_gen = 32
|
|
|
136 |
null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
|
137 |
null_slots = null_slots[:, num_slots_to_gen:, :]
|
138 |
slots_gen = torch.cat([slots_gen, null_slots], dim=1)
|
139 |
+
|
140 |
+
slots = slots_gen
|
141 |
+
n_slots_inf = []
|
142 |
+
for num_slots_to_inference in nums:
|
143 |
+
drop_mask = ae_model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
|
144 |
+
recon_n = ae_model.sample(slots, drop_mask=drop_mask, cfg=cfg)
|
145 |
+
n_slots_inf.append(recon_n)
|
146 |
+
drop_mask = ae_model.nested_sampler(slots.shape[0], device, 32)
|
147 |
+
recon_n = ae_model.sample(slots, drop_mask=drop_mask, cfg=cfg)
|
148 |
+
recon = recon_n
|
149 |
+
|
150 |
+
return slots_gen, [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))], convert_np(recon[0])
|
151 |
|
152 |
with gr.Blocks() as demo:
|
153 |
with gr.Row():
|
|
|
224 |
while len(labels) < 3:
|
225 |
labels.append(256)
|
226 |
|
227 |
+
# # Generate the image based on the selected class
|
228 |
+
# slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
|
229 |
|
230 |
+
# recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
|
231 |
|
232 |
+
# # Always generate the model decomposition for potential gallery display
|
233 |
+
# model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
|
234 |
+
slots_gen, model_decompose, recon = \
|
235 |
+
generate_from_class_and_viz_diff_slots(selected_class_id, cfg_scale=slider_value, nums=labels, cfg=slider_value)
|
236 |
|
237 |
if not show_gallery_value:
|
238 |
# If only the image should be shown, return just the processed image
|