uruguayai commited on
Commit
7b46a28
·
verified ·
1 Parent(s): 6f411d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -113,7 +113,7 @@ def train_step(state, batch, rng):
113
  batch_size = pixel_values.shape[0]
114
 
115
  # Generate random noise
116
- noise_rng, timestep_rng = jax.random.split(rng)
117
  noise = jax.random.normal(noise_rng, pixel_values.shape)
118
 
119
  # Sample random timesteps
@@ -132,11 +132,15 @@ def train_step(state, batch, rng):
132
  timesteps=timesteps
133
  )
134
 
 
 
 
135
  # Predict noise
136
  model_output = state.apply_fn.apply(
137
  {'params': params},
138
  jnp.array(noisy_images),
139
  jnp.array(timesteps),
 
140
  train=True,
141
  )
142
 
@@ -173,7 +177,6 @@ for epoch in range(num_epochs):
173
  avg_loss = epoch_loss / num_batches
174
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
175
 
176
-
177
  # Save the fine-tuned model
178
  output_dir = "/tmp/montevideo_fine_tuned_model"
179
  os.makedirs(output_dir, exist_ok=True)
 
113
  batch_size = pixel_values.shape[0]
114
 
115
  # Generate random noise
116
+ noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
117
  noise = jax.random.normal(noise_rng, pixel_values.shape)
118
 
119
  # Sample random timesteps
 
132
  timesteps=timesteps
133
  )
134
 
135
+ # Generate random latents for text encoder
136
+ latents = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
137
+
138
  # Predict noise
139
  model_output = state.apply_fn.apply(
140
  {'params': params},
141
  jnp.array(noisy_images),
142
  jnp.array(timesteps),
143
+ encoder_hidden_states=latents,
144
  train=True,
145
  )
146
 
 
177
  avg_loss = epoch_loss / num_batches
178
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
179
 
 
180
  # Save the fine-tuned model
181
  output_dir = "/tmp/montevideo_fine_tuned_model"
182
  os.makedirs(output_dir, exist_ok=True)