uruguayai committed (verified)
Commit 7cbe1c1 · Parent: 920c999

Update app.py

Files changed (1): app.py (+12, -12)
app.py CHANGED
@@ -57,7 +57,7 @@ def preprocess_images(examples):
     image = Image.open(image)
     if not isinstance(image, Image.Image):
         raise ValueError(f"Unexpected image type: {type(image)}")
-    image = image.convert("RGB").resize((128, 128))  # Further reduced image size
+    image = image.convert("RGB").resize((512, 512))  # Keep original size
     image = np.array(image).astype(np.float32) / 255.0
     return image.transpose(2, 0, 1)
 
@@ -97,6 +97,7 @@ def clear_jit_cache():
 # Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
+        # Encode images to latent space
         latents = pipeline.vae.apply(
             {"params": params["vae"]},
             pixel_values,
@@ -104,10 +105,15 @@ def train_step(state, batch, rng):
         ).latent_dist.sample(rng)
         latents = latents * 0.18215
 
+        # Generate random noise
         noise = jax.random.normal(rng, latents.shape)
+
+        # Sample random timesteps
         timesteps = jax.random.randint(
             rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
         )
+
+        # Add noise to latents
         noisy_latents = pipeline.scheduler.add_noise(
             pipeline.scheduler.create_state(),
             original_samples=latents,
@@ -115,11 +121,13 @@ def train_step(state, batch, rng):
             timesteps=timesteps
         )
 
+        # Generate random encoder hidden states (simulating text embeddings)
         encoder_hidden_states = jax.random.normal(
             rng,
             (latents.shape[0], pipeline.text_encoder.config.hidden_size)
         )
 
+        # Predict noise
         model_output = state.apply_fn.apply(
             {'params': params["unet"]},
             noisy_latents,
@@ -128,6 +136,7 @@ def train_step(state, batch, rng):
             train=True,
         )
 
+        # Compute loss
         return jnp.mean((model_output - noise) ** 2)
 
     grad_fn = jax.value_and_grad(compute_loss)
@@ -136,18 +145,9 @@ def train_step(state, batch, rng):
     state = state.apply_gradients(grads=grads)
     return state, loss
 
-# Initialize training state
-learning_rate = 1e-5
-optimizer = optax.adam(learning_rate)
-state = train_state.TrainState.create(
-    apply_fn=unet,
-    params={"unet": params["unet"], "vae": params["vae"]},
-    tx=optimizer,
-)
-
 # Training loop
-num_epochs = 3  # Further reduced number of epochs
-batch_size = 2  # Reduced batch size for CPU
+num_epochs = 3
+batch_size = 1  # Reduced batch size due to memory constraints
 rng = jax.random.PRNGKey(0)
 
 for epoch in range(num_epochs):
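The two numeric changes above are linked: moving preprocessing from 128x128 back to 512x512 multiplies per-image activation memory by roughly 16, which is presumably why the batch size drops from 2 to 1 in the same commit. A rough back-of-the-envelope sketch, under the assumption that the VAE here is the standard Stable Diffusion autoencoder (8x spatial downsampling, 4 latent channels):

# Back-of-the-envelope memory arithmetic for the 128 -> 512 change.
# Assumption: standard Stable Diffusion VAE (8x downsampling, 4 latent channels).
image_hw = 512
pixels_per_image = 3 * image_hw * image_hw      # 786,432 float32 values (~3 MB per image)
latent_hw = image_hw // 8                       # 64
latents_per_image = 4 * latent_hw * latent_hw   # 16,384 values per sample after encoding
area_ratio = (512 * 512) // (128 * 128)         # 16x more activation memory than before
print(pixels_per_image, latents_per_image, area_ratio)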
 
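For reference, the block this commit removes from below train_step was the only place shown here that built the optimizer state. If that setup now lives elsewhere in app.py, a minimal standalone sketch of the same wiring (assuming `unet` and `params` are already loaded from the Flax Stable Diffusion pipeline, as the removed lines imply) would be:

# Minimal sketch of the TrainState setup performed by the removed block.
# Assumption: `unet` and `params` come from the diffusers Flax Stable Diffusion
# pipeline loaded earlier in app.py; values mirror the removed lines.
import optax
from flax.training import train_state

learning_rate = 1e-5                     # same constant as the removed block
optimizer = optax.adam(learning_rate)    # plain Adam, no schedule or weight decay
state = train_state.TrainState.create(
    apply_fn=unet,                                           # Flax UNet module
    params={"unet": params["unet"], "vae": params["vae"]},   # joint parameter tree
    tx=optimizer,
)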
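The diff cuts off at the `for epoch in range(num_epochs):` header, so the loop body is not shown. A hypothetical driver, assuming `train_step` is called once per batch and that batches are dicts holding a `pixel_values` array of shape (batch_size, 3, 512, 512) from preprocess_images (`get_batches` and `dataset` below are placeholders, not names from app.py):

# Hypothetical training-loop body; `get_batches` and `dataset` are placeholders.
import jax

rng = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    epoch_loss, steps = 0.0, 0
    for batch in get_batches(dataset, batch_size):    # yields {"pixel_values": (B, 3, 512, 512)}
        rng, step_rng = jax.random.split(rng)          # fresh RNG key per step
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += float(loss)
        steps += 1
    print(f"epoch {epoch}: mean loss {epoch_loss / max(steps, 1):.4f}")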