Update app.py
Browse files
app.py
CHANGED
@@ -111,12 +111,17 @@ print(f"Processed dataset size: {len(processed_dataset)}")
|
|
111 |
|
112 |
# Print sample input shape
|
113 |
sample_batch = next(iter(processed_dataset.batch(1)))
|
114 |
-
print(f"Sample
|
|
|
|
|
|
|
|
|
115 |
|
116 |
# Training function
|
117 |
def train_step(state, batch, rng):
|
118 |
def compute_loss(params, pixel_values, rng):
|
119 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
|
|
120 |
|
121 |
latents = pipeline.vae.apply(
|
122 |
{"params": params["vae"]},
|
@@ -124,6 +129,7 @@ def train_step(state, batch, rng):
|
|
124 |
method=pipeline.vae.encode
|
125 |
).latent_dist.sample(rng)
|
126 |
latents = latents * jnp.float32(0.18215)
|
|
|
127 |
|
128 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
129 |
|
@@ -186,7 +192,7 @@ for epoch in range(num_epochs):
|
|
186 |
epoch_loss = 0
|
187 |
num_batches = 0
|
188 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
189 |
-
batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
|
190 |
rng, step_rng = jax.random.split(rng)
|
191 |
state, loss = train_step(state, batch, step_rng)
|
192 |
epoch_loss += loss
|
|
|
111 |
|
112 |
# Print sample input shape
|
113 |
sample_batch = next(iter(processed_dataset.batch(1)))
|
114 |
+
print(f"Sample batch keys: {sample_batch.keys()}")
|
115 |
+
print(f"Sample pixel_values type: {type(sample_batch['pixel_values'])}")
|
116 |
+
print(f"Sample pixel_values length: {len(sample_batch['pixel_values'])}")
|
117 |
+
if len(sample_batch['pixel_values']) > 0:
|
118 |
+
print(f"Sample pixel_values[0] shape: {np.array(sample_batch['pixel_values'][0]).shape}")
|
119 |
|
120 |
# Training function
|
121 |
def train_step(state, batch, rng):
|
122 |
def compute_loss(params, pixel_values, rng):
|
123 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
124 |
+
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
125 |
|
126 |
latents = pipeline.vae.apply(
|
127 |
{"params": params["vae"]},
|
|
|
129 |
method=pipeline.vae.encode
|
130 |
).latent_dist.sample(rng)
|
131 |
latents = latents * jnp.float32(0.18215)
|
132 |
+
print(f"latents shape: {latents.shape}")
|
133 |
|
134 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
135 |
|
|
|
192 |
epoch_loss = 0
|
193 |
num_batches = 0
|
194 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
195 |
+
batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
|
196 |
rng, step_rng = jax.random.split(rng)
|
197 |
state, loss = train_step(state, batch, step_rng)
|
198 |
epoch_loss += loss
|