uruguayai committed
Commit ad3bf4c · verified · 1 Parent(s): 66bb520

Update app.py

Files changed (1): app.py (+8 -2)
app.py CHANGED
@@ -111,12 +111,17 @@ print(f"Processed dataset size: {len(processed_dataset)}")
 
 # Print sample input shape
 sample_batch = next(iter(processed_dataset.batch(1)))
-print(f"Sample input shape: {sample_batch['pixel_values'].shape}")
+print(f"Sample batch keys: {sample_batch.keys()}")
+print(f"Sample pixel_values type: {type(sample_batch['pixel_values'])}")
+print(f"Sample pixel_values length: {len(sample_batch['pixel_values'])}")
+if len(sample_batch['pixel_values']) > 0:
+    print(f"Sample pixel_values[0] shape: {np.array(sample_batch['pixel_values'][0]).shape}")
 
 # Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
         pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
+        print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
 
         latents = pipeline.vae.apply(
             {"params": params["vae"]},
@@ -124,6 +129,7 @@ def train_step(state, batch, rng):
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
         latents = latents * jnp.float32(0.18215)
+        print(f"latents shape: {latents.shape}")
 
         noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
 
@@ -186,7 +192,7 @@ for epoch in range(num_epochs):
     epoch_loss = 0
     num_batches = 0
     for batch in tqdm(processed_dataset.batch(batch_size)):
-        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
+        batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
         rng, step_rng = jax.random.split(rng)
         state, loss = train_step(state, batch, step_rng)
         epoch_loss += loss
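
Context for the one-line fix in the final hunk: the debug prints added above indicate that batches yielded by processed_dataset.batch() carry pixel_values as a nested list, so converting the column directly adds an extra leading axis, and indexing [0] drops it before the array reaches compute_loss. Below is a minimal Python sketch of that behaviour; the (3, 512, 512) image shape and the hand-built batch dict are illustrative assumptions, not values taken from app.py.

import numpy as np
import jax.numpy as jnp

# Illustrative stand-in for one element of processed_dataset.batch(1):
# the 'pixel_values' column arrives wrapped in an outer list (assumed here).
image = np.zeros((3, 512, 512), dtype=np.float32)  # assumed image shape
batch = {"pixel_values": [[image]]}                # assumed batch structure

direct = jnp.array(batch["pixel_values"], dtype=jnp.float32)
print(direct.shape)  # (1, 1, 3, 512, 512) -- extra leading axis

fixed = jnp.array(batch["pixel_values"][0], dtype=jnp.float32)
print(fixed.shape)   # (1, 3, 512, 512) -- shape the VAE encoder can consume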