uruguayai committed on
Commit
35bc545
·
verified ·
1 Parent(s): 5fadcb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -10,6 +10,9 @@ import os
10
  import pickle
11
  from PIL import Image
12
  import numpy as np
 
 
 
13
 
14
  # Set up cache directories
15
  cache_dir = "/tmp/huggingface_cache"
@@ -120,9 +123,13 @@ def train_step(state, batch, rng):
120
  batch_size = pixel_values.shape[0]
121
 
122
  # Encode images to latent space
123
- latents = pipeline.vae.encode(pixel_values).latent_dist.sample(rng)
124
- latents = latents * pipeline.vae.config.scaling_factor
125
-
 
 
 
 
126
  # Generate random noise
127
  noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
128
  noise = jax.random.normal(noise_rng, latents.shape)
@@ -148,7 +155,7 @@ def train_step(state, batch, rng):
148
 
149
  # Predict noise
150
  model_output = state.apply_fn.apply(
151
- {'params': params},
152
  jnp.array(noisy_latents),
153
  jnp.array(timesteps),
154
  encoder_hidden_states=encoder_hidden_states,
@@ -168,7 +175,7 @@ learning_rate = 1e-5
168
  optimizer = optax.adam(learning_rate)
169
  state = train_state.TrainState.create(
170
  apply_fn=unet,
171
- params=params["unet"], # Use only UNet params
172
  tx=optimizer,
173
  )
174
 
@@ -191,6 +198,7 @@ for epoch in range(num_epochs):
191
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
192
 
193
 
 
194
  # Save the fine-tuned model
195
  output_dir = "/tmp/montevideo_fine_tuned_model"
196
  os.makedirs(output_dir, exist_ok=True)
 
10
  import pickle
11
  from PIL import Image
12
  import numpy as np
13
+
14
+
15
+
16
 
17
  # Set up cache directories
18
  cache_dir = "/tmp/huggingface_cache"
 
123
  batch_size = pixel_values.shape[0]
124
 
125
  # Encode images to latent space
126
+ latents = pipeline.vae.apply(
127
+ {"params": params["vae"]},
128
+ pixel_values,
129
+ method=pipeline.vae.encode
130
+ ).latent_dist.sample(rng)
131
+ latents = latents * 0.18215 # scaling factor
132
+
133
  # Generate random noise
134
  noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
135
  noise = jax.random.normal(noise_rng, latents.shape)
 
155
 
156
  # Predict noise
157
  model_output = state.apply_fn.apply(
158
+ {'params': params["unet"]},
159
  jnp.array(noisy_latents),
160
  jnp.array(timesteps),
161
  encoder_hidden_states=encoder_hidden_states,
 
175
  optimizer = optax.adam(learning_rate)
176
  state = train_state.TrainState.create(
177
  apply_fn=unet,
178
+ params={"unet": params["unet"], "vae": params["vae"]}, # Include both UNet and VAE params
179
  tx=optimizer,
180
  )
181
 
 
198
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
199
 
200
 
201
+
202
  # Save the fine-tuned model
203
  output_dir = "/tmp/montevideo_fine_tuned_model"
204
  os.makedirs(output_dir, exist_ok=True)