Contrebande Labs commited on
Commit
552cad7
·
1 Parent(s): b50d751

debugging silent crash

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -68,6 +68,8 @@ def get_inference_lambda(seed):
68
 
69
  image_width = image_height = 256
70
 
 
 
71
  def __tokenize_prompt(prompt: str):
72
 
73
  return tokenizer(
@@ -79,10 +81,12 @@ def get_inference_lambda(seed):
79
  ).input_ids.astype(jnp.float32)
80
 
81
  def __convert_image(vae_output):
82
- return [
83
- Image.fromarray(image)
84
- for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
85
- ]
 
 
86
 
87
  def __predict_image(tokenized_prompt: jnp.array):
88
 
@@ -92,10 +96,10 @@ def get_inference_lambda(seed):
92
  params=text_encoder_params,
93
  train=False,
94
  )[0]
95
-
96
  context = jnp.concatenate(
97
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
98
  )
 
99
 
100
  latent_shape = (
101
  tokenized_prompt.shape[0],
@@ -152,6 +156,7 @@ def get_inference_lambda(seed):
152
  initial_scheduler_state = scheduler.set_timesteps(
153
  scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
154
  )
 
155
 
156
  # initialize latents
157
  initial_latents = (
@@ -160,11 +165,11 @@ def get_inference_lambda(seed):
160
  )
161
  * initial_scheduler_state.init_noise_sigma
162
  )
 
163
 
164
  final_latents, _ = jax.lax.fori_loop(
165
  0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
166
  )
167
-
168
  jax.debug.print("got final latents...")
169
 
170
  # scale and decode the image latents with vae
@@ -181,8 +186,7 @@ def get_inference_lambda(seed):
181
  .clip(0, 1)
182
  .transpose(0, 2, 3, 1)
183
  )
184
-
185
- jax.debug.print("got vae decoded image output...")
186
 
187
  # return reshaped vae outputs
188
  return image
@@ -212,7 +216,7 @@ with gr.Blocks(theme="gradio/soft") as demo:
212
  with gr.Tab("Journal"):
213
  gr.Markdown(
214
  """
215
- ## On How Four Crazy Fellows Embarked on Training a U-Net from Scratch in Five Days with JAX and Almost Died in the End
216
 
217
  Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris vitae varius libero. Nullam laoreet eget sapien quis tristique. Cras odio odio, consequat sed cursus quis, dignissim hendrerit ligula. Curabitur non lorem tellus. Nam bibendum malesuada mi sed faucibus. Sed euismod enim metus, sit amet venenatis elit elementum vel. Duis nec rhoncus tellus, rhoncus auctor justo. Proin id gravida dolor. Sed nulla lectus, finibus non fringilla ac, fermentum in sapien. Cras lobortis est augue, vel posuere justo pretium vitae. Aliquam lorem dolor, condimentum et finibus rutrum, rhoncus eget nunc.
218
 
 
68
 
69
  image_width = image_height = 256
70
 
71
+ print("all models setup")
72
+
73
  def __tokenize_prompt(prompt: str):
74
 
75
  return tokenizer(
 
81
  ).input_ids.astype(jnp.float32)
82
 
83
  def __convert_image(vae_output):
84
+ print("skipping image conversion...")
85
+ return None
86
+ # return [
87
+ # Image.fromarray(image)
88
+ # for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
89
+ # ]
90
 
91
  def __predict_image(tokenized_prompt: jnp.array):
92
 
 
96
  params=text_encoder_params,
97
  train=False,
98
  )[0]
 
99
  context = jnp.concatenate(
100
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
101
  )
102
+ jax.debug.print("got text encoding...")
103
 
104
  latent_shape = (
105
  tokenized_prompt.shape[0],
 
156
  initial_scheduler_state = scheduler.set_timesteps(
157
  scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
158
  )
159
+ jax.debug.print("initialized scheduler state...")
160
 
161
  # initialize latents
162
  initial_latents = (
 
165
  )
166
  * initial_scheduler_state.init_noise_sigma
167
  )
168
+ jax.debug.print("initialized latents...")
169
 
170
  final_latents, _ = jax.lax.fori_loop(
171
  0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
172
  )
 
173
  jax.debug.print("got final latents...")
174
 
175
  # scale and decode the image latents with vae
 
186
  .clip(0, 1)
187
  .transpose(0, 2, 3, 1)
188
  )
189
+ jax.debug.print("got vae processed image output...")
 
190
 
191
  # return reshaped vae outputs
192
  return image
 
216
  with gr.Tab("Journal"):
217
  gr.Markdown(
218
  """
219
+ ## On How Four Crazy Fellows Embarked on Training a JAX U-Net from Scratch in Five Days and Almost Died in the End
220
 
221
  Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris vitae varius libero. Nullam laoreet eget sapien quis tristique. Cras odio odio, consequat sed cursus quis, dignissim hendrerit ligula. Curabitur non lorem tellus. Nam bibendum malesuada mi sed faucibus. Sed euismod enim metus, sit amet venenatis elit elementum vel. Duis nec rhoncus tellus, rhoncus auctor justo. Proin id gravida dolor. Sed nulla lectus, finibus non fringilla ac, fermentum in sapien. Cras lobortis est augue, vel posuere justo pretium vitae. Aliquam lorem dolor, condimentum et finibus rutrum, rhoncus eget nunc.
222