# Fine-tune the Stable Diffusion v1.4 UNet (Flax/JAX) on the uruguayai/montevideo dataset.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
from diffusers.schedulers import FlaxPNDMScheduler
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
from inspect import signature


# Custom scheduler that casts timesteps to int32 before adding noise
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    def add_noise(self, state, original_samples, noise, timesteps):
        timesteps = timesteps.astype(jnp.int32)
        return super().add_noise(state, original_samples, noise, timesteps)


# Set up cache directories
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)

print(f"Cache directory: {cache_dir}")
print(f"Model cache directory: {model_cache_dir}")


def filter_dict(dict_to_filter, target_callable):
    """Filter a dictionary to only include keys that are valid parameters for the target callable."""
    valid_params = signature(target_callable).parameters.keys()
    return {k: v for k, v in dict_to_filter.items() if k in valid_params}


# Function to load or download the model
def get_model(model_id, revision):
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        with open(model_cache_file, 'rb') as f:
            return pickle.load(f)
    else:
        print("Downloading model...")
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            dtype=jnp.float32,
        )
        with open(model_cache_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        return pipeline, params


# Load the pre-trained model
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Use custom scheduler
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = custom_scheduler

# Extract UNet from pipeline
unet = pipeline.unet

# Print UNet configuration
print("UNet configuration:")
print(unet.config)


# Adjust the input layer of the UNet
def adjust_unet_input_layer(params):
    # The parameter tree may be nested under "unet" (pipeline params) or
    # "params" (the output of Module.init); unwrap it before inspecting conv_in.
    if 'unet' in params:
        unet_params = params['unet']
    elif 'params' in params:
        unet_params = params['params']
    else:
        unet_params = params

    if 'conv_in' not in unet_params:
        print("Warning: 'conv_in' not found in UNet params. Skipping input layer adjustment.")
        return params

    conv_in_weight = unet_params['conv_in']['kernel']
    print(f"Original conv_in weight shape: {conv_in_weight.shape}")

    if conv_in_weight.shape[2] != 4:
        # Widen the input convolution to 4 latent channels, keeping the existing channel weights.
        new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
        new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
        unet_params['conv_in']['kernel'] = new_conv_in_weight

    print(f"New conv_in weight shape: {unet_params['conv_in']['kernel'].shape}")

    # unet_params aliases the nested dict, so any in-place edits are already visible in params.
    return params


params = adjust_unet_input_layer(params)

# Load and preprocess your dataset
def preprocess_images(examples):
    def process_image(image):
        if isinstance(image, str):
            # Only accept JPEG paths; anything else is skipped.
            if not image.lower().endswith(('.jpg', '.jpeg')):
                return None
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            return None
        image = image.convert("RGB").resize((512, 512))
        # Scale to [-1, 1], the input range the Stable Diffusion VAE expects.
        image = np.array(image).astype(np.float32) / 127.5 - 1.0
        image = image.transpose(2, 0, 1)  # Change to channel-first format
        return image

    processed = [process_image(img) for img in examples["image"]]
    return {"pixel_values": [img for img in processed if img is not None]}


# Load dataset from Hugging Face
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")

if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    with open(dataset_cache_file, 'rb') as f:
        processed_dataset = pickle.load(f)
else:
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    processed_dataset = dataset["train"].map(
        preprocess_images,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    processed_dataset = processed_dataset.filter(lambda example: len(example['pixel_values']) > 0)
    with open(dataset_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)

print(f"Processed dataset size: {len(processed_dataset)}")

# Print sample input shape
sample_batch = next(iter(processed_dataset.batch(1)))
print(f"Sample batch keys: {sample_batch.keys()}")
print(f"Sample pixel_values type: {type(sample_batch['pixel_values'])}")
print(f"Sample pixel_values length: {len(sample_batch['pixel_values'])}")
if len(sample_batch['pixel_values']) > 0:
    print(f"Sample pixel_values[0] shape: {np.array(sample_batch['pixel_values'][0]).shape}")


# Training function
def train_step(state, batch, rng):
    def compute_loss(unet_params, pixel_values, rng):
        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
        if pixel_values.ndim == 3:
            pixel_values = jnp.expand_dims(pixel_values, axis=0)  # Add batch dimension if needed
        print(f"pixel_values shape in compute_loss: {pixel_values.shape}")

        # Independent RNG streams for latent sampling, noise, timesteps and the dummy conditioning
        sample_rng, noise_rng, timestep_rng, cond_rng = jax.random.split(rng, 4)

        # Encode images to latents with the frozen VAE from the pipeline params
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(sample_rng)
        # The Flax VAE returns NHWC latents; the UNet expects NCHW.
        latents = jnp.transpose(latents, (0, 3, 1, 2))
        latents = latents * jnp.float32(0.18215)
        print(f"latents shape: {latents.shape}")

        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
        timesteps = jax.random.randint(
            timestep_rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
        )

        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random placeholder conditioning of shape (batch, sequence_length, hidden_size)
        encoder_hidden_states = jax.random.normal(
            cond_rng,
            (
                latents.shape[0],
                pipeline.text_encoder.config.max_position_embeddings,
                pipeline.text_encoder.config.hidden_size,
            ),
            dtype=jnp.float32,
        )

        print(f"noisy_latents shape: {noisy_latents.shape}")
        print(f"timesteps shape: {timesteps.shape}")
        print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")

        # Use the state's apply_fn (which should be the adjusted UNet)
        model_output = state.apply_fn(
            {"params": unet_params},
            noisy_latents,
            jnp.array(timesteps, dtype=jnp.int32),
            encoder_hidden_states,
            train=True,
        ).sample

        return jnp.mean((model_output - noise) ** 2)

    grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    rng, step_rng = jax.random.split(rng)

    # Ensure we're passing the correct structure to grad_fn and compute_loss
    unet_params = state.params["params"] if "params" in state.params else state.params
    loss, grads = grad_fn(unet_params, batch["pixel_values"], step_rng)

    # Apply the optimizer update via the TrainState; the gradient tree must mirror state.params.
    grads = {"params": grads} if "params" in state.params else grads
    state = state.apply_gradients(grads=grads)

    return state, loss

# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
# Cast the pipeline params to float32 (kept for reference; not used further below)
float32_params = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params
)

# Create a new UNet with the adjusted parameters
unet_config = dict(unet.config)
filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel.__init__)
print("Filtered UNet config keys:", filtered_unet_config.keys())
adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
adjusted_params = adjusted_unet.init(
    jax.random.PRNGKey(0),
    jnp.ones((1, 4, 64, 64)),  # latent sample
    jnp.ones((1,)),            # timestep
    jnp.ones((1, 77, 768)),    # text-encoder hidden states
)
adjusted_params = adjust_unet_input_layer(adjusted_params)  # Adjust the input layer

# Create the training state. Module.init returns {"params": ...}, so store the inner
# tree; train_step re-wraps it as {"params": unet_params} when calling apply_fn.
state = train_state.TrainState.create(
    apply_fn=adjusted_unet.apply,
    params=adjusted_params["params"],
    tx=optimizer,
)

# Training loop
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    num_errors = 0

    for batch in tqdm(processed_dataset.batch(batch_size)):
        try:
            batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
            rng, step_rng = jax.random.split(rng)
            state, loss = train_step(state, batch, step_rng)
            epoch_loss += loss
            num_batches += 1

            if num_batches % 10 == 0:
                jax.clear_caches()
                print(f"Processed {num_batches} batches. Current loss: {loss}")
        except Exception as e:
            num_errors += 1
            print(f"Error processing batch: {e}")
            continue

    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}, Errors: {num_errors}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, No valid batches processed, Errors: {num_errors}")

    jax.clear_caches()

# Save the fine-tuned model
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
adjusted_unet.save_pretrained(output_dir, params=state.params)

print(f"Model saved to {output_dir}")
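
# ---------------------------------------------------------------------------
# Optional usage sketch (not part of the training run above): reload the saved
# UNet weights and swap them into the pipeline's parameter tree for inference.
# This is a minimal, hedged example, assuming the Flax `from_pretrained`
# convention of returning a (module, params) tuple and that the pipeline's
# parameter dict keys the UNet weights under "unet". The helper name and the
# prompt below are purely illustrative.
# ---------------------------------------------------------------------------
def load_fine_tuned_unet(output_dir, pipeline, params):
    """Reload the fine-tuned UNet and plug its weights into the pipeline params."""
    fine_tuned_unet, fine_tuned_unet_params = FlaxUNet2DConditionModel.from_pretrained(output_dir)
    pipeline.unet = fine_tuned_unet
    params["unet"] = fine_tuned_unet_params
    return pipeline, params

# Example (commented out so the script still ends after saving):
# pipeline, params = load_fine_tuned_unet(output_dir, pipeline, params)
# prompt_ids = pipeline.prepare_inputs(["a street in Montevideo"])
# images = pipeline(prompt_ids, params, jax.random.PRNGKey(0), num_inference_steps=50, jit=False).images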