Update app.py
Browse files
app.py
CHANGED
@@ -53,12 +53,9 @@ def preprocess_images(examples):
|
|
53 |
image = Image.open(image)
|
54 |
if not isinstance(image, Image.Image):
|
55 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
56 |
-
|
57 |
-
image = image.convert("RGB").resize((512, 512))
|
58 |
-
# Convert to numpy array and normalize
|
59 |
image = np.array(image).astype(np.float16) / 255.0
|
60 |
-
|
61 |
-
return image.transpose(2, 0, 1) # Change to channel-first format
|
62 |
|
63 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
64 |
|
@@ -76,15 +73,9 @@ try:
|
|
76 |
processed_dataset = pickle.load(f)
|
77 |
else:
|
78 |
print("Loading dataset from Hugging Face...")
|
79 |
-
dataset = load_dataset(dataset_name)
|
80 |
-
print("Dataset structure:", dataset)
|
81 |
-
print("Available splits:", dataset.keys())
|
82 |
-
|
83 |
-
if "train" not in dataset:
|
84 |
-
raise ValueError("The dataset does not contain a 'train' split.")
|
85 |
-
|
86 |
print("Processing dataset...")
|
87 |
-
processed_dataset = dataset
|
88 |
with open(dataset_cache_file, 'wb') as f:
|
89 |
pickle.dump(processed_dataset, f)
|
90 |
|
@@ -92,23 +83,7 @@ try:
|
|
92 |
|
93 |
except Exception as e:
|
94 |
print(f"Error loading or processing dataset: {str(e)}")
|
95 |
-
|
96 |
-
|
97 |
-
# List contents of current directory and parent directories
|
98 |
-
print("Current directory contents:")
|
99 |
-
print(os.listdir('.'))
|
100 |
-
print("Parent directory contents:")
|
101 |
-
print(os.listdir('..'))
|
102 |
-
print("Root directory contents:")
|
103 |
-
print(os.listdir('/'))
|
104 |
-
|
105 |
-
# Try to find any directory that might contain the dataset
|
106 |
-
for root, dirs, files in os.walk('/'):
|
107 |
-
if 'montevideo' in dirs:
|
108 |
-
print(f"Found 'montevideo' directory at: {os.path.join(root, 'montevideo')}")
|
109 |
-
print(f"Contents: {os.listdir(os.path.join(root, 'montevideo'))}")
|
110 |
-
|
111 |
-
raise ValueError("Unable to locate or load the dataset. Please check the dataset path and permissions.")
|
112 |
|
113 |
# Function to clear JIT cache
|
114 |
def clear_jit_cache():
|
@@ -116,122 +91,78 @@ def clear_jit_cache():
|
|
116 |
gc.collect()
|
117 |
|
118 |
# Training function with gradient accumulation
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
pixel_values = jnp.array(batch_slice["pixel_values"], dtype=jnp.float16)
|
123 |
-
batch_size = pixel_values.shape[0]
|
124 |
-
|
125 |
-
# Encode images to latent space
|
126 |
latents = pipeline.vae.apply(
|
127 |
{"params": params["vae"]},
|
128 |
pixel_values,
|
129 |
method=pipeline.vae.encode
|
130 |
).latent_dist.sample(rng)
|
131 |
-
latents = latents * jnp.float16(0.18215)
|
132 |
|
133 |
-
|
134 |
-
noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
|
135 |
-
noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float16)
|
136 |
-
|
137 |
-
# Sample random timesteps
|
138 |
timesteps = jax.random.randint(
|
139 |
-
|
140 |
)
|
141 |
-
|
142 |
-
# Create scheduler state
|
143 |
-
scheduler_state = pipeline.scheduler.create_state()
|
144 |
-
|
145 |
-
# Add noise to latents using the scheduler
|
146 |
noisy_latents = pipeline.scheduler.add_noise(
|
147 |
-
|
148 |
original_samples=latents,
|
149 |
noise=noise,
|
150 |
timesteps=timesteps
|
151 |
)
|
152 |
|
153 |
-
# Generate random latents for text encoder
|
154 |
encoder_hidden_states = jax.random.normal(
|
155 |
-
|
156 |
-
(
|
157 |
dtype=jnp.float16
|
158 |
)
|
159 |
|
160 |
-
# Predict noise
|
161 |
model_output = state.apply_fn.apply(
|
162 |
{'params': params["unet"]},
|
163 |
-
|
164 |
-
|
165 |
encoder_hidden_states=encoder_hidden_states,
|
166 |
train=True,
|
167 |
)
|
168 |
|
169 |
-
|
170 |
-
loss = jnp.mean((model_output - noise) ** 2)
|
171 |
-
return loss
|
172 |
|
173 |
grad_fn = jax.value_and_grad(compute_loss)
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
# Initialize accumulated gradients
|
180 |
-
acc_grads = jax.tree_map(jnp.zeros_like, state.params)
|
181 |
-
acc_loss = jnp.float16(0.0)
|
182 |
-
|
183 |
-
for i in range(grad_accumulation_steps):
|
184 |
-
start_idx = i * chunk_size
|
185 |
-
end_idx = start_idx + chunk_size if i < grad_accumulation_steps - 1 else batch_size
|
186 |
-
|
187 |
-
batch_slice = {
|
188 |
-
'pixel_values': batch['pixel_values'][start_idx:end_idx]
|
189 |
-
}
|
190 |
-
|
191 |
-
rng, step_rng = jax.random.split(rng)
|
192 |
-
loss, grads = grad_fn(state.params, batch_slice, step_rng)
|
193 |
-
|
194 |
-
# Accumulate gradients and loss
|
195 |
-
acc_grads = jax.tree_map(lambda acc, g: acc + g / grad_accumulation_steps, acc_grads, grads)
|
196 |
-
acc_loss += loss / grad_accumulation_steps
|
197 |
-
|
198 |
-
# Update state with accumulated gradients
|
199 |
-
state = state.apply_gradients(grads=acc_grads)
|
200 |
-
return state, acc_loss
|
201 |
|
202 |
# Initialize training state
|
203 |
learning_rate = jnp.float16(1e-5)
|
204 |
optimizer = optax.adam(learning_rate)
|
205 |
state = train_state.TrainState.create(
|
206 |
apply_fn=unet,
|
207 |
-
params={"unet": params["unet"], "vae": params["vae"]},
|
208 |
tx=optimizer,
|
209 |
)
|
210 |
|
211 |
# Training loop
|
212 |
-
num_epochs =
|
213 |
-
batch_size =
|
214 |
rng = jax.random.PRNGKey(0)
|
215 |
|
216 |
for epoch in range(num_epochs):
|
217 |
epoch_loss = 0
|
218 |
num_batches = 0
|
219 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
220 |
-
|
221 |
-
batch['pixel_values'] = np.array(batch['pixel_values'], dtype=np.float16)
|
222 |
rng, step_rng = jax.random.split(rng)
|
223 |
state, loss = train_step(state, batch, step_rng)
|
224 |
epoch_loss += loss
|
225 |
num_batches += 1
|
226 |
|
227 |
-
# Clear JIT cache every 10 batches
|
228 |
if num_batches % 10 == 0:
|
229 |
clear_jit_cache()
|
230 |
|
231 |
avg_loss = epoch_loss / num_batches
|
232 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
233 |
-
|
234 |
-
# Clear JIT cache after each epoch
|
235 |
clear_jit_cache()
|
236 |
|
237 |
# Save the fine-tuned model
|
|
|
53 |
image = Image.open(image)
|
54 |
if not isinstance(image, Image.Image):
|
55 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
56 |
+
image = image.convert("RGB").resize((256, 256)) # Reduced image size
|
|
|
|
|
57 |
image = np.array(image).astype(np.float16) / 255.0
|
58 |
+
return image.transpose(2, 0, 1)
|
|
|
59 |
|
60 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
61 |
|
|
|
73 |
processed_dataset = pickle.load(f)
|
74 |
else:
|
75 |
print("Loading dataset from Hugging Face...")
|
76 |
+
dataset = load_dataset(dataset_name, split="train[:1000]") # Load only first 1000 samples
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
print("Processing dataset...")
|
78 |
+
processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
|
79 |
with open(dataset_cache_file, 'wb') as f:
|
80 |
pickle.dump(processed_dataset, f)
|
81 |
|
|
|
83 |
|
84 |
except Exception as e:
|
85 |
print(f"Error loading or processing dataset: {str(e)}")
|
86 |
+
raise ValueError("Unable to load or process the dataset.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Function to clear JIT cache
|
89 |
def clear_jit_cache():
|
|
|
91 |
gc.collect()
|
92 |
|
93 |
# Training function: one optimizer update per batch (no gradient accumulation)
@jax.jit
def train_step(state, batch, rng):
    """Run a single gradient update on ``batch``.

    Args:
        state: Train state whose ``params`` hold the ``"unet"`` and ``"vae"``
            parameter trees and whose ``apply_fn`` is the UNet module.
        batch: Mapping with a ``"pixel_values"`` array of images. Assumed to
            be channel-first (batch, 3, H, W) -- TODO confirm against the
            preprocessing step.
        rng: JAX PRNG key for this step; it is split internally so every
            random draw uses its own independent sub-key.

    Returns:
        Tuple ``(new_state, loss)`` where ``loss`` is the scalar MSE between
        the predicted and the sampled noise.
    """

    def compute_loss(params, pixel_values, rng):
        # A JAX key must never be consumed twice: reusing one key makes the
        # VAE sample, the noise, the timesteps and the conditioning
        # statistically dependent on each other.
        vae_rng, noise_rng, timestep_rng, cond_rng = jax.random.split(rng, 4)

        # Encode images to latent space.
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(vae_rng)
        # The Flax VAE yields channel-last latents (NHWC) while the Flax UNet
        # expects channel-first input (NCHW).
        latents = jnp.transpose(latents, (0, 3, 1, 2))
        latents = latents * jnp.float16(0.18215)

        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float16)
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random stand-in for text conditioning. Cross-attention needs a
        # (batch, sequence, hidden) tensor, so a sequence axis is required.
        # NOTE(review): sequence length is taken from the text encoder's
        # max_position_embeddings -- confirm this matches the intended prompt
        # length, and whether random conditioning is really desired.
        encoder_hidden_states = jax.random.normal(
            cond_rng,
            (
                latents.shape[0],
                pipeline.text_encoder.config.max_position_embeddings,
                pipeline.text_encoder.config.hidden_size,
            ),
            dtype=jnp.float16,
        )

        # The UNet returns an output object; the predicted noise is `.sample`.
        model_output = state.apply_fn.apply(
            {'params': params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        ).sample

        return jnp.mean((model_output - noise) ** 2)

    grad_fn = jax.value_and_grad(compute_loss)
    # The caller already hands in a fresh per-step key, so it is used directly.
    loss, grads = grad_fn(state.params, batch["pixel_values"], rng)
    state = state.apply_gradients(grads=grads)
    return state, loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
# Initialize training state
# Keep the learning rate as a plain Python float: 1e-5 lies in float16's
# subnormal range and would be stored with reduced precision.
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
state = train_state.TrainState.create(
    apply_fn=unet,
    params={"unet": params["unet"], "vae": params["vae"]},
    tx=optimizer,
)

# Training loop
num_epochs = 5  # Reduced number of epochs
batch_size = 4
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    # Accumulate in a Python float; summing float16 scalars over many batches
    # loses precision and makes the reported average unreliable.
    epoch_loss = 0.0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float16)
        # Fresh key per step so no two steps share randomness.
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += float(loss)
        num_batches += 1

        # NOTE(review): if clear_jit_cache() drops compiled functions, this
        # forces train_step to recompile every 10 batches -- confirm it is
        # needed for memory reasons.
        if num_batches % 10 == 0:
            clear_jit_cache()

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("No batches were produced; the processed dataset is empty.")

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
    clear_jit_cache()
|
167 |
|
168 |
# Save the fine-tuned model
|