Update app.py
Browse files
app.py
CHANGED
@@ -29,12 +29,11 @@ def get_model(model_id, revision):
|
|
29 |
return pickle.load(f)
|
30 |
else:
|
31 |
print("Downloading model...")
|
32 |
-
pipeline = FlaxStableDiffusionPipeline.from_pretrained(
|
33 |
model_id,
|
34 |
revision=revision,
|
35 |
dtype=jnp.float32,
|
36 |
)
|
37 |
-
params = pipeline.params
|
38 |
with open(model_cache_file, 'wb') as f:
|
39 |
pickle.dump((pipeline, params), f)
|
40 |
return pipeline, params
|
@@ -102,5 +101,73 @@ except Exception as e:
|
|
102 |
else:
|
103 |
raise ValueError(f"Local path {local_path} does not exist.")
|
104 |
|
105 |
-
#
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
return pickle.load(f)
|
30 |
else:
|
31 |
print("Downloading model...")
|
32 |
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
33 |
model_id,
|
34 |
revision=revision,
|
35 |
dtype=jnp.float32,
|
36 |
)
|
|
|
37 |
with open(model_cache_file, 'wb') as f:
|
38 |
pickle.dump((pipeline, params), f)
|
39 |
return pipeline, params
|
|
|
101 |
else:
|
102 |
raise ValueError(f"Local path {local_path} does not exist.")
|
103 |
|
104 |
+
# Training function
|
105 |
+
def train_step(state, batch, rng):
    """Run a single optimisation step on one batch.

    Args:
        state: Training state. The code below reads ``state.params``, calls
            ``state.apply_fn.apply(...)`` and ``state.apply_gradients(...)``.
        batch: Mapping with a ``"pixel_values"`` entry; it is converted to a
            JAX array and its leading dimension is used as the batch size.
        rng: JAX PRNG key. It is split once into a key for the noise sample
            and a key for the timestep sample.

    Returns:
        A ``(state, loss)`` tuple: the state after the gradient update, and
        the mean squared difference between the model output and the noise.
    """
    scheduler = pipeline.scheduler  # module-level pipeline, not an argument

    # None of the following depends on the trainable parameters, so it is
    # prepared once outside the function that gets differentiated.
    images = jnp.array(batch["pixel_values"])
    n_samples = images.shape[0]

    noise_key, timestep_key = jax.random.split(rng)
    noise = jax.random.normal(noise_key, images.shape)
    timesteps = jax.random.randint(
        timestep_key,
        (n_samples,),
        0,
        scheduler.config.num_train_timesteps,
    )

    # NOTE(review): some Flax schedulers expect a scheduler-state object as
    # the first argument of add_noise -- confirm this call matches the
    # scheduler actually loaded.
    noisy_images = scheduler.add_noise(
        original_samples=images,
        noise=noise,
        timesteps=timesteps,
    )

    def compute_loss(params):
        # NOTE(review): apply_fn must expose ``.apply``. If it is a
        # text-conditioned UNet it may also require encoder hidden states and
        # may return an output object rather than a raw array -- verify.
        prediction = state.apply_fn.apply(
            {'params': params},
            jnp.array(noisy_images),
            jnp.array(timesteps),
            train=True,
        )
        residual = prediction - noise
        return jnp.mean(residual ** 2)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    return state.apply_gradients(grads=grads), loss
|
142 |
+
|
143 |
+
# Initialize training state
|
144 |
+
# Optimiser setup: plain Adam with a fixed step size.
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)

# Only the UNet entry of the parameter dict is placed in the training state,
# so only those parameters are updated by apply_gradients. ``apply_fn`` holds
# the UNet object itself; train_step calls ``state.apply_fn.apply(...)``.
state = train_state.TrainState.create(
    apply_fn=unet,
    tx=optimizer,
    params=params["unet"],
)
|
151 |
+
|
152 |
+
# Training loop
|
153 |
+
# Training-loop configuration.
num_epochs = 10
batch_size = 4
rng = jax.random.PRNGKey(0)  # fixed seed, so the key sequence is repeatable

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        # Derive a fresh sub-key for every step so no key is ever reused.
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += loss
        num_batches += 1

    # An empty dataset would otherwise surface as a bare ZeroDivisionError
    # on the average below; fail with a message that names the real cause.
    if num_batches == 0:
        raise ValueError(
            "processed_dataset produced no batches; cannot train on an "
            "empty dataset."
        )

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
167 |
+
|
168 |
+
# Save the fine-tuned model
|
169 |
+
# Persist the fine-tuned UNet weights (the parameters held in the training
# state) to a directory, creating it first if needed.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params)

print(f"Model saved to {output_dir}")
|