Update app.py
app.py CHANGED
@@ -145,6 +145,15 @@ def train_step(state, batch, rng):
     state = state.apply_gradients(grads=grads)
     return state, loss
 
+# Initialize training state
+learning_rate = 1e-5
+optimizer = optax.adam(learning_rate)
+state = train_state.TrainState.create(
+    apply_fn=unet.__call__,
+    params={"unet": params["unet"], "vae": params["vae"]},
+    tx=optimizer,
+)
+
 # Training loop
 num_epochs = 3
 batch_size = 1  # Reduced batch size due to memory constraints
@@ -167,6 +176,8 @@ for epoch in range(num_epochs):
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
     clear_jit_cache()
 
+
+
 # Save the fine-tuned model
 output_dir = "/tmp/montevideo_fine_tuned_model"
 os.makedirs(output_dir, exist_ok=True)
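For context, here is a minimal sketch of how the state created in this commit typically feeds the train_step named in the first hunk header. It assumes the conventional jax.value_and_grad pattern; loss_fn and the batch layout are placeholders, and unet, vae, and params are the Flax Stable Diffusion objects loaded earlier in app.py, which this diff does not show.

# Sketch only: a conventional Flax/Optax training step matching the
# TrainState added in this commit. loss_fn is a placeholder, not the
# app's actual objective.
import jax
import optax
from flax.training import train_state

def make_train_step(loss_fn):
    @jax.jit
    def train_step(state, batch, rng):
        # Differentiate the loss with respect to the full parameter
        # pytree held by the TrainState.
        loss, grads = jax.value_and_grad(loss_fn)(state.params, batch, rng)
        # Same update call as line 145 of the diff.
        state = state.apply_gradients(grads=grads)
        return state, loss
    return train_step

One consequence of the added initialization is that Adam will update the "vae" subtree alongside "unet", since both live in state.params; when only the UNet is meant to be fine-tuned, a common alternative is optax.multi_transform with optax.set_to_zero on the frozen subtree.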
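The loop calls clear_jit_cache() after every epoch, but its definition falls outside the diff. Given the batch-size comment about memory constraints, one plausible implementation is a thin wrapper over JAX's public cache-clearing API; this is an assumption, not app.py's actual helper.

import jax

def clear_jit_cache():
    # Hypothetical stand-in for the helper the loop calls: drop JAX's
    # compiled-program caches so repeated epochs on constrained
    # hardware don't accumulate compilation state.
    jax.clear_caches()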
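The diff ends just after the output directory is created, before any save call. A standard way to persist the fine-tuned pytree from this point is Flax's msgpack serialization; the params.msgpack filename is illustrative, and app.py may instead use diffusers' Flax save_pretrained.

import os
from flax import serialization

# Continues from the training loop: `state` is the TrainState above.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
# Illustrative filename; serializes the trained parameter pytree.
with open(os.path.join(output_dir, "params.msgpack"), "wb") as f:
    f.write(serialization.to_bytes(state.params))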