uruguayai commited on
Commit
1f8900f
·
verified ·
1 Parent(s): 77248af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -4
app.py CHANGED
@@ -29,12 +29,11 @@ def get_model(model_id, revision):
29
  return pickle.load(f)
30
  else:
31
  print("Downloading model...")
32
- pipeline = FlaxStableDiffusionPipeline.from_pretrained(
33
  model_id,
34
  revision=revision,
35
  dtype=jnp.float32,
36
  )
37
- params = pipeline.params
38
  with open(model_cache_file, 'wb') as f:
39
  pickle.dump((pipeline, params), f)
40
  return pipeline, params
@@ -102,5 +101,73 @@ except Exception as e:
102
  else:
103
  raise ValueError(f"Local path {local_path} does not exist.")
104
 
105
- # Rest of your code (training loop, etc.) remains the same
106
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  return pickle.load(f)
30
  else:
31
  print("Downloading model...")
32
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
33
  model_id,
34
  revision=revision,
35
  dtype=jnp.float32,
36
  )
 
37
  with open(model_cache_file, 'wb') as f:
38
  pickle.dump((pipeline, params), f)
39
  return pipeline, params
 
101
  else:
102
  raise ValueError(f"Local path {local_path} does not exist.")
103
 
104
+ # Training function
105
+ def train_step(state, batch, rng):
106
+ def compute_loss(params):
107
+ # Convert batch to JAX array
108
+ pixel_values = jnp.array(batch["pixel_values"])
109
+ batch_size = pixel_values.shape[0]
110
+
111
+ # Generate random noise
112
+ noise_rng, timestep_rng = jax.random.split(rng)
113
+ noise = jax.random.normal(noise_rng, pixel_values.shape)
114
+
115
+ # Sample random timesteps
116
+ timesteps = jax.random.randint(
117
+ timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
118
+ )
119
+
120
+ # Add noise to images using the scheduler
121
+ noisy_images = pipeline.scheduler.add_noise(
122
+ original_samples=pixel_values,
123
+ noise=noise,
124
+ timesteps=timesteps
125
+ )
126
+
127
+ # Predict noise
128
+ model_output = state.apply_fn.apply(
129
+ {'params': params},
130
+ jnp.array(noisy_images),
131
+ jnp.array(timesteps),
132
+ train=True,
133
+ )
134
+
135
+ # Compute loss
136
+ loss = jnp.mean((model_output - noise) ** 2)
137
+ return loss
138
+
139
+ loss, grads = jax.value_and_grad(compute_loss)(state.params)
140
+ state = state.apply_gradients(grads=grads)
141
+ return state, loss
142
+
143
+ # Initialize training state
144
+ learning_rate = 1e-5
145
+ optimizer = optax.adam(learning_rate)
146
+ state = train_state.TrainState.create(
147
+ apply_fn=unet,
148
+ params=params["unet"], # Use only UNet params
149
+ tx=optimizer,
150
+ )
151
+
152
+ # Training loop
153
+ num_epochs = 10
154
+ batch_size = 4
155
+ rng = jax.random.PRNGKey(0)
156
+
157
+ for epoch in range(num_epochs):
158
+ epoch_loss = 0
159
+ num_batches = 0
160
+ for batch in tqdm(processed_dataset.batch(batch_size)):
161
+ rng, step_rng = jax.random.split(rng)
162
+ state, loss = train_step(state, batch, step_rng)
163
+ epoch_loss += loss
164
+ num_batches += 1
165
+ avg_loss = epoch_loss / num_batches
166
+ print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
167
+
168
+ # Save the fine-tuned model
169
+ output_dir = "/tmp/montevideo_fine_tuned_model"
170
+ os.makedirs(output_dir, exist_ok=True)
171
+ unet.save_pretrained(output_dir, params=state.params)
172
+
173
+ print(f"Model saved to {output_dir}")