import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np

# Set up cache directories
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)

print(f"Cache directory: {cache_dir}")
print(f"Model cache directory: {model_cache_dir}")

# Function to load or download the model
def get_model(model_id, revision):
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        with open(model_cache_file, 'rb') as f:
            return pickle.load(f)
    else:
        print("Downloading model...")
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            dtype=jnp.float32,
        )
        with open(model_cache_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        return pipeline, params

# Load the pre-trained model
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Extract UNet from pipeline
unet = pipeline.unet

# Load and preprocess your dataset
def preprocess_images(examples):
    def process_image(image):
        if isinstance(image, str):
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")
        # Resize and convert to RGB
        image = image.convert("RGB").resize((512, 512))
        # Convert to numpy array and scale to [-1, 1], the range the Stable Diffusion VAE expects
        image = np.array(image).astype(np.float32) / 127.5 - 1.0
        # Ensure the image has the shape (3, height, width)
        return image.transpose(2, 0, 1)  # Change to channel-first format

    return {"pixel_values": [process_image(img) for img in examples["image"]]}

# Load dataset from Hugging Face
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")

try:
    if os.path.exists(dataset_cache_file):
        print("Loading dataset from cache...")
        with open(dataset_cache_file, 'rb') as f:
            processed_dataset = pickle.load(f)
    else:
        print("Loading dataset from Hugging Face...")
        dataset = load_dataset(dataset_name)
        print("Dataset structure:", dataset)
        print("Available splits:", dataset.keys())
        if "train" not in dataset:
            raise ValueError("The dataset does not contain a 'train' split.")
        print("Processing dataset...")
        processed_dataset = dataset["train"].map(
            preprocess_images,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )
        with open(dataset_cache_file, 'wb') as f:
            pickle.dump(processed_dataset, f)

    print(f"Processed dataset size: {len(processed_dataset)}")

except Exception as e:
    print(f"Error loading or processing dataset: {str(e)}")
    print("Attempting to find dataset...")

    # List contents of current directory and parent directories
    print("Current directory contents:")
    print(os.listdir('.'))
    print("Parent directory contents:")
    print(os.listdir('..'))
    print("Root directory contents:")
    print(os.listdir('/'))

    # Try to find any directory that might contain the dataset
    for root, dirs, files in os.walk('/'):
        if 'montevideo' in dirs:
            print(f"Found 'montevideo' directory at: {os.path.join(root, 'montevideo')}")
            print(f"Contents: {os.listdir(os.path.join(root, 'montevideo'))}")

    raise ValueError("Unable to locate or load the dataset. Please check the dataset path and permissions.")

# Training function
def train_step(state, batch, rng):
    def compute_loss(params):
        # Convert batch to JAX array
        pixel_values = jnp.array(batch["pixel_values"])
        batch_size = pixel_values.shape[0]

        # Split the step RNG so each source of randomness gets its own key
        sample_rng, noise_rng, timestep_rng, encoder_rng = jax.random.split(rng, 4)

        # Encode images to latent space
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(sample_rng)
        latents = latents * 0.18215  # scaling factor
        # Keep the VAE frozen: stop gradients from flowing into its parameters
        latents = jax.lax.stop_gradient(latents)

        # Generate random noise
        noise = jax.random.normal(noise_rng, latents.shape)

        # Sample random timesteps
        timesteps = jax.random.randint(
            timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
        )

        # Create scheduler state
        scheduler_state = pipeline.scheduler.create_state()

        # Add noise to latents using the scheduler
        noisy_latents = pipeline.scheduler.add_noise(
            scheduler_state,
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random placeholder for the text-encoder hidden states (no real captions are used);
        # the UNet expects shape (batch, sequence_length, hidden_size)
        encoder_hidden_states = jax.random.normal(
            encoder_rng,
            (
                batch_size,
                pipeline.text_encoder.config.max_position_embeddings,
                pipeline.text_encoder.config.hidden_size,
            ),
        )

        # Predict the noise added to the latents
        model_output = state.apply_fn(
            {"params": params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )
        noise_pred = model_output.sample

        # Mean-squared error between predicted and true noise
        loss = jnp.mean((noise_pred - noise) ** 2)
        return loss

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
state = train_state.TrainState.create(
    apply_fn=unet.apply,
    # UNet params are trained; VAE params are carried along only so the loss can
    # encode images (their gradients are stopped above)
    params={"unet": params["unet"], "vae": params["vae"]},
    tx=optimizer,
)

# Training loop
num_epochs = 10
batch_size = 4
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.iter(batch_size=batch_size)):
        # Convert the list of pixel values to a numpy array for each batch
        batch['pixel_values'] = np.array(batch['pixel_values'])
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += loss
        num_batches += 1
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

# Save the fine-tuned model (UNet weights only)
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params["unet"])

print(f"Model saved to {output_dir}")
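
# --- Optional sanity check (illustrative sketch, not part of the original script) ---
# One way to confirm the saved weights are usable: reload the fine-tuned UNet and
# drop its parameters back into the existing pipeline to generate a test image.
# The prompt string and output filename below are placeholders; adjust or remove
# this block as needed.
from diffusers import FlaxUNet2DConditionModel

fine_tuned_unet, fine_tuned_unet_params = FlaxUNet2DConditionModel.from_pretrained(output_dir)
params["unet"] = fine_tuned_unet_params  # reuse the original pipeline with the fine-tuned weights

prompt_ids = pipeline.prepare_inputs(["a photo of a street in Montevideo"])
sample_rng = jax.random.PRNGKey(42)
output = pipeline(prompt_ids, params, sample_rng, num_inference_steps=50, jit=False)
pipeline.numpy_to_pil(np.asarray(output.images))[0].save(os.path.join(output_dir, "sample.png"))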