uruguayai committed on
Commit
4a48f70
·
verified ·
1 Parent(s): e9745d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -9
app.py CHANGED
@@ -15,7 +15,6 @@ import numpy as np
15
  # Custom Scheduler
16
  class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
17
  def add_noise(self, state, original_samples, noise, timesteps):
18
- # Explicitly cast timesteps to int32
19
  timesteps = timesteps.astype(jnp.int32)
20
  return super().add_noise(state, original_samples, noise, timesteps)
21
 
@@ -97,10 +96,8 @@ def train_step(state, batch, rng):
97
  print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
98
  print("rng dtype:", rng.dtype)
99
 
100
- # Ensure pixel_values are float32
101
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
102
 
103
- # Encode images to latent space
104
  latents = pipeline.vae.apply(
105
  {"params": params["vae"]},
106
  pixel_values,
@@ -108,10 +105,8 @@ def train_step(state, batch, rng):
108
  ).latent_dist.sample(rng)
109
  latents = latents * jnp.float32(0.18215)
110
 
111
- # Generate random noise
112
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
113
 
114
- # Sample random timesteps
115
  timesteps = jax.random.randint(
116
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
  )
@@ -119,8 +114,8 @@ def train_step(state, batch, rng):
119
  print("timesteps dtype:", timesteps.dtype)
120
  print("latents dtype:", latents.dtype)
121
  print("noise dtype:", noise.dtype)
 
122
 
123
- # Add noise to latents
124
  noisy_latents = pipeline.scheduler.add_noise(
125
  pipeline.scheduler.create_state(),
126
  original_samples=latents,
@@ -128,14 +123,12 @@ def train_step(state, batch, rng):
128
  timesteps=timesteps
129
  )
130
 
131
- # Generate random encoder hidden states (simulating text embeddings)
132
  encoder_hidden_states = jax.random.normal(
133
  rng,
134
  (latents.shape[0], pipeline.text_encoder.config.hidden_size),
135
  dtype=jnp.float32
136
  )
137
 
138
- # Predict noise
139
  model_output = unet.apply(
140
  {'params': params["unet"]},
141
  noisy_latents,
@@ -144,7 +137,6 @@ def train_step(state, batch, rng):
144
  train=True,
145
  )
146
 
147
- # Compute loss
148
  return jnp.mean((model_output - noise) ** 2)
149
 
150
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
 
15
  # Custom Scheduler
16
  class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
17
  def add_noise(self, state, original_samples, noise, timesteps):
 
18
  timesteps = timesteps.astype(jnp.int32)
19
  return super().add_noise(state, original_samples, noise, timesteps)
20
 
 
96
  print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
97
  print("rng dtype:", rng.dtype)
98
 
 
99
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
100
 
 
101
  latents = pipeline.vae.apply(
102
  {"params": params["vae"]},
103
  pixel_values,
 
105
  ).latent_dist.sample(rng)
106
  latents = latents * jnp.float32(0.18215)
107
 
 
108
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
109
 
 
110
  timesteps = jax.random.randint(
111
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
112
  )
 
114
  print("timesteps dtype:", timesteps.dtype)
115
  print("latents dtype:", latents.dtype)
116
  print("noise dtype:", noise.dtype)
117
+ print("latents shape:", latents.shape)
118
 
 
119
  noisy_latents = pipeline.scheduler.add_noise(
120
  pipeline.scheduler.create_state(),
121
  original_samples=latents,
 
123
  timesteps=timesteps
124
  )
125
 
 
126
  encoder_hidden_states = jax.random.normal(
127
  rng,
128
  (latents.shape[0], pipeline.text_encoder.config.hidden_size),
129
  dtype=jnp.float32
130
  )
131
 
 
132
  model_output = unet.apply(
133
  {'params': params["unet"]},
134
  noisy_latents,
 
137
  train=True,
138
  )
139
 
 
140
  return jnp.mean((model_output - noise) ** 2)
141
 
142
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)