Spaces:
Runtime error
Runtime error
Contrebande Labs
commited on
Commit
·
552cad7
1
Parent(s):
b50d751
debugging silent crash
Browse files
app.py
CHANGED
@@ -68,6 +68,8 @@ def get_inference_lambda(seed):
|
|
68 |
|
69 |
image_width = image_height = 256
|
70 |
|
|
|
|
|
71 |
def __tokenize_prompt(prompt: str):
|
72 |
|
73 |
return tokenizer(
|
@@ -79,10 +81,12 @@ def get_inference_lambda(seed):
|
|
79 |
).input_ids.astype(jnp.float32)
|
80 |
|
81 |
def __convert_image(vae_output):
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
def __predict_image(tokenized_prompt: jnp.array):
|
88 |
|
@@ -92,10 +96,10 @@ def get_inference_lambda(seed):
|
|
92 |
params=text_encoder_params,
|
93 |
train=False,
|
94 |
)[0]
|
95 |
-
|
96 |
context = jnp.concatenate(
|
97 |
[negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
|
98 |
)
|
|
|
99 |
|
100 |
latent_shape = (
|
101 |
tokenized_prompt.shape[0],
|
@@ -152,6 +156,7 @@ def get_inference_lambda(seed):
|
|
152 |
initial_scheduler_state = scheduler.set_timesteps(
|
153 |
scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
|
154 |
)
|
|
|
155 |
|
156 |
# initialize latents
|
157 |
initial_latents = (
|
@@ -160,11 +165,11 @@ def get_inference_lambda(seed):
|
|
160 |
)
|
161 |
* initial_scheduler_state.init_noise_sigma
|
162 |
)
|
|
|
163 |
|
164 |
final_latents, _ = jax.lax.fori_loop(
|
165 |
0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
|
166 |
)
|
167 |
-
|
168 |
jax.debug.print("got final latents...")
|
169 |
|
170 |
# scale and decode the image latents with vae
|
@@ -181,8 +186,7 @@ def get_inference_lambda(seed):
|
|
181 |
.clip(0, 1)
|
182 |
.transpose(0, 2, 3, 1)
|
183 |
)
|
184 |
-
|
185 |
-
jax.debug.print("got vae decoded image output...")
|
186 |
|
187 |
# return reshaped vae outputs
|
188 |
return image
|
@@ -212,7 +216,7 @@ with gr.Blocks(theme="gradio/soft") as demo:
|
|
212 |
with gr.Tab("Journal"):
|
213 |
gr.Markdown(
|
214 |
"""
|
215 |
-
## On How Four Crazy Fellows Embarked on Training a U-Net from Scratch in Five Days
|
216 |
|
217 |
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris vitae varius libero. Nullam laoreet eget sapien quis tristique. Cras odio odio, consequat sed cursus quis, dignissim hendrerit ligula. Curabitur non lorem tellus. Nam bibendum malesuada mi sed faucibus. Sed euismod enim metus, sit amet venenatis elit elementum vel. Duis nec rhoncus tellus, rhoncus auctor justo. Proin id gravida dolor. Sed nulla lectus, finibus non fringilla ac, fermentum in sapien. Cras lobortis est augue, vel posuere justo pretium vitae. Aliquam lorem dolor, condimentum et finibus rutrum, rhoncus eget nunc.
|
218 |
|
|
|
68 |
|
69 |
image_width = image_height = 256
|
70 |
|
71 |
+
print("all models setup")
|
72 |
+
|
73 |
def __tokenize_prompt(prompt: str):
|
74 |
|
75 |
return tokenizer(
|
|
|
81 |
).input_ids.astype(jnp.float32)
|
82 |
|
83 |
def __convert_image(vae_output):
|
84 |
+
print("skipping image conversion...")
|
85 |
+
return None
|
86 |
+
# return [
|
87 |
+
# Image.fromarray(image)
|
88 |
+
# for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
|
89 |
+
# ]
|
90 |
|
91 |
def __predict_image(tokenized_prompt: jnp.array):
|
92 |
|
|
|
96 |
params=text_encoder_params,
|
97 |
train=False,
|
98 |
)[0]
|
|
|
99 |
context = jnp.concatenate(
|
100 |
[negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
|
101 |
)
|
102 |
+
jax.debug.print("got text encoding...")
|
103 |
|
104 |
latent_shape = (
|
105 |
tokenized_prompt.shape[0],
|
|
|
156 |
initial_scheduler_state = scheduler.set_timesteps(
|
157 |
scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
|
158 |
)
|
159 |
+
jax.debug.print("initialized scheduler state...")
|
160 |
|
161 |
# initialize latents
|
162 |
initial_latents = (
|
|
|
165 |
)
|
166 |
* initial_scheduler_state.init_noise_sigma
|
167 |
)
|
168 |
+
jax.debug.print("initialized latents...")
|
169 |
|
170 |
final_latents, _ = jax.lax.fori_loop(
|
171 |
0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
|
172 |
)
|
|
|
173 |
jax.debug.print("got final latents...")
|
174 |
|
175 |
# scale and decode the image latents with vae
|
|
|
186 |
.clip(0, 1)
|
187 |
.transpose(0, 2, 3, 1)
|
188 |
)
|
189 |
+
jax.debug.print("got vae processed image output...")
|
|
|
190 |
|
191 |
# return reshaped vae outputs
|
192 |
return image
|
|
|
216 |
with gr.Tab("Journal"):
|
217 |
gr.Markdown(
|
218 |
"""
|
219 |
+
## On How Four Crazy Fellows Embarked on Training a JAX U-Net from Scratch in Five Days and Almost Died in the End
|
220 |
|
221 |
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris vitae varius libero. Nullam laoreet eget sapien quis tristique. Cras odio odio, consequat sed cursus quis, dignissim hendrerit ligula. Curabitur non lorem tellus. Nam bibendum malesuada mi sed faucibus. Sed euismod enim metus, sit amet venenatis elit elementum vel. Duis nec rhoncus tellus, rhoncus auctor justo. Proin id gravida dolor. Sed nulla lectus, finibus non fringilla ac, fermentum in sapien. Cras lobortis est augue, vel posuere justo pretium vitae. Aliquam lorem dolor, condimentum et finibus rutrum, rhoncus eget nunc.
|
222 |
|