uruguayai committed on
Commit
571b479
·
verified ·
1 Parent(s): bec6160

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -97,18 +97,21 @@ def clear_jit_cache():
97
  # Training function
98
  def train_step(state, batch, rng):
99
  def compute_loss(params, pixel_values, rng):
 
 
 
100
  # Encode images to latent space
101
  latents = pipeline.vae.apply(
102
  {"params": params["vae"]},
103
  pixel_values,
104
  method=pipeline.vae.encode
105
  ).latent_dist.sample(rng)
106
- latents = latents * 0.18215
107
 
108
  # Generate random noise
109
- noise = jax.random.normal(rng, latents.shape)
110
 
111
- # Sample random timesteps
112
  timesteps = jax.random.randint(
113
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
114
  )
@@ -124,11 +127,12 @@ def train_step(state, batch, rng):
124
  # Generate random encoder hidden states (simulating text embeddings)
125
  encoder_hidden_states = jax.random.normal(
126
  rng,
127
- (latents.shape[0], pipeline.text_encoder.config.hidden_size)
 
128
  )
129
 
130
  # Predict noise
131
- model_output = state.apply_fn.apply(
132
  {'params': params["unet"]},
133
  noisy_latents,
134
  timesteps,
@@ -139,9 +143,11 @@ def train_step(state, batch, rng):
139
  # Compute loss
140
  return jnp.mean((model_output - noise) ** 2)
141
 
142
- grad_fn = jax.value_and_grad(compute_loss)
143
  rng, step_rng = jax.random.split(rng)
144
- loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
 
 
145
  state = state.apply_gradients(grads=grads)
146
  return state, loss
147
 
@@ -236,7 +242,6 @@ for epoch in range(num_epochs):
236
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
237
  clear_jit_cache()
238
 
239
-
240
 
241
  # Save the fine-tuned model
242
  output_dir = "/tmp/montevideo_fine_tuned_model"
 
97
  # Training function
98
  def train_step(state, batch, rng):
99
  def compute_loss(params, pixel_values, rng):
100
+ # Ensure pixel_values are float32
101
+ pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
102
+
103
  # Encode images to latent space
104
  latents = pipeline.vae.apply(
105
  {"params": params["vae"]},
106
  pixel_values,
107
  method=pipeline.vae.encode
108
  ).latent_dist.sample(rng)
109
+ latents = latents * jnp.float32(0.18215)
110
 
111
  # Generate random noise
112
+ noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
113
 
114
+ # Sample random timesteps (keep as integers)
115
  timesteps = jax.random.randint(
116
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
  )
 
127
  # Generate random encoder hidden states (simulating text embeddings)
128
  encoder_hidden_states = jax.random.normal(
129
  rng,
130
+ (latents.shape[0], pipeline.text_encoder.config.hidden_size),
131
+ dtype=jnp.float32
132
  )
133
 
134
  # Predict noise
135
+ model_output = state.apply_fn(
136
  {'params': params["unet"]},
137
  noisy_latents,
138
  timesteps,
 
143
  # Compute loss
144
  return jnp.mean((model_output - noise) ** 2)
145
 
146
+ grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
147
  rng, step_rng = jax.random.split(rng)
148
+
149
+ grads = grad_fn(state.params, batch["pixel_values"], step_rng)
150
+ loss = compute_loss(state.params, batch["pixel_values"], step_rng)
151
  state = state.apply_gradients(grads=grads)
152
  return state, loss
153
 
 
242
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
243
  clear_jit_cache()
244
 
 
245
 
246
  # Save the fine-tuned model
247
  output_dir = "/tmp/montevideo_fine_tuned_model"