uruguayai commited on
Commit
8835824
·
verified ·
1 Parent(s): 7cbe1c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -145,6 +145,15 @@ def train_step(state, batch, rng):
145
  state = state.apply_gradients(grads=grads)
146
  return state, loss
147
 
 
 
 
 
 
 
 
 
 
148
  # Training loop
149
  num_epochs = 3
150
  batch_size = 1 # Reduced batch size due to memory constraints
@@ -167,6 +176,8 @@ for epoch in range(num_epochs):
167
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
168
  clear_jit_cache()
169
 
 
 
170
  # Save the fine-tuned model
171
  output_dir = "/tmp/montevideo_fine_tuned_model"
172
  os.makedirs(output_dir, exist_ok=True)
 
145
  state = state.apply_gradients(grads=grads)
146
  return state, loss
147
 
148
+ # Initialize training state
149
+ learning_rate = 1e-5
150
+ optimizer = optax.adam(learning_rate)
151
+ state = train_state.TrainState.create(
152
+ apply_fn=unet.__call__,
153
+ params={"unet": params["unet"], "vae": params["vae"]},
154
+ tx=optimizer,
155
+ )
156
+
157
  # Training loop
158
  num_epochs = 3
159
  batch_size = 1 # Reduced batch size due to memory constraints
 
176
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
177
  clear_jit_cache()
178
 
179
+
180
+
181
  # Save the fine-tuned model
182
  output_dir = "/tmp/montevideo_fine_tuned_model"
183
  os.makedirs(output_dir, exist_ok=True)